Graphite
Loading...
Searching...
No Matches
block_jacobi.hpp
Go to the documentation of this file.
1
2#pragma once
#include <stdexcept>
#include <string>

#include <cublas_v2.h>
#include <thrust/count.h>
#include <thrust/execution_policy.h>
7
8namespace graphite {
9
10template <typename T, typename S>
12private:
13 using P = std::conditional_t<is_low_precision<S>::value, T, S>;
14 size_t dimension;
15 std::vector<std::pair<size_t, size_t>> block_sizes;
16 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
17 block_diagonals;
18
19 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
20 scalar_diagonals;
21
22 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
23 P_inv;
24
25 cublasHandle_t handle;
26
27 // For batched inversion
28 // TODO: Figure out a better way to handle the memory
29 thrust::host_vector<P *> A_ptrs, Ainv_ptrs;
30 thrust::device_vector<P *> A_ptrs_device, Ainv_ptrs_device;
31 thrust::device_vector<int> info;
32
33public:
35 cublasCreate(&handle);
36 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
37 }
38
39 ~BlockJacobiPreconditioner() { cublasDestroy(handle); }
40
41 virtual void update_structure(Graph<T, S> *graph, StreamPool &streams) {
42
43 this->dimension = dimension;
44 auto &vertex_descriptors = graph->get_vertex_descriptors();
45
46 for (auto &desc : vertex_descriptors) {
47 // Reserve space
48 const auto d = desc->dimension();
49 const size_t num_values =
50 d * d * desc->count(); // includes inactive vertices
51 block_diagonals[desc].resize(num_values);
52 scalar_diagonals[desc].resize(desc->count() * d);
53 P_inv[desc].resize(num_values);
54 }
55
56 // Determine max sizes for buffers
57 {
58 size_t max_num_blocks = 0;
59 size_t max_data_size = 0;
60
61 for (auto &desc : vertex_descriptors) {
62 const size_t num_blocks = desc->count();
63 const size_t d = desc->dimension();
64 const size_t block_size = d * d;
65
66 max_num_blocks = std::max(max_num_blocks, num_blocks);
67 max_data_size = std::max(max_data_size, num_blocks * block_size);
68 }
69
70 A_ptrs.resize(max_num_blocks);
71 Ainv_ptrs.resize(max_num_blocks);
72 info.resize(max_num_blocks);
73
74 A_ptrs_device.resize(max_num_blocks);
75 Ainv_ptrs_device.resize(max_num_blocks);
76 }
77 };
78
79 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) {
80 const cudaStream_t stream = 0;
81 auto &vertex_descriptors = graph->get_vertex_descriptors();
82 auto &factor_descriptors = graph->get_factor_descriptors();
83 auto jacobian_scales = graph->get_jacobian_scales().data().get();
84
85 // Compute Hessian blocks on the diagonal
86 for (auto &desc : vertex_descriptors) {
87 thrust::fill(thrust::cuda::par_nosync.on(stream),
88 block_diagonals[desc].begin(), block_diagonals[desc].end(),
89 static_cast<S>(0.0));
90 }
91 for (auto &desc : factor_descriptors) {
92 desc->compute_hessian_block_diagonal_async(block_diagonals,
93 jacobian_scales, stream);
94 }
95 // back up diagonals for each vertex descriptor
96 for (auto &desc : vertex_descriptors) {
97
98 auto b = block_diagonals[desc].data().get();
99 auto s = scalar_diagonals[desc].data().get();
100
101 auto start = thrust::make_counting_iterator<size_t>(0);
102 auto end = start + scalar_diagonals[desc].size();
103 const size_t D = desc->dimension();
104 thrust::for_each(thrust::cuda::par_nosync.on(stream), start, end,
105 [b, s, D] __device__(const size_t idx) {
106 const size_t vertex_id = idx / D;
107 const auto block = b + vertex_id * D * D;
108 const size_t col = idx % D;
109 s[idx] = block[col * D + col];
110 });
111 }
112
113 cudaStreamSynchronize(stream);
114 };
115
116 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
117 StreamPool &streams) {
118
119 const cudaStream_t stream = 0;
120 cublasSetStream(handle, stream);
121 auto &vertex_descriptors = graph->get_vertex_descriptors();
122 auto &factor_descriptors = graph->get_factor_descriptors();
123
124 // Invert the blocks
125
126 for (auto &desc : vertex_descriptors) {
127 desc->augment_block_diagonal_async(block_diagonals[desc].data().get(),
128 scalar_diagonals[desc].data().get(),
129 damping_factor, stream);
130
131 // Invert the block diagonal using cublas
132 const auto d = desc->dimension();
133 const size_t num_blocks = desc->count();
134 const auto block_size = d * d;
135 const size_t data_size = num_blocks * block_size;
136
137 P *a_ptr = block_diagonals[desc].data().get();
138
139 P *a_inv_ptr = P_inv[desc].data().get();
140 for (size_t i = 0; i < num_blocks; ++i) {
141 A_ptrs[i] = a_ptr + i * block_size;
142 Ainv_ptrs[i] = a_inv_ptr + i * block_size;
143 }
144
145 cudaMemcpyAsync(A_ptrs_device.data().get(), A_ptrs.data(),
146 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);
147 cudaMemcpyAsync(Ainv_ptrs_device.data().get(), Ainv_ptrs.data(),
148 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);
149
150 // cublas should use stream 0
151 if constexpr (std::is_same<P, double>::value) {
152
153 cublasDmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
154 Ainv_ptrs_device.data().get(), d,
155 info.data().get(), num_blocks);
156 } else if constexpr (std::is_same<P, float>::value) {
157 cublasSmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
158 Ainv_ptrs_device.data().get(), d,
159 info.data().get(), num_blocks);
160 } else {
161 static_assert(
162 is_low_precision<S>::value || std::is_same<S, float>::value ||
163 std::is_same<S, double>::value,
164 "BlockJacobiPreconditioner only supports bfloat16, float, or "
165 "double types.");
166 }
167 }
168
169 // Final sync
170 cudaStreamSynchronize(stream);
171 };
172
173 void apply(Graph<T, S> *graph, T *z, const T *r,
174 StreamPool &streams) override {
175 // Apply the preconditioner
176 size_t i = 0;
177 auto &vertex_descriptors = graph->get_vertex_descriptors();
178 for (auto &desc : vertex_descriptors) {
179 const auto d = desc->dimension();
180 P *blocks = P_inv[desc].data().get();
181 desc->apply_block_jacobi(z, r, blocks, streams.select(i));
182 i++;
183 }
184 streams.sync_n(i);
185 }
186};
187
188} // namespace graphite
Definition block_jacobi.hpp:11
Definition preconditioner.hpp:8
Definition stream.hpp:7