Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
pcg.hpp
Go to the documentation of this file.
1
2#pragma once
6#include <thrust/execution_policy.h>
7#include <thrust/inner_product.h>
8
9namespace graphite {
10
12template <typename T, typename S> class PCGSolver : public Solver<T, S> {
13private:
14 thrust::device_vector<T> v;
15
16 // Need vectors for residuals of each factor
17 thrust::device_vector<T> v1; // v1 = Jv (dimension same as r)
18 thrust::device_vector<T> v2; // v2 = J^T v1 (dimension same as x)
19 thrust::device_vector<T> r; // residual
20 thrust::device_vector<T> p; // search direction
21 thrust::device_vector<T> z; // preconditioned residual
22 thrust::device_vector<T> diag; // diagonal of Hessian
23 thrust::device_vector<T> x_backup;
24 thrust::device_vector<T> y;
25
26 size_t max_iter;
27 T tol;
28 T rejection_ratio;
29 T damping_factor;
30 bool use_identity_damping;
31
32 Preconditioner<T, S> *preconditioner;
33
34public:
35 PCGSolver(size_t max_iter, T tol, T rejection_ratio,
36 Preconditioner<T, S> *preconditioner)
37 : max_iter(max_iter), tol(tol), rejection_ratio(rejection_ratio),
38 damping_factor(0), use_identity_damping(false),
39 preconditioner(preconditioner) {}
40
41 virtual void update_structure(Graph<T, S> *graph,
42 StreamPool &streams) override {
43
44 preconditioner->update_structure(graph, streams);
45 }
46
47 virtual void update_values(Graph<T, S> *graph, StreamPool &streams) override {
48 preconditioner->update_values(graph, streams);
49 }
50
51 virtual void set_damping_factor(Graph<T, S> *graph, T damping_factor,
52 const bool use_identity,
53 StreamPool &streams) override {
54 this->damping_factor = damping_factor;
55 this->use_identity_damping = use_identity;
56 preconditioner->set_damping_factor(graph, damping_factor, use_identity,
57 streams);
58 }
59
60 // Assumes that x is already initialized
61 virtual bool solve(Graph<T, S> *graph, T *x, StreamPool &streams) override {
62
63 auto &vertex_descriptors = graph->get_vertex_descriptors();
64 auto &factor_descriptors = graph->get_factor_descriptors();
65 T *b = graph->get_b().data().get();
66 size_t dim_h = graph->get_hessian_dimension();
67
68 size_t dim_r = 0;
69 for (size_t i = 0; i < factor_descriptors.size(); i++) {
70 dim_r += factor_descriptors[i]->get_residual_size();
71 }
72
73 // Resize vectors (assuming this causes host synchronization)
74 v1.resize(dim_r);
75 v2.resize(dim_h);
76 r.resize(dim_h); // dim h because dim(r) = dim(Ax) = dim(b)
77 diag.resize(dim_h);
78
79 const cudaStream_t stream = 0;
80
81 thrust::fill(thrust::cuda::par_nosync.on(stream), x, x + dim_h, 0.0);
82 thrust::fill(thrust::cuda::par_nosync.on(stream), v1.begin(), v1.end(),
83 0.0);
84 thrust::fill(thrust::cuda::par_nosync.on(stream), v2.begin(), v2.end(),
85 0.0);
86
87 // Compute residual
88 thrust::copy(thrust::cuda::par_nosync.on(stream), graph->get_b().begin(),
89 graph->get_b().end(), r.begin());
90
91 // 3. Add damping factor
92 // v2 += damping_factor*diag(H)*x
93 thrust::fill(thrust::cuda::par_nosync.on(stream), diag.begin(), diag.end(),
94 0.0);
95 for (size_t i = 0; i < factor_descriptors.size(); i++) {
96 factor_descriptors[i]->compute_hessian_scalar_diagonal_async(
97 diag.data().get(),
98 graph->get_jacobian_scales().data().get()); // also on default stream
99 }
100
101 // Check for negative values in diag and print an error if found
102 T min_diag = static_cast<T>(1.0e-6);
103 T max_diag = static_cast<T>(1.0e32);
104 ops::clamp_async(stream, dim_h, min_diag, max_diag, diag.data().get());
105
106 cudaStreamSynchronize(stream);
107
108 // Rescale r
109 y.resize(dim_h);
110 auto rnorm = thrust::inner_product(thrust::device, r.begin(), r.end(),
111 r.begin(), static_cast<T>(0.0));
112 rnorm = std::sqrt(rnorm);
113 auto scale = 1.0 / rnorm;
114 ops::rescale_vec_async<T>(stream, dim_h, y.data().get(), scale,
115 r.data().get());
116 cudaStreamSynchronize(stream);
117 // Apply preconditioner
118 z.resize(dim_h);
119
120 thrust::fill(z.begin(), z.end(), 0.0);
121 preconditioner->apply(graph, z.data().get(), y.data().get(), streams);
122
123 p.resize(dim_h);
124 thrust::copy(z.begin(), z.end(), p.begin()); // p = z
125
126 x_backup.resize(dim_h);
127
128 // 1. First compute dot(r, z)
129 T rz = (T)thrust::inner_product(r.begin(), r.end(), z.begin(),
130 static_cast<T>(0.0));
131
132 T rz_0 = std::numeric_limits<T>::infinity();
133 for (size_t k = 0; k < max_iter; k++) {
134 if (rz == 0) {
135 // std::cout << "rz is zero, stopping at iteration " << k << std::endl;
136 break;
137 }
138
139 // auto t_jv_start = std::chrono::steady_clock::now();
140 // 2. Compute v1 = Jp
141 thrust::fill(v1.begin(), v1.end(), 0.0);
142 auto v1_ptr = v1.data().get(); // reset
143 for (size_t i = 0; i < factor_descriptors.size(); i++) {
144 factor_descriptors[i]->compute_Jv(
145 v1_ptr, p.data().get(), graph->get_jacobian_scales().data().get(),
146 streams);
147 v1_ptr += factor_descriptors[i]->get_residual_size();
148 }
149 // auto t_jv_end = std::chrono::steady_clock::now();
150 // std::cout << "Time for Jv: "
151 // << std::chrono::duration<double>(t_jv_end -
152 // t_jv_start).count()
153 // << " seconds" << std::endl;
154
155 // 3. Compute v2 = J^T v1
156 thrust::fill(v2.begin(), v2.end(), 0.0);
157 v1_ptr = v1.data().get(); // reset
158 for (size_t i = 0; i < factor_descriptors.size(); i++) {
159 factor_descriptors[i]->compute_Jtv(
160 v2.data().get(), v1_ptr, graph->get_jacobian_scales().data().get(),
161 streams);
162 v1_ptr += factor_descriptors[i]->get_residual_size();
163 }
164 // Add damping factor
165 // v2 += damping_factor*diag(H)*p
166 ops::damp_by_factor_async(stream, dim_h, v2.data().get(), damping_factor,
167 use_identity_damping, diag.data().get(),
168 p.data().get());
169
170 // 4. Compute alpha = dot(r, z) / dot(p, v2)
171 T alpha = (rz) / thrust::inner_product(thrust::cuda::par.on(stream),
172 p.begin(), p.end(), v2.begin(),
173 static_cast<T>(0.0));
174 // 5. x += alpha * p
175 thrust::copy(thrust::cuda::par.on(stream), x, x + dim_h,
176 x_backup.begin());
177 ops::axpy_async(stream, dim_h, x, alpha, p.data().get(), x);
178
179 // 6. r -= alpha * v2
180 ops::axpy_async(stream, dim_h, r.data().get(), -alpha, v2.data().get(),
181 r.data().get());
182 // cudaStreamSynchronize(0);
183
184 rnorm = (T)thrust::inner_product(thrust::cuda::par.on(stream), r.begin(),
185 r.end(), r.begin(), static_cast<T>(0.0));
186 rnorm = std::sqrt(rnorm);
187 scale = 1.0 / rnorm;
188 ops::rescale_vec_async<T>(stream, dim_h, y.data().get(), scale,
189 r.data().get());
190
191 // Apply preconditioner again
192 thrust::fill(thrust::cuda::par.on(stream), z.begin(), z.end(), 0.0);
193 preconditioner->apply(graph, z.data().get(), y.data().get(), streams);
194 T rz_new = thrust::inner_product(thrust::cuda::par.on(stream), r.begin(),
195 r.end(), z.begin(), static_cast<T>(0.0));
196
197 // if (rz_new > rejection_ratio * rz_0) {
198 if (std::abs(rz_new) > rejection_ratio * rz_0 || std::isnan(rz_new)) {
199 thrust::copy(thrust::device, x_backup.begin(), x_backup.end(), x);
200 // std::cout << "Rejection: rz_new = " << rz_new
201 // << ", rz_0 = " << rz_0 << " at iteration " << k + 1 <<
202 // std::endl;
203 std::cout << "rejected pcg update\n";
204 break;
205 }
206 rz_0 = std::min(rz_0, std::abs(rz_new));
207
208 // 8. Compute beta
209 // std::cout << "rz_new: " << rz_new << ", rz: " << rz
210 // << ", at iteration " << k + 1 << std::endl;
211
212 T beta = rz_new / (rz);
213 rz = rz_new;
214
215 // 9. Update p
216 ops::axpy_async(stream, dim_h, p.data().get(), beta, p.data().get(),
217 z.data().get());
218 cudaStreamSynchronize(stream);
219
220 if (std::abs(static_cast<T>(rz_new)) < tol) {
221 // std::cout << "Converged after " << k + 1
222 // << " iterations with residual: " << rz_new << std::endl;
223 break;
224 }
225 // if (k == max_iter - 1) {
226 // std::cout << "Reached maximum iterations: " << max_iter
227 // << " with residual: " << rz_new << std::endl;
228 // }
229 }
230 // TODO: Figure out failure cases
231 return true;
232 }
233};
234} // namespace graphite
Preconditioned Conjugate Gradient (PCG) solver.
Definition pcg.hpp:12
Linear solver interface. Implement this for your own linear solvers.
Definition solver.hpp:12
Definition stream.hpp:7
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4