// NOTE(review): this listing is an excerpt; some members used below (e.g. the
// cuDSS handle, the CUDA stream, H and schur_dim) are declared on lines not
// shown here.
// Restrict the sparse index type to 32- or 64-bit signed integers; it is
// forwarded to cuDSS via get_cuda_index_type<Index>() in fill_matrix_structure().
16 static_assert(std::is_same<Index, int32_t>::value ||
17 std::is_same<Index, int64_t>::value,
18 "cudssSchurSolver index type must be int32_t or int64_t");
// Schur-complement helper constructed from H. It builds the reduced matrix
// (written into d_matrix in CSC form) and the reduced right-hand side
// returned by get_b_Schur().
21 SchurComplement<T, S> schur;
// Device-side CSC storage of the reduced matrix (d_pointers / d_indices /
// d_values).
22 CSCMatrix<S, Index> d_matrix;
// cuDSS matrix type, copied from options.matrix_type in the constructor.
23 cudssMatrixType_t matrix_type;
// Set when the cuDSS analysis phase (or the hybrid-memory reconfiguration)
// fails in update_structure(); checked at the top of solve().
25 bool factorization_failed;
// cuDSS solver configuration and solver data objects.
30 cudssConfig_t solver_config;
31 cudssData_t solver_data;
// cuDSS matrix wrappers: solution (dense), right-hand side (dense), and
// system matrix (sparse).
33 cudssMatrix_t m_x, m_b, m_A;
// Device buffer resized to schur_dim; it receives the reduced solution.
35 thrust::device_vector<T> solver_x;
// Hybrid-memory device limit in bytes, initialized from options.hybrid_memory.
// A value <= 0 leaves hybrid memory mode disabled. update_structure() may
// raise it to the minimum that cuDSS reports after analysis.
37 int64_t configured_hybrid_memory_limit;
// Re-creates the cuDSS sparse-matrix wrapper m_A over the device arrays of
// d_matrix. It is called from update_structure() right after
// schur.build_csc_structure() rebuilds d_matrix.
// NOTE(review): d_matrix holds CSC arrays, but they are registered with
// cudssMatrixCreateCsr, so cuDSS sees the transpose. Combined with
// CUDSS_MVIEW_LOWER, this is only correct if the stored CSC triangle is the
// upper one (lower in CSR terms) or the matrix is fully stored and
// symmetric. Confirm against SchurComplement::build_csc_structure.
// NOTE(review): the wrapper presumably references, rather than copies, the
// device arrays. If so, d_matrix must not be reallocated while m_A is in use
// unless m_A is re-pointed (fill_matrix_values() does this for values only).
39 void fill_matrix_structure() {
// Square matrix; the CSC column-pointer array holds dim + 1 entries.
// NOTE(review): if d_pointers is empty, size() - 1 underflows (assuming an
// unsigned size type).
40 const auto dim = d_matrix.d_pointers.size() - 1;
41 const auto nnz = d_matrix.d_values.size();
42 const cudssMatrixViewType_t view_type = CUDSS_MVIEW_LOWER;
43 const cudssIndexBase_t index_base = CUDSS_BASE_ZERO;
// Release the previous wrapper. Any guard around this call is on lines not
// shown in this excerpt.
46 cudssMatrixDestroy(m_A);
// The nullptr argument (rowEnd in the cuDSS API) selects a single offsets
// array.
// NOTE(review): the returned cudssStatus_t is ignored.
50 cudssMatrixCreateCsr(&m_A, dim, dim, nnz, d_matrix.d_pointers.data().get(),
51 nullptr, d_matrix.d_indices.data().get(),
52 d_matrix.d_values.data().get(),
53 get_cuda_index_type<Index>(), get_cuda_data_type<S>(),
54 matrix_type, view_type, index_base);
// Re-points m_A's value array at d_matrix.d_values. solve() calls this after
// schur.update_csc_values() refreshes the numeric values.
// NOTE(review): the returned cudssStatus_t is ignored.
57 void fill_matrix_values() {
58 cudssMatrixSetValues(m_A, d_matrix.d_values.data().get());
// Constructor body; the signature line is not shown in this excerpt. It:
// - builds the Schur helper from H;
// - creates a dedicated CUDA stream and binds it to the cuDSS handle;
// - configures hybrid execution/memory and reordering on the cuDSS config;
// - creates the solver data object.
64 : schur(H), factorization_failed(
false), schur_dim(0) {
69 matrix_type = options.matrix_type;
70 configured_hybrid_memory_limit = options.hybrid_memory;
// NOTE(review): the return codes of cudaStreamCreate and of the cudss* calls
// below are not checked. The handle creation itself is on a line not shown.
72 cudaStreamCreate(&stream);
74 cudssSetStream(handle, stream);
76 cudssConfigCreate(&solver_config);
// Hybrid execute mode is enabled only when it is requested AND no positive
// hybrid memory limit is given. Together with the condition below, the two
// hybrid modes are therefore mutually exclusive.
77 int enable_hybrid_exec_mode =
78 (options.use_hybrid_execution && options.hybrid_memory <= 0) ? 1 : 0;
79 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_EXECUTE_MODE,
80 &enable_hybrid_exec_mode,
sizeof(enable_hybrid_exec_mode));
// A positive hybrid_memory (bytes) enables hybrid memory mode and sets the
// device memory limit. update_structure() may later raise this limit to the
// minimum that cuDSS reports after analysis.
81 int enable_hybrid_memory_mode = (options.hybrid_memory > 0) ? 1 : 0;
82 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_MODE,
83 &enable_hybrid_memory_mode,
84 sizeof(enable_hybrid_memory_mode));
85 if (options.hybrid_memory > 0) {
86 int64_t mem_limit = options.hybrid_memory;
87 cudssConfigSet(solver_config, CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
88 &mem_limit,
sizeof(mem_limit));
// Use cuDSS's default reordering algorithm.
90 cudssAlgType_t reordering_alg = CUDSS_ALG_DEFAULT;
91 cudssConfigSet(solver_config, CUDSS_CONFIG_REORDERING_ALG, &reordering_alg,
92 sizeof(reordering_alg));
93 cudssDataCreate(handle, &solver_data);
// Presumably the destructor body; the signature and any guards around the
// matrix destroys are on lines not shown. It releases, in order: the cuDSS
// matrices, data, config and handle, then the CUDA stream.
// NOTE(review): the stream is synchronized only after cudssDestroy(handle).
// Synchronizing before the cuDSS objects are destroyed would guarantee that
// no queued work still references them. update_structure() and solve() both
// call cudaStreamSynchronize(stream) after their cuDSS work, so this is likely
// benign in practice.
98 cudssMatrixDestroy(m_A);
102 cudssMatrixDestroy(m_b);
106 cudssMatrixDestroy(m_x);
110 cudssDataDestroy(handle, solver_data);
111 cudssConfigDestroy(solver_config);
112 cudssDestroy(handle);
113 cudaStreamSynchronize(stream);
114 cudaStreamDestroy(stream);
// Rebuilds all sparsity-dependent state:
// - the H and Schur structures;
// - the CSC matrix and its cuDSS wrapper;
// - the dense b/x wrappers.
// It then runs the cuDSS ANALYSIS phase. factorization_failed is reset first
// and set again if analysis (or the hybrid-memory reconfiguration) fails.
// The rest of the signature is on a line not shown in this excerpt.
117 virtual void update_structure(Graph<T, S> *graph,
119 H.build_structure(graph, streams);
120 schur.build_structure(graph, streams);
121 schur.build_csc_structure(graph, d_matrix);
122 fill_matrix_structure();
// Size of the reduced system: the scalar offset of the lowest eliminated
// block column. Presumably every variable before it is kept in the Schur
// system and the rest are eliminated; confirm against SchurComplement.
124 const auto &offsets = graph->get_offset_vector();
125 schur_dim = offsets[schur.lowest_eliminated_block_col];
// Re-create the dense right-hand-side wrapper (schur_dim x 1, column-major)
// over schur.get_b_Schur().
// NOTE(review): the cuDSS API reference declares nrows/ncols/ld of
// cudssMatrixCreateDn as int64_t (confirm for the version in use). The
// static_cast<int> here, and for ldx below, can therefore truncate when
// schur_dim exceeds INT_MAX; passing schur_dim directly would avoid that.
128 cudssMatrixDestroy(m_b);
131 int ldb =
static_cast<int>(schur_dim);
132 cudssMatrixCreateDn(&m_b, schur_dim, 1, ldb,
133 schur.get_b_Schur().data().get(),
134 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
// Re-create the dense solution wrapper over solver_x, resized to schur_dim.
137 cudssMatrixDestroy(m_x);
140 solver_x.resize(schur_dim);
141 int ldx =
static_cast<int>(schur_dim);
142 cudssMatrixCreateDn(&m_x, schur_dim, 1, ldx, solver_x.data().get(),
143 get_cuda_data_type<T>(), CUDSS_LAYOUT_COL_MAJOR);
// cuDSS analysis phase (reordering / symbolic setup) for the new sparsity
// pattern.
145 factorization_failed =
false;
146 auto status = cudssExecute(handle, CUDSS_PHASE_ANALYSIS, solver_config,
147 solver_data, m_A, m_x, m_b);
148 if (status != CUDSS_STATUS_SUCCESS) {
149 factorization_failed =
true;
150 std::cerr <<
"cudss Schur analysis failed with error code: " << status
152 }
else if (configured_hybrid_memory_limit > 0) {
// Hybrid memory mode: ask cuDSS for the minimum device memory it needs. If
// the configured limit is below that minimum, raise the limit, warn, and
// push the new value into the config.
// NOTE(review): the raised value persists in configured_hybrid_memory_limit,
// so a later, smaller structure keeps the larger limit. A failing
// cudssDataGet is silently ignored.
153 int64_t min_hybrid_memory = 0;
154 size_t size_written = 0;
155 status = cudssDataGet(
156 handle, solver_data, CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN,
157 &min_hybrid_memory,
sizeof(min_hybrid_memory), &size_written);
158 if (status == CUDSS_STATUS_SUCCESS &&
159 configured_hybrid_memory_limit < min_hybrid_memory) {
160 configured_hybrid_memory_limit = min_hybrid_memory;
161 std::cerr <<
"Requested cuDSS Schur hybrid memory limit is too low; "
162 "raising to minimum required "
163 << configured_hybrid_memory_limit <<
" bytes." << std::endl;
164 status = cudssConfigSet(solver_config,
165 CUDSS_CONFIG_HYBRID_DEVICE_MEMORY_LIMIT,
166 &configured_hybrid_memory_limit,
167 sizeof(configured_hybrid_memory_limit));
// A failure to apply the raised limit is treated like an analysis failure,
// so solve() will bail out.
168 if (status != CUDSS_STATUS_SUCCESS) {
169 factorization_failed =
true;
171 <<
"Failed to update cuDSS Schur hybrid memory limit with error "
173 << status << std::endl;
// Wait for all cuDSS work queued on the solver stream.
177 cudaStreamSynchronize(stream);
// Refreshes the numeric values of H only. The Schur complement and its CSC
// values are recomputed at the start of solve() (schur.update_values /
// schur.update_csc_values).
180 virtual void update_values(Graph<T, S> *graph,
StreamPool &streams)
override {
181 H.update_values(graph, streams);
// Applies damping to H (scaled by damping_factor; use_identity is forwarded
// to H.apply_damping). Since schur was constructed from H, the damped values
// presumably reach the reduced system the next time solve() recomputes the
// Schur complement. Part of the signature is on a line not shown here.
184 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
185 const bool use_identity,
187 H.apply_damping(graph, damping_factor, use_identity, streams);
// Solves the reduced (Schur) system with cuDSS, then recovers the eliminated
// variables into x, whose length is graph->get_hessian_dimension().
// If the last analysis failed it exits early; the early-exit body
// (presumably returning false) is on lines not shown in this excerpt.
190 virtual bool solve(Graph<T, S> *graph, T *x,
StreamPool &streams)
override {
191 if (factorization_failed) {
// Recompute the Schur complement values, copy them into the CSC arrays, and
// re-point m_A at the value buffer.
196 schur.update_values(graph, streams);
197 schur.update_csc_values(graph, d_matrix);
198 fill_matrix_values();
200 const auto dim = graph->get_hessian_dimension();
// Zero the full output vector and the reduced solution buffer.
// NOTE(review): thrust::device executes on the default stream, while cuDSS
// runs on `stream` (created with cudaStreamCreate, i.e. a blocking stream).
// The ordering between them relies on legacy default-stream semantics; under
// --default-stream per-thread it is not guaranteed. Consider
// thrust::cuda::par.on(stream).
202 thrust::fill(thrust::device, x, x + dim,
static_cast<T
>(0.0));
203 thrust::fill(thrust::device, solver_x.begin(), solver_x.end(),
204 static_cast<T
>(0.0));
206 cudssStatus_t status;
// Re-point b and x in case get_b_Schur() or solver_x were reallocated since
// the wrappers were created in update_structure().
208 cudssMatrixSetValues(m_b, schur.get_b_Schur().data().get());
209 cudssMatrixSetValues(m_x, solver_x.data().get());
// The numeric factorization is redone on every call, because the values
// change between calls.
211 status = cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, solver_config,
212 solver_data, m_A, m_x, m_b);
213 if (status != CUDSS_STATUS_SUCCESS) {
214 std::cerr <<
"cudss Schur factorization failed with error code: "
215 << status << std::endl;
219 status = cudssExecute(handle, CUDSS_PHASE_SOLVE, solver_config, solver_data,
221 if (status != CUDSS_STATUS_SUCCESS) {
222 std::cerr <<
"cudss Schur solve failed with error code: " << status
// Wait for the cuDSS solve before reading solver_x.
226 cudaStreamSynchronize(stream);
// Copy the first schur_dim entries of the reduced solution out. The
// destination argument is on a line not shown here; presumably it is x.
228 thrust::copy(thrust::device, solver_x.data(), solver_x.data() + schur_dim,
// Presumably back-substitutes the eliminated variables into x + schur_dim
// using the reduced solution in x; confirm the argument order against
// SchurComplement::compute_landmark_update.
231 schur.compute_landmark_update(graph, streams, x + schur_dim, x);