Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
hgemm_multiply_add.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/fragment.h"
32 
34 
35 namespace cutlass {
36 namespace gemm {
37 
39 
/// Template performing matrix multiply-add operation within a thread.
///
/// Partial specialization for half-precision A, B and accumulators. The update
/// d = a * b + c is evaluated two elements at a time with packed __half2 fused
/// multiply-adds, which is why both accumulator dimensions must be even.
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
  /// The shape of the instruction.
  typedef Shape<1, 1, 2, 1> InstructionShape;
  /// The number of accumulators per thread.
  typedef ThreadGemmShape_ ThreadGemmShape;
  /// Aliased for compatibility. Will be removed for CUTLASS v2.0.
  typedef ThreadGemmShape AccumulatorsPerThread;
  /// The number of threads per warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of accumulators per warp.
  typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
  /// The type for A.
  typedef half ScalarA;
  /// The fragment for A.
  typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
  /// The type for B.
  typedef half ScalarB;
  /// The fragment for B.
  typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
  /// The type for C and D.
  typedef half ScalarC;
  /// The accumulators.
  typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;

  /// Make sure there's an even number of elements in both dimensions.
  static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
  static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

  /// Multiply : d = a*b + c.
  ///
  /// \param a  Fragment of A holding AccumulatorsPerThread::kW half values.
  /// \param b  Fragment of B holding AccumulatorsPerThread::kH half values.
  /// \param c  Input accumulators (kH * kW half values).
  /// \param d  Output accumulators; element (row, col) receives a[col] * b[row] + c(row, col).
  ///
  /// The body is only compiled when __CUDACC__ is defined and __CUDA_ARCH__ >= 530;
  /// otherwise this function does nothing and `d` is left unmodified.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
    // The inputs, viewed as packed pairs of half values.
    // NOTE(review): assumes Fragment storage is suitably aligned for __half2 access - confirm.
    __half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
    __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
    __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);

    // The output.
    __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);

    // Each b_half2[j] packs two consecutive elements of b, so one iteration of j
    // produces two accumulator rows (2j and 2j+1). Each a_half2[i] packs two
    // consecutive elements of a, i.e. two adjacent columns of a row.
    for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
      for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
        // The offsets in the output fragment, in units of __half2. A row of the
        // accumulator tile spans kW / 2 packed elements.
        int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
        int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;

        // Compute the product a[i] * b[j].low (low half of b broadcast to both lanes).
        d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
        // Compute the product a[i] * b[j].high (high half of b broadcast to both lanes).
        d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
      }
    }
#endif
  }
};
102 
104 
105 } // namespace gemm
106 } // namespace cutlass
CUTLASS_DEVICE ThreadMultiplyAdd()
Make sure there's an even number of elements in both dimensions.
Definition: hgemm_multiply_add.h:71
half ScalarC
The type for C and D.
Definition: hgemm_multiply_add.h:62
Definition: convert.h:33
Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: hgemm_multiply_add.h:60
ThreadGemmShape_ ThreadGemmShape
The number of accumulators per thread.
Definition: hgemm_multiply_add.h:46
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
A template defining Fragment Concept.
Definition: fragment.h:99
Template implementing matrix multiply-add operations on fragments.
Shape< 1, 1, 2, 1 > InstructionShape
The shape of the instruction.
Definition: hgemm_multiply_add.h:44
ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: hgemm_multiply_add.h:52
Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: hgemm_multiply_add.h:56
Fragment< half, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
The accumulators.
Definition: hgemm_multiply_add.h:64
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
Multiply : d = a*b + c.
Definition: hgemm_multiply_add.h:74
#define static_assert(__e, __m)
Definition: platform.h:153
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:44
ThreadGemmShape AccumulatorsPerThread
Aliased for compatibility. Will be removed for CUTLASS v2.0.
Definition: hgemm_multiply_add.h:48
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: hgemm_multiply_add.h:50
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...