// Index type for the sparse Hessian copy `d_matrix` (CSCMatrix<S, StorageIndex>).
// NOTE(review): fill_matrix_structure() hard-codes CUDA_R_32I as the cuDSS
// index type, so this alias must stay a 32-bit signed integer.
22 using StorageIndex = int32_t;
// (Re)creates the cuDSS sparse-matrix descriptor m_A so it wraps the device
// arrays of d_matrix (pointers, indices, values). Any previous descriptor is
// destroyed first. The arrays are referenced, not copied.
24 void fill_matrix_structure() {
// Number of rows/cols is derived from the offsets array (dim + 1 entries).
// NOTE(review): if d_pointers is empty and size() is unsigned (presumably a
// thrust::device_vector), this wraps around instead of yielding -1.
25 const auto dim = d_matrix.d_pointers.size() - 1;
26 const auto nnz = d_matrix.d_values.size();
// The matrix is declared symmetric; cuDSS reads only the LOWER triangle of
// the stored pattern, using zero-based indices.
27 const cudssMatrixType_t matrix_type = CUDSS_MTYPE_SYMMETRIC;
28 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
29 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
// NOTE(review): relies on m_A being either null or a valid descriptor on the
// first call; its initialization is not visible in this excerpt.
32 cudssMatrixDestroy(m_A);
// d_matrix is a CSC matrix but is handed to cuDSS as CSR. For a symmetric
// matrix, CSC(A) == CSR(A^T) == CSR(A). However, the triangle stored in CSC
// appears as the opposite triangle in CSR.
// NOTE(review): confirm build_csc_structure stores the triangle that becomes
// LOWER in CSR terms.
// rowEnd is nullptr, so cuDSS treats d_pointers as standard CSR offsets.
// NOTE(review): the returned cudssStatus_t is ignored. This excerpt also ends
// mid-call; index_base and the closing parenthesis are not shown.
36 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
37 nullptr, d_matrix.d_indices.data().get(),
38 d_matrix.d_values.data().get(), CUDA_R_32I,
39 get_cuda_data_type<S>(), matrix_type, view_type,
// Re-points the existing descriptor m_A at d_matrix's current values buffer
// without rebuilding the sparsity structure. Call this after the CSC values
// have been refreshed.
// NOTE(review): the returned cudssStatus_t is ignored.
43 void fill_matrix_values() {
44 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
// Sparse copy of the Hessian. It is filled by H.build_csc_structure /
// H.update_csc_values and wrapped by m_A.
48 CSCMatrix<S, StorageIndex> d_matrix;
// Set when cuDSS analysis fails in update_structure(); checked at the top of
// solve().
50 bool factorization_failed;
// cuDSS solver configuration and per-problem solver state. Both are created
// in the constructor and destroyed in the destructor.
55 cudssConfig_t solver_config;
56 cudssData_t solver_data;
// cuDSS descriptors for the solution (x), right-hand side (b), and system
// matrix (A).
// NOTE(review): declared without initializers here; handle, stream, H and
// use_hybrid_execution are also referenced but not declared in this excerpt.
58 cudssMatrix_t m_x, m_b, m_A;
// Device buffer that receives the solution before it is copied to the
// caller's x in solve().
60 thrust::device_vector<T> solver_x;
// Constructor body (its signature is not shown in this excerpt): resets the
// failure flag, creates a dedicated CUDA stream, binds the cuDSS handle to it,
// and builds the solver configuration and data objects.
68 factorization_failed =
false;
// cudaStreamCreate yields a blocking stream, which still synchronizes with the
// legacy default stream.
70 cudaStreamCreate(&stream);
// All cudssExecute calls on this handle are enqueued on `stream`.
// NOTE(review): handle is presumably created with cudssCreate on a line not
// shown here.
72 cudssSetStream(handle, stream);
74 cudssConfigCreate(&solver_config);
// cuDSS expects an int flag for hybrid (host + device) execution mode.
75 int enable_hybrid_exec_mode = use_hybrid_execution ? 1 : 0;
76 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
77 &enable_hybrid_exec_mode,
sizeof(enable_hybrid_exec_mode));
// Use cuDSS's default fill-reducing reordering.
78 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
79 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
80 sizeof(reordering_alg));
// NOTE(review): none of the CUDA/cuDSS return codes above are checked.
81 cudssDataCreate(handle, &solver_data);
// Destructor body (its signature is not shown in this excerpt): releases the
// cuDSS descriptors, solver data, config and handle, then the CUDA stream.
// NOTE(review): any null/validity guards around these destroys are not visible
// here.
86 cudssMatrixDestroy(m_A);
90 cudssMatrixDestroy(m_b);
94 cudssMatrixDestroy(m_x);
// NOTE(review): the stream is synchronized only after the handle and solver
// data are destroyed. Consider synchronizing `stream` first so that no queued
// cuDSS work can still reference these objects.
98 cudssDataDestroy(handle, solver_data);
99 cudssConfigDestroy(solver_config);
100 cudssDestroy(handle);
101 cudaStreamSynchronize(stream);
102 cudaStreamDestroy(stream);
// Rebuilds everything that depends on the sparsity pattern and runs the cuDSS
// analysis phase:
// - the Hessian structure, its CSC copy, and the m_A descriptor;
// - the dense descriptors for b and x;
// - cuDSS analysis (reordering / symbolic factorization).
// On analysis failure, factorization_failed is set and an error is logged.
105 virtual void update_structure(Graph<T, S> *graph,
107 H.build_structure(graph, streams);
108 H.build_csc_structure(graph, d_matrix);
109 fill_matrix_structure();
113 cudssMatrixDestroy(m_b);
116 const auto dim = graph->get_hessian_dimension();
117 auto &b = graph->get_b();
// m_b wraps the graph's b buffer directly (no copy), as a dim x 1
// column-major matrix.
// NOTE(review): ldb is not declared in this excerpt; presumably dim.
120 cudssMatrixCreateDn(&m_b, dim, 1, ldb, b.data().get(),
121 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
124 cudssMatrixDestroy(m_x);
// NOTE(review): solver_x is sized from b.size() but described to cuDSS with
// dim rows. This assumes b.size() == dim. ldx is also not visible here.
127 solver_x.resize(b.size());
128 cudssMatrixCreateDn(&m_x, dim, 1, ldx, solver_x.data().get(),
129 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
132 factorization_failed =
false;
// Analysis is enqueued on `stream` (bound via cudssSetStream).
133 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
134 solver_data, m_A, m_x, m_b);
135 if (status != CUDSS_STATUS_SUCCESS) {
// solve() checks this flag before factorizing.
136 factorization_failed =
true;
137 std::cerr <<
"cudss Analysis failed with error code: " << status
// Wait for the enqueued cuDSS work before returning.
140 cudaStreamSynchronize(stream);
// Recomputes the Hessian values for the existing sparsity pattern, copies them
// into the CSC matrix, and re-points m_A at the refreshed values. The cuDSS
// analysis is not rerun.
143 virtual void update_values(Graph<T, S> *graph,
StreamPool &streams)
override {
144 H.update_values(graph, streams);
145 H.update_csc_values(graph, d_matrix);
146 fill_matrix_values();
// Applies damping_factor to the Hessian via H.apply_damping, then propagates
// the modified values into the CSC matrix and the m_A descriptor.
// NOTE(review): the rest of the signature (presumably `StreamPool &streams`)
// is not shown in this excerpt.
149 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
151 H.apply_damping(graph, damping_factor, streams);
152 H.update_csc_values(graph, d_matrix);
153 fill_matrix_values();
// Factorizes the current matrix with cuDSS, solves A * x = b with b taken from
// graph->get_b(), and writes the solution into the caller's device buffer x
// (dim entries).
// The bool return value is produced on lines not shown in this excerpt.
157 virtual bool solve(Graph<T, S> *graph, T *x,
StreamPool &streams)
override {
// Skip factorization if the analysis in update_structure() failed. The body
// of this branch is not shown in this excerpt.
159 if (factorization_failed) {
163 auto dim = graph->get_hessian_dimension();
// Zero the caller's x, then seed solver_x from it, so solver_x starts at zero.
// NOTE(review): this overwrites x before the solve, so a failure path leaves
// x zeroed. Filling solver_x directly would avoid touching x.
// NOTE(review): thrust::device is not bound to `stream`. Ordering against the
// cuDSS work on `stream` relies on legacy default-stream semantics, which
// break under --default-stream per-thread. Consider
// thrust::cuda::par.on(stream).
165 thrust::fill(thrust::device, x, x + dim,
static_cast<T
>(0.0));
166 thrust::copy(thrust::device, x, x + dim, solver_x.data());
168 cudssStatus_t status;
// Re-point the dense descriptors at the current buffers, in case b or
// solver_x was reallocated since update_structure().
171 cudssMatrixSetValues(m_b, graph->get_b().data().get());
172 cudssMatrixSetValues(m_x, solver_x.data().get());
174 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
175 solver_data, m_A, m_x, m_b);
// NOTE(review): in the visible lines, a factorization failure only logs and
// does not set factorization_failed.
176 if (status != CUDSS_STATUS_SUCCESS) {
177 std::cerr <<
"cudss Factorization failed with error code: " << status
// Triangular solves. The remaining arguments are on a line not shown here.
182 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
184 if (status != CUDSS_STATUS_SUCCESS) {
185 std::cerr <<
"cudss Solve failed with error code: " << status
// Wait for cuDSS on `stream` before copying the solution out.
189 cudaStreamSynchronize(stream);
191 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + dim, x);