16template <
typename T>
using is_low_precision = is_half_or_bfloat16<T>;
18template <
typename T,
typename S>
19using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
25template <>
struct vec2_type<float> {
using type = float2; };
27template <>
struct vec2_type<double> {
using type = double2; };
29template <>
struct vec2_type<__half> {
using type = __half2; };
31template <>
struct vec2_type<__nv_bfloat16> {
using type = __nv_bfloat162; };
34template <
typename hp,
typename lp>
35__device__ lp convert_to_low_precision(
const hp &value) {
36 if constexpr (std::is_same_v<lp, __half2>) {
37 return __float22half2_rn(value);
38 }
else if constexpr (std::is_same_v<lp, __nv_bfloat162>) {
39 return __float22bfloat162_rn(value);
41 return static_cast<lp
>(value);