// Storage type for the preconditioner blocks: P is T when
// is_low_precision<S>::value is true, otherwise P is S.
13 using P = std::conditional_t<is_low_precision<S>::value, T, S>;
// Pairs of sizes; not referenced by any line visible in this excerpt.
15 std::vector<std::pair<size_t, size_t>> block_sizes;
// Three maps from a vertex descriptor pointer to a device buffer of P.
// NOTE(review): the member names sit on lines that are not visible here;
// presumably these are the block_diagonals, scalar_diagonals and P_inv maps
// that the methods below index by descriptor -- confirm names and order.
16 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
19 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
22 std::unordered_map<BaseVertexDescriptor<T, S> *, thrust::device_vector<P>>
// cuBLAS handle used for the batched matrix-inverse calls in
// set_damping_factor.
25 cublasHandle_t handle;
// Host-side staging arrays of per-block pointers (input blocks / inverse
// blocks); filled on the host and then copied to the *_device arrays.
29 thrust::host_vector<P *> A_ptrs, Ainv_ptrs;
// Device-side copies of the pointer arrays, passed to the batched inverse.
30 thrust::device_vector<P *> A_ptrs_device, Ainv_ptrs_device;
// Per-block integer status output of the batched inverse call.
31 thrust::device_vector<int> info;
// Creates the cuBLAS handle and switches it to device pointer mode.
// NOTE(review): the enclosing function header is not visible in this excerpt
// (presumably the constructor). The status values returned by both calls are
// discarded, so a failed handle creation would go unnoticed here.
35 cublasCreate(&handle);
36 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
// Sizes all per-descriptor buffers and the shared pointer/status arrays to
// match the vertex descriptors currently held by `graph`.
//
// graph:   source of the vertex descriptors (get_vertex_descriptors()).
// streams: not used by any line visible in this excerpt.
41 virtual void update_structure(Graph<T, S> *graph,
StreamPool &streams) {
// NOTE(review): no parameter or local named `dimension` is visible here; if
// the right-hand side resolves to the member, this is a self-assignment --
// confirm against the full file.
43 this->dimension = dimension;
44 auto &vertex_descriptors = graph->get_vertex_descriptors();
// Pass 1: per descriptor, allocate count() blocks of d*d values for both the
// block diagonal and its inverse, plus count()*d values for the scalar
// diagonal.
46 for (
auto &desc : vertex_descriptors) {
48 const auto d = desc->dimension();
49 const size_t num_values =
50 d * d * desc->count();
51 block_diagonals[desc].resize(num_values);
52 scalar_diagonals[desc].resize(desc->count() * d);
53 P_inv[desc].resize(num_values);
// Pass 2: find the largest block count (and largest total value count) over
// all descriptors.
58 size_t max_num_blocks = 0;
59 size_t max_data_size = 0;
61 for (
auto &desc : vertex_descriptors) {
62 const size_t num_blocks = desc->count();
63 const size_t d = desc->dimension();
64 const size_t block_size = d * d;
66 max_num_blocks = std::max(max_num_blocks, num_blocks);
// max_data_size is computed but not read by any line visible in this excerpt.
67 max_data_size = std::max(max_data_size, num_blocks * block_size);
// The pointer arrays and the status array are shared by all descriptors, so
// they are sized for the descriptor with the most blocks.
70 A_ptrs.resize(max_num_blocks);
71 Ainv_ptrs.resize(max_num_blocks);
72 info.resize(max_num_blocks);
74 A_ptrs_device.resize(max_num_blocks);
75 Ainv_ptrs_device.resize(max_num_blocks);
// Recomputes the per-vertex block diagonals from the factor descriptors and
// extracts their main-diagonal entries into scalar_diagonals.
//
// graph:   source of the vertex/factor descriptors and the jacobian scales.
// streams: not used by any line visible in this excerpt; all work below is
//          issued on stream 0.
79 virtual void update_values(Graph<T, S> *graph,
StreamPool &streams) {
80 const cudaStream_t stream = 0;
81 auto &vertex_descriptors = graph->get_vertex_descriptors();
82 auto &factor_descriptors = graph->get_factor_descriptors();
// Raw pointer obtained via get_jacobian_scales().data().get().
83 auto jacobian_scales = graph->get_jacobian_scales().data().get();
// Step 1: overwrite every block-diagonal buffer with a constant.
// NOTE(review): the fill value is on a line not visible here -- presumably
// zero so that step 2 starts from a clean buffer; confirm.
86 for (
auto &desc : vertex_descriptors) {
87 thrust::fill(thrust::cuda::par_nosync.on(stream),
88 block_diagonals[desc].begin(), block_diagonals[desc].end(),
// Step 2: each factor descriptor is handed the whole block_diagonals map and
// the jacobian scales; what it writes is defined in the descriptor, not here.
91 for (
auto &desc : factor_descriptors) {
92 desc->compute_hessian_block_diagonal_async(block_diagonals,
93 jacobian_scales, stream);
// Step 3: for every vertex descriptor, copy the main diagonal of each D x D
// block into scalar_diagonals (one entry per (vertex, column) pair).
96 for (
auto &desc : vertex_descriptors) {
98 auto b = block_diagonals[desc].data().get();
99 auto s = scalar_diagonals[desc].data().get();
// One work item per element of scalar_diagonals[desc].
101 auto start = thrust::make_counting_iterator<size_t>(0);
102 auto end = start + scalar_diagonals[desc].size();
103 const size_t D = desc->dimension();
104 thrust::for_each(thrust::cuda::par_nosync.on(stream), start, end,
105 [b, s, D] __device__(
const size_t idx) {
// idx / D selects the vertex's block, idx % D the diagonal position in it.
106 const size_t vertex_id = idx / D;
107 const auto block = b + vertex_id * D * D;
108 const size_t col = idx % D;
// Element (col, col) of the block, addressed with stride D.
109 s[idx] = block[col * D + col];
// Block the host until the work queued on `stream` has completed.
113 cudaStreamSynchronize(stream);
// Applies damping to every vertex descriptor's block diagonal and then
// inverts all of its d x d blocks into P_inv with a batched cuBLAS call.
//
// graph:          source of the vertex descriptors.
// damping_factor: forwarded unchanged to augment_block_diagonal_async.
// use_identity:   forwarded unchanged to augment_block_diagonal_async.
// NOTE(review): the remaining parameter(s) of the signature are on lines not
// visible in this excerpt.
116 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
117 const bool use_identity,
// All work, including the cuBLAS calls, is issued on stream 0.
120 const cudaStream_t stream = 0;
121 cublasSetStream(handle, stream);
122 auto &vertex_descriptors = graph->get_vertex_descriptors();
// factor_descriptors is not read by any line visible in this excerpt.
123 auto &factor_descriptors = graph->get_factor_descriptors();
127 for (
auto &desc : vertex_descriptors) {
// The descriptor modifies its block diagonal using the scalar diagonal and
// the damping settings; the exact update is defined in the descriptor.
128 desc->augment_block_diagonal_async(block_diagonals[desc].data().get(),
129 scalar_diagonals[desc].data().get(),
130 damping_factor, use_identity, stream);
133 const auto d = desc->dimension();
134 const size_t num_blocks = desc->count();
135 const auto block_size = d * d;
// data_size is not read by any line visible in this excerpt.
136 const size_t data_size = num_blocks * block_size;
138 P *a_ptr = block_diagonals[desc].data().get();
140 P *a_inv_ptr = P_inv[desc].data().get();
// Build, on the host, one pointer per block into the contiguous input and
// output buffers (blocks are block_size elements apart).
141 for (
size_t i = 0; i < num_blocks; ++i) {
142 A_ptrs[i] = a_ptr + i * block_size;
143 Ainv_ptrs[i] = a_inv_ptr + i * block_size;
// Upload the first num_blocks pointers of each array to the device.
// NOTE(review): the host arrays are shared across descriptors and are
// overwritten on the next loop iteration; this relies on the copy having
// consumed the host buffer by then -- confirm for the allocator used by
// thrust::host_vector.
146 cudaMemcpyAsync(A_ptrs_device.data().get(), A_ptrs.data(),
147 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);
148 cudaMemcpyAsync(Ainv_ptrs_device.data().get(), Ainv_ptrs.data(),
149 sizeof(P *) * num_blocks, cudaMemcpyHostToDevice, stream);
// Compile-time dispatch on P: double -> cublasDmatinvBatched,
// float -> cublasSmatinvBatched. Both are called with matrix size d, leading
// dimensions d, and batch size num_blocks.
// NOTE(review): the cuBLAS reference restricts matinvBatched to small
// matrices (n <= 32); no bound on d is checked in the visible lines. The
// returned status and the contents of `info` are also never inspected here,
// so a singular block would not be reported -- confirm this is intended.
152 if constexpr (std::is_same<P, double>::value) {
154 cublasDmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
155 Ainv_ptrs_device.data().get(), d,
156 info.data().get(), num_blocks);
157 }
else if constexpr (std::is_same<P, float>::value) {
158 cublasSmatinvBatched(handle, d, A_ptrs_device.data().get(), d,
159 Ainv_ptrs_device.data().get(), d,
160 info.data().get(), num_blocks);
// NOTE(review): the lines below are the condition and message of a check on
// S whose opening keyword is not visible here (presumably a static_assert in
// the final else branch) -- confirm.
163 is_low_precision<S>::value || std::is_same<S, float>::value ||
164 std::is_same<S, double>::value,
165 "BlockJacobiPreconditioner only supports bfloat16, float, or "
// Block the host until the work queued on `stream` has completed.
171 cudaStreamSynchronize(stream);
// Applies the stored inverse blocks: for every vertex descriptor, forwards
// z, r and that descriptor's P_inv buffer to desc->apply_block_jacobi.
//
// graph: source of the vertex descriptors.
// z, r:  passed through unchanged to apply_block_jacobi (z is non-const, r is
//        const); presumably z is the output and r the input -- TODO confirm.
// NOTE(review): the rest of the signature and the declarations of `streams`
// and `i` are on lines not visible in this excerpt, and the function body may
// continue past the last visible line.
174 void apply(Graph<T, S> *graph, T *z,
const T *r,
178 auto &vertex_descriptors = graph->get_vertex_descriptors();
179 for (
auto &desc : vertex_descriptors) {
// d is not read by any line visible in this excerpt.
180 const auto d = desc->dimension();
181 P *blocks = P_inv[desc].data().get();
// The stream for this descriptor is chosen by streams.select(i).
182 desc->apply_block_jacobi(z, r, blocks, streams.select(i));