// Working precision for the preconditioner blocks: when the solve
// precision S is a low-precision type, blocks are kept in T instead;
// otherwise S is used directly.
13 using P = std::conditional_t<is_low_precision<S>::value, T, S>;
// Per-vertex-type (dimension, count) pairs.
// NOTE(review): not touched by any method visible in this chunk —
// confirm it is still needed.
15 std::vector<std::pair<size_t, size_t>> block_sizes;
// Per-vertex-descriptor device storage.
// NOTE(review): the member names were stripped from this listing; usage
// in update_values()/set_damping_factor() suggests the three maps below
// are `block_diagonals`, `scalar_diagonals`, and `P_inv` — the dense
// d x d Hessian blocks, their undamped diagonals, and the inverted
// blocks that form the preconditioner. Confirm against the original.
16 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
19 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
22 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
// cuBLAS handle used for the batched block inversions.
25 cublasHandle_t handle;
// Scratch for cublas<t>matinvBatched: host-side pointer tables, their
// device mirrors, and the per-block singularity info array. Sized to
// the largest vertex batch in update_structure().
// NOTE(review): A_ptrs/Ainv_ptrs are pageable host memory, so the
// cudaMemcpyAsync calls that consume them stage synchronously.
29 thrust::host_vector<P *> A_ptrs, Ainv_ptrs;
30 thrust::device_vector<P *> A_ptrs_device, Ainv_ptrs_device;
31 thrust::device_vector<int> info;
// Constructor body fragment (the constructor's header lies outside this
// chunk). Device pointer mode makes cuBLAS expect scalar arguments in
// device memory.
// NOTE(review): the cublasCreate status is not checked — a failed
// handle creation would surface only as later call failures.
35 cublasCreate(&handle);
36 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
// Rebuilds the per-vertex-type device buffers after the graph structure
// changed. For every vertex descriptor we keep one dense d x d Hessian
// block per vertex, its scalar diagonal, and its inverted block, and we
// size the batched-inversion scratch arrays once for the largest batch
// so they can be shared by all vertex types.
//
// NOTE(review): the `dimension` parameter was reconstructed from the
// garbled listing — `this->dimension = dimension;` cannot compile
// without it. Confirm the exact original signature.
virtual void update_structure(Graph<T, S> *graph, size_t dimension,
                              StreamPool &streams) {
  this->dimension = dimension;
  auto &vertex_descriptors = graph->get_vertex_descriptors();

  // Single pass over the descriptors: the original iterated twice, once
  // to resize the per-descriptor storage and once to compute the batch
  // maximum. It also computed a `max_data_size` that no visible line
  // consumed; that dead local has been dropped (NOTE(review): restore it
  // if a line outside this chunk used it).
  size_t max_num_blocks = 0;
  for (auto &desc : vertex_descriptors) {
    const size_t num_blocks = desc->count();
    const size_t d = desc->dimension();
    const size_t num_values = num_blocks * d * d;

    block_diagonals[desc].resize(num_values);      // dense d x d blocks
    scalar_diagonals[desc].resize(num_blocks * d); // undamped diagonals
    P_inv[desc].resize(num_values);                // inverted blocks

    max_num_blocks = std::max(max_num_blocks, num_blocks);
  }

  // Scratch for cublas<t>matinvBatched: pointer tables (host staging +
  // device mirror) and the per-block singularity info array.
  A_ptrs.resize(max_num_blocks);
  Ainv_ptrs.resize(max_num_blocks);
  info.resize(max_num_blocks);
  A_ptrs_device.resize(max_num_blocks);
  Ainv_ptrs_device.resize(max_num_blocks);
}
// Recomputes the Hessian block diagonal for every vertex type:
//   1) clear each descriptor's block storage,
//   2) let every factor descriptor accumulate its contribution,
//   3) copy each block's main diagonal into scalar_diagonals (kept so
//      set_damping_factor can re-augment without recomputation),
// then block until all device work on the stream has finished.
virtual void update_values(Graph<T, S> *graph, StreamPool &streams) {
  const cudaStream_t stream = 0; // legacy default stream
  auto &vertex_descriptors = graph->get_vertex_descriptors();
  auto &factor_descriptors = graph->get_factor_descriptors();
  auto jacobian_scales = graph->get_jacobian_scales().data().get();

  // Zero the accumulators before the factors add into them.
  // NOTE(review): the fill value was stripped from the garbled listing;
  // zero is assumed — confirm against the original.
  for (auto &desc : vertex_descriptors) {
    auto &blocks = block_diagonals[desc];
    thrust::fill(thrust::cuda::par_nosync.on(stream), blocks.begin(),
                 blocks.end(), P(0.0));
  }

  for (auto &desc : factor_descriptors) {
    desc->compute_hessian_block_diagonal_async(block_diagonals,
                                               jacobian_scales, stream);
  }

  // Extract the main diagonal of every d x d block. One thread per
  // diagonal entry: idx / D selects the block, idx % D the column.
  for (auto &desc : vertex_descriptors) {
    auto b = block_diagonals[desc].data().get();
    auto s = scalar_diagonals[desc].data().get();
    const size_t D = desc->dimension();
    auto first = thrust::make_counting_iterator<size_t>(0);
    thrust::for_each(thrust::cuda::par_nosync.on(stream), first,
                     first + scalar_diagonals[desc].size(),
                     [b, s, D] __device__(const size_t idx) {
                       const size_t vertex_id = idx / D;
                       const size_t col = idx % D;
                       s[idx] = b[vertex_id * D * D + col * D + col];
                     });
  }

  cudaStreamSynchronize(stream);
}
// Applies Levenberg-Marquardt damping to every Hessian block diagonal
// and inverts the damped blocks in place into P_inv via batched cuBLAS
// matrix inversion, then blocks until the stream has drained.
//
// NOTE(review): the tail of the signature was stripped from the garbled
// listing; `StreamPool &streams` is assumed by analogy with the sibling
// methods — confirm against the original.
// NOTE(review): cublas<t>matinvBatched only supports matrix sizes up to
// 32 (cuBLAS docs); d is not checked here.
virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
                                StreamPool &streams) {
  const cudaStream_t stream = 0;
  cublasSetStream(handle, stream);
  auto &vertex_descriptors = graph->get_vertex_descriptors();
  auto &factor_descriptors = graph->get_factor_descriptors();

  for (auto &desc : vertex_descriptors) {
    // Re-augment from the saved undamped diagonals so repeated damping
    // changes do not compound.
    desc->augment_block_diagonal_async(block_diagonals[desc].data().get(),
                                       scalar_diagonals[desc].data().get(),
                                       damping_factor, stream);

    const auto d = desc->dimension();
    const size_t num_blocks = desc->count();
    const auto block_size = d * d;
    const size_t data_size = num_blocks * block_size; // kept from original; unused in visible lines

    // Build the batched pointer tables on the host, then mirror them to
    // the device. The host vectors are pageable, so these async copies
    // stage synchronously and the tables may be reused next iteration.
    P *a_ptr = block_diagonals[desc].data().get();
    P *a_inv_ptr = P_inv[desc].data().get();
    for (size_t i = 0; i < num_blocks; ++i) {
      A_ptrs[i] = a_ptr + i * block_size;
      Ainv_ptrs[i] = a_inv_ptr + i * block_size;
    }
    cudaMemcpyAsync(A_ptrs_device.data().get(), A_ptrs.data(),
                    sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(Ainv_ptrs_device.data().get(), Ainv_ptrs.data(),
                    sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);

    // Dispatch on the working precision; anything else is rejected at
    // compile time. NOTE(review): the `} else { static_assert(` opener
    // was stripped from the listing and is reconstructed here.
    if constexpr (std::is_same<P, double>::value) {
      cublasDmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
                           Ainv_ptrs_device.data().get(), d,
                           info.data().get(), num_blocks);
    } else if constexpr (std::is_same<P, float>::value) {
      cublasSmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
                           Ainv_ptrs_device.data().get(), d,
                           info.data().get(), num_blocks);
    } else {
      static_assert(
          is_low_precision<S>::value || std::is_same<S, float>::value ||
              std::is_same<S, double>::value,
          "BlockJacobiPreconditioner only supports bfloat16, float, or "
          "double");
    }
  }

  cudaStreamSynchronize(stream);
}
// Applies the preconditioner, z = M^{-1} r, by dispatching one
// apply_block_jacobi call per vertex descriptor using the inverted
// blocks computed in set_damping_factor.
// NOTE(review): this fragment is truncated by the chunk boundary — the
// remainder of the signature (presumably `StreamPool &streams`), the
// declaration of the stream index `i`, and the closing braces are not
// visible. `d` is unused within the visible lines.
173 void apply(Graph<T, S> *graph, T *z,
const T *r,
177 auto &vertex_descriptors = graph->get_vertex_descriptors();
178 for (
auto &desc : vertex_descriptors) {
179 const auto d = desc->dimension();
180 P *blocks = P_inv[desc].data().get();
181 desc->apply_block_jacobi(z, r, blocks, streams.select(i));