Graphite
Loading...
Searching...
No Matches
cudss.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <cuda_runtime.h>
4#include <cudss.h>
7
8namespace graphite {
9
10template <typename T> cudaDataType_t get_cuda_data_type();
11
12template <> inline cudaDataType_t get_cuda_data_type<double>() {
13 return CUDA_R_64F;
14}
15
16template <> inline cudaDataType_t get_cuda_data_type<float>() {
17 return CUDA_R_32F;
18}
19
20template <typename T, typename S> class cudssSolver : public Solver<T, S> {
21private:
22 using StorageIndex = int32_t;
23
24 void fill_matrix_structure() {
25 const auto dim = d_matrix.d_pointers.size() - 1;
26 const auto nnz = d_matrix.d_values.size();
27 const cudssMatrixType_t matrix_type = CUDSS_MTYPE_SYMMETRIC;
28 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
29 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
30
31 if (m_A != NULL) {
32 cudssMatrixDestroy(m_A);
33 m_A = NULL;
34 }
35
36 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
37 nullptr, d_matrix.d_indices.data().get(),
38 d_matrix.d_values.data().get(), CUDA_R_32I,
39 get_cuda_data_type<S>(), matrix_type, view_type,
40 index_base);
41 }
42
43 void fill_matrix_values() {
44 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
45 }
46
47 Hessian<T, S> H;
48 CSCMatrix<S, StorageIndex> d_matrix;
49
50 bool factorization_failed;
51
52 cudaStream_t stream;
53 cudssHandle_t handle;
54
55 cudssConfig_t solver_config;
56 cudssData_t solver_data;
57
58 cudssMatrix_t m_x, m_b, m_A;
59
60 thrust::device_vector<T> solver_x;
61
62public:
63 cudssSolver(bool use_hybrid_execution) {
64 stream = NULL;
65 m_x = NULL;
66 m_b = NULL;
67 m_A = NULL;
68 factorization_failed = false;
69
70 cudaStreamCreate(&stream);
71 cudssCreate(&handle);
72 cudssSetStream(handle, stream);
73
74 cudssConfigCreate(&solver_config);
75 int enable_hybrid_exec_mode = use_hybrid_execution ? 1 : 0;
76 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
77 &enable_hybrid_exec_mode, sizeof(enable_hybrid_exec_mode));
78 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
79 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
80 sizeof(reordering_alg));
81 cudssDataCreate(handle, &solver_data);
82 }
83
84 ~cudssSolver() {
85 if (m_A != NULL) {
86 cudssMatrixDestroy(m_A);
87 m_A = NULL;
88 }
89 if (m_b != NULL) {
90 cudssMatrixDestroy(m_b);
91 m_b = NULL;
92 }
93 if (m_x != NULL) {
94 cudssMatrixDestroy(m_x);
95 m_x = NULL;
96 }
97
98 cudssDataDestroy(handle, solver_data);
99 cudssConfigDestroy(solver_config);
100 cudssDestroy(handle);
101 cudaStreamSynchronize(stream);
102 cudaStreamDestroy(stream);
103 }
104
105 virtual void update_structure(Graph<T, S> *graph,
106 StreamPool &streams) override {
107 H.build_structure(graph, streams);
108 H.build_csc_structure(graph, d_matrix);
109 fill_matrix_structure();
110
111 // Create matrices for b and x
112 if (m_b != NULL) {
113 cudssMatrixDestroy(m_b);
114 m_b = NULL;
115 }
116 const auto dim = graph->get_hessian_dimension();
117 auto &b = graph->get_b();
118 int ldb = dim;
119 int ldx = dim;
120 cudssMatrixCreateDn(&m_b, dim, 1, ldb, b.data().get(),
121 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
122
123 if (m_x != NULL) {
124 cudssMatrixDestroy(m_x);
125 m_x = NULL;
126 }
127 solver_x.resize(b.size());
128 cudssMatrixCreateDn(&m_x, dim, 1, ldx, solver_x.data().get(),
129 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
130
131 // Factorize
132 factorization_failed = false;
133 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
134 solver_data, m_A, m_x, m_b);
135 if (status != CUDSS_STATUS_SUCCESS) {
136 factorization_failed = true;
137 std::cerr << "cudss Analysis failed with error code: " << status
138 << std::endl;
139 }
140 cudaStreamSynchronize(stream);
141 }
142
143 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
144 H.update_values(graph, streams);
145 H.update_csc_values(graph, d_matrix);
146 fill_matrix_values(); // for CPU matrix
147 }
148
149 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
150 StreamPool &streams) override {
151 H.apply_damping(graph, damping_factor, streams);
152 H.update_csc_values(graph, d_matrix);
153 fill_matrix_values(); // TODO: Use a more lightweight method to just update
154 // diagonal
155 }
156
157 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
158
159 if (factorization_failed) {
160 return false;
161 }
162
163 auto dim = graph->get_hessian_dimension();
164
165 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
166 thrust::copy(thrust::device, x, x + dim, solver_x.data());
167
168 cudssStatus_t status;
169
170 // set values for b and x
171 cudssMatrixSetValues(m_b, graph->get_b().data().get());
172 cudssMatrixSetValues(m_x, solver_x.data().get());
173
174 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
175 solver_data, m_A, m_x, m_b);
176 if (status != CUDSS_STATUS_SUCCESS) {
177 std::cerr << "cudss Factorization failed with error code: " << status
178 << std::endl;
179 return false;
180 }
181
182 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
183 m_A, m_x, m_b);
184 if (status != CUDSS_STATUS_SUCCESS) {
185 std::cerr << "cudss Solve failed with error code: " << status
186 << std::endl;
187 return false;
188 }
189 cudaStreamSynchronize(stream);
190
191 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + dim, x);
192
193 return true;
194 }
195};
196
197} // namespace graphite
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
Definition cudss.hpp:20