Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
cudss_schur.hpp
Go to the documentation of this file.
1
2#pragma once
3
4#include <cuda_runtime.h>
5#include <cudss.h>
7#include <graphite/schur.hpp>
10
11namespace graphite {
12
13template <typename T, typename S, typename Index = int32_t>
14class cudssSchurSolver : public Solver<T, S> {
15private:
16 static_assert(std::is_same<Index, int32_t>::value ||
17 std::is_same<Index, int64_t>::value,
18 "cudssSchurSolver index type must be int32_t or int64_t");
19
20 Hessian<T, S> H;
21 SchurComplement<T, S> schur;
22 CSCMatrix<S, Index> d_matrix;
23 cudssMatrixType_t matrix_type;
24
25 bool factorization_failed;
26
27 cudaStream_t stream;
28 cudssHandle_t handle;
29
30 cudssConfig_t solver_config;
31 cudssData_t solver_data;
32
33 cudssMatrix_t m_x, m_b, m_A;
34
35 thrust::device_vector<T> solver_x;
36 size_t schur_dim;
37 int64_t configured_hybrid_memory_limit;
38
39 void fill_matrix_structure() {
40 const auto dim = d_matrix.d_pointers.size() - 1;
41 const auto nnz = d_matrix.d_values.size();
42 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
43 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
44
45 if (m_A != NULL) {
46 cudssMatrixDestroy(m_A);
47 m_A = NULL;
48 }
49
50 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
51 nullptr, d_matrix.d_indices.data().get(),
52 d_matrix.d_values.data().get(),
53 get_cuda_index_type<Index>(), get_cuda_data_type<S>(),
54 matrix_type, view_type, index_base);
55 }
56
57 void fill_matrix_values() {
58 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
59 }
60
61public:
62 explicit cudssSchurSolver(
63 const cudssSolverOptions &options = cudssSolverOptions())
64 : schur(H), factorization_failed(false), schur_dim(0) {
65 stream = NULL;
66 m_x = NULL;
67 m_b = NULL;
68 m_A = NULL;
69 matrix_type = options.matrix_type;
70 configured_hybrid_memory_limit = options.hybrid_memory;
71
72 cudaStreamCreate(&stream);
73 cudssCreate(&handle);
74 cudssSetStream(handle, stream);
75
76 cudssConfigCreate(&solver_config);
77 int enable_hybrid_exec_mode =
78 (options.use_hybrid_execution && options.hybrid_memory <= 0) ? 1 : 0;
79 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
80 &enable_hybrid_exec_mode, sizeof(enable_hybrid_exec_mode));
81 int enable_hybrid_memory_mode = (options.hybrid_memory > 0) ? 1 : 0;
82 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_MODE,
83 &enable_hybrid_memory_mode,
84 sizeof(enable_hybrid_memory_mode));
85 if (options.hybrid_memory > 0) {
86 int64_t mem_limit = options.hybrid_memory;
87 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
88 &mem_limit, sizeof(mem_limit));
89 }
90 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
91 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
92 sizeof(reordering_alg));
93 cudssDataCreate(handle, &solver_data);
94 }
95
97 if (m_A != NULL) {
98 cudssMatrixDestroy(m_A);
99 m_A = NULL;
100 }
101 if (m_b != NULL) {
102 cudssMatrixDestroy(m_b);
103 m_b = NULL;
104 }
105 if (m_x != NULL) {
106 cudssMatrixDestroy(m_x);
107 m_x = NULL;
108 }
109
110 cudssDataDestroy(handle, solver_data);
111 cudssConfigDestroy(solver_config);
112 cudssDestroy(handle);
113 cudaStreamSynchronize(stream);
114 cudaStreamDestroy(stream);
115 }
116
117 virtual void update_structure(Graph<T, S> *graph,
118 StreamPool &streams) override {
119 H.build_structure(graph, streams);
120 schur.build_structure(graph, streams);
121 schur.build_csc_structure(graph, d_matrix);
122 fill_matrix_structure();
123
124 const auto &offsets = graph->get_offset_vector();
125 schur_dim = offsets[schur.lowest_eliminated_block_col];
126
127 if (m_b != NULL) {
128 cudssMatrixDestroy(m_b);
129 m_b = NULL;
130 }
131 int ldb = static_cast<int>(schur_dim);
132 cudssMatrixCreateDn(&m_b, schur_dim, 1, ldb,
133 schur.get_b_Schur().data().get(),
134 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
135
136 if (m_x != NULL) {
137 cudssMatrixDestroy(m_x);
138 m_x = NULL;
139 }
140 solver_x.resize(schur_dim);
141 int ldx = static_cast<int>(schur_dim);
142 cudssMatrixCreateDn(&m_x, schur_dim, 1, ldx, solver_x.data().get(),
143 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
144
145 factorization_failed = false;
146 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
147 solver_data, m_A, m_x, m_b);
148 if (status != CUDSS_STATUS_SUCCESS) {
149 factorization_failed = true;
150 std::cerr << "cudss Schur analysis failed with error code: " << status
151 << std::endl;
152 } else if (configured_hybrid_memory_limit > 0) {
153 int64_t min_hybrid_memory = 0;
154 size_t size_written = 0;
155 status = cudssDataGet(
156 handle, solver_data, CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN,
157 &min_hybrid_memory, sizeof(min_hybrid_memory), &size_written);
158 if (status == CUDSS_STATUS_SUCCESS &&
159 configured_hybrid_memory_limit < min_hybrid_memory) {
160 configured_hybrid_memory_limit = min_hybrid_memory;
161 std::cerr << "Requested cuDSS Schur hybrid memory limit is too low; "
162 "raising to minimum required "
163 << configured_hybrid_memory_limit << " bytes." << std::endl;
164 status = cudssConfigSet(solver_config,
165 CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
166 &configured_hybrid_memory_limit,
167 sizeof(configured_hybrid_memory_limit));
168 if (status != CUDSS_STATUS_SUCCESS) {
169 factorization_failed = true;
170 std::cerr
171 << "Failed to update cuDSS Schur hybrid memory limit with error "
172 "code: "
173 << status << std::endl;
174 }
175 }
176 }
177 cudaStreamSynchronize(stream);
178 }
179
180 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
181 H.update_values(graph, streams);
182 }
183
184 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
185 const bool use_identity,
186 StreamPool &streams) override {
187 H.apply_damping(graph, damping_factor, use_identity, streams);
188 }
189
190 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
191 if (factorization_failed) {
192 return false;
193 }
194
195 // Update values before solving
196 schur.update_values(graph, streams);
197 schur.update_csc_values(graph, d_matrix);
198 fill_matrix_values();
199
200 const auto dim = graph->get_hessian_dimension();
201
202 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
203 thrust::fill(thrust::device, solver_x.begin(), solver_x.end(),
204 static_cast<T>(0.0));
205
206 cudssStatus_t status;
207
208 cudssMatrixSetValues(m_b, schur.get_b_Schur().data().get());
209 cudssMatrixSetValues(m_x, solver_x.data().get());
210
211 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
212 solver_data, m_A, m_x, m_b);
213 if (status != CUDSS_STATUS_SUCCESS) {
214 std::cerr << "cudss Schur factorization failed with error code: "
215 << status << std::endl;
216 return false;
217 }
218
219 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
220 m_A, m_x, m_b);
221 if (status != CUDSS_STATUS_SUCCESS) {
222 std::cerr << "cudss Schur solve failed with error code: " << status
223 << std::endl;
224 return false;
225 }
226 cudaStreamSynchronize(stream);
227
228 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + schur_dim,
229 x);
230
231 schur.compute_landmark_update(graph, streams, x + schur_dim, x);
232
233 return true;
234 }
235};
236
237} // namespace graphite
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
Definition cudss_schur.hpp:14
Definition cudss.hpp:12
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4