Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
gradient_descent.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <graphite/graph.hpp>
6#include <iomanip>
7
8namespace graphite {
9
10namespace optimizer {
11
12template <typename T, typename S> class GradientDescentOptions {
13public:
15 : iterations(100), learning_rate(1e-3), optimization_level(0),
16 verbose(false), stop_flag(nullptr), streams(nullptr) {}
17
18 size_t iterations;
19 double learning_rate;
20 uint8_t optimization_level;
21 bool verbose;
22 bool *stop_flag;
23 StreamPool *streams;
24
25 bool validate() const {
26 if (streams == nullptr) {
27 if (verbose) {
28 std::cerr << "Gradient Descent options invalid: streams is null"
29 << std::endl;
30 }
31 return false;
32 }
33
34 return true;
35 }
36};
37
47template <typename T, typename S>
48bool gradient_descent(Graph<T, S> *graph,
49 GradientDescentOptions<T, S> *options) {
50
51 // Initialize something for all iterations
52 auto start = std::chrono::steady_clock::now();
53
54 if (!options->validate()) {
55 if (options->verbose) {
56 std::cerr << "Gradient Descent options invalid" << std::endl;
57 }
58 return false;
59 }
60
61 auto streams = options->streams;
62
63 if (!graph->initialize_optimization(options->optimization_level)) {
64 return false;
65 }
66
67 if (!graph->build_structure()) {
68 return false;
69 }
70
71 thrust::device_vector<T> delta_x(graph->get_hessian_dimension());
72
73 bool run = true;
74
75 double time =
76 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
77 .count();
78 // Print iteration table headers
79 if (options->verbose) {
80 std::cout << std::setprecision(12) << std::setw(18) << "Iteration"
81 << std::setw(24) << "Initial Chi2" << std::setw(24)
82 << "Current Chi2" << std::setw(24) << std::setw(24) << "Time"
83 << std::setw(24) << "Total Time" << std::endl;
84 std::cout
85 << "---------------------------------------------------------------"
86 "---------------------------------------------------------------"
87 "------------"
88 << std::endl;
89 }
90
91 const T alpha = options->learning_rate;
92
93 const auto num_iterations = options->iterations;
94 for (size_t i = 0; i < num_iterations && run; i++) {
95
96 start = std::chrono::steady_clock::now();
97 graph->linearize(*streams);
98 T chi2 = graph->chi2();
99 thrust::fill(thrust::device, delta_x.begin(), delta_x.end(), T(0.0));
100 ops::axpy_async(0, delta_x.size(), delta_x.data().get(), alpha,
101 graph->get_b().data().get(), delta_x.data().get());
102 cudaStreamSynchronize(0);
103 graph->apply_update(delta_x.data().get(), *streams);
104
105 // Try step
106 graph->compute_error();
107 T new_chi2 = graph->chi2();
108
109 double iteration_time =
110 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
111 .count();
112 time += iteration_time;
113 if (options->verbose) {
114 std::cout << std::setprecision(12) << std::setw(18) << i << std::setw(24)
115 << chi2 << std::setw(24) << new_chi2 << std::setw(24)
116 << iteration_time << std::setw(24) << time << std::endl;
117 }
118
119 if (options->stop_flag && *(options->stop_flag)) {
120 std::cout << "Stopping optimization due to stop flag" << std::endl;
121 break;
122 }
123 }
124
125 return run;
126}
127
128} // namespace optimizer
129} // namespace graphite
Definition stream.hpp:7
Definition gradient_descent.hpp:12
bool gradient_descent(Graph< T, S > *graph, GradientDescentOptions< T, S > *options)
Naive gradient descent optimization algorithm.
Definition gradient_descent.hpp:48
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4