Graphite 0.5.0
GPU-accelerated graph optimization framework
vertex.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <graphite/active.hpp>
4#include <graphite/common.hpp>
6#include <graphite/ops/state.hpp>
8#include <graphite/stream.hpp>
9#include <graphite/vector.hpp>
10#include <type_traits>
11
12namespace graphite {
13
14template <typename T, typename S> class BaseVertexDescriptor {
15public:
16 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
17
18 virtual ~BaseVertexDescriptor(){};
19
20 // virtual void update(const T* x, const T* delta) = 0;
21 virtual void apply_update_async(const T *delta_x, T *jacobian_scales,
22 cudaStream_t stream) = 0;
23 virtual void augment_block_diagonal_async(InvP *block_diagonal,
24 InvP *scalar_diagonal, const T mu,
25 const bool use_identity,
26 cudaStream_t stream) = 0;
27 virtual void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal,
28 cudaStream_t stream) = 0;
29
30 virtual size_t dimension() const = 0;
31 virtual size_t count() const = 0;
32
33 virtual void backup_parameters_async() = 0;
34 virtual void restore_parameters_async() = 0;
35 virtual void to_device() = 0;
36
37 virtual const std::unordered_map<size_t, size_t> &get_global_map() const = 0;
38 virtual const size_t *get_hessian_ids() const = 0;
39 virtual void set_hessian_column(const size_t global_id,
40 const size_t hessian_column,
41 const size_t block_index) = 0;
42 virtual bool is_fixed(const size_t id) const = 0;
43 virtual bool is_active(const size_t id) const = 0;
44 virtual void set_eliminate(bool eliminate) = 0;
45 virtual bool get_eliminate() const = 0;
46 virtual uint8_t *get_active_state() const = 0;
47 virtual const size_t *get_block_ids() const = 0;
48};
49
54template <typename T, typename S, typename VTraits>
56public:
// Scalar type of the block-diagonal buffers: T when is_low_precision<S>
// holds, otherwise S (same definition as BaseVertexDescriptor::InvP).
57 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
58
// Vertex traits: supply the Vertex type and the parameter dimension.
59 using Traits = VTraits;
60
61 using VertexType = typename Traits::Vertex;
// Per-vertex value saved by backup_parameters_async() and written back by
// restore_parameters_async(). Resolved by get_State_or_t (defined elsewhere;
// presumably Traits' State type if it declares one, else VertexType - confirm).
62 using State = get_State_or_t<Traits, VertexType>;
63
64private:
// Device copy of x_host; only refreshed by to_device().
65 thrust::device_vector<VertexType *> x_device;
// Host-side vertex pointers indexed by local index. The pointees are not
// deleted by this class.
66 thrust::host_vector<VertexType *> x_host;
// Backup buffer, resized to count() by backup_parameters_async().
67 thrust::device_vector<State> backup_state;
// Descriptor-wide flag set via set_eliminate(); false after construction.
68 bool eliminate;
69
70public:
71 // Mappings
// Global vertex ID -> local index into x_host and the per-vertex arrays.
72 std::unordered_map<size_t, size_t> global_to_local_map;
// Local index -> global vertex ID (inverse of global_to_local_map).
73 std::vector<size_t> local_to_global_map;
// Host-side Hessian column per vertex, written by set_hessian_column();
// copied into hessian_ids by to_device().
74 thrust::host_vector<size_t> local_to_hessian_offsets;
75 thrust::device_vector<size_t> hessian_ids;
// Block index per vertex, written by set_hessian_column().
76 managed_vector<size_t> block_ids;
// Per-vertex flag byte: bit 0 marks a fixed vertex (see is_fixed). The whole
// byte is passed to is_vertex_active (defined elsewhere), which decides
// activity; set_fixed overwrites the entire byte.
77 managed_vector<uint8_t> active_state;
78
// Parameter dimension of every vertex in this descriptor.
79 static constexpr size_t dim = Traits::dimension;
80
81public:
82 VertexDescriptor() : eliminate(false) {}
83
84 virtual ~VertexDescriptor(){};
85
89 void apply_update_async(const T *delta_x, T *jacobian_scales,
90 cudaStream_t stream) override {
91 ops::apply_update<T, S>(this, delta_x, jacobian_scales, stream);
92 }
93
97 void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal,
98 const T mu, const bool use_identity,
99 cudaStream_t stream) override {
100 ops::augment_block_diagonal<T, S>(this, block_diagonal, scalar_diagonal, mu,
101 use_identity, stream);
102 }
103
107 void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal,
108 cudaStream_t stream) override {
109 ops::apply_block_jacobi<T, S>(this, z, r, block_diagonal, stream);
110 }
111
115 virtual void to_device() override {
116 x_device = x_host;
117 hessian_ids = local_to_hessian_offsets;
118 }
119
124 VertexType **vertices() { return x_device.data().get(); }
125
129 virtual void backup_parameters_async() override {
130 VertexType **vertices = x_device.data().get();
131
132 const int num_vertices = static_cast<int>(count());
133 const int num_threads = num_vertices;
134 const int block_size = 256;
135 const auto num_blocks = (num_threads + block_size - 1) / block_size;
136 backup_state.resize(num_vertices);
137
138 ops::backup_state_kernel<VertexType, State, Traits, T>
139 <<<num_blocks, block_size>>>(vertices, backup_state.data().get(),
140 active_state.data().get(), num_vertices);
141 }
142
146 virtual void restore_parameters_async() override {
147 VertexType **vertices = x_device.data().get();
148
149 const int num_vertices = static_cast<int>(count());
150 const int num_threads = num_vertices;
151 const int block_size = 256;
152 const auto num_blocks = (num_threads + block_size - 1) / block_size;
153 ops::set_state_kernel<VertexType, State, Traits, T>
154 <<<num_blocks, block_size>>>(vertices, backup_state.data().get(),
155 active_state.data().get(), num_vertices);
156 }
157
162 virtual size_t count() const override { return x_host.size(); }
163
169 void reserve(size_t size) {
170 x_host.reserve(size);
171 global_to_local_map.reserve(size);
172 local_to_global_map.reserve(size);
173 local_to_hessian_offsets.reserve(size);
174 block_ids.reserve(size);
175 active_state.reserve(size);
176 }
177
182 void remove_vertex(const size_t id) {
183 if (count() == 0) {
184 return;
185 }
186
187 if (global_to_local_map.find(id) == global_to_local_map.end()) {
188 std::cerr << "Vertex with id " << id << " not found." << std::endl;
189 return;
190 }
191
192 const auto local_id = global_to_local_map[id];
193 const auto last_index = x_host.size() - 1;
194
195 // Swap the vertex to be removed with the last vertex
196 std::swap(x_host[local_id], x_host[last_index]);
197
198 if (local_to_hessian_offsets.size() > 0) { // may not be initialized yet
199 std::swap(local_to_hessian_offsets[local_id],
200 local_to_hessian_offsets[last_index]);
201 }
202
203 // Update the global_to_local_map for the swapped vertex
204 const auto last_global_id = local_to_global_map[last_index];
205 global_to_local_map[last_global_id] = local_id;
206 local_to_global_map[local_id] = last_global_id;
207
208 // Only need to update the fixed mask for the swapped vertex
209 active_state[local_id] = active_state[last_index];
210
211 // Remove unused entry
212 active_state.pop_back();
213
214 // Remove the last vertex
215 x_host.pop_back();
216 local_to_hessian_offsets.pop_back();
217 global_to_local_map.erase(id);
218 local_to_global_map.pop_back();
219 block_ids.pop_back();
220 }
221
227 void replace_vertex(const size_t id, VertexType *vertex) {
228 if (global_to_local_map.find(id) == global_to_local_map.end()) {
229 std::cerr << "Vertex with id " << id << " not found." << std::endl;
230 return;
231 }
232
233 const auto local_id = global_to_local_map[id];
234 x_host[local_id] = vertex;
235 }
236
244 void add_vertex(const size_t id, VertexType *vertex,
245 const bool fixed = false) {
246 x_host.push_back(vertex);
247 const auto local_id = x_host.size() - 1;
248 global_to_local_map.insert({id, local_id});
249 local_to_global_map.push_back(id);
250 local_to_hessian_offsets.push_back(0); // Initialize to 0
251 block_ids.push_back(0); // Initialize to 0
252
253 // Update fixed mask
254 active_state.push_back(static_cast<uint8_t>(fixed));
255 }
256
262 void set_fixed(const size_t id, const bool fixed) {
263 const auto local_id = global_to_local_map.at(id);
264 // Don't preserve MSB flag
265 active_state[local_id] = static_cast<uint8_t>(fixed);
266 }
267
273 bool is_fixed(const size_t id) const override {
274 const auto local_id = global_to_local_map.at(id);
275 return (active_state[local_id] & 0x1) > 0;
276 }
277
283 bool is_active(const size_t id) const override {
284 const auto local_id = global_to_local_map.at(id);
285 return is_vertex_active(active_state.data().get(), local_id);
286 }
287
293 void set_eliminate(bool eliminate) override { this->eliminate = eliminate; }
294
299 bool get_eliminate() const override { return eliminate; }
300
305 uint8_t *get_active_state() const override {
306 return active_state.data().get();
307 }
308
313 const size_t *get_block_ids() const override {
314 return block_ids.data().get();
315 }
316
323 VertexType *get_vertex(const size_t id) {
324 auto it = global_to_local_map.find(id);
325 if (it != global_to_local_map.end()) {
326 return x_host[it->second];
327 } else {
328 std::cerr << "Vertex with id " << id << " not found." << std::endl;
329 return nullptr;
330 }
331 }
332
338 bool exists(const size_t id) const {
339 return global_to_local_map.find(id) != global_to_local_map.end();
340 }
341
342 const std::unordered_map<size_t, size_t> &get_global_map() const override {
343 return global_to_local_map;
344 }
345
350 size_t dimension() const override { return dim; }
351
352 const size_t *get_hessian_ids() const override {
353 return hessian_ids.data().get();
354 }
355
362 void set_hessian_column(const size_t global_id, const size_t hessian_column,
363 const size_t block_index) {
364 const auto local_id = global_to_local_map.at(global_id);
365 local_to_hessian_offsets[local_id] = hessian_column;
366 block_ids[local_id] = block_index;
367 }
368
372 void clear() {
373 x_device.clear();
374 x_host.clear();
375 backup_state.clear();
376
377 global_to_local_map.clear();
378 local_to_global_map.clear();
379 local_to_hessian_offsets.clear();
380 hessian_ids.clear();
381 block_ids.clear();
382 active_state.clear();
383 }
384};
385
386} // namespace graphite
BaseVertexDescriptor
Definition vertex.hpp:14
VertexDescriptor
Represents a collection of optimizable variables to be processed together on the GPU.
Definition vertex.hpp:55
void clear()
Clears all vertices and resets the descriptor to an empty state.
Definition vertex.hpp:372
virtual void to_device() override
Prepares descriptor for GPU processing.
Definition vertex.hpp:115
bool exists(const size_t id) const
Checks if a vertex exists in the descriptor.
Definition vertex.hpp:338
virtual size_t count() const override
Returns the number of vertices in the descriptor (including fixed vertices).
Definition vertex.hpp:162
VertexType * get_vertex(const size_t id)
Retrieves a pointer to the vertex associated with the given ID.
Definition vertex.hpp:323
void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal, cudaStream_t stream) override
Applies the block Jacobi preconditioner asynchronously.
Definition vertex.hpp:107
void replace_vertex(const size_t id, VertexType *vertex)
Replaces a vertex pointer in the descriptor for the given ID.
Definition vertex.hpp:227
void remove_vertex(const size_t id)
Removes a vertex from the descriptor by its ID.
Definition vertex.hpp:182
virtual void backup_parameters_async() override
Backs up the state of the vertices asynchronously.
Definition vertex.hpp:129
bool get_eliminate() const override
Gets the eliminate flag for the vertex descriptor.
Definition vertex.hpp:299
VertexType ** vertices()
Returns a pointer to the array of vertex pointers for the descriptor.
Definition vertex.hpp:124
virtual void restore_parameters_async() override
Restores the state of the vertices from the backup asynchronously.
Definition vertex.hpp:146
uint8_t * get_active_state() const override
Retrieves a pointer to the active state array for the vertices.
Definition vertex.hpp:305
void add_vertex(const size_t id, VertexType *vertex, const bool fixed=false)
Adds a vertex to the descriptor.
Definition vertex.hpp:244
void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal, const T mu, const bool use_identity, cudaStream_t stream) override
Adds damping to the Hessian block diagonal asynchronously.
Definition vertex.hpp:97
void set_hessian_column(const size_t global_id, const size_t hessian_column, const size_t block_index)
Sets the Hessian column and block index for a vertex.
Definition vertex.hpp:362
size_t dimension() const override
Returns the dimension of the vertex parameterization.
Definition vertex.hpp:350
void reserve(size_t size)
Reserves memory for the specified number of vertices. You should call this before constructing the graph.
Definition vertex.hpp:169
bool is_fixed(const size_t id) const override
Checks if a vertex is fixed.
Definition vertex.hpp:273
void apply_update_async(const T *delta_x, T *jacobian_scales, cudaStream_t stream) override
Applies an update to the vertex parameters asynchronously.
Definition vertex.hpp:89
const size_t * get_block_ids() const override
Retrieves a pointer to the block IDs array for the vertices.
Definition vertex.hpp:313
void set_fixed(const size_t id, const bool fixed)
Sets the fixed state of a vertex.
Definition vertex.hpp:262
void set_eliminate(bool eliminate) override
Sets the eliminate flag for the vertex descriptor. All vertices will be excluded from the Schur complement.
Definition vertex.hpp:293
bool is_active(const size_t id) const override
Checks if a vertex is active.
Definition vertex.hpp:283
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4