Graphite
Loading...
Searching...
No Matches
types.hpp
Go to the documentation of this file.
1
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <type_traits>
5
namespace graphite {

// ---------------------------------------------------------------------------
// Low-precision type traits
// ---------------------------------------------------------------------------

/// true_type for the 16-bit CUDA floating-point types (__half and
/// __nv_bfloat16); false_type for every other type.
template <typename T> struct is_half_or_bfloat16 : std::false_type {};

template <> struct is_half_or_bfloat16<__half> : std::true_type {};

template <> struct is_half_or_bfloat16<__nv_bfloat16> : std::true_type {};

/// Alias of is_half_or_bfloat16: the set of "low precision" types recognised
/// by this header.
template <typename T> using is_low_precision = is_half_or_bfloat16<T>;

/// Resolves to T when S is a low-precision type (__half / __nv_bfloat16),
/// and to S itself otherwise.
/// NOTE(review): presumably T is a wider compute type substituted for 16-bit
/// storage types S -- confirm against callers.
template <typename T, typename S>
using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;

// ---------------------------------------------------------------------------
// Two-component vector types
// ---------------------------------------------------------------------------

/// Maps a scalar type to its two-component CUDA vector type via `::type`.
/// Only float, double, __half and __nv_bfloat16 are specialised; for any
/// other T the primary template is incomplete, so `vec2_type<T>::type` is a
/// compile error.
template <typename T> struct vec2_type;

template <> struct vec2_type<float> { using type = float2; };

template <> struct vec2_type<double> { using type = double2; };

template <> struct vec2_type<__half> { using type = __half2; };

template <> struct vec2_type<__nv_bfloat16> { using type = __nv_bfloat162; };

// ---------------------------------------------------------------------------
// Precision conversion
// ---------------------------------------------------------------------------

/// Converts a two-component value of type `hp` to the (usually narrower)
/// vector type `lp`.
///
/// Dispatch, in order:
///  - hp == lp: the value is returned unchanged.
///  - hp == double2 and lp is float2, __half2 or __nv_bfloat162: each
///    component is first cast to float, then handled as a float2 source.
///    Rounding twice (double->float->16-bit) can, in rare near-tie cases,
///    differ by one ulp from a single direct rounding.
///  - lp == __half2: __float22half2_rn (round-to-nearest-even); hp must be
///    float2.
///  - lp == __nv_bfloat162: __float22bfloat162_rn (round-to-nearest-even);
///    hp must be float2.
///  - otherwise: static_cast<lp>(value).
///
/// @tparam hp Source (higher precision) type.
/// @tparam lp Destination (lower precision) type.
/// @param value Value to convert.
/// @return `value` converted to `lp`.
template <typename hp, typename lp>
__device__ lp convert_to_low_precision(const hp &value) {
  if constexpr (std::is_same_v<hp, lp>) {
    // Nothing to narrow. Without this branch a __half2 -> __half2 request
    // would be routed to the float2-only intrinsic below and fail to compile.
    return value;
  } else if constexpr (std::is_same_v<hp, double2> &&
                       (std::is_same_v<lp, float2> ||
                        std::is_same_v<lp, __half2> ||
                        std::is_same_v<lp, __nv_bfloat162>)) {
    // double2 has no conversion to float2 (unrelated structs), and the
    // vec2 narrowing intrinsics only accept float2, so narrow componentwise
    // to float2 first and reuse the float2 path.
    const float2 narrowed =
        make_float2(static_cast<float>(value.x), static_cast<float>(value.y));
    return convert_to_low_precision<float2, lp>(narrowed);
  } else if constexpr (std::is_same_v<lp, __half2>) {
    return __float22half2_rn(value);
  } else if constexpr (std::is_same_v<lp, __nv_bfloat162>) {
    return __float22bfloat162_rn(value);
  } else {
    return static_cast<lp>(value);
  }
}

/// Placeholder type with no members, for when no storage is needed.
struct Empty {};

} // namespace graphite
Definition types.hpp:46
Definition types.hpp:10
Definition types.hpp:23