15 : iterations(100), learning_rate(1e-3), beta1(0.9), beta2(0.999),
16 epsilon(1e-8), optimization_level(0), verbose(
false),
17 stop_flag(
nullptr), streams(
nullptr) {}
// Optimization level handed to graph->initialize_optimization() at the start
// of adam(). The constructor initializes it to 0.
// NOTE(review): the meaning of individual level values is defined by the
// Graph implementation, which is not visible here — confirm before relying
// on specific values.
24 uint8_t optimization_level;
29 bool validate()
const {
30 if (streams ==
nullptr) {
32 std::cerr <<
"Adam options invalid: streams is null" << std::endl;
51bool adam(Graph<T, S> *graph, AdamOptions<T, S> *options) {
54 auto start = std::chrono::steady_clock::now();
56 if (!options->validate()) {
57 if (options->verbose) {
58 std::cerr <<
"Adam options invalid" << std::endl;
63 auto streams = options->streams;
65 if (!graph->initialize_optimization(options->optimization_level)) {
69 if (!graph->build_structure()) {
73 thrust::device_vector<T> delta_x(graph->get_hessian_dimension());
78 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
81 if (options->verbose) {
82 std::cout << std::setprecision(12) << std::setw(18) <<
"Iteration"
83 << std::setw(24) <<
"Initial Chi2" << std::setw(24)
84 <<
"Current Chi2" << std::setw(24) << std::setw(24) <<
"Time"
85 << std::setw(24) <<
"Total Time" << std::endl;
87 <<
"---------------------------------------------------------------"
88 "---------------------------------------------------------------"
93 const T lr = options->learning_rate;
94 const T beta1 = options->beta1;
95 const T beta2 = options->beta2;
96 const T epsilon = options->epsilon;
98 const auto dim = delta_x.size();
99 thrust::device_vector<T> m(dim, T(0.0));
100 thrust::device_vector<T> v(dim, T(0.0));
101 thrust::device_vector<T> g(dim, T(0.0));
103 const auto num_iterations = options->iterations;
104 for (
size_t i = 0; i < num_iterations && run; i++) {
106 start = std::chrono::steady_clock::now();
107 graph->linearize(*streams);
108 T chi2 = graph->chi2();
109 thrust::copy(thrust::device, graph->get_b().begin(), graph->get_b().end(),
111 ops::compute_adam_step_async(0, g.size(), g.data().get(),
112 delta_x.data().get(), m.data().get(),
113 v.data().get(), lr, beta1, beta2, epsilon, i);
114 cudaStreamSynchronize(0);
115 graph->apply_update(delta_x.data().get(), *streams);
118 graph->compute_error();
119 T new_chi2 = graph->chi2();
121 double iteration_time =
122 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
124 time += iteration_time;
125 if (options->verbose) {
126 std::cout << std::setprecision(12) << std::setw(18) << i << std::setw(24)
127 << chi2 << std::setw(24) << new_chi2 << std::setw(24)
128 << iteration_time << std::setw(24) << time << std::endl;
131 if (options->stop_flag && *(options->stop_flag)) {
132 std::cout <<
"Stopping optimization due to stop flag" << std::endl;