49 GradientDescentOptions<T, S> *options) {
52 auto start = std::chrono::steady_clock::now();
54 if (!options->validate()) {
55 if (options->verbose) {
56 std::cerr <<
"Gradient Descent options invalid" << std::endl;
61 auto streams = options->streams;
63 if (!graph->initialize_optimization(options->optimization_level)) {
67 if (!graph->build_structure()) {
71 thrust::device_vector<T> delta_x(graph->get_hessian_dimension());
76 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
79 if (options->verbose) {
80 std::cout << std::setprecision(12) << std::setw(18) <<
"Iteration"
81 << std::setw(24) <<
"Initial Chi2" << std::setw(24)
82 <<
"Current Chi2" << std::setw(24) << std::setw(24) <<
"Time"
83 << std::setw(24) <<
"Total Time" << std::endl;
85 <<
"---------------------------------------------------------------"
86 "---------------------------------------------------------------"
91 const T alpha = options->learning_rate;
93 const auto num_iterations = options->iterations;
94 for (
size_t i = 0; i < num_iterations && run; i++) {
96 start = std::chrono::steady_clock::now();
97 graph->linearize(*streams);
98 T chi2 = graph->chi2();
99 thrust::fill(thrust::device, delta_x.begin(), delta_x.end(), T(0.0));
100 ops::axpy_async(0, delta_x.size(), delta_x.data().get(), alpha,
101 graph->get_b().data().get(), delta_x.data().get());
102 cudaStreamSynchronize(0);
103 graph->apply_update(delta_x.data().get(), *streams);
106 graph->compute_error();
107 T new_chi2 = graph->chi2();
109 double iteration_time =
110 std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
112 time += iteration_time;
113 if (options->verbose) {
114 std::cout << std::setprecision(12) << std::setw(18) << i << std::setw(24)
115 << chi2 << std::setw(24) << new_chi2 << std::setw(24)
116 << iteration_time << std::setw(24) << time << std::endl;
119 if (options->stop_flag && *(options->stop_flag)) {
120 std::cout <<
"Stopping optimization due to stop flag" << std::endl;
bool gradient_descent(Graph< T, S > *graph, GradientDescentOptions< T, S > *options)
Naive gradient descent optimization algorithm.
Definition gradient_descent.hpp:48