10 constexpr static size_t value = N;
11 static_assert(N == P,
"Helper to print N and P at compile time");
// Partial specialization that derives from
// std::integral_constant<std::size_t, sizeof...(Args)>, i.e. its ::value is
// the number of types in the Args pack.
// NOTE(review): the line naming the struct is not present in this view;
// presumably this is the fn_arity trait queried elsewhere in this file via
// fn_arity<...>::value — confirm against the complete file.
17template <
class R,
class... Args>
19 : std::integral_constant<std::size_t,
sizeof...(Args)> {};
// Template header over R, a leading type First and a trailing pack Rest.
// NOTE(review): the declaration this header introduces is not present in this
// view; presumably it is the first_arg trait whose ::type is inspected by
// takes_vertices() — confirm against the complete file.
23template <
class R,
class First,
class... Rest>
// Compile-time predicate for factor type F and scalar type D.
//
// Returns true when first_arg<decltype(&F::Traits::template error<D>)>::type
// is a reference type (std::is_reference).
// NOTE(review): first_arg is not defined in the visible lines; presumably its
// ::type is the type of the first parameter of F::Traits::error<D> — confirm.
28template <
typename F,
typename D> __device__
constexpr bool takes_vertices() {
29 return std::is_reference<
30 typename first_arg<
decltype(&F::Traits::template error<D>)>::type>::value;
// Calls F::Traits::error with an argument list selected at compile time.
//
// The selection depends on:
//   (a) whether F::ObservationType and/or F::ConstraintDataType are empty
//       types (std::is_empty), and
//   (b) the value of fn_arity for F::Traits::error<D>, compared against
//       N = F::get_num_vertices(), together with takes_vertices<F, D>().
//
// Arguments, as used by the visible calls:
//   vertices    - indexed with cuda::std::get<Is>; each element is
//                 dereferenced before being passed.
//   parameters  - indexed with cuda::std::get<Is>; each element's .data() is
//                 passed.
//   local_obs   - dereferenced and passed in the branches that forward an
//                 observation.
//   local_data  - dereferenced and passed in the branches that forward
//                 constraint data.
//   local_error - passed as the trailing argument.
//   Is...       - index pack deduced from the std::index_sequence argument.
//
// NOTE(review): the line carrying the return type/qualifiers, several
// `else`/closing-brace lines and some trailing-argument lines are not present
// in this view — confirm against the complete file before editing.
33template <
typename F,
typename D,
typename VertexPointers,
34 typename ParameterBlocks,
typename Observation,
35 typename ConstraintData,
typename ErrorVector, std::size_t... Is>
// NOTE(review): `¶meters` below looks like a mis-encoded `&parameters`
// (the body refers to `parameters`) — confirm and repair in the real file.
37call_error_fn(VertexPointers &vertices, ParameterBlocks ¶meters,
38 Observation &local_obs, ConstraintData &local_data,
39 ErrorVector &local_error, std::index_sequence<Is...>) {
41 using DataType =
typename F::ConstraintDataType;
42 using ObsType =
typename F::ObservationType;
43 constexpr size_t N = F::get_num_vertices();
// Value reported by fn_arity for F::Traits::error<D>; compared below against
// N * 2 + k to detect the form that takes both vertices and parameter blocks.
45 constexpr size_t num_parameters =
46 fn_arity<
decltype(&F::Traits::template error<D>)>::value;
// Case: both the observation type and the constraint-data type are empty.
48 if constexpr (std::is_empty<ObsType>::value &&
49 std::is_empty<DataType>::value) {
// error(vertices..., parameter pointers..., error)
50 if constexpr (N * 2 + 1 == num_parameters) {
51 F::Traits::error((*cuda::std::get<Is>(vertices))...,
52 cuda::std::get<Is>(parameters).data()..., local_error);
53 }
// error(vertices..., error)
else if constexpr (takes_vertices<F, D>()) {
54 F::Traits::error((*cuda::std::get<Is>(vertices))..., local_error);
// error(parameter pointers..., error)
56 F::Traits::error(cuda::std::get<Is>(parameters).data()..., local_error);
58 }
// Case: the constraint-data type is empty; the observation is forwarded.
else if constexpr (std::is_empty<DataType>::value) {
60 if constexpr (N * 2 + 2 == num_parameters) {
61 F::Traits::error((*cuda::std::get<Is>(vertices))...,
62 cuda::std::get<Is>(parameters).data()..., *local_obs,
64 }
else if constexpr (takes_vertices<F, D>()) {
65 F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_obs,
68 F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_obs,
71 }
// Case: the observation type is empty; the constraint data is forwarded.
else if constexpr (std::is_empty<ObsType>::value) {
72 if constexpr (N * 2 + 2 == num_parameters) {
73 F::Traits::error((*cuda::std::get<Is>(vertices))...,
74 cuda::std::get<Is>(parameters).data()..., *local_data,
76 }
else if constexpr (takes_vertices<F, D>()) {
77 F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_data,
80 F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_data,
// Remaining calls forward both the observation and the constraint data.
84 if constexpr (N * 2 + 3 == num_parameters) {
85 F::Traits::error((*cuda::std::get<Is>(vertices))...,
86 cuda::std::get<Is>(parameters).data()..., *local_obs,
87 *local_data, local_error);
88 }
else if constexpr (takes_vertices<F, D>()) {
89 F::Traits::error((*cuda::std::get<Is>(vertices))..., *local_obs,
90 *local_data, local_error);
92 F::Traits::error(cuda::std::get<Is>(parameters).data()..., *local_obs,
93 *local_data, local_error);
// Kernel that evaluates F's error function on Dual<T, G> values to obtain
// derivatives with respect to vertex slot I.
//
// Thread mapping (from the index arithmetic below): thread idx handles the
// factor active_ids[idx / vertex_sizes[I]] and seeds component
// idx % vertex_sizes[I] of vertex slot I with a dual part of 1.
//
// Outputs:
//   error - .real of each of the E local_error entries, written at
//           factor_id * E + i by the thread whose component index is 0.
//   jacs  - .dual of each of the E local_error entries, written at
//           j_size * factor_id + col_offset + i.
//
// hessian_ids is not referenced in the visible lines.
// NOTE(review): the bodies of the guards and a number of closing lines are not
// present in this view — confirm against the complete file before editing.
98template <
typename T,
typename S,
size_t I,
size_t N,
typename M,
size_t E,
99 typename F,
typename VT, std::size_t... Is>
100__global__
void compute_error_kernel_autodiff(
101 const M *obs, T *error,
102 const typename F::ConstraintDataType *constraint_data,
103 const size_t *active_ids,
const size_t *ids,
const size_t *hessian_ids,
104 const size_t num_threads, VT args, S *jacs,
const uint8_t *active_state,
105 std::index_sequence<Is...>) {
106 const size_t idx = get_thread_id();
// Bounds guard; its body is not visible here (presumably an early return —
// confirm).
108 if (idx >= num_threads) {
112 constexpr auto vertex_sizes = F::get_vertex_sizes();
// vertex_sizes[I] consecutive thread indices share one factor.
113 const auto factor_id = active_ids[idx / vertex_sizes[I]];
// Id of the vertex occupying slot I of this factor (N ids per factor).
114 const auto vertex_id = ids[factor_id * N + I];
// Type of the dual part: T when is_low_precision<S> holds, otherwise S.
118 using G = std::conditional_t<is_low_precision<S>::value, T, S>;
119 const M *local_obs = obs + factor_id;
120 Dual<T, G> local_error[E];
121 const typename F::ConstraintDataType *local_data =
122 constraint_data + factor_id;
// One value-initialized array of Dual numbers per vertex slot, sized by
// vertex_sizes[Is].
124 auto v = cuda::std::make_tuple(std::array<Dual<T, G>, vertex_sizes[Is]>{}...);
// For each slot, read the entry of that slot's array at this factor's vertex
// id; the entries are themselves dereferenced where they are used below.
126 auto vargs = cuda::std::make_tuple(
127 (*(std::get<Is>(args) + ids[factor_id * N + Is]))...);
// Calls each vertex descriptor's Traits::parameters with the dereferenced
// vertex and the matching array's data pointer (presumably filling the array
// with the vertex parameters — confirm). The ptrs pack is not referenced in
// the visible lines.
129 auto copy_vertices = [&v, &vertex_sizes, &vargs](
auto &&...ptrs) {
130 ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
131 Traits::parameters(*cuda::std::get<Is>(vargs),
132 cuda::std::get<Is>(v).data())),
136 cuda::std::apply(copy_vertices, vargs);
// Seed: set the dual part of this thread's component of vertex slot I to 1.
138 cuda::std::get<I>(v)[idx % vertex_sizes[I]].dual =
static_cast<G
>(1);
140 call_error_fn<F, Dual<T, G>>(vargs, v, local_obs, local_data, local_error,
141 std::make_index_sequence<N>{});
// Number of jacs entries per factor for this slot, and this thread's offset
// within that range.
143 constexpr auto j_size = vertex_sizes[I] * E;
145 const auto col_offset = (idx % vertex_sizes[I]) * E;
// Only the thread handling component 0 stores the residual values.
154 if (idx % vertex_sizes[I] == 0) {
156 for (
size_t i = 0; i < E; ++i) {
157 error[factor_id * E + i] = local_error[i].real;
// Guard on the vertex's active state; its body is not visible here.
163 if (!is_vertex_active(active_state, vertex_id)) {
// __half output: clamp the derivative to [-65504, 65504] before the cast
// (65504 is the largest finite half-precision magnitude).
// NOTE(review): std::clamp is called with float bounds on a value of type G —
// confirm G is float whenever S is __half.
167 if constexpr (std::is_same<S, __half>::value) {
170 for (
size_t i = 0; i < E; ++i) {
171 jacs[j_size * factor_id + col_offset + i] =
172 static_cast<S
>(std::clamp(local_error[i].dual, -65504.0f, 65504.0f));
// Second store loop: writes the dual part without clamping.
176 for (
size_t i = 0; i < E; ++i) {
177 jacs[j_size * factor_id + col_offset + i] = local_error[i].dual;
// Host-side launcher for compute_error_kernel_autodiff.
//
// For vertex slot Is the visible statements compute
// num_factors * F::get_vertex_sizes()[Is] threads, use 256 threads per block,
// and launch on streams.select(Is), passing hessian_ids[Is], jacs[Is] and the
// active state obtained from f->vertex_descriptors[Is].
//
// NOTE(review): `Is` appears unexpanded in the visible statements, so the
// construct that expands the pack (and the declaration line of num_blocks) is
// not present in this view — confirm against the complete file.
182template <
typename T,
typename S,
typename F,
typename VT, std::size_t... Is>
183void launch_kernel_autodiff(
184 F *f, std::array<
const size_t *, F::get_num_vertices()> &hessian_ids,
185 VT &verts, std::array<S *, F::get_num_vertices()> &jacs,
186 const size_t num_factors, StreamPool &streams, std::index_sequence<Is...>) {
188 constexpr auto num_vertices = F::get_num_vertices();
189 const auto num_threads = num_factors * F::get_vertex_sizes()[Is];
192 size_t threads_per_block = 256;
// Ceiling division of num_threads by threads_per_block.
194 (num_threads + threads_per_block - 1) / threads_per_block;
// No dynamic shared memory is requested (third launch argument is 0).
201 compute_error_kernel_autodiff<T, S, Is, num_vertices,
202 typename F::ObservationType, F::error_dim, F,
203 typename F::VertexPointerPointerTuple>
204 <<<num_blocks, threads_per_block, 0, streams.select(Is)>>>(
205 f->device_obs.data().get(), f->residuals.data().get(),
206 f->data.data().get(), f->active_indices.data().get(),
207 f->device_ids.data().get(), hessian_ids[Is], num_threads, verts,
208 jacs[Is], f->vertex_descriptors[Is]->get_active_state(),
209 std::make_index_sequence<num_vertices>{});
// Host entry point for the autodiff path of factor f.
//
// Visible steps:
//   1. Gather, per vertex slot, the raw Jacobian data pointer and the
//      hessian-id pointer of the slot's vertex descriptor.
//   2. Issue a thrust::fill over each Jacobian buffer on streams.select(i)
//      using the par_nosync policy (the fill value is not visible here).
//   3. When !is_analytical<F>(), call launch_kernel_autodiff<T, S>.
//   4. Call streams.sync_n(num_vertices).
//
// NOTE(review): closing lines and part of the argument lists are not present
// in this view — confirm against the complete file before editing.
214template <
typename T,
typename S,
typename F>
215void compute_error_autodiff(F *f, StreamPool &streams) {
219 constexpr auto num_vertices = F::get_num_vertices();
220 constexpr auto vertex_sizes = F::get_vertex_sizes();
224 auto verts = f->get_vertices();
225 std::array<S *, num_vertices> jacs;
226 std::array<const size_t *, num_vertices> hessian_ids;
227 for (
int i = 0; i < num_vertices; i++) {
229 jacs[i] = f->jacobians[i].data.data().get();
230 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
// Fill the whole Jacobian buffer of slot i on that slot's stream.
233 thrust::fill(thrust::cuda::par_nosync.on(streams.select(i)),
234 f->jacobians[i].data.begin(), f->jacobians[i].data.end(),
238 const auto num_factors = f->active_count();
// The autodiff kernels are launched only for factors that are not analytical.
240 if constexpr (!is_analytical<F>()) {
241 launch_kernel_autodiff<T, S>(f, hessian_ids, verts, jacs, num_factors,
243 std::make_index_sequence<num_vertices>{});
245 streams.sync_n(num_vertices);
// Kernel that evaluates F's error function with plain T values (no dual
// numbers) for one factor per thread: thread idx handles active_ids[idx].
//
// local_error points directly into error at factor_id * E, and is the buffer
// handed to call_error_fn.
//
// NOTE(review): the end of the template header, the qualifier/return-type
// line, the guard body and closing lines are not present in this view —
// confirm against the complete file before editing.
250template <
typename T,
size_t N,
typename M,
size_t E,
typename F,
typename VT,
253compute_error_kernel(
const M *obs, T *error,
254 const typename F::ConstraintDataType *constraint_data,
255 const size_t *active_ids,
const size_t *ids,
256 const size_t num_threads, VT args,
257 std::index_sequence<Is...>) {
258 const size_t idx = get_thread_id();
// Bounds guard; its body is not visible here (presumably an early return —
// confirm).
260 if (idx >= num_threads) {
264 constexpr auto vertex_sizes = F::get_vertex_sizes();
265 const auto factor_id = active_ids[idx];
267 const M *local_obs = obs + factor_id;
268 T *local_error = error + factor_id * E;
269 const typename F::ConstraintDataType *local_data =
270 constraint_data + factor_id;
// One value-initialized array of T per vertex slot, sized by vertex_sizes[Is].
272 auto v = cuda::std::make_tuple(std::array<T, vertex_sizes[Is]>{}...);
// For each slot, read the entry of that slot's array at this factor's vertex
// id; the entries are themselves dereferenced where they are used below.
274 auto vargs = cuda::std::make_tuple(
275 (*(std::get<Is>(args) + ids[factor_id * N + Is]))...);
// Calls each vertex descriptor's Traits::parameters with the dereferenced
// vertex and the matching array's data pointer. The ptrs pack is not
// referenced in the visible lines.
277 auto copy_vertices = [&v, &vertex_sizes, &vargs](
auto &&...ptrs) {
278 ((std::tuple_element<Is, typename F::Traits::VertexDescriptors>::type::
279 Traits::parameters(*cuda::std::get<Is>(vargs),
280 cuda::std::get<Is>(v).data())),
284 cuda::std::apply(copy_vertices, vargs);
286 call_error_fn<F, T>(vargs, v, local_obs, local_data, local_error,
287 std::make_index_sequence<N>{});
// NOTE(review): local_error already equals error + factor_id * E, so this loop
// stores each element back onto itself — confirm whether it is intentional.
290 for (
size_t i = 0; i < E; ++i) {
291 error[factor_id * E + i] = local_error[i];
// Host entry point that evaluates the residuals of factor f without
// derivatives: launches compute_error_kernel with one thread per active
// factor (f->active_count()), 256 threads per block, with no stream argument
// in the launch configuration, then calls cudaStreamSynchronize(0).
//
// NOTE(review): closing lines are not present in this view — confirm against
// the complete file before editing.
295template <
typename T,
typename F>
void compute_error(F *f) {
297 constexpr auto num_vertices = F::get_num_vertices();
298 constexpr auto vertex_sizes = F::get_vertex_sizes();
301 auto verts = f->get_vertices();
// NOTE(review): hessian_ids is filled here but is not passed to the kernel
// launch in the visible lines.
302 std::array<const size_t *, num_vertices> hessian_ids;
303 for (
int i = 0; i < num_vertices; i++) {
304 hessian_ids[i] = f->vertex_descriptors[i]->get_hessian_ids();
307 constexpr auto error_dim = F::error_dim;
308 const auto num_factors = f->active_count();
310 const auto num_threads = num_factors;
311 size_t threads_per_block = 256;
// Ceiling division so that every active factor gets a thread.
312 size_t num_blocks = (num_threads + threads_per_block - 1) / threads_per_block;
314 compute_error_kernel<T, num_vertices,
typename F::ObservationType,
315 F::error_dim, F,
typename F::VertexPointerPointerTuple>
316 <<<num_blocks, threads_per_block>>>(
317 f->device_obs.data().get(), f->residuals.data().get(),
318 f->data.data().get(), f->active_indices.data().get(),
319 f->device_ids.data().get(), num_threads, verts,
320 std::make_index_sequence<num_vertices>{});
// NOTE(review): neither the launch nor this call has its cudaError_t result
// checked in the visible lines.
322 cudaStreamSynchronize(0);