Ginkgo Generated from branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
 
Loading...
Searching...
No Matches
half.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6#define GKO_PUBLIC_CORE_BASE_HALF_HPP_
7
8
9#include <climits>
10#include <complex>
11#include <cstdint>
12#include <cstring>
13#include <type_traits>
14
15
// Forward declaration of the global __half type, which is defined outside
// this file. Only its name is needed here, as a tag for the
// basic_float_traits specialization below.
class __half;


namespace gko {


// Forward declaration; in this file `truncated` is only used as a template
// argument of basic_float_traits.
template <typename, std::size_t, std::size_t>
class truncated;


class half;


namespace detail {


// Number of bits in one byte.
constexpr std::size_t byte_size = CHAR_BIT;

/**
 * Selects an unsigned integer type by bit count.
 *
 * The primary template deliberately has no `type` member, so any bit count
 * that is not covered by one of the specializations below (i.e. more than 64
 * bits) fails at the point where `uint_of` is used.
 */
template <std::size_t, typename = void>
struct uint_of_impl {};

// Up to 16 bits -> std::uint16_t.
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits <= 16)>> {
    using type = std::uint16_t;
};

// 17 to 32 bits -> std::uint32_t.
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(16 < Bits && Bits <= 32)>> {
    using type = std::uint32_t;
};

// 33 to 64 bits -> std::uint64_t.
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(32 < Bits) && (Bits <= 64)>> {
    using type = std::uint64_t;
};

// Unsigned integer type used to hold the raw bits of a `Bits`-wide value.
template <std::size_t Bits>
using uint_of = typename uint_of_impl<Bits>::type;


/**
 * Describes the bit layout of a floating point type: number of sign,
 * significand and exponent bits, and whether conversions into the type round
 * to nearest. The primary template is empty; every supported type provides a
 * specialization.
 */
template <typename T>
struct basic_float_traits {};

// 16 bit layout: 1 sign bit, 5 exponent bits, 10 significand bits.
template <>
struct basic_float_traits<half> {
    using type = half;
    static constexpr int sign_bits = 1;
    static constexpr int significand_bits = 10;
    static constexpr int exponent_bits = 5;
    static constexpr bool rounds_to_nearest = true;
};

// Same layout as gko::half.
template <>
struct basic_float_traits<__half> {
    using type = __half;
    static constexpr int sign_bits = 1;
    static constexpr int significand_bits = 10;
    static constexpr int exponent_bits = 5;
    static constexpr bool rounds_to_nearest = true;
};

// 32 bit layout: 1 sign bit, 8 exponent bits, 23 significand bits.
template <>
struct basic_float_traits<float> {
    using type = float;
    static constexpr int sign_bits = 1;
    static constexpr int significand_bits = 23;
    static constexpr int exponent_bits = 8;
    static constexpr bool rounds_to_nearest = true;
};

// 64 bit layout: 1 sign bit, 11 exponent bits, 52 significand bits.
template <>
struct basic_float_traits<double> {
    using type = double;
    static constexpr int sign_bits = 1;
    static constexpr int significand_bits = 52;
    static constexpr int exponent_bits = 11;
    static constexpr bool rounds_to_nearest = true;
};

/**
 * Layout of one component of a truncated floating point type.
 *
 * Component 0 carries the sign bit and the full exponent of FloatType; the
 * rest of its storage is significand. Every other component consists of
 * significand bits only. Truncated types do not round to nearest.
 */
template <typename FloatType, std::size_t NumComponents,
          std::size_t ComponentId>
struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
    using type = truncated<FloatType, NumComponents, ComponentId>;
    static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
    static constexpr int exponent_bits =
        ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
    static constexpr int significand_bits =
        ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
                         : sizeof(type) * byte_size;
    static constexpr bool rounds_to_nearest = false;
};


/**
 * Creates a value of UintType whose lowest `n` bits are set.
 *
 * @param n  number of low bits to set; must be between 0 and the bit width of
 *           UintType (inclusive)
 *
 * @return  (1 << n) - 1, where the full-width case is handled separately
 *          because shifting by the whole width is not allowed
 */
template <typename UintType>
constexpr UintType create_ones(int n)
{
    // The explicit cast makes the int/size_t comparison intentional.
    return (static_cast<std::size_t>(n) == sizeof(UintType) * byte_size
                ? static_cast<UintType>(0)
                : static_cast<UintType>(1) << n) -
           static_cast<UintType>(1);
}


/**
 * Bit masks and classification helpers derived from basic_float_traits<T>.
 *
 * All masks are expressed in `bits_type`, the unsigned integer type that has
 * room for sizeof(type) bytes.
 */
template <typename T>
struct float_traits {
    using type = typename basic_float_traits<T>::type;
    using bits_type = uint_of<sizeof(type) * byte_size>;
    static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
    static constexpr int significand_bits =
        basic_float_traits<T>::significand_bits;
    static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
    // The lowest significand_bits bits.
    static constexpr bits_type significand_mask =
        create_ones<bits_type>(significand_bits);
    // The exponent field, located directly above the significand.
    static constexpr bits_type exponent_mask =
        create_ones<bits_type>(significand_bits + exponent_bits) -
        significand_mask;
    // The exponent field without its top bit, i.e. the exponent bias
    // 2^(exponent_bits - 1) - 1 placed in the exponent field.
    static constexpr bits_type bias_mask =
        create_ones<bits_type>(significand_bits + exponent_bits - 1) -
        significand_mask;
    // The bits above the exponent field (empty if sign_bits is 0).
    static constexpr bits_type sign_mask =
        create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
        exponent_mask - significand_mask;
    static constexpr bool rounds_to_nearest =
        basic_float_traits<T>::rounds_to_nearest;

    // 2^-(significand_bits + rounds_to_nearest) as a double; the bool is
    // promoted to 0 or 1 in the shift amount.
    static constexpr auto eps =
        1.0 / (1ll << (significand_bits + rounds_to_nearest));

    // Exponent field all ones and significand zero.
    static constexpr bool is_inf(bits_type data)
    {
        return (data & exponent_mask) == exponent_mask &&
               (data & significand_mask) == bits_type{};
    }

    // Exponent field all ones and significand non-zero.
    static constexpr bool is_nan(bits_type data)
    {
        return (data & exponent_mask) == exponent_mask &&
               (data & significand_mask) != bits_type{};
    }

    // Exponent field all zeros; this matches zero as well as denormals.
    static constexpr bool is_denom(bits_type data)
    {
        return (data & exponent_mask) == bits_type{};
    }
};


/**
 * Moves the sign, exponent and significand fields of a SourceType bit
 * pattern to the positions they have in ResultType.
 *
 * The third template parameter selects the implementation: `true` widens
 * (sizeof(SourceType) <= sizeof(ResultType)), `false` narrows.
 */
template <typename SourceType, typename ResultType,
          bool = (sizeof(SourceType) <= sizeof(ResultType))>
struct precision_converter;

// upcasting implementation details
template <typename SourceType, typename ResultType>
struct precision_converter<SourceType, ResultType, true> {
    using source_traits = float_traits<SourceType>;
    using result_traits = float_traits<ResultType>;
    using source_bits = typename source_traits::bits_type;
    using result_bits = typename result_traits::bits_type;

    static_assert(source_traits::exponent_bits <=
                          result_traits::exponent_bits &&
                      source_traits::significand_bits <=
                          result_traits::significand_bits,
                  "SourceType has to have both lower range and precision or "
                  "higher range and precision than ResultType");

    // Left shift that aligns the top of the source significand with the top
    // of the result significand; the exponent field moves by the same amount.
    static constexpr int significand_offset =
        result_traits::significand_bits - source_traits::significand_bits;
    static constexpr int exponent_offset = significand_offset;
    // The sign additionally has to skip the extra exponent bits.
    static constexpr int sign_offset = result_traits::exponent_bits -
                                       source_traits::exponent_bits +
                                       exponent_offset;
    // Difference of the two exponent biases, in result exponent position.
    static constexpr result_bits bias_change =
        result_traits::bias_mask -
        (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);

    // Returns the significand of `data` in result position.
    static constexpr result_bits shift_significand(source_bits data) noexcept
    {
        return static_cast<result_bits>(data & source_traits::significand_mask)
               << significand_offset;
    }

    // Returns the re-biased exponent of `data` in result position.
    static constexpr result_bits shift_exponent(source_bits data) noexcept
    {
        return update_bias(
            static_cast<result_bits>(data & source_traits::exponent_mask)
            << exponent_offset);
    }

    // Returns the sign of `data` in result position.
    static constexpr result_bits shift_sign(source_bits data) noexcept
    {
        return static_cast<result_bits>(data & source_traits::sign_mask)
               << sign_offset;
    }

private:
    // A zero exponent field stays zero; every other exponent is re-biased.
    static constexpr result_bits update_bias(result_bits data) noexcept
    {
        return data == typename result_traits::bits_type{} ? data
                                                           : data + bias_change;
    }
};

// downcasting implementation details
template <typename SourceType, typename ResultType>
struct precision_converter<SourceType, ResultType, false> {
    using source_traits = float_traits<SourceType>;
    using result_traits = float_traits<ResultType>;
    using source_bits = typename source_traits::bits_type;
    using result_bits = typename result_traits::bits_type;

    static_assert(source_traits::exponent_bits >=
                          result_traits::exponent_bits &&
                      source_traits::significand_bits >=
                          result_traits::significand_bits,
                  "SourceType has to have both lower range and precision or "
                  "higher range and precision than ResultType");

    // Right shift that drops the low significand bits the result cannot
    // store; the exponent field moves by the same amount.
    static constexpr int significand_offset =
        source_traits::significand_bits - result_traits::significand_bits;
    static constexpr int exponent_offset = significand_offset;
    // The sign additionally has to skip the exponent bits that are dropped.
    static constexpr int sign_offset = source_traits::exponent_bits -
                                       result_traits::exponent_bits +
                                       exponent_offset;
    // Difference of the two exponent biases, in result exponent position.
    static constexpr source_bits bias_change =
        (source_traits::bias_mask >> exponent_offset) -
        static_cast<source_bits>(result_traits::bias_mask);

    // Returns the truncated significand of `data` in result position.
    static constexpr result_bits shift_significand(source_bits data) noexcept
    {
        return static_cast<result_bits>(
            (data & source_traits::significand_mask) >> significand_offset);
    }

    // Returns the re-biased and clamped exponent of `data` in result
    // position.
    static constexpr result_bits shift_exponent(source_bits data) noexcept
    {
        return static_cast<result_bits>(update_bias(
            (data & source_traits::exponent_mask) >> exponent_offset));
    }

    // Returns the sign of `data` in result position.
    static constexpr result_bits shift_sign(source_bits data) noexcept
    {
        return static_cast<result_bits>((data & source_traits::sign_mask) >>
                                        sign_offset);
    }

private:
    // Exponents that would become zero or negative after re-biasing are
    // mapped to a zero exponent field; all others are re-biased and clamped.
    static constexpr source_bits update_bias(source_bits data) noexcept
    {
        return data <= bias_change ? typename source_traits::bits_type{}
                                   : limit_exponent(data - bias_change);
    }

    // Exponents at or above the all-ones result exponent saturate to it.
    static constexpr source_bits limit_exponent(source_bits data) noexcept
    {
        return data >= static_cast<source_bits>(result_traits::exponent_mask)
                   ? static_cast<source_bits>(result_traits::exponent_mask)
                   : data;
    }
};


}  // namespace detail


/**
 * A class providing basic support for half precision floating point types.
 *
 * The value is stored as a 16 bit pattern (1 sign, 5 exponent, 10
 * significand bits). Arithmetic is carried out in single precision and the
 * result is converted back. Conversion from float rounds to nearest, ties to
 * even; values whose exponent is too small for a normal half are flushed to
 * (signed) zero in both conversion directions.
 */
class alignas(std::uint16_t) half {
public:
    // create half value from the bits directly.
    static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
    {
        half result;
        result.data_ = bits;
        return result;
    }

    // TODO: NVHPC (host side) may not use zero initialization for the data
    // member by default constructor in some cases. Not sure whether it is
    // caused by something else in jacobi or isai.
    constexpr half() noexcept : data_(0) {}

    // Converting constructor: the value is cast to float and then rounded to
    // half precision.
    template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
    half(const T& val) : data_(0)
    {
        this->float2half(static_cast<float>(val));
    }

    // Assignment from any type that can be cast to float.
    template <typename V>
    half& operator=(const V& val)
    {
        this->float2half(static_cast<float>(val));
        return *this;
    }

    // Widens the stored bits to single precision.
    operator float() const noexcept
    {
        const auto bits = half2float(data_);
        float ans(0);
        std::memcpy(&ans, &bits, sizeof(float));
        return ans;
    }

    // half _op half is declared explicitly: relying on the implicit
    // conversion to float would perform the operation in float and also
    // yield a float. Here the operation is done in float and the result is
    // converted back to half.
#define HALF_OPERATOR(_op, _opeq)                                  \
    friend half operator _op(const half& lhf, const half& rhf)     \
    {                                                              \
        return static_cast<half>(static_cast<float>(lhf)           \
                                     _op static_cast<float>(rhf)); \
    }                                                              \
    half& operator _opeq(const half& hf)                           \
    {                                                              \
        auto result = *this _op hf;                                \
        data_ = result.data_;                                      \
        return *this;                                              \
    }

    HALF_OPERATOR(+, +=)
    HALF_OPERATOR(-, -=)
    HALF_OPERATOR(*, *=)
    HALF_OPERATOR(/, /=)

#undef HALF_OPERATOR

    // Operations with a different scalar type:
    // if the other operand is a floating point type, the result has that
    // type; otherwise (e.g. integers) the result is half.
#define HALF_FRIEND_OPERATOR(_op, _opeq)                                   \
    template <typename T>                                                  \
    friend std::enable_if_t<                                               \
        !std::is_same<T, half>::value && std::is_scalar<T>::value,         \
        std::conditional_t<std::is_floating_point<T>::value, T, half>>     \
    operator _op(const half& hf, const T& val)                             \
    {                                                                      \
        using type =                                                       \
            std::conditional_t<std::is_floating_point<T>::value, T, half>; \
        auto result = static_cast<type>(hf);                               \
        result _opeq static_cast<type>(val);                               \
        return result;                                                     \
    }                                                                      \
    template <typename T>                                                  \
    friend std::enable_if_t<                                               \
        !std::is_same<T, half>::value && std::is_scalar<T>::value,         \
        std::conditional_t<std::is_floating_point<T>::value, T, half>>     \
    operator _op(const T& val, const half& hf)                             \
    {                                                                      \
        using type =                                                       \
            std::conditional_t<std::is_floating_point<T>::value, T, half>; \
        auto result = static_cast<type>(val);                              \
        result _opeq static_cast<type>(hf);                                \
        return result;                                                     \
    }

    HALF_FRIEND_OPERATOR(+, +=)
    HALF_FRIEND_OPERATOR(-, -=)
    HALF_FRIEND_OPERATOR(*, *=)
    HALF_FRIEND_OPERATOR(/, /=)

#undef HALF_FRIEND_OPERATOR

    // the negative
    //
    // Negation only flips the sign bit. Computing `0.0f - value` instead
    // would turn +0 into +0 (not -0) and would push the value through the
    // float round trip, which discards bit patterns that the conversions
    // flush to zero.
    half operator-() const noexcept
    {
        return create_from_bits(
            static_cast<std::uint16_t>(data_ ^ f16_traits::sign_mask));
    }

private:
    using f16_traits = detail::float_traits<half>;
    using f32_traits = detail::float_traits<float>;

    // Stores the half precision representation of `val` in data_.
    void float2half(const float& val) noexcept
    {
        std::uint32_t bit_val(0);
        std::memcpy(&bit_val, &val, sizeof(float));
        data_ = float2half(bit_val);
    }

    // Converts a single precision bit pattern to a half precision bit
    // pattern.
    static constexpr std::uint16_t float2half(std::uint32_t bits) noexcept
    {
        using conv = detail::precision_converter<float, half>;
        if (f32_traits::is_inf(bits)) {
            return conv::shift_sign(bits) | f16_traits::exponent_mask;
        } else if (f32_traits::is_nan(bits)) {
            // The payload is not preserved: every NaN maps to the all-ones
            // significand.
            return conv::shift_sign(bits) | f16_traits::exponent_mask |
                   f16_traits::significand_mask;
        } else {
            const auto exponent = conv::shift_exponent(bits);
            if (f16_traits::is_inf(exponent)) {
                // Exponent too large for half: saturate to infinity.
                return conv::shift_sign(bits) | exponent;
            } else if (f16_traits::is_denom(exponent)) {
                // TODO: handle denormals
                return conv::shift_sign(bits);
            } else {
                // Rounding to even
                const auto result = conv::shift_sign(bits) | exponent |
                                    conv::shift_significand(bits);
                // The significand bits that do not fit into half.
                const auto tail =
                    bits & static_cast<f32_traits::bits_type>(
                               (1 << conv::significand_offset) - 1);

                // Value of `tail` that lies exactly between two half values.
                constexpr auto midpoint = static_cast<f32_traits::bits_type>(
                    1 << (conv::significand_offset - 1));
                // Adding 1 to the raw bits also carries into the exponent
                // when the significand overflows.
                return result + (tail > midpoint ||
                                 ((tail == midpoint) && (result & 1)));
            }
        }
    }

    // Converts a half precision bit pattern to a single precision bit
    // pattern.
    static constexpr std::uint32_t half2float(std::uint16_t bits) noexcept
    {
        using conv = detail::precision_converter<half, float>;
        if (f16_traits::is_inf(bits)) {
            return conv::shift_sign(bits) | f32_traits::exponent_mask;
        } else if (f16_traits::is_nan(bits)) {
            return conv::shift_sign(bits) | f32_traits::exponent_mask |
                   f32_traits::significand_mask;
        } else if (f16_traits::is_denom(bits)) {
            // TODO: handle denormals
            return conv::shift_sign(bits);
        } else {
            return conv::shift_sign(bits) | conv::shift_exponent(bits) |
                   conv::shift_significand(bits);
        }
    }

    std::uint16_t data_;
};


}  // namespace gko
452
453
454namespace std {
455
456
457template <>
458class complex<gko::half> {
459public:
460 using value_type = gko::half;
461
462 complex(const value_type& real = value_type(0.f),
463 const value_type& imag = value_type(0.f))
464 : real_(real), imag_(imag)
465 {}
466
467 template <typename T, typename U,
468 typename = std::enable_if_t<std::is_scalar<T>::value &&
469 std::is_scalar<U>::value>>
470 explicit complex(const T& real, const U& imag)
471 : real_(static_cast<value_type>(real)),
472 imag_(static_cast<value_type>(imag))
473 {}
474
475 template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
476 complex(const T& real)
477 : real_(static_cast<value_type>(real)),
478 imag_(static_cast<value_type>(0.f))
479 {}
480
481 // When using complex(real, imag), MSVC with CUDA try to recognize the
482 // complex is a member not constructor.
483 template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
484 explicit complex(const complex<T>& other)
485 : real_(static_cast<value_type>(other.real())),
486 imag_(static_cast<value_type>(other.imag()))
487 {}
488
489 value_type real() const noexcept { return real_; }
490
491 value_type imag() const noexcept { return imag_; }
492
493 operator std::complex<float>() const noexcept
494 {
495 return std::complex<float>(static_cast<float>(real_),
496 static_cast<float>(imag_));
497 }
498
499 template <typename V>
500 complex& operator=(const V& val)
501 {
502 real_ = val;
503 imag_ = value_type();
504 return *this;
505 }
506
507 template <typename V>
508 complex& operator=(const std::complex<V>& val)
509 {
510 real_ = val.real();
511 imag_ = val.imag();
512 return *this;
513 }
514
515 complex& operator+=(const value_type& real)
516 {
517 real_ += real;
518 return *this;
519 }
520
521 complex& operator-=(const value_type& real)
522 {
523 real_ -= real;
524 return *this;
525 }
526
527 complex& operator*=(const value_type& real)
528 {
529 real_ *= real;
530 imag_ *= real;
531 return *this;
532 }
533
534 complex& operator/=(const value_type& real)
535 {
536 real_ /= real;
537 imag_ /= real;
538 return *this;
539 }
540
541 template <typename T>
542 complex& operator+=(const complex<T>& val)
543 {
544 real_ += val.real();
545 imag_ += val.imag();
546 return *this;
547 }
548
549 template <typename T>
550 complex& operator-=(const complex<T>& val)
551 {
552 real_ -= val.real();
553 imag_ -= val.imag();
554 return *this;
555 }
556
557 template <typename T>
558 complex& operator*=(const complex<T>& val)
559 {
560 auto val_f = static_cast<std::complex<float>>(val);
561 auto result_f = static_cast<std::complex<float>>(*this);
562 result_f *= val_f;
563 real_ = result_f.real();
564 imag_ = result_f.imag();
565 return *this;
566 }
567
568 template <typename T>
569 complex& operator/=(const complex<T>& val)
570 {
571 auto val_f = static_cast<std::complex<float>>(val);
572 auto result_f = static_cast<std::complex<float>>(*this);
573 result_f /= val_f;
574 real_ = result_f.real();
575 imag_ = result_f.imag();
576 return *this;
577 }
578
579#define COMPLEX_HALF_OPERATOR(_op, _opeq) \
580 friend complex operator _op(const complex& lhf, const complex& rhf) \
581 { \
582 auto a = lhf; \
583 a _opeq rhf; \
584 return a; \
585 }
586
587 COMPLEX_HALF_OPERATOR(+, +=)
588 COMPLEX_HALF_OPERATOR(-, -=)
589 COMPLEX_HALF_OPERATOR(*, *=)
590 COMPLEX_HALF_OPERATOR(/, /=)
591
592#undef COMPLEX_HALF_OPERATOR
593
594private:
595 value_type real_;
596 value_type imag_;
597};
598
599
600template <>
601struct numeric_limits<gko::half> {
602 static constexpr bool is_specialized{true};
603 static constexpr bool is_signed{true};
604 static constexpr bool is_integer{false};
605 static constexpr bool is_exact{false};
606 static constexpr bool is_bounded{true};
607 static constexpr bool is_modulo{false};
608 static constexpr int digits{
609 gko::detail::float_traits<gko::half>::significand_bits + 1};
610 // 3/10 is approx. log_10(2)
611 static constexpr int digits10{digits * 3 / 10};
612
613 static constexpr gko::half epsilon()
614 {
615 constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
616 return gko::half::create_from_bits(bits);
617 }
618
619 static constexpr gko::half infinity()
620 {
621 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
622 return gko::half::create_from_bits(bits);
623 }
624
625 static constexpr gko::half min()
626 {
627 constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
628 return gko::half::create_from_bits(bits);
629 }
630
631 static constexpr gko::half max()
632 {
633 constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
634 return gko::half::create_from_bits(bits);
635 }
636
637 static constexpr gko::half lowest()
638 {
639 constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
640 return gko::half::create_from_bits(bits);
641 };
642
643 static constexpr gko::half quiet_NaN()
644 {
645 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
646 return gko::half::create_from_bits(bits);
647 }
648};
649
650
651// complex using a template on operator= for any kind of complex<T>, so we can
652// do full specialization for half
653template <>
654inline complex<double>& complex<double>::operator=(
655 const std::complex<gko::half>& a)
656{
657 complex<double> t(a.real(), a.imag());
658 operator=(t);
659 return *this;
660}
661
662
663// For MSVC
664template <>
665inline complex<float>& complex<float>::operator=(
666 const std::complex<gko::half>& a)
667{
668 complex<float> t(a.real(), a.imag());
669 operator=(t);
670 return *this;
671}
672
673
674} // namespace std
675
676
677#endif // GKO_PUBLIC_CORE_BASE_HALF_HPP_
A class providing basic support for half precision floating point types.
Definition half.hpp:286
Definition half.hpp:23
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr size_type byte_size
Number of bits in a byte.
Definition types.hpp:177
STL namespace.