Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
adam.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <graphite/graph.hpp>
6#include <iomanip>
7
8namespace graphite {
9
10namespace optimizer {
11
12template <typename T, typename S> class AdamOptions {
13public:
15 : iterations(100), learning_rate(1e-3), beta1(0.9), beta2(0.999),
16 epsilon(1e-8), optimization_level(0), verbose(false),
17 stop_flag(nullptr), streams(nullptr) {}
18
19 size_t iterations;
20 double learning_rate;
21 double beta1;
22 double beta2;
23 double epsilon;
24 uint8_t optimization_level;
25 bool verbose;
26 bool *stop_flag;
27 StreamPool *streams;
28
29 bool validate() const {
30 if (streams == nullptr) {
31 if (verbose) {
32 std::cerr << "Adam options invalid: streams is null" << std::endl;
33 }
34 return false;
35 }
36
37 return true;
38 }
39};
40
50template <typename T, typename S>
51bool adam(Graph<T, S> *graph, AdamOptions<T, S> *options) {
52
53 // Initialize something for all iterations
54 auto start = std::chrono::steady_clock::now();
55
56 if (!options->validate()) {
57 if (options->verbose) {
58 std::cerr << "Adam options invalid" << std::endl;
59 }
60 return false;
61 }
62
63 auto streams = options->streams;
64
65 if (!graph->initialize_optimization(options->optimization_level)) {
66 return false;
67 }
68
69 if (!graph->build_structure()) {
70 return false;
71 }
72
73 thrust::device_vector<T> delta_x(graph->get_hessian_dimension());
74
75 bool run = true;
76
77 double time =
78 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
79 .count();
80 // Print iteration table headers
81 if (options->verbose) {
82 std::cout << std::setprecision(12) << std::setw(18) << "Iteration"
83 << std::setw(24) << "Initial Chi2" << std::setw(24)
84 << "Current Chi2" << std::setw(24) << std::setw(24) << "Time"
85 << std::setw(24) << "Total Time" << std::endl;
86 std::cout
87 << "---------------------------------------------------------------"
88 "---------------------------------------------------------------"
89 "------------"
90 << std::endl;
91 }
92
93 const T lr = options->learning_rate;
94 const T beta1 = options->beta1;
95 const T beta2 = options->beta2;
96 const T epsilon = options->epsilon;
97
98 const auto dim = delta_x.size();
99 thrust::device_vector<T> m(dim, T(0.0));
100 thrust::device_vector<T> v(dim, T(0.0));
101 thrust::device_vector<T> g(dim, T(0.0));
102
103 const auto num_iterations = options->iterations;
104 for (size_t i = 0; i < num_iterations && run; i++) {
105
106 start = std::chrono::steady_clock::now();
107 graph->linearize(*streams);
108 T chi2 = graph->chi2();
109 thrust::copy(thrust::device, graph->get_b().begin(), graph->get_b().end(),
110 g.begin());
111 ops::compute_adam_step_async(0, g.size(), g.data().get(),
112 delta_x.data().get(), m.data().get(),
113 v.data().get(), lr, beta1, beta2, epsilon, i);
114 cudaStreamSynchronize(0);
115 graph->apply_update(delta_x.data().get(), *streams);
116
117 // Try step
118 graph->compute_error();
119 T new_chi2 = graph->chi2();
120
121 double iteration_time =
122 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
123 .count();
124 time += iteration_time;
125 if (options->verbose) {
126 std::cout << std::setprecision(12) << std::setw(18) << i << std::setw(24)
127 << chi2 << std::setw(24) << new_chi2 << std::setw(24)
128 << iteration_time << std::setw(24) << time << std::endl;
129 }
130
131 if (options->stop_flag && *(options->stop_flag)) {
132 std::cout << "Stopping optimization due to stop flag" << std::endl;
133 break;
134 }
135 }
136
137 return run;
138}
139
140} // namespace optimizer
141
142} // namespace graphite
Definition stream.hpp:7
Definition adam.hpp:12
bool adam(Graph< T, S > *graph, AdamOptions< T, S > *options)
Adam optimization algorithm.
Definition adam.hpp:51
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4