Graphite
Loading...
Searching...
No Matches
eigen.hpp
Go to the documentation of this file.
1
2#pragma once
6
7namespace graphite {
8
9template <typename T, typename S> class EigenLDLTSolver : public Solver<T, S> {
10private:
11 EigenLDLTWrapper<S> solver;
12
13 Eigen::SparseMatrix<S, Eigen::ColMajor> matrix;
14
15 using Index = typename Eigen::SparseMatrix<S, Eigen::ColMajor>::StorageIndex;
16
17 Hessian<T, S> H;
18 CSCMatrix<S, Index> d_matrix;
19
20 thrust::host_vector<S> h_x;
21 thrust::host_vector<S> h_b;
22
23 void fill_matrix_structure() {
24 const auto dim = d_matrix.d_pointers.size() - 1;
25 matrix.resize(dim, dim);
26 matrix.resizeNonZeros(d_matrix.d_values.size());
27
28 auto h_ptrs = matrix.outerIndexPtr();
29 auto h_indices = matrix.innerIndexPtr();
30
31 thrust::copy(d_matrix.d_pointers.begin(), d_matrix.d_pointers.end(),
32 h_ptrs);
33 thrust::copy(d_matrix.d_indices.begin(), d_matrix.d_indices.end(),
34 h_indices);
35
36 h_x.resize(dim);
37 h_b.resize(dim);
38 }
39
40 void fill_matrix_values() {
41 auto h_values = matrix.valuePtr();
42 thrust::copy(d_matrix.d_values.begin(), d_matrix.d_values.end(), h_values);
43 }
44
45public:
46 EigenLDLTSolver() : solver() {}
47
48 virtual void update_structure(Graph<T, S> *graph,
49 StreamPool &streams) override {
50 H.build_structure(graph, streams);
51 H.build_csc_structure(graph, d_matrix);
52 fill_matrix_structure(); // for CPU matrix
53 solver.analyze_pattern(matrix);
54 }
55
56 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
57 H.update_values(graph, streams);
58 H.update_csc_values(graph, d_matrix);
59 fill_matrix_values(); // for CPU matrix
60 }
61
62 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
63 StreamPool &streams) override {
64 H.apply_damping(graph, damping_factor, streams);
65 H.update_csc_values(graph, d_matrix);
66 fill_matrix_values(); // TODO: Use a more lightweight method to just update
67 // diagonal
68 }
69
70 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
71
72 auto dim = graph->get_hessian_dimension();
73
74 if (!solver.factorize(matrix)) {
75 std::cerr << "LDLT matrix decomposition failed!";
76 return false;
77 }
78
79 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
80
81 thrust::copy(graph->get_b().begin(), graph->get_b().end(), h_b.data());
82 thrust::device_ptr<T> d_x(
83 x); // If you don't wrap the pointer, thrust breaks on older toolkits
84 thrust::copy(d_x, d_x + dim, h_x.data());
85
86 auto map_b = VecMap<S>(h_b.data(), dim, 1);
87 auto map_x = VecMap<S>(h_x.data(), dim, 1);
88 if (!solver.solve(map_b, map_x)) {
89 std::cerr << "LDLT solve failed!";
90 return false;
91 }
92
93 thrust::copy(h_x.begin(), h_x.end(), d_x);
94
95 return true;
96 }
97};
98
99} // namespace graphite
Definition eigen.hpp:9
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7