37 template <
typename ScalarAlphaBeta_,
38 typename ScalarAccum_,
53 template <
typename FragmentB_,
typename FragmentCd_>
55 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 56 int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
57 for (
int j = 0; j < FragmentCd_::kElements; ++j) {
58 d[j] = b[j * kReduction + 0];
59 for (
int k = 1; k < kReduction; ++k) {
60 d[j] += b[j * kReduction + k];
68 template <
typename FragmentB_,
typename FragmentCd_>
73 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 74 int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
75 for (
int j = 0; j < FragmentCd_::kElements; ++j) {
76 d[j] = b[j * kReduction + 0];
77 for (
int k = 1; k < kReduction; ++k) {
78 d[j] += b[j * kReduction + k];
88 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) 102 template <
typename FragmentB_,
typename FragmentCd_>
103 CUTLASS_DEVICE
void multiply(half a, FragmentB_
const& b, FragmentCd_& d) {
104 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 106 __half2
const* b_half2 =
reinterpret_cast<__half2 const*
>(&b[0]);
108 __half2* d_half2 =
reinterpret_cast<__half2*
>(&d[0]);
111 __half2
const a_half2 = __half2half2(a);
113 int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
115 for (
int j = 0; j < FragmentCd_::kElements / 2; ++j) {
116 d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
118 for (
int k = 1; k < kReduction; ++k) {
119 d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
127 template <
typename FragmentB_,
typename FragmentCd_>
130 FragmentCd_
const& c,
132 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 134 __half2
const* b_half2 =
reinterpret_cast<__half2 const*
>(&b[0]);
135 __half2
const* c_half2 =
reinterpret_cast<__half2 const*
>(&c[0]);
137 __half2* d_half2 =
reinterpret_cast<__half2*
>(&d[0]);
140 __half2
const a_half2 = __half2half2(a);
142 int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
143 for (
int j = 0; j < FragmentCd_::kElements / 2; ++j) {
144 d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
146 for (
int k = 1; k < kReduction; ++k) {
147 d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const &b, FragmentCd_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:54
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:92
half ScalarAlphaBeta
The type for alpha and beta.
Definition: fragment_multiply_add.h:94
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:50
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:99
CUTLASS_DEVICE void multiply(half a, FragmentB_ const &b, FragmentCd_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:103
ScalarAccum_ ScalarAccum
The type for the accumulator.
Definition: fragment_multiply_add.h:47
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
ScalarAlphaBeta_ ScalarAlphaBeta
The type for alpha and beta.
Definition: fragment_multiply_add.h:45
CUTLASS_DEVICE void multiply_add(half a, FragmentB_ const &b, FragmentCd_ const &c, FragmentCd_ &d)
Multiply-add : d = a*b + c.
Definition: fragment_multiply_add.h:128
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:43
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
half ScalarAccum
The type for the accumulator.
Definition: fragment_multiply_add.h:96
CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a, FragmentB_ const &b, FragmentCd_ const &c, FragmentCd_ &d)
Multiply-add : d = a*b + c.
Definition: fragment_multiply_add.h:69
Definition: fragment_multiply_add.h:41