Graphite
Loading...
Searching...
No Matches
vertex.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <graphite/common.hpp>
6#include <graphite/stream.hpp>
7#include <graphite/vector.hpp>
8#include <type_traits>
9
10namespace graphite {
11
/// Detection trait: `value` is true exactly when `Candidate::State` names a
/// type. The primary template is the "not found" case.
template <typename Candidate, typename Enable = void>
struct has_type_alias_State : std::bool_constant<false> {};

/// Chosen by partial ordering whenever `Candidate::State` is well-formed
/// (std::void_t collapses to `void`, matching the defaulted second argument).
template <typename Candidate>
struct has_type_alias_State<Candidate, std::void_t<typename Candidate::State>>
    : std::bool_constant<true> {};
18
/// Type selector: yields `Source::State` when that nested type exists and
/// `Default` otherwise. This primary template is the fallback case.
template <typename Source, typename Default, typename Enable = void>
struct get_State_or {
  using type = Default;
};

/// Preferred whenever `Source::State` is a valid type.
template <typename Source, typename Default>
struct get_State_or<Source, Default, std::void_t<typename Source::State>> {
  using type = typename Source::State;
};

/// Shorthand for `typename get_State_or<T, Fallback>::type`.
template <typename T, typename Fallback>
using get_State_or_t = typename get_State_or<T, Fallback>::type;
32
/**
 * @brief Copies the current state of every active vertex into @p dst.
 *
 * Each thread handles the vertex whose index is returned by get_thread_id()
 * (defined elsewhere; presumably a flat 1D global index — confirm). Threads
 * whose index is out of range, or whose vertex is reported inactive by
 * is_vertex_active(), write nothing, so the corresponding @p dst slots keep
 * their previous contents.
 *
 * When `Traits` declares a nested `State` type, the state is obtained through
 * `Traits::get_state`; otherwise the whole vertex object is copied.
 *
 * @tparam VertexType Vertex object type.
 * @tparam State      Element type of the backup buffer.
 * @tparam Traits     Vertex traits; optionally provides `State`/`get_state`.
 * @tparam T          Unused here; kept so existing launch sites still compile.
 *
 * @param vertices     Array of pointers to the vertex objects.
 * @param dst          Output buffer, indexed by vertex index.
 * @param active_state Per-vertex flags consumed by is_vertex_active().
 * @param num_vertices Number of entries in @p vertices / @p dst.
 */
template <typename VertexType, typename State, typename Traits, typename T>
__global__ void backup_state_kernel(VertexType **vertices, State *dst,
                                    const uint8_t *active_state,
                                    const size_t num_vertices) {

  const size_t vertex_id = get_thread_id();

  if (vertex_id >= num_vertices || !is_vertex_active(active_state, vertex_id))
    return;

  // Same compile-time dispatch as set_state_kernel: use the traits accessor
  // only when the traits actually define a dedicated State type.
  if constexpr (has_type_alias_State<Traits>::value) {
    dst[vertex_id] = Traits::get_state(*vertices[vertex_id]);
  } else {
    dst[vertex_id] = *vertices[vertex_id];
  }
}
48
/**
 * @brief Writes a previously saved state from @p src back into every active
 * vertex.
 *
 * Each thread handles the vertex whose index is returned by get_thread_id()
 * (defined elsewhere). Out-of-range threads and vertices reported inactive by
 * is_vertex_active() are left untouched.
 *
 * When `Traits` declares a nested `State` type, the state is applied through
 * `Traits::set_state`; otherwise the saved value is assigned to the vertex
 * object directly.
 *
 * @param vertices     Array of pointers to the vertex objects.
 * @param src          Saved states, indexed by vertex index.
 * @param active_state Per-vertex flags consumed by is_vertex_active().
 * @param num_vertices Number of entries in @p vertices / @p src.
 */
template <typename VertexType, typename State, typename Traits, typename T>
__global__ void set_state_kernel(VertexType **vertices, const State *src,
                                 const uint8_t *active_state,
                                 const size_t num_vertices) {
  const size_t idx = get_thread_id();

  // Range check first so the activity lookup never sees an invalid index.
  if (idx >= num_vertices) {
    return;
  }
  if (!is_vertex_active(active_state, idx)) {
    return;
  }

  if constexpr (has_type_alias_State<Traits>::value) {
    Traits::set_state(*vertices[idx], src[idx]);
  } else {
    *vertices[idx] = src[idx];
  }
}
65
66template <typename T, typename S> class BaseVertexDescriptor {
67public:
68 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
69
70 virtual ~BaseVertexDescriptor(){};
71
72 // virtual void update(const T* x, const T* delta) = 0;
73 virtual void apply_update_async(const T *delta_x, T *jacobian_scales,
74 cudaStream_t stream) = 0;
75 virtual void augment_block_diagonal_async(InvP *block_diagonal,
76 InvP *scalar_diagonal, T mu,
77 cudaStream_t stream) = 0;
78 virtual void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal,
79 cudaStream_t stream) = 0;
80
81 virtual size_t dimension() const = 0;
82 virtual size_t count() const = 0;
83
84 virtual void backup_parameters_async() = 0;
85 virtual void restore_parameters_async() = 0;
86 virtual void to_device() = 0;
87
88 virtual const std::unordered_map<size_t, size_t> &get_global_map() const = 0;
89 virtual const size_t *get_hessian_ids() const = 0;
90 virtual void set_hessian_column(const size_t global_id,
91 const size_t hessian_column,
92 const size_t block_index) = 0;
93 virtual bool is_fixed(const size_t id) const = 0;
94 virtual bool is_active(const size_t id) const = 0;
95 virtual uint8_t *get_active_state() const = 0;
96 virtual const size_t *get_block_ids() const = 0;
97};
98
/**
 * @brief Represents a collection of optimizable variables to be processed
 * together on the GPU.
 *
 * Vertices are stored as raw, non-owning pointers and addressed either by a
 * caller-chosen global id or by a dense local index. All per-vertex containers
 * (`x_host`, `local_to_global_map`, `local_to_hessian_offsets`, `block_ids`,
 * `active_state`) are kept parallel: entry `i` of each describes local vertex
 * `i`.
 *
 * @tparam T       Scalar type of the update vector / damping parameter.
 * @tparam S       Second scalar type; selects InvP.
 * @tparam VTraits Vertex traits. Must provide `Vertex` and `dimension`; may
 *                 provide `State` together with `get_state` / `set_state`.
 */
template <typename T, typename S, typename VTraits>
class VertexDescriptor : public BaseVertexDescriptor<T, S> {
public:
  /// Scalar type of the diagonal buffers: T when is_low_precision<S> holds,
  /// otherwise S.
  using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;

  using Traits = VTraits;

  using VertexType = typename Traits::Vertex;
  /// Backup element type: `Traits::State` if declared, else the vertex type.
  using State = get_State_or_t<Traits, VertexType>;

private:
  thrust::device_vector<VertexType *> x_device; // copy of x_host (to_device)
  thrust::host_vector<VertexType *> x_host;     // one pointer per vertex
  thrust::device_vector<State> backup_state;    // filled by backup kernel

public:
  // Mappings
  std::unordered_map<size_t, size_t> global_to_local_map;
  std::vector<size_t> local_to_global_map;
  thrust::host_vector<size_t> local_to_hessian_offsets;
  thrust::device_vector<size_t> hessian_ids; // copy of the offsets (to_device)
  managed_vector<size_t> block_ids;
  managed_vector<uint8_t> active_state;

  static constexpr size_t dim = Traits::dimension;

public:
  virtual ~VertexDescriptor() = default;

  /**
   * @brief Applies an update to the vertex parameters asynchronously.
   *
   * Forwards to ops::apply_update on @p stream.
   */
  void apply_update_async(const T *delta_x, T *jacobian_scales,
                          cudaStream_t stream) override {
    ops::apply_update<T, S>(this, delta_x, jacobian_scales, stream);
  }

  /**
   * @brief Adds damping to the Hessian block diagonal asynchronously.
   *
   * Forwards to ops::augment_block_diagonal on @p stream.
   */
  void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal,
                                    T mu, cudaStream_t stream) override {
    ops::augment_block_diagonal<T, S>(this, block_diagonal, scalar_diagonal, mu,
                                      stream);
  }

  /**
   * @brief Applies the block Jacobi preconditioner asynchronously.
   *
   * Forwards to ops::apply_block_jacobi on @p stream.
   */
  void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal,
                          cudaStream_t stream) override {
    ops::apply_block_jacobi<T, S>(this, z, r, block_diagonal, stream);
  }

  /**
   * @brief Prepares descriptor for GPU processing.
   *
   * Copies the host vertex pointers into `x_device` and the host Hessian
   * offsets into `hessian_ids`. Must be called again after any change to the
   * vertex set for the device copies to reflect it.
   */
  virtual void to_device() override {
    x_device = x_host;
    hessian_ids = local_to_hessian_offsets;
  }

  /**
   * @brief Returns a pointer to the array of vertex pointers for the
   * descriptor (the device-side copy made by to_device()).
   */
  VertexType **vertices() { return x_device.data().get(); }

  /**
   * @brief Backs up the state of the vertices asynchronously.
   *
   * Resizes the backup buffer to count() and launches backup_state_kernel
   * with one thread per vertex. The kernel reads `x_device`, so to_device()
   * must have been called for the current vertex set. The launch does not
   * specify a stream.
   */
  virtual void backup_parameters_async() override {
    const size_t num_vertices = count();
    backup_state.resize(num_vertices);

    // A grid of zero blocks is an invalid launch configuration.
    if (num_vertices == 0) {
      return;
    }

    constexpr size_t block_size = 256;
    const auto num_blocks = static_cast<unsigned int>(
        (num_vertices + block_size - 1) / block_size);

    backup_state_kernel<VertexType, State, Traits, T>
        <<<num_blocks, static_cast<unsigned int>(block_size)>>>(
            x_device.data().get(), backup_state.data().get(),
            active_state.data().get(), num_vertices);
  }

  /**
   * @brief Restores the state of the vertices from the backup asynchronously.
   *
   * Launches set_state_kernel with one thread per vertex. Does nothing for an
   * empty descriptor, and refuses to run (with a message on std::cerr) when
   * the backup holds fewer entries than there are vertices, since the kernel
   * would otherwise read past the end of the backup buffer. The launch does
   * not specify a stream.
   */
  virtual void restore_parameters_async() override {
    const size_t num_vertices = count();
    if (num_vertices == 0) {
      return;
    }

    if (backup_state.size() < num_vertices) {
      std::cerr << "Cannot restore parameters: backup holds "
                << backup_state.size() << " states for " << num_vertices
                << " vertices." << std::endl;
      return;
    }

    constexpr size_t block_size = 256;
    const auto num_blocks = static_cast<unsigned int>(
        (num_vertices + block_size - 1) / block_size);

    set_state_kernel<VertexType, State, Traits, T>
        <<<num_blocks, static_cast<unsigned int>(block_size)>>>(
            x_device.data().get(), backup_state.data().get(),
            active_state.data().get(), num_vertices);
  }

  /**
   * @brief Returns the number of vertices in the descriptor (including fixed
   * vertices).
   */
  virtual size_t count() const override { return x_host.size(); }

  /**
   * @brief Reserves memory for the specified number of vertices in every
   * host-side per-vertex container.
   */
  void reserve(size_t size) {
    x_host.reserve(size);
    global_to_local_map.reserve(size);
    local_to_global_map.reserve(size);
    local_to_hessian_offsets.reserve(size);
    block_ids.reserve(size);
    active_state.reserve(size);
  }

  /**
   * @brief Removes a vertex from the descriptor by its ID.
   *
   * The last vertex is moved into the removed vertex's slot, so local indices
   * are not stable across removals. Returns silently when the descriptor is
   * empty and prints a message to std::cerr when @p id is unknown.
   */
  void remove_vertex(const size_t id) {
    if (count() == 0) {
      return;
    }

    const auto it = global_to_local_map.find(id);
    if (it == global_to_local_map.end()) {
      std::cerr << "Vertex with id " << id << " not found." << std::endl;
      return;
    }

    const size_t local_id = it->second;
    const size_t last_index = x_host.size() - 1;
    const size_t last_global_id = local_to_global_map[last_index];

    // Move every per-vertex attribute of the last vertex into the vacated
    // slot. block_ids must move too, otherwise the relocated vertex would
    // inherit the removed vertex's block id.
    x_host[local_id] = x_host[last_index];
    block_ids[local_id] = block_ids[last_index];
    active_state[local_id] = active_state[last_index];
    local_to_global_map[local_id] = last_global_id;
    global_to_local_map[last_global_id] = local_id;

    // Drop the now-duplicated last entry everywhere.
    x_host.pop_back();
    block_ids.pop_back();
    active_state.pop_back();
    local_to_global_map.pop_back();
    global_to_local_map.erase(id);

    if (!local_to_hessian_offsets.empty()) { // may not be initialized yet
      local_to_hessian_offsets[local_id] = local_to_hessian_offsets[last_index];
      local_to_hessian_offsets.pop_back();
    }
  }

  /**
   * @brief Replaces a vertex pointer in the descriptor for the given ID.
   *
   * Only the host-side pointer is updated. Prints a message to std::cerr and
   * changes nothing when @p id is unknown.
   */
  void replace_vertex(const size_t id, VertexType *vertex) {
    const auto it = global_to_local_map.find(id);
    if (it == global_to_local_map.end()) {
      std::cerr << "Vertex with id " << id << " not found." << std::endl;
      return;
    }

    x_host[it->second] = vertex;
  }

  /**
   * @brief Adds a vertex to the descriptor.
   *
   * @param id     Global id of the new vertex; must not already be present.
   * @param vertex Pointer to the vertex object (stored, not copied).
   * @param fixed  Initial fixed flag.
   *
   * A duplicate @p id is rejected with a message on std::cerr; appending it
   * would leave a vertex that no id maps to and break later removals.
   */
  void add_vertex(const size_t id, VertexType *vertex,
                  const bool fixed = false) {
    const size_t local_id = x_host.size();
    if (!global_to_local_map.insert({id, local_id}).second) {
      std::cerr << "Vertex with id " << id << " already exists." << std::endl;
      return;
    }

    x_host.push_back(vertex);
    local_to_global_map.push_back(id);
    local_to_hessian_offsets.push_back(0); // Initialize to 0
    block_ids.push_back(0);                // Initialize to 0

    // Update fixed mask
    active_state.push_back(static_cast<uint8_t>(fixed));
  }

  /**
   * @brief Sets the fixed state of a vertex.
   *
   * Overwrites the vertex's whole state byte with the fixed flag, so any
   * other bits stored there are cleared.
   *
   * @throws std::out_of_range if @p id is unknown.
   */
  void set_fixed(const size_t id, const bool fixed) {
    const auto local_id = global_to_local_map.at(id);
    // Don't preserve MSB flag
    active_state[local_id] = static_cast<uint8_t>(fixed);
  }

  /**
   * @brief Checks if a vertex is fixed (lowest bit of its state byte is set).
   *
   * @throws std::out_of_range if @p id is unknown.
   */
  bool is_fixed(const size_t id) const override {
    const auto local_id = global_to_local_map.at(id);
    return (active_state[local_id] & 0x1) > 0;
  }

  /**
   * @brief Checks if a vertex is active, as decided by is_vertex_active().
   *
   * @throws std::out_of_range if @p id is unknown.
   */
  bool is_active(const size_t id) const override {
    const auto local_id = global_to_local_map.at(id);
    return is_vertex_active(active_state.data().get(), local_id);
  }

  /**
   * @brief Retrieves a pointer to the active state array for the vertices.
   */
  uint8_t *get_active_state() const override {
    return active_state.data().get();
  }

  /**
   * @brief Retrieves a pointer to the block IDs array for the vertices.
   */
  const size_t *get_block_ids() const override {
    return block_ids.data().get();
  }

  /**
   * @brief Retrieves a pointer to the vertex associated with the given ID.
   *
   * @return The stored host-side pointer, or nullptr (after printing a
   *         message to std::cerr) when @p id is unknown.
   */
  VertexType *get_vertex(const size_t id) {
    const auto it = global_to_local_map.find(id);
    if (it == global_to_local_map.end()) {
      std::cerr << "Vertex with id " << id << " not found." << std::endl;
      return nullptr;
    }
    return x_host[it->second];
  }

  /**
   * @brief Checks if a vertex exists in the descriptor.
   */
  bool exists(const size_t id) const {
    return global_to_local_map.find(id) != global_to_local_map.end();
  }

  /**
   * @brief Returns the map from global vertex id to local index.
   */
  const std::unordered_map<size_t, size_t> &get_global_map() const override {
    return global_to_local_map;
  }

  /**
   * @brief Returns the dimension of the vertex parameterization.
   */
  size_t dimension() const override { return dim; }

  /**
   * @brief Returns a pointer to the device copy of the Hessian offsets made
   * by to_device().
   */
  const size_t *get_hessian_ids() const override {
    return hessian_ids.data().get();
  }

  /**
   * @brief Sets the Hessian column and block index for a vertex.
   *
   * Only the host-side offset is written; `hessian_ids` is refreshed by the
   * next to_device().
   *
   * @throws std::out_of_range if @p global_id is unknown.
   */
  void set_hessian_column(const size_t global_id, const size_t hessian_column,
                          const size_t block_index) override {
    const auto local_id = global_to_local_map.at(global_id);
    local_to_hessian_offsets[local_id] = hessian_column;
    block_ids[local_id] = block_index;
  }

  /**
   * @brief Clears all vertices and resets the descriptor to an empty state.
   */
  void clear() {
    x_device.clear();
    x_host.clear();
    backup_state.clear();

    global_to_local_map.clear();
    local_to_global_map.clear();
    local_to_hessian_offsets.clear();
    hessian_ids.clear();
    block_ids.clear();
    active_state.clear();
  }
};
417
418} // namespace graphite
Definition vertex.hpp:66
Represents a collection of optimizable variables to be processed together on the GPU.
Definition vertex.hpp:104
void clear()
Clears all vertices and resets the descriptor to an empty state.
Definition vertex.hpp:404
virtual void to_device() override
Prepares descriptor for GPU processing.
Definition vertex.hpp:160
bool exists(const size_t id) const
Checks if a vertex exists in the descriptor.
Definition vertex.hpp:370
virtual size_t count() const override
Returns the number of vertices in the descriptor (including fixed vertices).
Definition vertex.hpp:207
VertexType * get_vertex(const size_t id)
Retrieves a pointer to the vertex associated with the given ID.
Definition vertex.hpp:355
void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal, cudaStream_t stream) override
Applies the block Jacobi preconditioner asynchronously.
Definition vertex.hpp:152
void replace_vertex(const size_t id, VertexType *vertex)
Replaces a vertex pointer in the descriptor for the given ID.
Definition vertex.hpp:272
void remove_vertex(const size_t id)
Removes a vertex from the descriptor by its ID.
Definition vertex.hpp:227
virtual void backup_parameters_async() override
Backs up the state of the vertices asynchronously.
Definition vertex.hpp:174
VertexType ** vertices()
Returns a pointer to the array of vertex pointers for the descriptor.
Definition vertex.hpp:169
virtual void restore_parameters_async() override
Restores the state of the vertices from the backup asynchronously.
Definition vertex.hpp:191
uint8_t * get_active_state() const override
Retrieves a pointer to the active state array for the vertices.
Definition vertex.hpp:337
void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal, T mu, cudaStream_t stream) override
Adds damping to the Hessian block diagonal asynchronously.
Definition vertex.hpp:143
void add_vertex(const size_t id, VertexType *vertex, const bool fixed=false)
Adds a vertex to the descriptor.
Definition vertex.hpp:289
void set_hessian_column(const size_t global_id, const size_t hessian_column, const size_t block_index)
Sets the Hessian column and block index for a vertex.
Definition vertex.hpp:394
size_t dimension() const override
Returns the dimension of the vertex parameterization.
Definition vertex.hpp:382
void reserve(size_t size)
Reserves memory for the specified number of vertices. You should call this before constructing the gr...
Definition vertex.hpp:214
bool is_fixed(const size_t id) const override
Checks if a vertex is fixed.
Definition vertex.hpp:318
void apply_update_async(const T *delta_x, T *jacobian_scales, cudaStream_t stream) override
Applies an update to the vertex parameters asynchronously.
Definition vertex.hpp:135
const size_t * get_block_ids() const override
Retrieves a pointer to the block IDs array for the vertices.
Definition vertex.hpp:345
void set_fixed(const size_t id, const bool fixed)
Sets the fixed state of a vertex.
Definition vertex.hpp:307
bool is_active(const size_t id) const override
Checks if a vertex is active.
Definition vertex.hpp:328
Definition vertex.hpp:19
Definition vertex.hpp:13