Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
eigen.hpp
Go to the documentation of this file.
1
2#pragma once
6
7namespace graphite {
8
9template <typename T, typename S,
10 typename Index =
11 typename Eigen::SparseMatrix<S, Eigen::ColMajor>::StorageIndex>
12class EigenLDLTSolver : public Solver<T, S> {
13private:
14 EigenLDLTWrapper<S, Index> solver;
15
16 Eigen::SparseMatrix<S, Eigen::ColMajor, Index> matrix;
17
18 Hessian<T, S> H;
19 CSCMatrix<S, Index> d_matrix;
20
21 thrust::host_vector<S> h_x;
22 thrust::host_vector<S> h_b;
23
24 void fill_matrix_structure() {
25 const auto dim = d_matrix.d_pointers.size() - 1;
26 matrix.resize(dim, dim);
27 matrix.resizeNonZeros(d_matrix.d_values.size());
28
29 auto h_ptrs = matrix.outerIndexPtr();
30 auto h_indices = matrix.innerIndexPtr();
31
32 thrust::copy(d_matrix.d_pointers.begin(), d_matrix.d_pointers.end(),
33 h_ptrs);
34 thrust::copy(d_matrix.d_indices.begin(), d_matrix.d_indices.end(),
35 h_indices);
36
37 h_x.resize(dim);
38 h_b.resize(dim);
39 }
40
41 void fill_matrix_values() {
42 auto h_values = matrix.valuePtr();
43 thrust::copy(d_matrix.d_values.begin(), d_matrix.d_values.end(), h_values);
44 }
45
46public:
47 EigenLDLTSolver() : solver() {}
48
49 virtual void update_structure(Graph<T, S> *graph,
50 StreamPool &streams) override {
51 H.build_structure(graph, streams);
52 H.build_csc_structure(graph, d_matrix);
53 fill_matrix_structure(); // for CPU matrix
54 solver.analyze_pattern(matrix);
55 }
56
57 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
58 H.update_values(graph, streams);
59 H.update_csc_values(graph, d_matrix);
60 fill_matrix_values(); // for CPU matrix
61 }
62
63 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
64 const bool use_identity,
65 StreamPool &streams) override {
66 H.apply_damping(graph, damping_factor, use_identity, streams);
67 H.update_csc_values(graph, d_matrix);
68 fill_matrix_values(); // TODO: Use a more lightweight method to just update
69 // diagonal
70 }
71
72 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
73
74 auto dim = graph->get_hessian_dimension();
75
76 if (!solver.factorize(matrix)) {
77 std::cerr << "LDLT matrix decomposition failed!";
78 return false;
79 }
80
81 thrust::fill(thrust::device, x, x + dim, static_cast<T>(0.0));
82
83 thrust::copy(graph->get_b().begin(), graph->get_b().end(), h_b.data());
84 thrust::device_ptr<T> d_x(
85 x); // If you don't wrap the pointer, thrust breaks on older toolkits
86 thrust::copy(d_x, d_x + dim, h_x.data());
87
88 auto map_b = VecMap<S>(h_b.data(), dim, 1);
89 auto map_x = VecMap<S>(h_x.data(), dim, 1);
90 if (!solver.solve(map_b, map_x)) {
91 std::cerr << "LDLT solve failed!";
92 return false;
93 }
94
95 thrust::copy(h_x.begin(), h_x.end(), d_x);
96
97 return true;
98 }
99};
100
101} // namespace graphite
Definition eigen.hpp:12
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4