Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
cudss.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <cstdint>
4#include <cuda_runtime.h>
5#include <cudss.h>
8#include <type_traits>
9
10namespace graphite {
11
13public:
16
21
24 cudssMatrixType_t matrix_type;
25
28 hybrid_memory = -1;
29 matrix_type = CUDSS_MTYPE_SYMMETRIC;
30 }
31};
32
33template <typename T> cudaDataType_t get_cuda_data_type();
34
35template <> inline cudaDataType_t get_cuda_data_type<double>() {
36 return CUDA_R_64F;
37}
38
39template <> inline cudaDataType_t get_cuda_data_type<float>() {
40 return CUDA_R_32F;
41}
42
43template <typename Index> inline cudaDataType_t get_cuda_index_type() {
44 static_assert(std::is_same<Index, int32_t>::value ||
45 std::is_same<Index, int64_t>::value,
46 "cudssSolver index type must be int32_t or int64_t");
47 if (std::is_same<Index, int32_t>::value) {
48 return CUDA_R_32I;
49 }
50 return CUDA_R_64I;
51}
52
53template <typename T, typename S, typename Index = int32_t>
54class cudssSolver : public Solver<T, S> {
55private:
56 static_assert(std::is_same<Index, int32_t>::value ||
57 std::is_same<Index, int64_t>::value,
58 "cudssSolver index type must be int32_t or int64_t");
59
60 void fill_matrix_structure() {
61 const auto dim = d_matrix.d_pointers.size() - 1;
62 const auto nnz = d_matrix.d_values.size();
63 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
64 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
65
66 if (m_A != NULL) {
67 cudssMatrixDestroy(m_A);
68 m_A = NULL;
69 }
70
71 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
72 nullptr, d_matrix.d_indices.data().get(),
73 d_matrix.d_values.data().get(),
74 get_cuda_index_type<Index>(), get_cuda_data_type<S>(),
75 matrix_type, view_type, index_base);
76 }
77
78 void fill_matrix_values() {
79 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
80 }
81
82 Hessian<T, S> H;
83 CSCMatrix<S, Index> d_matrix;
84 cudssMatrixType_t matrix_type;
85
86 bool factorization_failed;
87
88 cudaStream_t stream;
89 cudssHandle_t handle;
90
91 cudssConfig_t solver_config;
92 cudssData_t solver_data;
93
94 cudssMatrix_t m_x, m_b, m_A;
95
96 thrust::device_vector<T> solver_x;
97 int64_t configured_hybrid_memory_limit;
98
99public:
100 explicit cudssSolver(
101 const cudssSolverOptions &options = cudssSolverOptions()) {
102 stream = NULL;
103 m_x = NULL;
104 m_b = NULL;
105 m_A = NULL;
106 factorization_failed = false;
107 matrix_type = options.matrix_type;
108 configured_hybrid_memory_limit = options.hybrid_memory;
109
110 cudaStreamCreate(&stream);
111 cudssCreate(&handle);
112 cudssSetStream(handle, stream);
113
114 cudssConfigCreate(&solver_config);
115 int enable_hybrid_exec_mode =
116 (options.use_hybrid_execution && options.hybrid_memory <= 0) ? 1 : 0;
117 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
118 &enable_hybrid_exec_mode, sizeof(enable_hybrid_exec_mode));
119 int enable_hybrid_memory_mode = (options.hybrid_memory > 0) ? 1 : 0;
120 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_MODE,
121 &enable_hybrid_memory_mode,
122 sizeof(enable_hybrid_memory_mode));
123 if (options.hybrid_memory > 0) {
124 int64_t mem_limit = options.hybrid_memory;
125 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
126 &mem_limit, sizeof(mem_limit));
127 }
128 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
129 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
130 sizeof(reordering_alg));
131 cudssDataCreate(handle, &solver_data);
132 }
133
134 ~cudssSolver() {
135 if (m_A != NULL) {
136 cudssMatrixDestroy(m_A);
137 m_A = NULL;
138 }
139 if (m_b != NULL) {
140 cudssMatrixDestroy(m_b);
141 m_b = NULL;
142 }
143 if (m_x != NULL) {
144 cudssMatrixDestroy(m_x);
145 m_x = NULL;
146 }
147
148 cudssDataDestroy(handle, solver_data);
149 cudssConfigDestroy(solver_config);
150 cudssDestroy(handle);
151 cudaStreamSynchronize(stream);
152 cudaStreamDestroy(stream);
153 }
154
155 virtual void update_structure(Graph<T, S> *graph,
156 StreamPool &streams) override {
157 H.build_structure(graph, streams);
158 H.build_csc_structure(graph, d_matrix);
159 fill_matrix_structure();
160
161 // Create matrices for b and x
162 if (m_b != NULL) {
163 cudssMatrixDestroy(m_b);
164 m_b = NULL;
165 }
166 const auto dim = graph->get_hessian_dimension();
167 auto &b = graph->get_b();
168 int ldb = dim;
169 int ldx = dim;
170 cudssMatrixCreateDn(&m_b, dim, 1, ldb, b.data().get(),
171 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
172
173 if (m_x != NULL) {
174 cudssMatrixDestroy(m_x);
175 m_x = NULL;
176 }
177 solver_x.resize(b.size());
178 cudssMatrixCreateDn(&m_x, dim, 1, ldx, solver_x.data().get(),
179 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
180
181 // Factorize
182 factorization_failed = false;
183 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
184 solver_data, m_A, m_x, m_b);
185 if (status != CUDSS_STATUS_SUCCESS) {
186 factorization_failed = true;
187 std::cerr << "cudss Analysis failed with error code: " << status
188 << std::endl;
189 } else if (configured_hybrid_memory_limit > 0) {
190 int64_t min_hybrid_memory = 0;
191 size_t size_written = 0;
192 status = cudssDataGet(
193 handle, solver_data, CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN,
194 &min_hybrid_memory, sizeof(min_hybrid_memory), &size_written);
195 if (status == CUDSS_STATUS_SUCCESS &&
196 configured_hybrid_memory_limit < min_hybrid_memory) {
197 configured_hybrid_memory_limit = min_hybrid_memory;
198 std::cerr << "Requested cuDSS hybrid memory limit is too low; raising "
199 "to minimum required "
200 << configured_hybrid_memory_limit << " bytes." << std::endl;
201 status = cudssConfigSet(solver_config,
202 CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
203 &configured_hybrid_memory_limit,
204 sizeof(configured_hybrid_memory_limit));
205 if (status != CUDSS_STATUS_SUCCESS) {
206 factorization_failed = true;
207 std::cerr << "Failed to update cuDSS hybrid memory limit with error "
208 "code: "
209 << status << std::endl;
210 }
211 }
212 }
213 cudaStreamSynchronize(stream);
214 }
215
216 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
217 H.update_values(graph, streams);
218 H.update_csc_values(graph, d_matrix);
219 fill_matrix_values(); // for CPU matrix
220 }
221
222 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
223 const bool use_identity,
224 StreamPool &streams) override {
225 H.apply_damping(graph, damping_factor, use_identity, streams);
226 H.update_csc_values(graph, d_matrix);
227 fill_matrix_values(); // TODO: Use a more lightweight method to just update
228 // diagonal
229 }
230
231 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
232
233 if (factorization_failed) {
234 return false;
235 }
236
237 auto dim = graph->get_hessian_dimension();
238
239 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
240 thrust::copy(thrust::device, x, x + dim, solver_x.data());
241
242 cudssStatus_t status;
243
244 // set values for b and x
245 cudssMatrixSetValues(m_b, graph->get_b().data().get());
246 cudssMatrixSetValues(m_x, solver_x.data().get());
247
248 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
249 solver_data, m_A, m_x, m_b);
250 if (status != CUDSS_STATUS_SUCCESS) {
251 std::cerr << "cudss Factorization failed with error code: " << status
252 << std::endl;
253 return false;
254 }
255
256 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
257 m_A, m_x, m_b);
258 if (status != CUDSS_STATUS_SUCCESS) {
259 std::cerr << "cudss Solve failed with error code: " << status
260 << std::endl;
261 return false;
262 }
263 cudaStreamSynchronize(stream);
264
265 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + dim, x);
266
267 return true;
268 }
269};
270
271} // namespace graphite
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
Definition cudss.hpp:12
cudssMatrixType_t matrix_type
Definition cudss.hpp:24
int64_t hybrid_memory
Definition cudss.hpp:20
bool use_hybrid_execution
Default is true, which enables cuDSS hybrid execution mode.
Definition cudss.hpp:15
Definition cudss.hpp:54
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4