Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
schur.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <Eigen/Core>
4#include <Eigen/Dense>
5#include <graphite/block.hpp>
7
8namespace graphite {
9
10namespace ops {
/// Operand pointers for a single block product. The product kernels in this
/// file consume it as `destination -= left * middle * right^T`.
///
/// @tparam T Element type of the stored blocks.
template <typename T>
struct MulOp {
  /// Block that receives the (negated) product.
  T *destination;
  /// Left factor of the product.
  T *left;
  /// Middle factor of the product.
  T *middle;
  /// Right factor; the kernels read it transposed.
  T *right;
};
20
/// Offsets describing one block matrix-vector product, as consumed by the
/// batched mat-vec kernels in this file.
struct HplMatVecOp {
  /// Offset of the matrix block inside the `values` array.
  size_t a_offset;
  /// Offset of the input vector segment inside `x_base`.
  size_t x_offset;
  /// Offset of the output vector segment inside `y_base`.
  size_t y_offset;
};
26
/// Offsets describing one block copy, as consumed by
/// block_copy_batched_kernel.
struct BlockCopyOp {
  /// Offset of the block's first element inside the source array.
  size_t src_offset;
  /// Offset of the block's first element inside the destination array.
  size_t dst_offset;
};
31
/// One (landmark column, row i, row j) combination emitted by
/// fill_schur_mul_tuples_kernel, together with the offsets of the two blocks
/// involved.
struct SchurMulTuple {
  /// Block column the two entries were taken from.
  size_t landmark_col;
  /// Block-row index of the first entry.
  size_t pose_row_i;
  /// Block-row index of the second entry.
  size_t pose_row_j;
  /// `block_offsets` value of the entry at (pose_row_i, landmark_col).
  size_t left_offset;
  /// `block_offsets` value of the entry at (pose_row_j, landmark_col).
  size_t right_offset;
};
39
/// For each block column in [landmark_col_start, num_block_columns), counts
/// how many leading entries of that column have a block-row index below
/// landmark_col_start. One thread handles one column.
///
/// @param col_pointers       Start index of each block column's entries;
///                           entry `c + 1` marks the end of column `c`.
/// @param row_indices        Block-row index of every stored entry.
/// @param landmark_col_start First block column to process; also the row
///                           threshold at which counting stops.
/// @param num_block_columns  Total number of block columns.
/// @param pose_counts        Output, indexed by (column - landmark_col_start).
///
/// NOTE(review): the scan stops at the first row index that reaches the
/// threshold, so the count is only the full number of such rows if each
/// column's row indices are sorted ascending - confirm with the caller.
__global__ void count_pose_rows_per_landmark_column_kernel(
    const size_t *col_pointers, const size_t *row_indices,
    size_t landmark_col_start, size_t num_block_columns, size_t *pose_counts) {
  const size_t tid = get_thread_id();
  if (tid >= num_block_columns - landmark_col_start) {
    return;
  }

  const size_t column = landmark_col_start + tid;
  const size_t first = col_pointers[column];
  const size_t last = col_pointers[column + 1];

  // Advance past every leading entry whose row lies below the threshold.
  size_t cursor = first;
  while (cursor < last && row_indices[cursor] < landmark_col_start) {
    ++cursor;
  }

  pose_counts[tid] = cursor - first;
}
62
/// Emits, for each processed block column, every ordered pair (i, j) formed
/// from the column's first `pose_counts` row indices, where j is taken from
/// the same or a later position than i. One thread handles one column.
///
/// @param col_pointers       Start index of each block column's entries.
/// @param row_indices        Block-row index of every stored entry.
/// @param landmark_col_start First block column to process.
/// @param num_block_columns  Total number of block columns.
/// @param pose_counts        Number of leading entries to pair, per column.
/// @param pair_offsets       Index in `pairs_out` where each column starts
///                           writing.
/// @param pairs_out          Output array of coordinate pairs.
__global__ void fill_schur_structure_pairs_kernel(
    const size_t *col_pointers, const size_t *row_indices,
    const size_t landmark_col_start, const size_t num_block_columns,
    const size_t *pose_counts, const size_t *pair_offsets,
    BlockCoordinates *pairs_out) {
  const size_t tid = get_thread_id();
  if (tid >= num_block_columns - landmark_col_start) {
    return;
  }

  // View of this column's row indices and of its slice of the output.
  const size_t *rows = row_indices + col_pointers[landmark_col_start + tid];
  const size_t n = pose_counts[tid];
  BlockCoordinates *out = pairs_out + pair_offsets[tid];

  for (size_t first = 0; first < n; ++first) {
    const size_t row_i = rows[first];
    for (size_t second = first; second < n; ++second) {
      *out++ = BlockCoordinates{row_i, rows[second]};
    }
  }
}
87
/// Emits one SchurMulTuple per pair of entries (a, b), b at or after a, among
/// the first `pose_counts` entries of each processed block column. The tuple
/// records the column, both row indices and both `block_offsets` values.
/// One thread handles one column.
///
/// @param col_pointers       Start index of each block column's entries.
/// @param row_indices        Block-row index of every stored entry.
/// @param block_offsets      Per-entry offset, copied into the tuples.
/// @param landmark_col_start First block column to process.
/// @param num_block_columns  Total number of block columns.
/// @param pose_counts        Number of leading entries to pair, per column.
/// @param pair_offsets       Index in `tuples_out` where each column starts
///                           writing.
/// @param tuples_out         Output array of tuples.
__global__ void fill_schur_mul_tuples_kernel(
    const size_t *col_pointers, const size_t *row_indices,
    const size_t *block_offsets, size_t landmark_col_start,
    size_t num_block_columns, const size_t *pose_counts,
    const size_t *pair_offsets, SchurMulTuple *tuples_out) {
  const size_t tid = get_thread_id();
  if (tid >= num_block_columns - landmark_col_start) {
    return;
  }

  const size_t column = landmark_col_start + tid;
  const size_t begin = col_pointers[column];
  const size_t end = begin + pose_counts[tid];
  size_t write_pos = pair_offsets[tid];

  for (size_t entry_a = begin; entry_a < end; ++entry_a) {
    const size_t row_i = row_indices[entry_a];
    const size_t offset_a = block_offsets[entry_a];
    // Pair entry_a with itself and with every later entry of the column.
    for (size_t entry_b = entry_a; entry_b < end; ++entry_b) {
      tuples_out[write_pos] =
          SchurMulTuple{column, row_i, row_indices[entry_b], offset_a,
                        block_offsets[entry_b]};
      ++write_pos;
    }
  }
}
116
/// Applies `destination -= left * middle * right^T` for a batch of MulOps.
/// Each thread produces one element of one destination block.
///
/// Blocks are indexed as `block[row + col * num_rows]`: `left` is read as
/// dim_a x dim_b, `middle` as dim_b x dim_b, `right` as dim_c x dim_b, and
/// `destination` as dim_a x dim_c.
///
/// @tparam T Type used for the arithmetic.
/// @tparam S Element type stored in the blocks.
/// @param ops     Batch of operand pointer sets.
/// @param num_ops Number of entries in `ops`.
/// @param dim_a   Rows of `left` and of `destination`.
/// @param dim_b   Inner dimension (size of the square `middle` block).
/// @param dim_c   Rows of `right`, columns of `destination`.
template <typename T, typename S>
__global__ void
schur_block_product_kernel(const MulOp<S> *ops, const size_t num_ops,
                           const size_t dim_a, const size_t dim_b,
                           const size_t dim_c) {
  const size_t tid = get_thread_id();
  const size_t elems_per_block = dim_a * dim_c;
  const size_t op_index = tid / elems_per_block;
  if (op_index >= num_ops) {
    return;
  }

  // Position of this thread's element inside the destination block.
  const size_t local = tid % elems_per_block;
  const size_t out_row = local % dim_a;
  const size_t out_col = local / dim_a;

  const MulOp<S> &op = ops[op_index];
  const S *lhs = op.left;
  const S *mid = op.middle;
  const S *rhs = op.right;

  T acc = 0;
#pragma unroll
  for (size_t k = 0; k < dim_b; k++) {
    // Element (k, out_col) of middle * right^T.
    T mid_rhs = 0;
#pragma unroll
    for (size_t j = 0; j < dim_b; j++) {
      mid_rhs += static_cast<T>(mid[k + j * dim_b]) *
                 static_cast<T>(rhs[out_col + j * dim_c]);
    }
    acc += static_cast<T>(lhs[out_row + k * dim_a]) * mid_rhs;
  }

  // Subtract by atomically adding the negated value.
  atomicAdd(&op.destination[out_row + out_col * dim_a], static_cast<S>(-acc));
}
153
/// Variant of the block product `destination -= left * middle * right^T`
/// whose inner dimension is the compile-time constant DIM_B. Each thread
/// produces one element of one destination block.
///
/// Blocks are indexed as `block[row + col * num_rows]`: `left` is read as
/// dim_a x DIM_B, `middle` as DIM_B x DIM_B, `right` as dim_c x DIM_B, and
/// `destination` as dim_a x dim_c.
///
/// @tparam DIM_B Inner dimension (size of the square `middle` block).
/// @tparam T     Type used for the arithmetic.
/// @tparam S     Element type stored in the blocks.
/// @param ops     Batch of operand pointer sets.
/// @param num_ops Number of entries in `ops`.
/// @param dim_a   Rows of `left` and of `destination`.
/// @param dim_c   Rows of `right`, columns of `destination`.
template <int DIM_B, typename T, typename S>
__global__ void
schur_block_product_kernel_dim_b(const MulOp<S> *ops, const size_t num_ops,
                                 const size_t dim_a, const size_t dim_c) {
  const size_t tid = get_thread_id();
  const size_t elems_per_block = dim_a * dim_c;
  const size_t op_index = tid / elems_per_block;
  if (op_index >= num_ops) {
    return;
  }

  // Position of this thread's element inside the destination block.
  const size_t local = tid % elems_per_block;
  const size_t out_row = local % dim_a;
  const size_t out_col = local / dim_a;

  const MulOp<S> &op = ops[op_index];
  const S *lhs = op.left;
  const S *mid = op.middle;
  const S *rhs = op.right;

  T acc = 0;
#pragma unroll
  for (int k = 0; k < DIM_B; k++) {
    // Element (k, out_col) of middle * right^T.
    T mid_rhs = 0;
#pragma unroll
    for (int j = 0; j < DIM_B; j++) {
      mid_rhs += static_cast<T>(mid[k + j * DIM_B]) *
                 static_cast<T>(rhs[out_col + static_cast<size_t>(j) * dim_c]);
    }
    acc += static_cast<T>(lhs[out_row + static_cast<size_t>(k) * dim_a]) *
           mid_rhs;
  }

  // Subtract by atomically adding the negated value.
  atomicAdd(&op.destination[out_row + out_col * dim_a], static_cast<S>(-acc));
}
189
/// Computes `y = A * x` for a batch of square `dim x dim` blocks, overwriting
/// `y`. Each thread produces one output row of one block.
///
/// A block is indexed as `A[row + col * dim]`. The input and output segments
/// of a block share the same offset (`vec_offsets[block]`) into `x_base` and
/// `y_base` respectively.
///
/// @tparam highp Type in which the dot product is accumulated; the result is
///               converted to T only when it is stored.
/// @tparam S     Element type of the matrix blocks.
/// @tparam T     Element type of the vectors.
/// @param values      Array holding all matrix blocks.
/// @param a_offsets   Per-block offset of the matrix inside `values`.
/// @param x_base      Base of the input vector.
/// @param y_base      Base of the output vector.
/// @param vec_offsets Per-block offset into both `x_base` and `y_base`.
/// @param num_blocks  Number of blocks in the batch.
/// @param dim         Row/column count of every block.
template <typename highp, typename S, typename T>
__global__ void block_matvec_assign_batched_kernel(
    const S *values, const size_t *a_offsets, const T *x_base, T *y_base,
    const size_t *vec_offsets, size_t num_blocks, size_t dim) {
  const size_t idx = get_thread_id();
  const size_t total_rows = num_blocks * dim;
  if (idx >= total_rows) {
    return;
  }

  const size_t block_id = idx / dim;
  const size_t row = idx % dim;

  const S *A = values + a_offsets[block_id];
  const size_t vec_offset = vec_offsets[block_id];
  const T *x = x_base + vec_offset;
  T *y = y_base + vec_offset;

  // Accumulate in highp so the caller-selected precision is actually used.
  highp sum = 0;
  for (size_t c = 0; c < dim; c++) {
    sum += static_cast<highp>(A[row + c * dim]) * static_cast<highp>(x[c]);
  }
  y[row] = static_cast<T>(sum);
}
214
/// Computes `y += A * x` for a batch of `rows x cols` blocks. Each thread
/// produces one output row of one op and adds it to `y` with atomicAdd.
///
/// A block is indexed as `A[row + col * rows]`.
///
/// @tparam T Element type of the vectors; also the accumulation type.
/// @tparam S Element type of the matrix blocks.
/// @param values  Array holding all matrix blocks.
/// @param ops     Per-op offsets into `values`, `x_base` and `y_base`.
/// @param num_ops Number of entries in `ops`.
/// @param x_base  Base of the input vector.
/// @param y_base  Base of the output vector.
/// @param rows    Row count of every block (length of each output segment).
/// @param cols    Column count of every block (length of each input segment).
template <typename T, typename S>
__global__ void block_matvec_add_batched_kernel(
    const S *values, const HplMatVecOp *ops, const size_t num_ops,
    const T *x_base, T *y_base, const size_t rows, const size_t cols) {
  const size_t tid = get_thread_id();
  if (tid >= num_ops * rows) {
    return;
  }

  const HplMatVecOp &op = ops[tid / rows];
  const size_t out_row = tid % rows;

  // Pointer to the first element of this thread's matrix row.
  const S *row_ptr = values + op.a_offset + out_row;
  const T *in = x_base + op.x_offset;

  T dot = 0;
  for (size_t c = 0; c < cols; c++) {
    dot += static_cast<T>(row_ptr[c * rows]) * static_cast<T>(in[c]);
  }

  atomicAdd(y_base + op.y_offset + out_row, dot);
}
239
/// Computes `y += A^T * x` for a batch of `rows x cols` blocks. Each thread
/// produces one output element (one column of A) of one op and adds it to `y`
/// with atomicAdd.
///
/// A block is indexed as `A[row + col * rows]`.
///
/// @tparam T Element type of the vectors; also the accumulation type.
/// @tparam S Element type of the matrix blocks.
/// @param values  Array holding all matrix blocks.
/// @param ops     Per-op offsets into `values`, `x_base` and `y_base`.
/// @param num_ops Number of entries in `ops`.
/// @param x_base  Base of the input vector.
/// @param y_base  Base of the output vector.
/// @param rows    Row count of every block (length of each input segment).
/// @param cols    Column count of every block (length of each output segment).
template <typename T, typename S>
__global__ void block_matvec_transpose_add_batched_kernel(
    const S *values, const HplMatVecOp *ops, const size_t num_ops,
    const T *x_base, T *y_base, const size_t rows, const size_t cols) {
  const size_t tid = get_thread_id();
  if (tid >= num_ops * cols) {
    return;
  }

  const HplMatVecOp &op = ops[tid / cols];
  const size_t out_col = tid % cols;

  // Pointer to the first element of this thread's matrix column.
  const S *col_ptr = values + op.a_offset + out_col * rows;
  const T *in = x_base + op.x_offset;

  T dot = 0;
  for (size_t r = 0; r < rows; r++) {
    dot += static_cast<T>(col_ptr[r]) * static_cast<T>(in[r]);
  }

  atomicAdd(y_base + op.y_offset + out_col, dot);
}
264
/// Copies a batch of `rows x cols` blocks from `src_values` to `dst_values`,
/// converting each element from Src to Dst. Each thread copies one element.
///
/// @tparam Src Element type of the source array.
/// @tparam Dst Element type of the destination array (defaults to Src).
/// @param src_values Source array.
/// @param dst_values Destination array.
/// @param ops        Per-block source and destination offsets.
/// @param num_ops    Number of entries in `ops`.
/// @param rows       Row count of every block.
/// @param cols       Column count of every block.
template <typename Src, typename Dst = Src>
__global__ void
block_copy_batched_kernel(const Src *src_values, Dst *dst_values,
                          const BlockCopyOp *ops, const size_t num_ops,
                          const size_t rows, const size_t cols) {
  const size_t tid = get_thread_id();
  const size_t elems_per_block = rows * cols;
  if (tid >= num_ops * elems_per_block) {
    return;
  }

  const BlockCopyOp &op = ops[tid / elems_per_block];
  const size_t within = tid % elems_per_block;

  const Src *from = src_values + op.src_offset;
  Dst *to = dst_values + op.dst_offset;
  to[within] = static_cast<Dst>(from[within]);
}
283
284} // namespace ops
285
286} // namespace graphite
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4
Definition schur.hpp:27
Definition schur.hpp:21
Stores offsets for Hpl*Hll^(-1)*Hpl^T operation.
Definition schur.hpp:14
Definition schur.hpp:32