// NOTE(review): this region appears damaged. Original line numbers are fused
// into the code (e.g. "9template", "12compute_Jcol_ad", "25 auto v"), and
// several lines were dropped:
//   - the return-type/qualifier line before the function name,
//   - the declaration that introduces `vargs`,
//   - the end of the copy_vertices lambda,
//   - the remaining arguments of F::Traits::error,
//   - the closing brace.
// It does not compile as shown; restore it from version control before
// editing any logic.
//
// compute_Jcol_ad: evaluates factor `factor_id`'s error on Dual<T, G> copies
// of its vertices' parameters, after seeding the dual part of parameter `col`
// of vertex slot I with 1. This is forward-mode AD for one Jacobian column.
// The derivative presumably ends up in the dual parts of `error`
// (TODO confirm against Dual and F::Traits::error).
//
// Template parameters (as used in the visible lines):
//   T, G  value and dual component types of Dual<T, G>
//   I     vertex slot being differentiated
//   N     vertices per factor (stride into `ids`)
//   M     observation type
//   E     not referenced in the visible lines
//   F     factor type (provides get_vertex_sizes, Traits, ConstraintDataType)
//   VT    type of `args`; Is... indexes its elements
// Parameters:
//   error            output argument handed to F::Traits::error
//   col              index of the parameter of vertex I whose dual is seeded
//   factor_id        factor to evaluate; selects its obs / constraint_data entry
//   vertex_id        not referenced in the visible lines
//   obs              observations, indexed by factor
//   constraint_data  constraint data, indexed by factor
//   ids              vertex ids, N per factor: ids[factor_id * N + Is]
//   hessian_ids      not referenced in the visible lines
//   args             element Is is offset by the vertex id, then dereferenced
9template <
typename T,
typename G,
size_t I,
size_t N,
typename M,
size_t E,
10 typename F,
typename VT, std::size_t... Is>
12compute_Jcol_ad(Dual<T, G> *error,
const size_t col,
const size_t factor_id,
13 const size_t vertex_id,
const M *obs,
14 const typename F::ConstraintDataType *constraint_data,
15 size_t *ids,
const size_t *hessian_ids, VT args,
16 std::index_sequence<Is...>) {
18 constexpr auto vertex_sizes = F::get_vertex_sizes();
// This factor's observation and constraint data.
// NOTE(review): local_data is not used in the visible lines; it was
// presumably passed on the dropped tail of the F::Traits::error call.
20 const M *local_obs = obs + factor_id;
22 const typename F::ConstraintDataType *local_data =
23 constraint_data + factor_id;
// One value-initialised Dual array per vertex slot, sized by that vertex's
// parameter count.
25 auto v = cuda::std::make_tuple(std::array<Dual<T, G>, vertex_sizes[Is]>{}...);
// NOTE(review): the left-hand side of this statement (presumably
// `auto vargs =`) is missing. For each slot Is it picks args' Is-th element
// at this factor's vertex id.
28 std::make_tuple((*(std::get<Is>(args) + ids[factor_id * N + Is]))...);
// Fills each Dual array through the vertex descriptor's Traits::parameters.
// NOTE(review): the lambda parameter pack `ptrs` and the `vertex_sizes`
// capture are unused in the visible body.
30 auto copy_vertices = [&v, &vertex_sizes, &vargs](
auto &&...ptrs) {
31 ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
32 Traits::parameters(*std::get<Is>(vargs),
33 cuda::std::get<Is>(v).data())),
37 std::apply(copy_vertices, vargs);
// Seed the derivative direction: dual part of parameter `col` of vertex I = 1.
39 cuda::std::get<I>(v)[col].dual =
static_cast<G
>(1);
// Evaluate the error on the Dual inputs (the rest of the argument list is on
// a dropped line).
41 F::Traits::error(cuda::std::get<Is>(v).data()..., local_obs, error, vargs,
// NOTE(review): damaged by extraction. There are fused line numbers
// ("49template", "51__global__"), and these lines were dropped:
//   - the right-hand side of `local_id`,
//   - the `return;` and closing brace of both guards,
//   - the declaration of the accumulator `value`,
//   - the kernel's closing brace.
// Does not compile as shown.
//
// compute_Jv_kernel: accumulates y += J_I * x for vertex slot I of every
// active factor, using Jacobian blocks already stored in `jacs`.
//
// Launch layout:
//   - One thread per (active factor, error row): num_threads = active
//     factors * E.
//   - Thread idx handles active_ids[idx / E] and row idx % E; threads with
//     idx >= num_threads do nothing.
//   - get_thread_id() is assumed to be the global 1-D thread index; confirm.
//
// Layouts implied by the indexing:
//   jacs  factor f's E x D block starts at f * D * E; element (row r,
//         column c) is at c * E + r, i.e. column-major.
//   x     D contiguous entries starting at hessian_ids[local_id].
//   y     one entry per thread index (see NOTE at the atomicAdd).
//
// Inactive vertices (is_vertex_active == false) contribute nothing. Products
// are formed in S and accumulated in T. The launcher issues one such kernel
// per vertex slot on separate streams, so several launches may add into the
// same y entry. That is presumably why atomicAdd is used.
49template <
typename T,
typename S,
size_t I,
size_t N,
size_t E,
size_t D,
50 typename F, std::size_t... Is>
51__global__
void compute_Jv_kernel(T *y,
const T *x,
const size_t *active_ids,
52 const size_t *ids,
const size_t *hessian_ids,
53 const size_t num_threads,
const S *jacs,
54 const uint8_t *active_state,
55 std::index_sequence<Is...>) {
57 const size_t idx = get_thread_id();
59 if (idx >= num_threads) {
63 constexpr auto jacobian_size = D * E;
// NOTE(review): the right-hand side of `local_id` is missing. It is presumably
// ids[factor_id * N + I], as in compute_Jv_dynamic_manual2.
66 const size_t factor_id = active_ids[idx / E];
67 const size_t local_id =
70 if (!is_vertex_active(active_state, local_id)) {
73 const auto jacobian_offset = factor_id * jacobian_size;
77 const auto hessian_offset =
78 hessian_ids[local_id];
79 const auto row_offset = (idx % E);
// jrow[i * E] walks row `row_offset` across the D columns of the block.
89 const S *jrow = jacs + jacobian_offset + row_offset;
90 const T *x_start = x + hessian_offset;
// NOTE(review): `int i` is compared with the size_t D (sign-compare warning).
93 for (
int i = 0; i < D; i++) {
94 value += (T)(jrow[i * E] * (S)x_start[i]);
// NOTE(review): y is indexed by the thread index, which enumerates *active*
// factors (idx / E indexes active_ids). compute_JtPv_kernel, however, reads
// its residual-sized input at factor_id * E. If the two vectors share a
// layout, this should probably be y[factor_id * E + row_offset]. The two
// agree only when active_ids[k] == k. Confirm the intended layout of y.
97 atomicAdd(&y[idx], value);
// NOTE(review): damaged by extraction. There are fused line numbers, and these
// lines were dropped:
//   - the declaration of `sum`,
//   - the bounds `return` (if any),
//   - closing braces.
// Does not compile as shown.
//
// compute_Jv_dynamic_manual2: accumulates y += J_I * x for vertex slot I.
// Unlike compute_Jv_kernel, it recomputes the E x D Jacobian block on the fly
// with compute_Jblock instead of reading stored blocks. Column i of the block
// is multiplied by jacobian_scales[hess_col + i] before the product.
//
// Launch layout: one thread per (active factor, error row). Thread idx
// handles active_ids[idx / E], row idx % E.
//
// NOTE(review) — suspected defects visible here:
//   * `G` is not a template parameter of this kernel (T, S, I, N, M, E, F, VT,
//     Is), so `G jacobian[...]` does not compile unless G comes from an
//     enclosing scope.
//   * active_ids[idx / E] is read before any bounds check. The launcher rounds
//     the grid up to whole 256-thread blocks, so tail threads read past the
//     end of active_ids.
//   * The guard compares the looked-up factor_id with num_factors, which the
//     caller sets to f->active_count(). Factor ids need not be below the
//     active count, so valid active factors can be skipped. Testing
//     idx < num_factors * E before the lookup seems to be what was intended.
//   * y is indexed by idx (active-slot order); see the same note on
//     compute_Jv_kernel.
//   * Each of the E threads of a factor recomputes the whole E x D block but
//     uses only one row of it.
101template <
typename T,
typename S,
size_t I,
size_t N,
typename M,
size_t E,
102 typename F,
typename VT, std::size_t... Is>
103__global__
void compute_Jv_dynamic_manual2(
104 T *y, T *x,
const M *obs,
const T *jacobian_scales,
105 const typename F::ConstraintDataType *constraint_data,
106 const size_t *active_ids,
const size_t *ids,
const size_t *hessian_ids,
107 const size_t num_factors, VT args,
const uint8_t *active_state,
108 std::index_sequence<Is...>) {
109 const size_t idx = get_thread_id();
110 constexpr size_t D = F::get_vertex_sizes()[I];
112 const size_t factor_id = active_ids[idx / E];
113 const size_t row_in_jacobian = idx % E;
114 if (factor_id < num_factors) {
// Vertex I of this factor, and where its D parameters start in x.
116 const auto vertex_id = ids[factor_id * N + I];
118 const auto hess_col = hessian_ids[vertex_id];
119 const T *x_start = x + hess_col;
122 if (is_vertex_active(active_state, vertex_id)) {
// Recompute this factor's E x D block for vertex I. It is stored column-major
// (element (r, c) at r + c * E), as the indexing below implies.
126 constexpr auto jacobian_size = E * D;
127 G jacobian[jacobian_size];
128 compute_Jblock<T, I, N, M, E, F, VT>(jacobian, factor_id, vertex_id, obs,
129 constraint_data, ids, hessian_ids,
130 args, std::make_index_sequence<N>{});
// Dot product of the scaled Jacobian row with x (products in S, sum in T).
134 for (
size_t i = 0; i < D; i++) {
135 const T scaled_j =
static_cast<T
>(jacobian[row_in_jacobian + i * E]) *
136 jacobian_scales[hess_col + i];
137 sum +=
static_cast<T
>((S)scaled_j * (S)x_start[i]);
139 atomicAdd(&y[idx], sum);
// NOTE(review): damaged by extraction. There are fused line numbers, and these
// lines were dropped:
//   - the parameters `f`, `out` and `in` (used below and passed by
//     compute_Jv),
//   - the per-Is pack-expansion wrapper that makes the bare uses of `Is`
//     legal,
//   - both `num_blocks` declarations (only their right-hand sides remain),
//   - the `else` between the two paths,
//   - the argument in the dynamic launch's `VT args` slot,
//   - closing braces.
// Does not compile as shown.
//
// launch_kernel_compute_Jv: host-side launcher for the J * in product. For
// each vertex slot Is it launches one kernel on streams.select(Is), with 256
// threads per block and enough blocks to cover num_factors * F::error_dim
// threads:
//   - If f stores its Jacobians, or F is not analytical: compute_Jv_kernel,
//     reading the stored blocks in jacs[Is].
//   - Otherwise, for analytical F: compute_Jv_dynamic_manual2, which
//     recomputes each block and applies jacobian_scales.
// Launches are asynchronous; compute_Jv synchronises the streams afterwards.
//
// NOTE(review):
//   * f->get_vertex_sizes()[Is] is used as a non-type template argument.
//     Evaluating it reads the runtime pointer `f`, which is not a constant
//     expression. The constexpr `vertex_sizes[Is]` declared just below
//     avoids this.
//   * The stored path passes num_threads as the bound, while the dynamic path
//     passes num_factors; that kernel compares looked-up factor ids against it.
//   * Neither launch is followed by cudaGetLastError(), so launch-config
//     errors surface only at a later call.
144template <
typename T,
typename S,
typename F, std::size_t... Is>
145void launch_kernel_compute_Jv(
147 std::array<
const size_t *, F::get_num_vertices()> &hessian_ids,
148 std::array<S *, F::get_num_vertices()> &jacs,
const T *jacobian_scales,
149 const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
151 constexpr auto num_vertices = F::get_num_vertices();
152 constexpr auto vertex_sizes = F::get_vertex_sizes();
// Path 1: multiply with the stored Jacobian blocks.
153 if (f->store_jacobians() || !is_analytical<F>()) {
154 const auto num_threads = num_factors * F::error_dim;
156 size_t threads_per_block = 256;
158 (num_threads + threads_per_block - 1) / threads_per_block;
159 compute_Jv_kernel<T, S, Is, num_vertices, F::error_dim,
160 f->get_vertex_sizes()[Is], F>
161 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
162 out, in, f->active_indices.data().get(),
163 f->device_ids.data().get(), hessian_ids[Is], num_threads,
164 jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
165 std::make_index_sequence<num_vertices>{});
// Path 2 (the dropped `else`): recompute the Jacobian blocks on the fly.
167 constexpr auto num_vertices = F::get_num_vertices();
168 constexpr auto vertex_sizes = F::get_vertex_sizes();
169 const auto num_threads = num_factors * F::error_dim;
171 size_t threads_per_block = 256;
173 (num_threads + threads_per_block - 1) / threads_per_block;
174 constexpr size_t E = F::error_dim;
176 if constexpr (is_analytical<F>()) {
178 compute_Jv_dynamic_manual2<T, S, Is, num_vertices,
179 typename F::ObservationType, E, F,
180 typename F::VertexPointerPointerTuple>
181 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
182 out, in, f->device_obs.data().get(), jacobian_scales,
183 f->data.data().get(), f->active_indices.data().get(),
184 f->device_ids.data().get(), hessian_ids[Is], num_factors,
186 f->vertex_descriptors[Is]->get_active_state(),
187 std::make_index_sequence<num_vertices>{});
// NOTE(review): fused line numbers and dropped lines (closing braces at
// least). Does not compile as shown.
//
// compute_Jv: accumulates the product of factor set f's Jacobian with `in`
// into `out`, then waits for the per-vertex streams.
//
// Parameters:
//   f                factor set; supplies Jacobian storage, vertex
//                    descriptors and the active count
//   out              accumulated into by the kernels with atomicAdd; it is
//                    not cleared in the visible lines
//   in               input vector, read at each vertex's hessian_ids offset
//   jacobian_scales  forwarded to the launcher. Only the on-the-fly
//                    (analytical, non-stored) path applies it; presumably
//                    stored Jacobians are already scaled (confirm).
//   streams          pool supplying one stream per vertex slot
//
// NOTE(review):
//   * `verts` and `vertex_sizes` are unused in the visible lines.
//   * The loop index is a signed int compared with num_vertices.
//   * streams.sync_n(num_vertices) presumably waits for the streams handed
//     out by select(0..num_vertices-1); confirm against StreamPool.
194template <
typename T,
typename S,
typename F>
195void compute_Jv(F *f, T *out, T *in,
const T *jacobian_scales,
196 StreamPool &streams) {
197 constexpr auto num_vertices = F::get_num_vertices();
198 constexpr auto vertex_sizes = F::get_vertex_sizes();
201 auto verts = f->get_vertices();
// Gather, per vertex slot, the device pointer to its stored Jacobian blocks
// and to its Hessian-offset table.
202 std::array<S *, num_vertices> jacs;
203 std::array<const size_t *, num_vertices> hessian_ids;
204 for (
int i = 0; i < num_vertices; i++) {
206 jacs[i] = f->jacobians[i].data.data().get();
207 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
210 const auto num_factors = f->active_count();
// One asynchronous launch per vertex slot; then block until they finish.
212 launch_kernel_compute_Jv<T, S>(f, out, in, hessian_ids, jacs, jacobian_scales,
213 num_factors, streams,
214 std::make_index_sequence<num_vertices>{});
215 streams.sync_n(num_vertices);
// NOTE(review): damaged by extraction. There are fused line numbers, and these
// lines were dropped:
//   - a parameter between active_ids and hessian_ids (the launcher passes
//     f->device_ids there, presumably `const size_t *ids`),
//   - the right-hand side of `local_id`,
//   - both guards' `return;`,
//   - the declarations of `value` and `x2`,
//   - closing braces.
// Does not compile as shown.
//
// compute_JtPv_kernel: accumulates y += chi2'_f * J_I^T * P_f * x for vertex
// slot I of every active factor f, using stored Jacobian blocks.
//
// Launch layout:
//   - One thread per (active factor, vertex parameter): num_threads = active
//     factors * D.
//   - Thread idx handles active_ids[idx / D] and Jacobian column c = idx % D.
//
// Per thread:
//   value = sum_i J[i][c] * (sum_j P[i][j] * x[f * E + j])
//   value *= chi2_derivative[f]
//   atomically added to y[hessian_ids[local_id] + c]
//
// Layouts implied by the indexing:
//   jacs  E x D block per factor at f * D * E; column c is contiguous at c * E
//   pmat  E x E matrix per factor at f * E * E; row i at i * E
//   x     E entries per factor at f * E
//
// NOTE(review):
//   * x2 must restart at zero for every i. It is presumably declared inside
//     the i loop on a dropped line; if declared outside, the result is wrong.
//   * Each of the D threads of a factor recomputes the full P_f * x
//     (E * E multiply-adds).
//   * The loop indices are signed ints compared with the size_t E.
226template <
typename T,
typename S,
size_t I,
size_t N,
size_t E,
size_t D,
227 typename F, std::size_t... Is>
228__global__
void compute_JtPv_kernel(T *y,
const T *x,
const size_t *active_ids,
230 const size_t *hessian_ids,
231 const size_t num_threads,
const S *jacs,
232 const uint8_t *active_state,
const S *pmat,
233 const S *chi2_derivative,
234 const std::index_sequence<Is...>) {
235 const size_t idx = get_thread_id();
237 if (idx >= num_threads) {
241 constexpr auto jacobian_size = D * E;
245 const size_t factor_id = active_ids[idx / D];
246 const size_t local_id =
249 if (!is_vertex_active(active_state, local_id)) {
252 const auto jacobian_offset = factor_id * jacobian_size;
253 const auto error_offset = factor_id * E;
254 const auto col_offset = (idx % D) * E;
256 constexpr auto precision_matrix_size = E * E;
257 const auto precision_offset = factor_id * precision_matrix_size;
// Column c of this factor's block, its precision matrix, and its slice of x.
265 const S *jcol = jacs + jacobian_offset + col_offset;
267 const S *precision_matrix = pmat + precision_offset;
268 const T *x_start = x + error_offset;
// Outer loop: row i of J^T P x. Inner loop: x2 = (P x)_i.
272 for (
int i = 0; i < E; i++) {
273 const auto p_row = precision_matrix + i * E;
276 for (
int j = 0; j < E; j++) {
277 x2 += p_row[j] * (S)x_start[j];
279 value += (T)(jcol[i] * x2);
282 value *= (T)chi2_derivative[factor_id];
284 const auto hessian_offset =
285 hessian_ids[local_id];
// Other factors that share this vertex add into the same entry, hence atomic.
287 atomicAdd(&y[hessian_offset + (idx % D)], value);
// NOTE(review): damaged by extraction. There are fused line numbers, and these
// lines were dropped:
//   - the right-hand side of `local_id`,
//   - both guards' `return;`,
//   - the declarations of `value` and `x2`,
//   - closing braces.
// Does not compile as shown.
//
// compute_JtPv_dynamic_kernel: same product as compute_JtPv_kernel,
// y += chi2'_f * J_I^T * P_f * x, but without stored Jacobians. Each thread
// recomputes the E x D block with compute_Jblock. The result for column c is
// additionally multiplied by jacobian_scales[hessian_ids[local_id] + c].
//
// Launch layout:
//   - One thread per (active factor, vertex parameter): num_threads = active
//     factors * D.
//   - Thread idx handles active_ids[idx / D], column idx % D.
//   - Unlike compute_Jv_dynamic_manual2, the bounds check runs before
//     active_ids is read.
//
// NOTE(review):
//   * `G` is not a template parameter of this kernel, so `G jacobian[...]` and
//     `const G *jcol` do not compile unless G comes from an enclosing scope.
//   * Every one of the D threads of a factor recomputes the whole E x D block
//     but uses only one column of it.
//   * x2 must restart at zero for every i; see the declaration dropped inside
//     the i loop.
290template <
typename T,
typename S,
size_t I,
size_t N,
typename M,
size_t E,
291 size_t D,
typename F,
typename VT, std::size_t... Is>
292__global__
void compute_JtPv_dynamic_kernel(
293 T *y,
const T *x,
const size_t *active_ids,
const size_t *ids,
294 const size_t *hessian_ids,
const size_t num_threads,
const VT args,
295 const M *obs,
const T *jacobian_scales,
296 const typename F::ConstraintDataType *constraint_data,
297 const uint8_t *active_state,
const S *pmat,
const S *chi2_derivative,
298 const std::index_sequence<Is...>) {
299 const size_t idx = get_thread_id();
301 if (idx >= num_threads) {
305 constexpr auto jacobian_size = D * E;
309 const size_t factor_id = active_ids[idx / D];
310 const size_t local_id =
313 if (!is_vertex_active(active_state, local_id)) {
316 const auto error_offset = factor_id * E;
317 const auto col_offset = (idx % D) * E;
319 constexpr auto precision_matrix_size = E * E;
320 const auto precision_offset = factor_id * precision_matrix_size;
// Recompute this factor's E x D block for vertex I (column-major, column c at
// c * E, as col_offset implies).
323 G jacobian[jacobian_size];
325 compute_Jblock<T, I, N, M, E, F, VT>(jacobian, factor_id, local_id, obs,
326 constraint_data, ids, hessian_ids, args,
327 std::make_index_sequence<N>{});
// Output entry for this parameter, and the scale applied to its column.
329 const auto hessian_offset = hessian_ids[local_id];
330 const auto scale = jacobian_scales[hessian_offset + (idx % D)];
332 const G *jcol = jacobian + col_offset;
334 const S *precision_matrix = pmat + precision_offset;
335 const T *x_start = x + error_offset;
// Outer loop: row i of J^T P x. Inner loop: x2 = (P x)_i.
339 for (
int i = 0; i < E; i++) {
340 const auto p_row = precision_matrix + i * E;
343 for (
int j = 0; j < E; j++) {
344 x2 += p_row[j] * (S)x_start[j];
346 value += (T)((S)jcol[i] * x2);
349 value *= (T)chi2_derivative[factor_id] * scale;
351 atomicAdd(&y[hessian_offset + (idx % D)], value);
// NOTE(review): damaged by extraction. There are fused line numbers, and these
// lines were dropped:
//   - the parameters `f`, `out` and `in` (used below and passed by
//     compute_Jtv),
//   - the per-Is pack-expansion wrapper that makes the bare uses of `Is`
//     legal,
//   - the `num_blocks` declaration (only its right-hand side remains),
//   - the `else` between the two paths,
//   - closing braces.
// Does not compile as shown.
//
// launch_kernel_compute_JtPv: host-side launcher for the chi2' * J^T * P * in
// product. For each vertex slot Is it launches one kernel on
// streams.select(Is), with 256 threads per block. The thread count is
// num_factors * (parameter count of vertex slot Is), one thread per output
// column:
//   - If f stores its Jacobians, or F is not analytical: compute_JtPv_kernel
//     over jacs[Is].
//   - Otherwise, for analytical F: compute_JtPv_dynamic_kernel, which
//     recomputes each block and applies jacobian_scales.
// Launches are asynchronous; compute_Jtv synchronises the streams afterwards.
//
// NOTE(review):
//   * f->get_vertex_sizes()[Is] is used as a non-type template argument.
//     Evaluating it reads the runtime pointer `f`, which is not a constant
//     expression. F::get_vertex_sizes()[Is], as used for num_threads, avoids
//     this.
//   * Neither launch is followed by cudaGetLastError().
354template <
typename T,
typename S,
typename F, std::size_t... Is>
355void launch_kernel_compute_JtPv(
357 std::array<
const size_t *, F::get_num_vertices()> &hessian_ids,
358 std::array<S *, F::get_num_vertices()> &jacs,
const T *jacobian_scales,
359 const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
361 constexpr auto num_vertices = F::get_num_vertices();
362 const auto num_threads = num_factors * F::get_vertex_sizes()[Is];
365 size_t threads_per_block = 256;
367 (num_threads + threads_per_block - 1) / threads_per_block;
// Path 1: use the stored Jacobian blocks.
373 if (f->store_jacobians() || !is_analytical<F>()) {
374 compute_JtPv_kernel<T, S, Is, num_vertices, F::error_dim,
375 f->get_vertex_sizes()[Is], F>
376 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
377 out, in, f->active_indices.data().get(),
378 f->device_ids.data().get(), hessian_ids[Is], num_threads,
379 jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
380 f->precision_matrices.data().get(),
381 f->chi2_derivative.data().get(),
382 std::make_index_sequence<num_vertices>{});
// Path 2 (the dropped `else`): recompute the Jacobian blocks on the fly.
384 if constexpr (is_analytical<F>()) {
385 compute_JtPv_dynamic_kernel<T, S, Is, num_vertices,
386 typename F::ObservationType, F::error_dim,
387 f->get_vertex_sizes()[Is], F,
388 typename F::VertexPointerPointerTuple>
389 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
390 out, in, f->active_indices.data().get(),
391 f->device_ids.data().get(), hessian_ids[Is], num_threads,
392 f->get_vertices(), f->device_obs.data().get(), jacobian_scales,
393 f->data.data().get(),
394 f->vertex_descriptors[Is]->get_active_state(),
395 f->precision_matrices.data().get(),
396 f->chi2_derivative.data().get(),
397 std::make_index_sequence<num_vertices>{});
// NOTE(review): fused line numbers and dropped lines (closing braces at
// least). Does not compile as shown.
//
// compute_Jtv: accumulates, for every vertex slot, chi2' * J^T * P * in into
// `out`, then waits for the per-vertex streams. Despite its name, it calls
// launch_kernel_compute_JtPv, so the precision matrices and chi2 derivatives
// are applied; it is not a plain J^T * v.
//
// Parameters:
//   f                factor set
//   out              accumulated into with atomicAdd at each vertex's
//                    hessian_ids offset; not cleared in the visible lines
//   in               residual-sized input, read at factor_id * error_dim
//   jacobian_scales  forwarded to the launcher; only the on-the-fly path
//                    applies it
//   streams          pool supplying one stream per vertex slot
//
// NOTE(review):
//   * `constexpr auto num_vertices = f->get_num_vertices();` reads the runtime
//     pointer `f`, which is not a constant expression. compute_Jv spells this
//     F::get_num_vertices().
//   * `verts` and `vertex_sizes` are unused in the visible lines.
//   * The loop index is a signed int compared with num_vertices.
404template <
typename T,
typename S,
typename F>
405void compute_Jtv(F *f, T *out, T *in,
const T *jacobian_scales,
406 StreamPool &streams) {
407 constexpr auto num_vertices = f->get_num_vertices();
408 constexpr auto vertex_sizes = F::get_vertex_sizes();
411 auto verts = f->get_vertices();
// Gather, per vertex slot, the device pointer to its stored Jacobian blocks
// and to its Hessian-offset table.
412 std::array<S *, num_vertices> jacs;
413 std::array<const size_t *, num_vertices> hessian_ids;
414 for (
int i = 0; i < num_vertices; i++) {
416 jacs[i] = f->jacobians[i].data.data().get();
417 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
420 const auto num_factors = f->active_count();
// One asynchronous launch per vertex slot; then block until they finish.
422 launch_kernel_compute_JtPv<T, S>(f, out, in, hessian_ids, jacs,
423 jacobian_scales, num_factors, streams,
424 std::make_index_sequence<num_vertices>{});
425 streams.sync_n(num_vertices);