Graphite
Loading...
Searching...
No Matches
product.hpp
Go to the documentation of this file.
1
2#pragma once
4
5namespace graphite {
6
7namespace ops {
8
9template <typename T, typename G, size_t I, size_t N, typename M, size_t E,
10 typename F, typename VT, std::size_t... Is>
11__device__ void
12compute_Jcol_ad(Dual<T, G> *error, const size_t col, const size_t factor_id,
13 const size_t vertex_id, const M *obs,
14 const typename F::ConstraintDataType *constraint_data,
15 size_t *ids, const size_t *hessian_ids, VT args,
16 std::index_sequence<Is...>) {
17
18 constexpr auto vertex_sizes = F::get_vertex_sizes();
19
20 const M *local_obs = obs + factor_id;
21 // Dual<T, G> local_error[E];
22 const typename F::ConstraintDataType *local_data =
23 constraint_data + factor_id;
24
25 auto v = cuda::std::make_tuple(std::array<Dual<T, G>, vertex_sizes[Is]>{}...);
26
27 auto vargs =
28 std::make_tuple((*(std::get<Is>(args) + ids[factor_id * N + Is]))...);
29
30 auto copy_vertices = [&v, &vertex_sizes, &vargs](auto &&...ptrs) {
31 ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
32 Traits::parameters(*std::get<Is>(vargs),
33 cuda::std::get<Is>(v).data())),
34 ...);
35 };
36
37 std::apply(copy_vertices, vargs);
38
39 cuda::std::get<I>(v)[col].dual = static_cast<G>(1);
40
41 F::Traits::error(cuda::std::get<Is>(v).data()..., local_obs, error, vargs,
42 local_data);
43}
44
45// Compute J * x where the length of vector x matches the Hessian dimension
46// Each Jacobian block needs to be accessed just once
47// So we need E threads for each block (error dimension)
48// In total we should hae E*num_factors threads?
49template <typename T, typename S, size_t I, size_t N, size_t E, size_t D,
50 typename F, std::size_t... Is>
51__global__ void compute_Jv_kernel(T *y, const T *x, const size_t *active_ids,
52 const size_t *ids, const size_t *hessian_ids,
53 const size_t num_threads, const S *jacs,
54 const uint8_t *active_state,
55 std::index_sequence<Is...>) {
56
57 const size_t idx = get_thread_id();
58
59 if (idx >= num_threads) {
60 return;
61 }
62
63 constexpr auto jacobian_size = D * E;
64
65 // Each J block is stored as E x d col major, where d is the vertex size
66 const size_t factor_id = active_ids[idx / E];
67 const size_t local_id =
68 ids[factor_id * N +
69 I]; // N is the number of vertices involved in the factor
70 if (!is_vertex_active(active_state, local_id)) {
71 return;
72 }
73 const auto jacobian_offset = factor_id * jacobian_size;
74
75 T value = 0;
76
77 const auto hessian_offset =
78 hessian_ids[local_id]; // each vertex has a hessian_ids array
79 const auto row_offset = (idx % E);
80 // Adding i*E skips to the next column
81 // size_t residual_offset = 0; // need to pass this in
82 // it's the offset into the r vector
83 // #pragma unroll
84 // for (int i = 0; i < d; i++) {
85 // value += jacs[jacobian_offset + row_offset + i*E] * x[hessian_offset +
86 // i];
87 // }
88
89 const S *jrow = jacs + jacobian_offset + row_offset;
90 const T *x_start = x + hessian_offset;
91
92#pragma unroll
93 for (int i = 0; i < D; i++) {
94 value += (T)(jrow[i * E] * (S)x_start[i]);
95 }
96
97 atomicAdd(&y[idx], value);
98 // y[idx] += value; // avoid unless sure that atomicAdd is not needed
99}
100
101template <typename T, typename S, size_t I, size_t N, typename M, size_t E,
102 typename F, typename VT, std::size_t... Is>
103__global__ void compute_Jv_dynamic_manual2(
104 T *y, T *x, const M *obs, const T *jacobian_scales,
105 const typename F::ConstraintDataType *constraint_data,
106 const size_t *active_ids, const size_t *ids, const size_t *hessian_ids,
107 const size_t num_factors, VT args, const uint8_t *active_state,
108 std::index_sequence<Is...>) {
109 const size_t idx = get_thread_id();
110 constexpr size_t D = F::get_vertex_sizes()[I];
111
112 const size_t factor_id = active_ids[idx / E];
113 const size_t row_in_jacobian = idx % E;
114 if (factor_id < num_factors) {
115
116 const auto vertex_id = ids[factor_id * N + I];
117
118 const auto hess_col = hessian_ids[vertex_id];
119 const T *x_start = x + hess_col;
120
121 // Each thread block stores a complete Jacobian row in shared memory
122 if (is_vertex_active(active_state, vertex_id)) {
123 // using G = std::conditional_t<is_low_precision<S>::value, T, S>;
124 using G = T;
125 // Dual<T, G> error[E];
126 constexpr auto jacobian_size = E * D;
127 G jacobian[jacobian_size];
128 compute_Jblock<T, I, N, M, E, F, VT>(jacobian, factor_id, vertex_id, obs,
129 constraint_data, ids, hessian_ids,
130 args, std::make_index_sequence<N>{});
131
132 T sum = 0.0;
133#pragma unroll
134 for (size_t i = 0; i < D; i++) {
135 const T scaled_j = static_cast<T>(jacobian[row_in_jacobian + i * E]) *
136 jacobian_scales[hess_col + i];
137 sum += static_cast<T>((S)scaled_j * (S)x_start[i]);
138 }
139 atomicAdd(&y[idx], sum);
140 }
141 }
142}
143
144template <typename T, typename S, typename F, std::size_t... Is>
145void launch_kernel_compute_Jv(
146 F *f, T *out, T *in,
147 std::array<const size_t *, F::get_num_vertices()> &hessian_ids,
148 std::array<S *, F::get_num_vertices()> &jacs, const T *jacobian_scales,
149 const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
150 (([&] {
151 constexpr auto num_vertices = F::get_num_vertices();
152 constexpr auto vertex_sizes = F::get_vertex_sizes();
153 if (f->store_jacobians() || !is_analytical<F>()) {
154 const auto num_threads = num_factors * F::error_dim;
155
156 size_t threads_per_block = 256;
157 size_t num_blocks =
158 (num_threads + threads_per_block - 1) / threads_per_block;
159 compute_Jv_kernel<T, S, Is, num_vertices, F::error_dim,
160 f->get_vertex_sizes()[Is], F>
161 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
162 out, in, f->active_indices.data().get(),
163 f->device_ids.data().get(), hessian_ids[Is], num_threads,
164 jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
165 std::make_index_sequence<num_vertices>{});
166 } else {
167 constexpr auto num_vertices = F::get_num_vertices();
168 constexpr auto vertex_sizes = F::get_vertex_sizes();
169 const auto num_threads = num_factors * F::error_dim;
170
171 size_t threads_per_block = 256;
172 size_t num_blocks =
173 (num_threads + threads_per_block - 1) / threads_per_block;
174 constexpr size_t E = F::error_dim;
175
176 if constexpr (is_analytical<F>()) {
177
178 compute_Jv_dynamic_manual2<T, S, Is, num_vertices,
179 typename F::ObservationType, E, F,
180 typename F::VertexPointerPointerTuple>
181 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
182 out, in, f->device_obs.data().get(), jacobian_scales,
183 f->data.data().get(), f->active_indices.data().get(),
184 f->device_ids.data().get(), hessian_ids[Is], num_factors,
185 f->get_vertices(),
186 f->vertex_descriptors[Is]->get_active_state(),
187 std::make_index_sequence<num_vertices>{});
188 }
189 }
190 }()),
191 ...);
192}
193
194template <typename T, typename S, typename F>
195void compute_Jv(F *f, T *out, T *in, const T *jacobian_scales,
196 StreamPool &streams) {
197 constexpr auto num_vertices = F::get_num_vertices();
198 constexpr auto vertex_sizes = F::get_vertex_sizes();
199
200 // std::array<T*, num_vertices> verts;
201 auto verts = f->get_vertices();
202 std::array<S *, num_vertices> jacs;
203 std::array<const size_t *, num_vertices> hessian_ids;
204 for (int i = 0; i < num_vertices; i++) {
205 // verts[i] = f->vertex_descriptors[i]->x();
206 jacs[i] = f->jacobians[i].data.data().get();
207 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
208 }
209
210 const auto num_factors = f->active_count();
211
212 launch_kernel_compute_Jv<T, S>(f, out, in, hessian_ids, jacs, jacobian_scales,
213 num_factors, streams,
214 std::make_index_sequence<num_vertices>{});
215 streams.sync_n(num_vertices);
216}
217
218// Compute J^T * x where x is the size of the residual vector
219// Each Jacobian block needs to be accessed just once
220// For each block, we need d threads where d is the vertex size
221// We need to load the x vector location for the corresponding block row of J
222// So this assumes that the x vector has the same layout as the residual vector
223// for this factor (rather than a global residual vector) The aggregate output
224// will be H x len(x) where H is hessian dimension
225// Compute J^T * P * x where P is the precision matrix
226template <typename T, typename S, size_t I, size_t N, size_t E, size_t D,
227 typename F, std::size_t... Is>
228__global__ void compute_JtPv_kernel(T *y, const T *x, const size_t *active_ids,
229 const size_t *ids,
230 const size_t *hessian_ids,
231 const size_t num_threads, const S *jacs,
232 const uint8_t *active_state, const S *pmat,
233 const S *chi2_derivative,
234 const std::index_sequence<Is...>) {
235 const size_t idx = get_thread_id();
236
237 if (idx >= num_threads) {
238 return;
239 }
240
241 constexpr auto jacobian_size = D * E;
242
243 // Stored as E x d col major, but we need to transpose it to d x E, where d is
244 // the vertex size
245 const size_t factor_id = active_ids[idx / D];
246 const size_t local_id =
247 ids[factor_id * N +
248 I]; // N is the number of vertices involved in the factor
249 if (!is_vertex_active(active_state, local_id)) {
250 return;
251 }
252 const auto jacobian_offset = factor_id * jacobian_size;
253 const auto error_offset = factor_id * E;
254 const auto col_offset = (idx % D) * E; // for untransposed J
255
256 constexpr auto precision_matrix_size = E * E;
257 const auto precision_offset = factor_id * precision_matrix_size;
258
259 // Use loss kernel
260 // const auto dL = chi2_derivative[factor_id];
261
262 // T x2[E] = {0};
263 // T value = 0;
264
265 const S *jcol = jacs + jacobian_offset + col_offset;
266
267 const S *precision_matrix = pmat + precision_offset;
268 const T *x_start = x + error_offset;
269
270 T value = 0;
271#pragma unroll
272 for (int i = 0; i < E; i++) { // pmat row
273 const auto p_row = precision_matrix + i * E;
274 S x2 = 0;
275#pragma unroll
276 for (int j = 0; j < E; j++) { // pmat col
277 x2 += p_row[j] * (S)x_start[j];
278 }
279 value += (T)(jcol[i] * x2);
280 }
281
282 value *= (T)chi2_derivative[factor_id];
283
284 const auto hessian_offset =
285 hessian_ids[local_id]; // each vertex has a hessian_ids array
286
287 atomicAdd(&y[hessian_offset + (idx % D)], value);
288}
289
290template <typename T, typename S, size_t I, size_t N, typename M, size_t E,
291 size_t D, typename F, typename VT, std::size_t... Is>
292__global__ void compute_JtPv_dynamic_kernel(
293 T *y, const T *x, const size_t *active_ids, const size_t *ids,
294 const size_t *hessian_ids, const size_t num_threads, const VT args,
295 const M *obs, const T *jacobian_scales,
296 const typename F::ConstraintDataType *constraint_data,
297 const uint8_t *active_state, const S *pmat, const S *chi2_derivative,
298 const std::index_sequence<Is...>) {
299 const size_t idx = get_thread_id();
300
301 if (idx >= num_threads) {
302 return;
303 }
304
305 constexpr auto jacobian_size = D * E;
306
307 // Stored as E x d col major, but we need to transpose it to d x E, where d is
308 // the vertex size
309 const size_t factor_id = active_ids[idx / D];
310 const size_t local_id =
311 ids[factor_id * N +
312 I]; // N is the number of vertices involved in the factor
313 if (!is_vertex_active(active_state, local_id)) {
314 return;
315 }
316 const auto error_offset = factor_id * E;
317 const auto col_offset = (idx % D) * E; // for untransposed J
318
319 constexpr auto precision_matrix_size = E * E;
320 const auto precision_offset = factor_id * precision_matrix_size;
321
322 using G = T;
323 G jacobian[jacobian_size];
324
325 compute_Jblock<T, I, N, M, E, F, VT>(jacobian, factor_id, local_id, obs,
326 constraint_data, ids, hessian_ids, args,
327 std::make_index_sequence<N>{});
328
329 const auto hessian_offset = hessian_ids[local_id];
330 const auto scale = jacobian_scales[hessian_offset + (idx % D)];
331
332 const G *jcol = jacobian + col_offset;
333
334 const S *precision_matrix = pmat + precision_offset;
335 const T *x_start = x + error_offset;
336
337 T value = 0;
338#pragma unroll
339 for (int i = 0; i < E; i++) { // pmat row
340 const auto p_row = precision_matrix + i * E;
341 S x2 = 0;
342#pragma unroll
343 for (int j = 0; j < E; j++) { // pmat col
344 x2 += p_row[j] * (S)x_start[j];
345 }
346 value += (T)((S)jcol[i] * x2);
347 }
348
349 value *= (T)chi2_derivative[factor_id] * scale;
350
351 atomicAdd(&y[hessian_offset + (idx % D)], value);
352}
353
354template <typename T, typename S, typename F, std::size_t... Is>
355void launch_kernel_compute_JtPv(
356 F *f, T *out, T *in,
357 std::array<const size_t *, F::get_num_vertices()> &hessian_ids,
358 std::array<S *, F::get_num_vertices()> &jacs, const T *jacobian_scales,
359 const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
360 (([&] {
361 constexpr auto num_vertices = F::get_num_vertices();
362 const auto num_threads = num_factors * F::get_vertex_sizes()[Is];
363 // std::cout << "Launching compute Jtv kernel" << std::endl;
364 // std::cout << "Num threads: " << num_threads << std::endl;
365 size_t threads_per_block = 256;
366 size_t num_blocks =
367 (num_threads + threads_per_block - 1) / threads_per_block;
368
369 // std::cout << "Checking obs ptr: " << f->device_obs.data().get() <<
370 // std::endl; std::cout << "Checking residual ptr: " <<
371 // f->residuals.data().get() << std::endl; std::cout << "Checking ids
372 // ptr: " << f->device_ids.data().get() << std::endl;
373 if (f->store_jacobians() || !is_analytical<F>()) {
374 compute_JtPv_kernel<T, S, Is, num_vertices, F::error_dim,
375 f->get_vertex_sizes()[Is], F>
376 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
377 out, in, f->active_indices.data().get(),
378 f->device_ids.data().get(), hessian_ids[Is], num_threads,
379 jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
380 f->precision_matrices.data().get(),
381 f->chi2_derivative.data().get(),
382 std::make_index_sequence<num_vertices>{});
383 } else {
384 if constexpr (is_analytical<F>()) {
385 compute_JtPv_dynamic_kernel<T, S, Is, num_vertices,
386 typename F::ObservationType, F::error_dim,
387 f->get_vertex_sizes()[Is], F,
388 typename F::VertexPointerPointerTuple>
389 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
390 out, in, f->active_indices.data().get(),
391 f->device_ids.data().get(), hessian_ids[Is], num_threads,
392 f->get_vertices(), f->device_obs.data().get(), jacobian_scales,
393 f->data.data().get(),
394 f->vertex_descriptors[Is]->get_active_state(),
395 f->precision_matrices.data().get(),
396 f->chi2_derivative.data().get(),
397 std::make_index_sequence<num_vertices>{});
398 }
399 }
400 }()),
401 ...);
402}
403
404template <typename T, typename S, typename F>
405void compute_Jtv(F *f, T *out, T *in, const T *jacobian_scales,
406 StreamPool &streams) {
407 constexpr auto num_vertices = f->get_num_vertices();
408 constexpr auto vertex_sizes = F::get_vertex_sizes();
409
410 // std::array<T*, num_vertices> verts;
411 auto verts = f->get_vertices();
412 std::array<S *, num_vertices> jacs;
413 std::array<const size_t *, num_vertices> hessian_ids;
414 for (int i = 0; i < num_vertices; i++) {
415 // verts[i] = f->vertex_descriptors[i]->x();
416 jacs[i] = f->jacobians[i].data.data().get();
417 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
418 }
419
420 const auto num_factors = f->active_count();
421
422 launch_kernel_compute_JtPv<T, S>(f, out, in, hessian_ids, jacs,
423 jacobian_scales, num_factors, streams,
424 std::make_index_sequence<num_vertices>{});
425 streams.sync_n(num_vertices);
426}
427
428} // namespace ops
429
430} // namespace graphite