Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
eigen_schur.hpp
1
2#pragma once
3
5#include <graphite/schur.hpp>
8
9namespace graphite {
10
11template <typename T, typename S,
12 typename Index =
13 typename Eigen::SparseMatrix<S, Eigen::ColMajor>::StorageIndex>
14class EigenSchurLDLTSolver : public Solver<T, S> {
15private:
16 EigenLDLTWrapper<S, Index> solver;
17
18 Eigen::SparseMatrix<S, Eigen::ColMajor, Index> matrix;
19
20 Hessian<T, S> H;
21 SchurComplement<T, S> schur;
22 CSCMatrix<S, Index> d_matrix;
23
24 thrust::host_vector<S> h_x;
25 thrust::host_vector<S> h_b;
26
27 void fill_matrix_structure() {
28 const auto dim = d_matrix.d_pointers.size() - 1;
29 matrix.resize(dim, dim);
30 matrix.resizeNonZeros(d_matrix.d_values.size());
31
32 auto h_ptrs = matrix.outerIndexPtr();
33 auto h_indices = matrix.innerIndexPtr();
34
35 thrust::copy(d_matrix.d_pointers.begin(), d_matrix.d_pointers.end(),
36 h_ptrs);
37 thrust::copy(d_matrix.d_indices.begin(), d_matrix.d_indices.end(),
38 h_indices);
39
40 h_x.resize(dim);
41 h_b.resize(dim);
42 }
43
44 void fill_matrix_values() {
45 auto h_values = matrix.valuePtr();
46 thrust::copy(d_matrix.d_values.begin(), d_matrix.d_values.end(), h_values);
47 }
48
49public:
50 EigenSchurLDLTSolver() : solver(), schur(H) {}
51
52 virtual void update_structure(Graph<T, S> *graph,
53 StreamPool &streams) override {
54 H.build_structure(graph, streams);
55 schur.build_structure(graph, streams);
56 schur.build_csc_structure(graph, d_matrix);
57 fill_matrix_structure();
58 solver.analyze_pattern(matrix);
59 }
60
61 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
62 H.update_values(graph, streams);
63 }
64
65 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
66 const bool use_identity,
67 StreamPool &streams) override {
68 H.apply_damping(graph, damping_factor, use_identity, streams);
69 }
70
71 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
72
73 // Update matrix values here (to avoid extra work when damping
74 schur.update_values(graph, streams);
75 schur.update_csc_values(graph, d_matrix);
76 fill_matrix_values();
77
78 if (!solver.factorize(matrix)) {
79 std::cerr << "Schur LDLT matrix decomposition failed!";
80 return false;
81 }
82
83 const auto dim = graph->get_hessian_dimension();
84 const auto &offsets = graph->get_offset_vector();
85 const auto p_block_col = schur.lowest_eliminated_block_col;
86 const auto p_dim = offsets[p_block_col];
87
88 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
89
90 thrust::copy(schur.get_b_Schur().begin(), schur.get_b_Schur().end(),
91 h_b.begin());
92
93 thrust::device_ptr<T> d_x(x);
94 thrust::copy(d_x, d_x + p_dim, h_x.data());
95
96 auto map_b = VecMap<S>(h_b.data(), p_dim, 1);
97 auto map_x = VecMap<S>(h_x.data(), p_dim, 1);
98 if (!solver.solve(map_b, map_x)) {
99 std::cerr << "Schur LDLT solve failed!";
100 return false;
101 }
102
103 thrust::copy(h_x.begin(), h_x.end(), d_x);
104
105 schur.compute_landmark_update(graph, streams, x + p_dim, x);
106
107 return true;
108 }
109};
110
111} // namespace graphite
Definition eigen_schur.hpp:14
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4