56 static_assert(std::is_same<Index, int32_t>::value ||
57 std::is_same<Index, int64_t>::value,
58 "cudssSolver index type must be int32_t or int64_t");
// Rebuilds the cuDSS sparse matrix descriptor m_A from the device arrays held
// in d_matrix (d_pointers, d_indices, d_values). Called after the sparsity
// structure changes. The value pointer is wrapped, not copied, so later value
// updates only need cudssMatrixSetValues.
60 void fill_matrix_structure() {
// Square dimension = length of the pointer array minus one; nnz = stored values.
// NOTE(review): if d_pointers is empty, size() - 1 underflows (unsigned) — confirm
// this is never called before the structure is built.
61 const auto dim = d_matrix.d_pointers.size() - 1;
62 const auto nnz = d_matrix.d_values.size();
// Only the lower triangle is referenced by cuDSS; indices are zero-based.
63 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
64 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
// Release the previous descriptor before creating a replacement.
// NOTE(review): no guard is visible around this destroy; confirm m_A is always
// valid (or null-initialized) on the first call.
67 cudssMatrixDestroy(m_A);
// d_matrix is a CSCMatrix but is handed to cudssMatrixCreateCsr; nullptr is
// passed for the row-end array (presumably standard CSR layout where each row
// ends where the next begins — confirm against the cuDSS docs).
// NOTE(review): reading CSC arrays as CSR yields the transpose; with a LOWER
// view this is only correct if the stored pattern is the upper triangle in CSC
// (or full symmetric) — confirm. Also, values are typed S here while the dense
// b/x descriptors use T; confirm cuDSS accepts that pairing when S != T.
// The return status of the create call is not checked.
71 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
72 nullptr, d_matrix.d_indices.data().get(),
73 d_matrix.d_values.data().get(),
74 get_cuda_index_type<Index>(), get_cuda_data_type<S>(),
75 matrix_type, view_type, index_base);
// Re-points the existing descriptor m_A at d_matrix's value array, keeping the
// previously created sparsity structure. The returned status is not checked.
78 void fill_matrix_values() {
79 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
// Device-side sparse matrix arrays (d_pointers / d_indices / d_values) wrapped by m_A.
83 CSCMatrix<S, Index> d_matrix;
// cuDSS matrix type, copied from options.matrix_type in the constructor.
84 cudssMatrixType_t matrix_type;
// Set when the analysis phase (or the hybrid-limit update) fails in
// update_structure; checked at the start of solve().
86 bool factorization_failed;
// NOTE(review): `stream`, `handle` and `H` are used by this class but their
// declarations are not present in this listing.
// cuDSS solver configuration and per-problem solver data.
91 cudssConfig_t solver_config;
92 cudssData_t solver_data;
// Dense descriptors for x (solver_x) and b (graph's b), plus the sparse matrix A.
94 cudssMatrix_t m_x, m_b, m_A;
// Device buffer wrapped by m_x; receives the cuDSS solution before it is copied to the caller.
96 thrust::device_vector<T> solver_x;
// Starts as options.hybrid_memory; may be raised to the cuDSS-reported minimum
// after analysis.
97 int64_t configured_hybrid_memory_limit;
// Constructor body: caches options, creates a dedicated CUDA stream and cuDSS
// handle bound to it, configures hybrid modes and reordering, and creates the
// solver data object. None of the returned statuses are checked here.
106 factorization_failed =
false;
107 matrix_type = options.matrix_type;
108 configured_hybrid_memory_limit = options.hybrid_memory;
// All cuDSS work for this solver is issued on this stream.
110 cudaStreamCreate(&stream);
111 cudssCreate(&handle);
112 cudssSetStream(handle, stream);
114 cudssConfigCreate(&solver_config);
// Hybrid *execute* mode is enabled only when requested AND no positive
// hybrid_memory limit is given; a positive limit instead enables hybrid
// *memory* mode below. This code never enables both.
115 int enable_hybrid_exec_mode =
116 (options.use_hybrid_execution && options.hybrid_memory <= 0) ? 1 : 0;
117 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
118 &enable_hybrid_exec_mode,
sizeof(enable_hybrid_exec_mode));
119 int enable_hybrid_memory_mode = (options.hybrid_memory > 0) ? 1 : 0;
120 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_MODE,
121 &enable_hybrid_memory_mode,
122 sizeof(enable_hybrid_memory_mode));
// With hybrid memory mode, pass the requested device-memory limit (the log in
// update_structure reports it in bytes). update_structure may raise it later.
// NOTE(review): the closing brace of this `if` is not present in this listing.
123 if (options.hybrid_memory > 0) {
124 int64_t mem_limit = options.hybrid_memory;
125 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
126 &mem_limit,
sizeof(mem_limit));
128 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
129 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
130 sizeof(reordering_alg));
131 cudssDataCreate(handle, &solver_data);
136 cudssMatrixDestroy(m_A);
140 cudssMatrixDestroy(m_b);
144 cudssMatrixDestroy(m_x);
148 cudssDataDestroy(handle, solver_data);
149 cudssConfigDestroy(solver_config);
150 cudssDestroy(handle);
151 cudaStreamSynchronize(stream);
152 cudaStreamDestroy(stream);
// Rebuilds the Hessian sparsity structure and all cuDSS descriptors, then runs
// the cuDSS analysis phase. On analysis failure, factorization_failed is set so
// that solve() bails out early. If a positive hybrid memory limit is
// configured, it is raised to the cuDSS-reported minimum when it is too low.
// NOTE(review): part of the signature and the ldb/ldx declarations are not
// present in this listing.
155 virtual void update_structure(Graph<T, S> *graph,
// Build the Hessian structure, export it into d_matrix, and recreate m_A.
157 H.build_structure(graph, streams);
158 H.build_csc_structure(graph, d_matrix);
159 fill_matrix_structure();
// Recreate the dense RHS descriptor, wrapping the graph's b buffer in place
// (dim x 1, column-major).
163 cudssMatrixDestroy(m_b);
166 const auto dim = graph->get_hessian_dimension();
167 auto &b = graph->get_b();
170 cudssMatrixCreateDn(&m_b, dim, 1, ldb, b.data().get(),
171 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
// Recreate the solution descriptor over solver_x, resized to match b.
// NOTE(review): create-call statuses are not checked.
174 cudssMatrixDestroy(m_x);
177 solver_x.resize(b.size());
178 cudssMatrixCreateDn(&m_x, dim, 1, ldx, solver_x.data().get(),
179 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
// Reset the failure flag and run symbolic analysis on the new structure.
182 factorization_failed =
false;
183 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
184 solver_data, m_A, m_x, m_b);
185 if (status != CUDSS_STATUS_SUCCESS) {
186 factorization_failed =
true;
187 std::cerr <<
"cudss Analysis failed with error code: " << status
189 }
else if (configured_hybrid_memory_limit > 0) {
// Analysis succeeded with hybrid memory mode: query the minimum device memory
// cuDSS needs and raise the configured limit if it falls below that.
// NOTE(review): if cudssDataGet itself fails, the configured limit is kept
// silently.
190 int64_t min_hybrid_memory = 0;
191 size_t size_written = 0;
192 status = cudssDataGet(
193 handle, solver_data, CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN,
194 &min_hybrid_memory,
sizeof(min_hybrid_memory), &size_written);
195 if (status == CUDSS_STATUS_SUCCESS &&
196 configured_hybrid_memory_limit < min_hybrid_memory) {
197 configured_hybrid_memory_limit = min_hybrid_memory;
198 std::cerr <<
"Requested cuDSS hybrid memory limit is too low; raising "
199 "to minimum required "
200 << configured_hybrid_memory_limit <<
" bytes." << std::endl;
// Push the raised limit into the config; a failure here also marks the solver
// as unusable for solve().
201 status = cudssConfigSet(solver_config,
202 CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
203 &configured_hybrid_memory_limit,
204 sizeof(configured_hybrid_memory_limit));
205 if (status != CUDSS_STATUS_SUCCESS) {
206 factorization_failed =
true;
207 std::cerr <<
"Failed to update cuDSS hybrid memory limit with error "
209 << status << std::endl;
// Block until all work queued on the solver stream has completed.
213 cudaStreamSynchronize(stream);
// Recomputes Hessian values for the existing sparsity structure, exports them
// into d_matrix, and re-points m_A at the refreshed value array.
// No analysis is rerun here; only values change.
216 virtual void update_values(Graph<T, S> *graph,
StreamPool &streams)
override {
217 H.update_values(graph, streams);
218 H.update_csc_values(graph, d_matrix);
219 fill_matrix_values();
// Applies damping_factor to the Hessian (use_identity is forwarded to
// H.apply_damping — presumably selecting identity vs. diagonal scaling;
// confirm), then refreshes d_matrix values and m_A.
// NOTE(review): the final signature line is not present in this listing.
222 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
223 const bool use_identity,
225 H.apply_damping(graph, damping_factor, use_identity, streams);
226 H.update_csc_values(graph, d_matrix);
227 fill_matrix_values();
// Factorizes the current matrix and solves A x = b for the graph's b, writing
// the result into the device array x (length = Hessian dimension).
// NOTE(review): the early-return bodies and the final return statement are not
// present in this listing.
231 virtual bool solve(Graph<T, S> *graph, T *x,
StreamPool &streams)
override {
// Bail out if analysis (or the hybrid-limit update) failed in update_structure.
233 if (factorization_failed) {
237 auto dim = graph->get_hessian_dimension();
// Zero the caller's x and copy it into solver_x.
// NOTE(review): this is equivalent to zero-filling solver_x directly.
// Both thrust calls use thrust::device, i.e. the default stream, while cuDSS
// runs on `stream`. Ordering relies on legacy default-stream semantics and
// would break under per-thread default streams; consider
// thrust::cuda::par.on(stream).
239 thrust::fill(thrust::device, x, x + dim,
static_cast<T
>(0.0));
240 thrust::copy(thrust::device, x, x + dim, solver_x.data());
242 cudssStatus_t status;
// Re-bind the dense descriptors to the current b and solver_x buffers.
245 cudssMatrixSetValues(m_b, graph->get_b().data().get());
246 cudssMatrixSetValues(m_x, solver_x.data().get());
// Numeric factorization runs on every solve() call.
248 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
249 solver_data, m_A, m_x, m_b);
250 if (status != CUDSS_STATUS_SUCCESS) {
251 std::cerr <<
"cudss Factorization failed with error code: " << status
// Forward/backward solve into solver_x (remaining arguments are on a line not
// present in this listing).
256 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
258 if (status != CUDSS_STATUS_SUCCESS) {
259 std::cerr <<
"cudss Solve failed with error code: " << status
// Wait for the cuDSS work on `stream`, then copy the solution back to x.
263 cudaStreamSynchronize(stream);
265 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + dim, x);