16 #ifndef EIGEN_BFLOAT16_H
17 #define EIGEN_BFLOAT16_H
19 #include "../../InternalHeaderCheck.h"
21 #if defined(EIGEN_HAS_HIP_BF16)
28 #pragma push_macro("EIGEN_CONSTEXPR")
29 #undef EIGEN_CONSTEXPR
30 #define EIGEN_CONSTEXPR
33 #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
35 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
36 PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
37 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
41 #if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
42 #define EIGEN_USE_HIP_BF16
56 namespace bfloat16_impl {
58 #if defined(EIGEN_USE_HIP_BF16)
70 #if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
82 template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
120 template<
typename RealScalar>
131 namespace bfloat16_impl {
132 template <
typename =
void>
230 class numeric_limits<const
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
232 class numeric_limits<volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
234 class numeric_limits<const volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
239 namespace bfloat16_impl {
244 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
246 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
248 #pragma push_macro("EIGEN_DEVICE_FUNC")
249 #undef EIGEN_DEVICE_FUNC
250 #if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
251 #define EIGEN_DEVICE_FUNC __host__
253 #define EIGEN_DEVICE_FUNC __host__ __device__
264 return bfloat16(
float(
a) +
static_cast<float>(
b));
267 return bfloat16(
static_cast<float>(
a) +
float(
b));
280 return numext::bit_cast<bfloat16>(
x);
309 return original_value;
314 return original_value;
323 return float(
a) < float(
b);
326 return float(
a) <= float(
b);
329 return float(
a) > float(
b);
332 return float(
a) >= float(
b);
335 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
336 #pragma pop_macro("EIGEN_DEVICE_FUNC")
343 return bfloat16(
static_cast<float>(
a) /
static_cast<float>(
b));
347 #if defined(EIGEN_USE_HIP_BF16)
348 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(
v, __bfloat16_raw::truncate));
361 #if defined(EIGEN_USE_HIP_BF16)
371 #if defined(EIGEN_USE_HIP_BF16)
382 #if defined(EIGEN_USE_HIP_BF16)
556 #if defined(EIGEN_USE_HIP_BF16)
565 input += rounding_bias;
572 #if defined(EIGEN_USE_HIP_BF16)
573 return static_cast<float>(h);
583 #if defined(EIGEN_USE_HIP_BF16)
591 #if defined(EIGEN_USE_HIP_BF16)
603 return numext::bit_cast<bfloat16>(
x);
630 return bfloat16(::atan2f(
float(
a),
float(
b)));
685 const float f1 =
static_cast<float>(
a);
686 const float f2 =
static_cast<float>(
b);
687 return f2 < f1 ?
b :
a;
691 const float f1 =
static_cast<float>(
a);
692 const float f2 =
static_cast<float>(
b);
693 return f1 < f2 ?
b :
a;
697 const float f1 =
static_cast<float>(
a);
698 const float f2 =
static_cast<float>(
b);
703 const float f1 =
static_cast<float>(
a);
704 const float f2 =
static_cast<float>(
b);
710 os << static_cast<float>(
v);
720 struct random_default_impl<bfloat16, false, false>
722 static inline bfloat16 run(
const bfloat16&
x,
const bfloat16&
y)
724 return x + (
y-
x) * bfloat16(
float(std::rand()) / float(RAND_MAX));
726 static inline bfloat16 run()
728 return run(bfloat16(-1.f), bfloat16(1.f));
732 template<>
struct is_arithmetic<bfloat16> {
enum { value =
true }; };
769 #if defined(EIGEN_HAS_HIP_BF16)
770 #pragma pop_macro("EIGEN_CONSTEXPR")
807 #if EIGEN_HAS_STD_HASH
810 struct hash<
Eigen::bfloat16> {
812 return static_cast<std::size_t
>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(
a));
829 #if defined(EIGEN_HIPCC)
831 #if defined(EIGEN_HAS_HIP_BF16)
834 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
835 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
839 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
840 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
844 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
845 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
849 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
850 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
857 #if defined(EIGEN_HIPCC)
const Log1pReturnType log1p() const
const Expm1ReturnType expm1() const
Array< int, Dynamic, 1 > v
IndexedView_or_Block operator()(const RowIndices &rowIndices, const ColIndices &colIndices)
#define EIGEN_ALWAYS_INLINE
#define EIGEN_USING_STD(FUNC)
#define EIGEN_DEVICE_FUNC
#define EIGEN_NOT_A_MACRO
bfloat16 asin(const bfloat16 &a)
bfloat16 cos(const bfloat16 &a)
bfloat16 rint(const bfloat16 &a)
bfloat16 acosh(const bfloat16 &a)
bfloat16 acos(const bfloat16 &a)
float bfloat16_to_float(__bfloat16_raw h)
bool() isinf(const bfloat16 &a)
bfloat16 sin(const bfloat16 &a)
bfloat16 tanh(const bfloat16 &a)
bfloat16 fmax(const bfloat16 &a, const bfloat16 &b)
bfloat16 asinh(const bfloat16 &a)
bfloat16 floor(const bfloat16 &a)
bfloat16 expm1(const bfloat16 &a)
bfloat16 operator+(const bfloat16 &a, const bfloat16 &b)
bfloat16 & operator/=(bfloat16 &a, const bfloat16 &b)
bfloat16() max(const bfloat16 &a, const bfloat16 &b)
bfloat16 ceil(const bfloat16 &a)
__bfloat16_raw truncate_to_bfloat16(const float v)
bfloat16 & operator*=(bfloat16 &a, const bfloat16 &b)
bool operator==(const bfloat16 &a, const bfloat16 &b)
bfloat16 log1p(const bfloat16 &a)
bfloat16 atanh(const bfloat16 &a)
bfloat16 atan(const bfloat16 &a)
bfloat16 abs(const bfloat16 &a)
bfloat16 cosh(const bfloat16 &a)
bfloat16 log2(const bfloat16 &a)
__bfloat16_raw float_to_bfloat16_rtne< false >(float ff)
EIGEN_ALWAYS_INLINE std::ostream & operator<<(std::ostream &os, const bfloat16 &v)
bfloat16 log10(const bfloat16 &a)
EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value)
bool operator>=(const bfloat16 &a, const bfloat16 &b)
bool operator>(const bfloat16 &a, const bfloat16 &b)
bfloat16 sinh(const bfloat16 &a)
bfloat16 operator*(const bfloat16 &a, const bfloat16 &b)
bfloat16 operator++(bfloat16 &a)
bool() isfinite(const bfloat16 &a)
EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw &bf)
EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value)
bfloat16 round(const bfloat16 &a)
bfloat16 & operator+=(bfloat16 &a, const bfloat16 &b)
bfloat16 operator--(bfloat16 &a)
bfloat16 exp(const bfloat16 &a)
__bfloat16_raw float_to_bfloat16_rtne(float ff)
bfloat16 pow(const bfloat16 &a, const bfloat16 &b)
bool operator<(const bfloat16 &a, const bfloat16 &b)
bfloat16 & operator-=(bfloat16 &a, const bfloat16 &b)
bfloat16 atan2(const bfloat16 &a, const bfloat16 &b)
bfloat16 tan(const bfloat16 &a)
bool operator<=(const bfloat16 &a, const bfloat16 &b)
__bfloat16_raw float_to_bfloat16_rtne< true >(float ff)
bool operator!=(const bfloat16 &a, const bfloat16 &b)
bfloat16 fmin(const bfloat16 &a, const bfloat16 &b)
bfloat16 operator/(const bfloat16 &a, const bfloat16 &b)
bfloat16 log(const bfloat16 &a)
bool() isnan(const bfloat16 &a)
bfloat16 fmod(const bfloat16 &a, const bfloat16 &b)
bfloat16 operator-(const bfloat16 &a, const bfloat16 &b)
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
bfloat16 sqrt(const bfloat16 &a)
bool equal_strict(const X &x, const Y &y)
bool not_equal_strict(const X &x, const Y &y)
EIGEN_ALWAYS_INLINE bool() isinf(const Eigen::bfloat16 &h)
static constexpr EIGEN_ALWAYS_INLINE Scalar signbit(const Scalar &x)
EIGEN_ALWAYS_INLINE bool() isnan(const Eigen::bfloat16 &h)
EIGEN_ALWAYS_INLINE bool() isfinite(const Eigen::bfloat16 &h)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
static EIGEN_CONSTEXPR Eigen::bfloat16 dummy_precision()
static EIGEN_CONSTEXPR Eigen::bfloat16 lowest()
static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN()
static EIGEN_CONSTEXPR Eigen::bfloat16 infinity()
static EIGEN_CONSTEXPR Eigen::bfloat16 highest()
static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon()
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw)
EIGEN_CONSTEXPR __bfloat16_raw()
EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw &h)
EIGEN_CONSTEXPR bfloat16_base()
static EIGEN_CONSTEXPR const bool tinyness_before
static EIGEN_CONSTEXPR Eigen::bfloat16() min()
static EIGEN_CONSTEXPR const std::float_round_style round_style
static EIGEN_CONSTEXPR const bool has_denorm_loss
static EIGEN_CONSTEXPR const int min_exponent
static EIGEN_CONSTEXPR const bool has_infinity
static EIGEN_CONSTEXPR const int radix
static EIGEN_CONSTEXPR const bool is_iec559
static EIGEN_CONSTEXPR const bool is_bounded
static EIGEN_CONSTEXPR const bool is_exact
static EIGEN_CONSTEXPR const bool is_integer
static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min()
static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN()
static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon()
static EIGEN_CONSTEXPR const int max_digits10
static EIGEN_CONSTEXPR const int max_exponent10
static EIGEN_CONSTEXPR const int max_exponent
static EIGEN_CONSTEXPR const int digits10
static EIGEN_CONSTEXPR Eigen::bfloat16 lowest()
static EIGEN_CONSTEXPR const bool has_signaling_NaN
static EIGEN_CONSTEXPR Eigen::bfloat16 infinity()
static EIGEN_CONSTEXPR Eigen::bfloat16() max()
static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN()
static EIGEN_CONSTEXPR Eigen::bfloat16 round_error()
static EIGEN_CONSTEXPR const bool traps
static EIGEN_CONSTEXPR const bool is_modulo
static EIGEN_CONSTEXPR const bool is_specialized
static EIGEN_CONSTEXPR const bool is_signed
static EIGEN_CONSTEXPR const bool has_quiet_NaN
static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm
static EIGEN_CONSTEXPR const int min_exponent10
static EIGEN_CONSTEXPR const int digits
EIGEN_CONSTEXPR bfloat16(bool b)
bfloat16_impl::__bfloat16_raw __bfloat16_raw
EIGEN_CONSTEXPR bfloat16(T val)
EIGEN_CONSTEXPR bfloat16()
EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw &h)
EIGEN_CONSTEXPR bfloat16(const std::complex< RealScalar > &val)