6#include <graphite/ops/state.hpp>
/// Element type of the block-diagonal / scalar-diagonal buffers passed to the
/// methods below: `T` when `is_low_precision<S>::value` is true, otherwise `S`.
16 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
/// Applies an update to the vertex parameters asynchronously.
/// Pure virtual; the concrete descriptor in this file forwards `delta_x`,
/// `jacobian_scales` and `stream` to `ops::apply_update`.
21 virtual void apply_update_async(
const T *delta_x, T *jacobian_scales,
22 cudaStream_t stream) = 0;
/// Adds damping to the Hessian block diagonal asynchronously.
/// Pure virtual; the concrete descriptor forwards all arguments to
/// `ops::augment_block_diagonal`.
/// NOTE(review): the exact roles of `mu` and `use_identity` are defined by that
/// operation and are not visible here — confirm against its definition.
23 virtual void augment_block_diagonal_async(InvP *block_diagonal,
24 InvP *scalar_diagonal,
const T mu,
25 const bool use_identity,
26 cudaStream_t stream) = 0;
/// Applies the block Jacobi preconditioner asynchronously.
/// Pure virtual; the concrete descriptor forwards `z`, `r`, `block_diagonal`
/// and `stream` to `ops::apply_block_jacobi`.
27 virtual void apply_block_jacobi(T *z,
const T *r, InvP *block_diagonal,
28 cudaStream_t stream) = 0;
/// Returns the dimension of the vertex parameterization.
30 virtual size_t dimension()
const = 0;
/// Returns the number of vertices in the descriptor (including fixed vertices).
31 virtual size_t count()
const = 0;
/// Backs up the state of the vertices asynchronously.
33 virtual void backup_parameters_async() = 0;
/// Restores the state of the vertices from the backup asynchronously.
34 virtual void restore_parameters_async() = 0;
/// Prepares the descriptor for GPU processing.
35 virtual void to_device() = 0;
/// Returns the id map of the descriptor by const reference; the concrete
/// descriptor returns its global-id -> local-index map.
37 virtual const std::unordered_map<size_t, size_t> &get_global_map()
const = 0;
/// Returns a raw pointer to the per-vertex Hessian ids; the concrete descriptor
/// returns the raw pointer of a `thrust::device_vector<size_t>`.
38 virtual const size_t *get_hessian_ids()
const = 0;
/// Sets the Hessian column and block index for the vertex with `global_id`.
39 virtual void set_hessian_column(
const size_t global_id,
40 const size_t hessian_column,
41 const size_t block_index) = 0;
/// Checks if the vertex with the given id is fixed.
42 virtual bool is_fixed(
const size_t id)
const = 0;
/// Checks if the vertex with the given id is active.
43 virtual bool is_active(
const size_t id)
const = 0;
/// Sets the eliminate flag for the vertex descriptor.
44 virtual void set_eliminate(
bool eliminate) = 0;
/// Gets the eliminate flag for the vertex descriptor.
45 virtual bool get_eliminate()
const = 0;
/// Retrieves a pointer to the active state array for the vertices
/// (one `uint8_t` entry per vertex).
46 virtual uint8_t *get_active_state()
const = 0;
/// Retrieves a pointer to the block IDs array for the vertices.
47 virtual const size_t *get_block_ids()
const = 0;
54template <
typename T,
typename S,
typename VTraits>
/// Element type of the block-diagonal buffers: `T` when
/// `is_low_precision<S>::value` is true, otherwise `S`.
57 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
/// Traits type supplied as the `VTraits` template argument.
59 using Traits = VTraits;
/// Vertex type named by the traits (`Traits::Vertex`).
61 using VertexType =
typename Traits::Vertex;
/// State type stored by the backup buffer, resolved via `get_State_or_t`.
/// NOTE(review): presumably `Traits::State` when the traits declare one and
/// `VertexType` otherwise — confirm against `get_State_or_t`.
62 using State = get_State_or_t<Traits, VertexType>;
/// Device-side array of vertex pointers; its raw pointer is what `vertices()`
/// returns and what the backup/restore kernels receive.
65 thrust::device_vector<VertexType *> x_device;
/// Host-side array of vertex pointers, indexed by local id; `count()` is its size.
66 thrust::host_vector<VertexType *> x_host;
/// Device buffer of saved vertex states, resized to the vertex count when a
/// backup is taken and read back when parameters are restored.
67 thrust::device_vector<State> backup_state;
/// Maps a global vertex id to its local index in `x_host`.
72 std::unordered_map<size_t, size_t> global_to_local_map;
/// Maps a local index back to the global vertex id.
73 std::vector<size_t> local_to_global_map;
/// Host-side Hessian column per local index (written by `set_hessian_column`).
74 thrust::host_vector<size_t> local_to_hessian_offsets;
/// Device copy of `local_to_hessian_offsets`, refreshed by assignment from it
/// and exposed through `get_hessian_ids()`.
75 thrust::device_vector<size_t> hessian_ids;
/// Block index per local index (written by `set_hessian_column`).
/// NOTE(review): `managed_vector` is presumably backed by unified memory — confirm.
76 managed_vector<size_t> block_ids;
/// Per-vertex state flags; `add_vertex`/`set_fixed` store the `fixed` flag here
/// and `is_fixed` tests bit 0x1.
77 managed_vector<uint8_t> active_state;
/// Compile-time dimension taken from the traits (`Traits::dimension`).
79 static constexpr size_t dim = Traits::dimension;
90 cudaStream_t stream)
override {
91 ops::apply_update<T, S>(
this, delta_x, jacobian_scales, stream);
98 const T mu,
const bool use_identity,
99 cudaStream_t stream)
override {
100 ops::augment_block_diagonal<T, S>(
this, block_diagonal, scalar_diagonal, mu,
101 use_identity, stream);
108 cudaStream_t stream)
override {
109 ops::apply_block_jacobi<T, S>(
this, z, r, block_diagonal, stream);
117 hessian_ids = local_to_hessian_offsets;
/// Returns a pointer to the array of vertex pointers for the descriptor.
/// The pointer is the raw pointer of the `x_device` device vector, so it
/// refers to the device-side array.
124 VertexType **
vertices() {
return x_device.data().get(); }
130 VertexType **
vertices = x_device.data().get();
132 const int num_vertices =
static_cast<int>(
count());
133 const int num_threads = num_vertices;
134 const int block_size = 256;
135 const auto num_blocks = (num_threads + block_size - 1) / block_size;
136 backup_state.resize(num_vertices);
138 ops::backup_state_kernel<VertexType, State, Traits, T>
139 <<<num_blocks, block_size>>>(
vertices, backup_state.data().get(),
140 active_state.data().get(), num_vertices);
147 VertexType **
vertices = x_device.data().get();
149 const int num_vertices =
static_cast<int>(
count());
150 const int num_threads = num_vertices;
151 const int block_size = 256;
152 const auto num_blocks = (num_threads + block_size - 1) / block_size;
153 ops::set_state_kernel<VertexType, State, Traits, T>
154 <<<num_blocks, block_size>>>(
vertices, backup_state.data().get(),
155 active_state.data().get(), num_vertices);
/// Returns the number of vertices in the descriptor (including fixed vertices),
/// i.e. the current size of the host-side vertex array `x_host`.
162 virtual size_t count()
const override {
return x_host.size(); }
170 x_host.reserve(size);
171 global_to_local_map.reserve(size);
172 local_to_global_map.reserve(size);
173 local_to_hessian_offsets.reserve(size);
174 block_ids.reserve(size);
175 active_state.reserve(size);
187 if (global_to_local_map.find(
id) == global_to_local_map.end()) {
188 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
192 const auto local_id = global_to_local_map[id];
193 const auto last_index = x_host.size() - 1;
196 std::swap(x_host[local_id], x_host[last_index]);
198 if (local_to_hessian_offsets.size() > 0) {
199 std::swap(local_to_hessian_offsets[local_id],
200 local_to_hessian_offsets[last_index]);
204 const auto last_global_id = local_to_global_map[last_index];
205 global_to_local_map[last_global_id] = local_id;
206 local_to_global_map[local_id] = last_global_id;
209 active_state[local_id] = active_state[last_index];
212 active_state.pop_back();
216 local_to_hessian_offsets.pop_back();
217 global_to_local_map.erase(
id);
218 local_to_global_map.pop_back();
219 block_ids.pop_back();
228 if (global_to_local_map.find(
id) == global_to_local_map.end()) {
229 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
233 const auto local_id = global_to_local_map[id];
234 x_host[local_id] = vertex;
245 const bool fixed =
false) {
246 x_host.push_back(vertex);
247 const auto local_id = x_host.size() - 1;
248 global_to_local_map.insert({id, local_id});
249 local_to_global_map.push_back(
id);
250 local_to_hessian_offsets.push_back(0);
251 block_ids.push_back(0);
254 active_state.push_back(
static_cast<uint8_t
>(fixed));
263 const auto local_id = global_to_local_map.at(
id);
265 active_state[local_id] =
static_cast<uint8_t
>(fixed);
274 const auto local_id = global_to_local_map.at(
id);
275 return (active_state[local_id] & 0x1) > 0;
284 const auto local_id = global_to_local_map.at(
id);
285 return is_vertex_active(active_state.data().get(), local_id);
/// Sets the eliminate flag for the vertex descriptor.
/// `this->` is required because the parameter shadows the member of the same name.
293 void set_eliminate(
bool eliminate)
override { this->eliminate = eliminate; }
306 return active_state.data().get();
314 return block_ids.data().get();
324 auto it = global_to_local_map.find(
id);
325 if (it != global_to_local_map.end()) {
326 return x_host[it->second];
328 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
339 return global_to_local_map.find(
id) != global_to_local_map.end();
342 const std::unordered_map<size_t, size_t> &get_global_map()
const override {
343 return global_to_local_map;
352 const size_t *get_hessian_ids()
const override {
353 return hessian_ids.data().get();
363 const size_t block_index) {
364 const auto local_id = global_to_local_map.at(global_id);
365 local_to_hessian_offsets[local_id] = hessian_column;
366 block_ids[local_id] = block_index;
375 backup_state.clear();
377 global_to_local_map.clear();
378 local_to_global_map.clear();
379 local_to_hessian_offsets.clear();
382 active_state.clear();
Represents a collection of optimizable variables to be processed together on the GPU.
Definition vertex.hpp:55
void clear()
Clears all vertices and resets the descriptor to an empty state.
Definition vertex.hpp:372
virtual void to_device() override
Prepares descriptor for GPU processing.
Definition vertex.hpp:115
bool exists(const size_t id) const
Checks if a vertex exists in the descriptor.
Definition vertex.hpp:338
virtual size_t count() const override
Returns the number of vertices in the descriptor (including fixed vertices).
Definition vertex.hpp:162
VertexType * get_vertex(const size_t id)
Retrieves a pointer to the vertex associated with the given ID.
Definition vertex.hpp:323
void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal, cudaStream_t stream) override
Applies the block Jacobi preconditioner asynchronously.
Definition vertex.hpp:107
void replace_vertex(const size_t id, VertexType *vertex)
Replaces a vertex pointer in the descriptor for the given ID.
Definition vertex.hpp:227
void remove_vertex(const size_t id)
Removes a vertex from the descriptor by its ID.
Definition vertex.hpp:182
virtual void backup_parameters_async() override
Backs up the state of the vertices asynchronously.
Definition vertex.hpp:129
bool get_eliminate() const override
Gets the eliminate flag for the vertex descriptor.
Definition vertex.hpp:299
VertexType ** vertices()
Returns a pointer to the array of vertex pointers for the descriptor.
Definition vertex.hpp:124
virtual void restore_parameters_async() override
Restores the state of the vertices from the backup asynchronously.
Definition vertex.hpp:146
uint8_t * get_active_state() const override
Retrieves a pointer to the active state array for the vertices.
Definition vertex.hpp:305
void add_vertex(const size_t id, VertexType *vertex, const bool fixed=false)
Adds a vertex to the descriptor.
Definition vertex.hpp:244
void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal, const T mu, const bool use_identity, cudaStream_t stream) override
Adds damping to the Hessian block diagonal asynchronously.
Definition vertex.hpp:97
void set_hessian_column(const size_t global_id, const size_t hessian_column, const size_t block_index)
Sets the Hessian column and block index for a vertex.
Definition vertex.hpp:362
size_t dimension() const override
Returns the dimension of the vertex parameterization.
Definition vertex.hpp:350
void reserve(size_t size)
Reserves memory for the specified number of vertices. You should call this before constructing the graph...
Definition vertex.hpp:169
bool is_fixed(const size_t id) const override
Checks if a vertex is fixed.
Definition vertex.hpp:273
void apply_update_async(const T *delta_x, T *jacobian_scales, cudaStream_t stream) override
Applies an update to the vertex parameters asynchronously.
Definition vertex.hpp:89
const size_t * get_block_ids() const override
Retrieves a pointer to the block IDs array for the vertices.
Definition vertex.hpp:313
void set_fixed(const size_t id, const bool fixed)
Sets the fixed state of a vertex.
Definition vertex.hpp:262
void set_eliminate(bool eliminate) override
Sets the eliminate flag for the vertex descriptor. All vertices will be excluded from the Schur complement...
Definition vertex.hpp:293
bool is_active(const size_t id) const override
Checks if a vertex is active.
Definition vertex.hpp:283
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4