// Element type used for the extracted blocks and their inverses: T when
// is_low_precision<S>::value is true, otherwise S.
19 using P = std::conditional_t<is_low_precision<S>::value, T, S>;
// The fields below are accessed as members of the DimGroup values stored in
// dim_groups; the enclosing declaration is not visible in this excerpt.
// Host list of the values found in schur->block_indices for each diagonal
// (block, block) entry assigned to this group.
23 thrust::host_vector<size_t> h_src_offsets;
// Host list of graph->get_offset_vector()[block] for the same blocks.
24 thrust::host_vector<size_t> h_vec_offsets;
// Host list of BlockCopyOp{source offset, destination offset} descriptors.
25 thrust::host_vector<ops::BlockCopyOp> h_copy_ops;
// Device copy of the per-block destination offsets (i * dim * dim).
26 thrust::device_vector<size_t> d_a_offsets;
// Device copy of h_vec_offsets.
27 thrust::device_vector<size_t> d_vec_offsets;
// Device copy of h_copy_ops, consumed by ops::block_copy_batched_kernel.
28 thrust::device_vector<ops::BlockCopyOp> d_copy_ops;
// Device storage for the gathered blocks: num_blocks * dim * dim elements.
29 thrust::device_vector<P> blocks;
// Device storage for the inverted blocks, same size as `blocks`.
30 thrust::device_vector<P> blocks_inv;
// Host arrays of per-block pointers into `blocks` / `blocks_inv`.
31 thrust::host_vector<P *> h_A_ptrs;
32 thrust::host_vector<P *> h_Ainv_ptrs;
// Device copies of the pointer arrays, passed to the cuBLAS batched inverse.
33 thrust::device_vector<P *> d_A_ptrs;
34 thrust::device_vector<P *> d_Ainv_ptrs;
// Per-block status array passed as the `info` argument of matinvBatched.
35 thrust::device_vector<int> d_info;
// Groups of diagonal blocks, keyed by the variable dimension of the blocks.
38 std::unordered_map<size_t, DimGroup> dim_groups;
// cuBLAS handle used for the batched matrix inversions.
40 cublasHandle_t handle;
// Create the handle and select device pointer mode.
// NOTE(review): the cublasStatus_t results of both calls are discarded in the
// visible code; a matching cublasDestroy is not visible in this excerpt.
44 cublasCreate(&handle);
45 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
// Rebuilds the per-dimension block groups from the Schur complement layout:
// collects the diagonal blocks that precede schur->lowest_eliminated_block_col,
// groups them by variable dimension, sizes the per-group buffers, and uploads
// the offset, copy-op and pointer tables to the device.
// NOTE(review): several original lines are missing from this excerpt (loop
// increments, closing braces, some call arguments); comments describe only
// what is visible.
50 void update_structure(Graph<T, S> *graph, SchurComplement<T, S> *schur,
51 StreamPool &streams)
override {
// pose_dim is the offset at which the first eliminated block column starts.
55 pose_dim = graph->get_offset_vector()[schur->lowest_eliminated_block_col];
56 const auto &offsets = graph->get_offset_vector();
// Pass 1: record, per dimension, where each diagonal block lives.
// NOTE(review): no reset of dim_groups is visible here; if none exists in the
// missing lines, repeated calls would append duplicate entries - confirm.
58 for (
size_t block = 0; block < schur->lowest_eliminated_block_col;
60 const size_t dim = graph->get_variable_dimension(block);
61 auto diag_it = schur->block_indices.find(BlockCoordinates{block, block});
// Blocks without a diagonal entry take this branch (its body is not visible).
62 if (diag_it == schur->block_indices.end()) {
// operator[] creates the group on first use of this dimension.
// NOTE(review): the assignment of group.dim is not visible in this excerpt.
66 auto &group = dim_groups[dim];
68 group.h_src_offsets.push_back(diag_it->second);
69 group.h_vec_offsets.push_back(offsets[block]);
// Pass 2: size every group's buffers and build its lookup tables.
72 for (
auto &entry : dim_groups) {
73 auto &group = entry.second;
74 const size_t num_blocks = group.h_src_offsets.size();
// Number of elements in one dim x dim block.
75 const size_t block_size = group.dim * group.dim;
77 group.blocks.resize(num_blocks * block_size);
78 group.blocks_inv.resize(num_blocks * block_size);
79 group.h_A_ptrs.resize(num_blocks);
80 group.h_Ainv_ptrs.resize(num_blocks);
81 group.d_A_ptrs.resize(num_blocks);
82 group.d_Ainv_ptrs.resize(num_blocks);
83 group.d_info.resize(num_blocks);
// Assigning a host_vector to a device_vector copies the data to the device.
84 group.d_vec_offsets = group.h_vec_offsets;
85 group.h_copy_ops.resize(num_blocks);
86 group.d_copy_ops.resize(num_blocks);
// Blocks are packed back to back: block i starts at i * block_size.
88 thrust::host_vector<size_t> h_a_offsets(num_blocks);
89 for (
size_t i = 0; i < num_blocks; ++i) {
90 h_a_offsets[i] = i * block_size;
// Pairs the source offset with the packed destination offset; the
// left-hand side is not visible (presumably group.h_copy_ops[i] - confirm).
92 ops::BlockCopyOp{group.h_src_offsets[i], h_a_offsets[i]};
94 group.d_a_offsets = h_a_offsets;
95 group.d_copy_ops = group.h_copy_ops;
// Raw device pointers to the start of each buffer; taken after the resizes
// above so they refer to the current allocations.
97 P *blocks_ptr = group.blocks.data().get();
98 P *blocks_inv_ptr = group.blocks_inv.data().get();
99 for (
size_t i = 0; i < num_blocks; ++i) {
100 group.h_A_ptrs[i] = blocks_ptr + i * block_size;
101 group.h_Ainv_ptrs[i] = blocks_inv_ptr + i * block_size;
// Upload both pointer tables. The stream argument of each call is not
// visible in this excerpt (presumably streams.select(0) - confirm).
104 cudaMemcpyAsync(group.d_A_ptrs.data().get(), group.h_A_ptrs.data(),
105 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice,
107 cudaMemcpyAsync(group.d_Ainv_ptrs.data().get(), group.h_Ainv_ptrs.data(),
108 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice,
// Block until the work queued on stream 0 has finished.
110 cudaStreamSynchronize(streams.select(0));
// Refreshes the numeric content of every group: gathers the diagonal blocks
// from schur->values into group.blocks, then inverts them in one batched
// cuBLAS call per group, writing the results to group.blocks_inv.
// NOTE(review): several original lines are missing from this excerpt (branch
// bodies, closing braces); comments describe only what is visible.
114 void update_values(Graph<T, S> *graph, SchurComplement<T, S> *schur,
115 StreamPool &streams)
override {
// All kernels and cuBLAS calls below are issued on the same stream.
117 auto stream = streams.select(0);
118 cublasSetStream(handle, stream);
120 const S *schur_values = schur->values.data().get();
121 constexpr size_t threads_per_block = 256;
123 for (
auto &entry : dim_groups) {
124 auto &group = entry.second;
125 const size_t num_blocks = group.h_src_offsets.size();
// Empty groups are special-cased (the branch body is not visible).
126 if (num_blocks == 0) {
129 const size_t block_size = group.dim * group.dim;
130 P *blocks_ptr = group.blocks.data().get();
// Ceil-divide the total element count by the thread-block size to get the
// grid size for the gather kernel.
132 const size_t total = num_blocks * block_size;
133 const size_t blocks = (total + threads_per_block - 1) / threads_per_block;
// Copy (and convert from S to P) each block described by d_copy_ops.
134 ops::block_copy_batched_kernel<S, P>
135 <<<blocks, threads_per_block, 0, stream>>>(
136 schur_values, blocks_ptr, group.d_copy_ops.data().get(),
137 num_blocks, group.dim, group.dim);
// Batched inverse, selected at compile time by the element type P.
// NOTE(review): the cuBLAS status is discarded and d_info is not read in the
// visible code, so a singular block would go unnoticed - confirm.
// NOTE(review): the cuBLAS documentation limits matinvBatched to n <= 32;
// no check of group.dim against that limit is visible here.
// NOTE(review): no branch for other P types is visible in this excerpt.
139 if constexpr (std::is_same<P, double>::value) {
140 cublasDmatinvBatched(handle, group.dim, group.d_A_ptrs.data().get(),
141 group.dim, group.d_Ainv_ptrs.data().get(),
142 group.dim, group.d_info.data().get(), num_blocks);
143 }
else if constexpr (std::is_same<P, float>::value) {
144 cublasSmatinvBatched(handle, group.dim, group.d_A_ptrs.data().get(),
145 group.dim, group.d_Ainv_ptrs.data().get(),
146 group.dim, group.d_info.data().get(), num_blocks);
// Block until the work queued on the stream has finished.
150 cudaStreamSynchronize(stream);
// Intentionally a no-op: the body is empty, so none of the arguments are used
// and no state is changed.
153 void set_damping_factor(Graph<T, S> *graph, SchurComplement<T, S> *schur,
154 T damping_factor,
const bool use_identity,
155 StreamPool &streams)
override {}
// Applies the stored inverse blocks to r and writes the result into z: the
// first pose_dim entries of z are filled, then one kernel per dimension group
// is launched with that group's inverse blocks and offset tables.
// NOTE(review): the fill value and the end of this function are not visible
// in this excerpt; comments describe only what is visible.
157 void apply(Graph<T, S> *graph, SchurComplement<T, S> *schur, T *z,
const T *r,
158 StreamPool &streams)
override {
161 const auto stream = streams.select(0);
// Fill z[0, pose_dim) on the stream using the par_nosync policy; the kernels
// below are queued on the same stream.
162 thrust::fill(thrust::cuda::par_nosync.on(stream), z, z + pose_dim,
165 constexpr size_t threads_per_block = 256;
166 for (
auto &entry : dim_groups) {
167 auto &group = entry.second;
168 const size_t num_blocks = group.h_src_offsets.size();
// Grid size is the ceil-division of the total row count (num_blocks * dim)
// by the thread-block size; presumably one thread per output row - confirm
// against the kernel.
// NOTE(review): unlike update_values, no num_blocks == 0 guard is visible.
169 const size_t total_rows = num_blocks * group.dim;
170 const size_t blocks =
171 (total_rows + threads_per_block - 1) / threads_per_block;
// Inputs: inverse blocks and their packed offsets, r, and the per-block
// vector offsets; output: z.
173 ops::block_matvec_assign_batched_kernel<T, P, T>
174 <<<blocks, threads_per_block, 0, stream>>>(
175 group.blocks_inv.data().get(), group.d_a_offsets.data().get(), r,
176 z, group.d_vec_offsets.data().get(), num_blocks, group.dim);