Graphite
Loading...
Searching...
No Matches
error.hpp
Go to the documentation of this file.
1
#pragma once
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>

#include <cuda/std/tuple>
#include <cuda_fp16.h>
5namespace graphite {
6
7namespace ops {
8
// Compile-time debugging aid: instantiate arg_helper<N, P> with N != P and
// the compiler's static_assert diagnostic will print both values. Not used
// at runtime; keep around for template-arity troubleshooting.
template <size_t N, size_t P> struct arg_helper {
  constexpr static size_t value = N;
  static_assert(N == P, "Helper to print N and P at compile time");
};
13
// Metafunction yielding the parameter count of a free-function pointer
// type. Primary template is left undefined; only function-pointer types
// are valid arguments.
template <class> struct fn_arity;

// Function-pointer specialization: exposes the argument count through
// std::integral_constant's ::value.
template <class Ret, class... Ts>
struct fn_arity<Ret (*)(Ts...)>
    : std::integral_constant<std::size_t, sizeof...(Ts)> {};
20
// Metafunction extracting the type of the first parameter of a
// free-function pointer type. Undefined for nullary functions (and for
// anything that is not a function pointer).
template <class> struct first_arg;

template <class Ret, class Head, class... Tail>
struct first_arg<Ret (*)(Head, Tail...)> {
  using type = Head;
};
27
// True when F's error functor (instantiated for scalar type D) declares
// its first parameter as a reference — i.e. it expects vertex objects
// rather than raw parameter-block pointers.
template <typename F, typename D> __device__ constexpr bool takes_vertices() {
  using ErrFn = decltype(&F::Traits::template error<D>);
  using FirstParam = typename first_arg<ErrFn>::type;
  return std::is_reference<FirstParam>::value;
}
32
// Dispatches to F::Traits::error with the argument list the functor
// actually declares. Depending on the factor type, the functor may take
// (in this order): the vertex objects (by reference), the raw
// parameter-block pointers, the observation, the constraint data, and
// always the output error vector last. Which combination it wants is
// decided at compile time from:
//   - its arity (fn_arity): N vertices account for N vertex arguments
//     plus N parameter-pointer arguments, so an arity of N*2 + extras
//     means the functor wants BOTH forms;
//   - otherwise takes_vertices<F, D>() distinguishes "vertices only"
//     from "parameter pointers only" via the first parameter's
//     reference-ness.
// Is... must be 0..N-1 and indexes `vertices` and `parameters` in
// lockstep.
template <typename F, typename D, typename VertexPointers,
          typename ParameterBlocks, typename Observation,
          typename ConstraintData, typename ErrorVector, std::size_t... Is>
__device__ inline void
call_error_fn(VertexPointers &vertices, ParameterBlocks &parameters,
              Observation &local_obs, ConstraintData &local_data,
              ErrorVector &local_error, std::index_sequence<Is...>) {

  using DataType = typename F::ConstraintDataType;
  using ObsType = typename F::ObservationType;
  constexpr size_t N = F::get_num_vertices();

  // Number of parameters the error functor declares when instantiated
  // for scalar type D.
  constexpr size_t num_parameters =
      fn_arity<decltype(&F::Traits::template error<D>)>::value;

  if constexpr (std::is_empty<ObsType>::value &&
                std::is_empty<DataType>::value) {
    // Neither observation nor constraint data is passed.
    if constexpr (N * 2 + 1 == num_parameters) {
      // error(vertices..., params..., error_out)
      F::Traits::error((*cuda::std::get<Is>(vertices))...,
                       cuda::std::get<Is>(parameters).data()..., local_error);
    } else if constexpr (takes_vertices<F, D>()) {
      F::Traits::error((*cuda::std::get<Is>(vertices))..., local_error);
    } else {
      F::Traits::error(cuda::std::get<Is>(parameters).data()..., local_error);
    }
  } else if constexpr (std::is_empty<DataType>::value) {
    // Observation only; +2 extras = observation + error vector.
    if constexpr (N * 2 + 2 == num_parameters) {
      F::Traits::error((*cuda::std::get<Is>(vertices))...,
                       cuda::std::get<Is>(parameters).data()..., *local_obs,
                       local_error);
    } else if constexpr (takes_vertices<F, D>()) {
      F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_obs,
                       local_error);
    } else {
      F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_obs,
                       local_error);
    }
  } else if constexpr (std::is_empty<ObsType>::value) {
    // Constraint data only; +2 extras = data + error vector.
    if constexpr (N * 2 + 2 == num_parameters) {
      F::Traits::error((*cuda::std::get<Is>(vertices))...,
                       cuda::std::get<Is>(parameters).data()..., *local_data,
                       local_error);
    } else if constexpr (takes_vertices<F, D>()) {
      F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_data,
                       local_error);
    } else {
      F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_data,
                       local_error);
    }
  } else {
    // Both observation and constraint data; +3 extras = obs + data +
    // error vector.
    if constexpr (N * 2 + 3 == num_parameters) {
      F::Traits::error((*cuda::std::get<Is>(vertices))...,
                       cuda::std::get<Is>(parameters).data()..., *local_obs,
                       *local_data, local_error);
    } else if constexpr (takes_vertices<F, D>()) {
      F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_obs,
                       *local_data, local_error);
    } else {
      F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_obs,
                       *local_data, local_error);
    }
  }
}
97
// Forward-mode autodiff kernel for vertex slot I of factor type F.
//
// Launch layout: one thread per (active factor, coordinate of vertex I)
// pair, i.e. num_threads = num_active_factors * vertex_sizes[I]. Each
// thread:
//   1. gathers all of this factor's vertex parameters into local
//      Dual<T, G> arrays,
//   2. seeds the dual part of its own coordinate of vertex I with 1,
//   3. evaluates the error functor once, and
//   4. writes one length-E column of the Jacobian block for vertex I,
//      which is stored E x vertex_sizes[I] in column-major order.
// The coordinate-0 thread of each factor additionally writes the real
// (residual) part to `error`.
//
// G is the dual-part scalar type: falls back to T when S is a
// low-precision storage type so derivatives are computed at full
// precision and only narrowed on store.
//
// NOTE(review): `hessian_ids` is accepted but never read in this kernel —
// confirm whether it can be dropped from the signature.
// NOTE(review): `args` is accessed with std::get while the other tuples
// use cuda::std::get — verify VT really is a host-std tuple type.
template <typename T, typename S, size_t I, size_t N, typename M, size_t E,
          typename F, typename VT, std::size_t... Is>
__global__ void compute_error_kernel_autodiff(
    const M *obs, T *error,
    const typename F::ConstraintDataType *constraint_data,
    const size_t *active_ids, const size_t *ids, const size_t *hessian_ids,
    const size_t num_threads, VT args, S *jacs, const uint8_t *active_state,
    std::index_sequence<Is...>) {
  const size_t idx = get_thread_id();

  // Grid tail guard: grids rarely divide the thread count evenly.
  if (idx >= num_threads) {
    return;
  }

  constexpr auto vertex_sizes = F::get_vertex_sizes();
  // vertex_sizes[I] consecutive threads share one factor.
  const auto factor_id = active_ids[idx / vertex_sizes[I]];
  const auto vertex_id = ids[factor_id * N + I];

  using G = std::conditional_t<is_low_precision<S>::value, T, S>;
  const M *local_obs = obs + factor_id;
  Dual<T, G> local_error[E];
  const typename F::ConstraintDataType *local_data =
      constraint_data + factor_id;

  // One dual-number parameter array per vertex slot, sized per slot.
  auto v = cuda::std::make_tuple(std::array<Dual<T, G>, vertex_sizes[Is]>{}...);

  // Pointers to this factor's vertices, gathered through the id table.
  auto vargs = cuda::std::make_tuple(
      (*(std::get<Is>(args) + ids[factor_id * N + Is]))...);

  // Copy every vertex's parameters into the dual arrays (real parts).
  auto copy_vertices = [&v, &vertex_sizes, &vargs](auto &&...ptrs) {
    ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
          Traits::parameters(*cuda::std::get<Is>(vargs),
                             cuda::std::get<Is>(v).data())),
     ...);
  };

  cuda::std::apply(copy_vertices, vargs);

  // Seed: differentiate w.r.t. this thread's coordinate of vertex I.
  cuda::std::get<I>(v)[idx % vertex_sizes[I]].dual = static_cast<G>(1);

  call_error_fn<F, Dual<T, G>>(vargs, v, local_obs, local_data, local_error,
                               std::make_index_sequence<N>{});

  constexpr auto j_size = vertex_sizes[I] * E;
  // Column offset of this coordinate inside the E x vertex_sizes[I]
  // column-major Jacobian block.
  const auto col_offset = (idx % vertex_sizes[I]) * E;

  // Store the residual once per factor (coordinate-0 thread only).
  // TODO: make sure this only writes to each location once for the error;
  // this check won't work across multiple kernel launches.
  if (idx % vertex_sizes[I] == 0) {
#pragma unroll
    for (size_t i = 0; i < E; ++i) {
      error[factor_id * E + i] = local_error[i].real;
    }
  }

  // A Jacobian column is only needed when the vertex is not fixed.
  if (!is_vertex_active(active_state, vertex_id)) {
    return;
  }

  if constexpr (std::is_same<S, __half>::value) {
// Clamp to the finite __half range before narrowing.
#pragma unroll
    for (size_t i = 0; i < E; ++i) {
      jacs[j_size * factor_id + col_offset + i] =
          static_cast<S>(std::clamp(local_error[i].dual, -65504.0f, 65504.0f));
    }
  } else {
#pragma unroll
    for (size_t i = 0; i < E; ++i) {
      jacs[j_size * factor_id + col_offset + i] = local_error[i].dual;
    }
  }
}
181
// Launches one autodiff error/Jacobian kernel per vertex slot of factor
// type F (fold over Is...), each on its own stream from `streams`. Each
// launch uses one thread per (factor, coordinate-of-vertex-Is) pair, so a
// thread differentiates the residual with respect to a single coordinate.
//
// NOTE(review): no cudaGetLastError() check follows the launches, so
// configuration errors are silently dropped — consider adding one per
// launch.
template <typename T, typename S, typename F, typename VT, std::size_t... Is>
void launch_kernel_autodiff(
    F *f, std::array<const size_t *, F::get_num_vertices()> &hessian_ids,
    VT &verts, std::array<S *, F::get_num_vertices()> &jacs,
    const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
  (([&] {
    constexpr auto num_vertices = F::get_num_vertices();
    // One thread per (factor, coordinate) pair for vertex slot Is.
    const auto num_threads = num_factors * F::get_vertex_sizes()[Is];
    size_t threads_per_block = 256;
    // Ceil-division so the partial tail block is still launched.
    size_t num_blocks =
        (num_threads + threads_per_block - 1) / threads_per_block;

    compute_error_kernel_autodiff<T, S, Is, num_vertices,
                                  typename F::ObservationType, F::error_dim, F,
                                  typename F::VertexPointerPointerTuple>
        <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
            f->device_obs.data().get(), f->residuals.data().get(),
            f->data.data().get(), f->active_indices.data().get(),
            f->device_ids.data().get(), hessian_ids[Is], num_threads, verts,
            jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
            std::make_index_sequence<num_vertices>{});
  }()),
   ...);
}
213
// Computes residuals and (via forward-mode autodiff) Jacobian blocks for
// every active factor of `f`, dispatching one kernel launch per vertex
// slot, each on its own stream, then synchronizing those streams.
//
// T: scalar type of vertex parameters / residuals.
// S: storage type of Jacobian entries (may be low precision).
//
// Jacobian storage is zero-filled up front because the kernels skip
// columns that belong to fixed (inactive) vertices.
template <typename T, typename S, typename F>
void compute_error_autodiff(F *f, StreamPool &streams) {
  constexpr auto num_vertices = F::get_num_vertices();

  // At this point all necessary data should already reside on the GPU.
  auto verts = f->get_vertices();
  std::array<S *, num_vertices> jacs;
  std::array<const size_t *, num_vertices> hessian_ids;
  // size_t loop index avoids a signed/unsigned comparison with the
  // constexpr size_t vertex count.
  for (size_t i = 0; i < num_vertices; i++) {
    jacs[i] = f->jacobians[i].data.data().get();
    hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();

    // Important: must clear Jacobian storage, since the kernels only
    // write entries for active vertices.
    thrust::fill(thrust::cuda::par_nosync.on(streams.select(i)),
                 f->jacobians[i].data.begin(), f->jacobians[i].data.end(),
                 static_cast<S>(0));
  }

  const auto num_factors = f->active_count();

  if constexpr (!is_analytical<F>()) {
    launch_kernel_autodiff<T, S>(f, hessian_ids, verts, jacs, num_factors,
                                 streams,
                                 std::make_index_sequence<num_vertices>{});
  }
  streams.sync_n(num_vertices);
}
247
248// TODO: Make this more efficient and see if code can be shared with the
249// autodiff kernel
// Residual-only kernel: one thread per active factor, no derivatives.
//
// `error` holds E residual entries per factor; `local_error` aliases this
// factor's slice, so the error functor writes its result directly into
// global memory. (The original version re-copied local_error back into
// error afterwards — a redundant self-assignment, since the two pointers
// alias; that loop has been removed.)
template <typename T, size_t N, typename M, size_t E, typename F, typename VT,
          std::size_t... Is>
__global__ void
compute_error_kernel(const M *obs, T *error,
                     const typename F::ConstraintDataType *constraint_data,
                     const size_t *active_ids, const size_t *ids,
                     const size_t num_threads, VT args,
                     std::index_sequence<Is...>) {
  const size_t idx = get_thread_id();

  // Grid tail guard: grids rarely divide the factor count evenly.
  if (idx >= num_threads) {
    return;
  }

  constexpr auto vertex_sizes = F::get_vertex_sizes();
  const auto factor_id = active_ids[idx];

  const M *local_obs = obs + factor_id;
  // Aliases this factor's slice of the global residual buffer; filled in
  // place by the error functor below.
  T *local_error = error + factor_id * E;
  const typename F::ConstraintDataType *local_data =
      constraint_data + factor_id;

  // Local copy of each vertex's parameter block, sized per slot.
  auto v = cuda::std::make_tuple(std::array<T, vertex_sizes[Is]>{}...);

  // Pointers to this factor's vertices, gathered through the id table.
  auto vargs = cuda::std::make_tuple(
      (*(std::get<Is>(args) + ids[factor_id * N + Is]))...);

  // Copy every vertex's parameters into the local arrays.
  auto copy_vertices = [&v, &vertex_sizes, &vargs](auto &&...ptrs) {
    ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
          Traits::parameters(*cuda::std::get<Is>(vargs),
                             cuda::std::get<Is>(v).data())),
     ...);
  };

  cuda::std::apply(copy_vertices, vargs);

  // Writes the E residual entries for this factor directly into `error`
  // through the `local_error` alias; no copy-back is needed.
  call_error_fn<F, T>(vargs, v, local_obs, local_data, local_error,
                      std::make_index_sequence<N>{});
}
294
// Computes residuals only (no Jacobians) for every active factor of `f`
// on the default stream, then blocks until the kernel completes.
//
// T: residual scalar type. Launch layout: one thread per active factor.
//
// Removed dead code from the original: a hessian_ids array (and the loop
// filling it) plus vertex_sizes/error_dim locals that were never used —
// the residual kernel takes none of them.
template <typename T, typename F> void compute_error(F *f) {
  constexpr auto num_vertices = F::get_num_vertices();

  // At this point all necessary data should already reside on the GPU.
  auto verts = f->get_vertices();

  const auto num_factors = f->active_count();

  // One thread per active factor; ceil-divide into 256-thread blocks so
  // the partial tail block is still launched.
  const auto num_threads = num_factors;
  const size_t threads_per_block = 256;
  const size_t num_blocks =
      (num_threads + threads_per_block - 1) / threads_per_block;

  compute_error_kernel<T, num_vertices, typename F::ObservationType,
                       F::error_dim, F, typename F::VertexPointerPointerTuple>
      <<<num_blocks, threads_per_block>>>(
          f->device_obs.data().get(), f->residuals.data().get(),
          f->data.data().get(), f->active_indices.data().get(),
          f->device_ids.data().get(), num_threads, verts,
          std::make_index_sequence<num_vertices>{});

  // Default-stream sync so callers may read residuals immediately.
  cudaStreamSynchronize(0);
}
324
325} // namespace ops
326} // namespace graphite
Definition error.hpp:9
Definition error.hpp:21
Definition error.hpp:14