Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_multiply_add.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/wmma_matrix.h"
31 #ifdef CUTLASS_USE_WMMA_API
32 #include "cutlass/fragment.h"
33 
34 namespace cutlass {
35 namespace gemm {
36 
38 
/// Warp-level GEMM multiply-add built on the CUDA WMMA API (nvcuda::wmma).
///
/// The accumulator tile owned by one warp (WarpGemmShape_) is covered by a grid of
/// Iterations::kH x Iterations::kW WMMA instructions, each of shape InstructionShape_.
/// multiply_add() issues one nvcuda::wmma::mma_sync per grid cell.
///
/// @tparam kLayoutA_          Layout of the A operand fragments.
/// @tparam ScalarA_           Element type of the A operand.
/// @tparam kLayoutB_          Layout of the B operand fragments.
/// @tparam ScalarB_           Element type of the B operand.
/// @tparam kLayoutC_          Layout of the accumulator (C/D) fragments.
/// @tparam ScalarC_           Element type of the accumulators.
/// @tparam WarpGemmShape_     Shape of the accumulator tile computed by one warp.
/// @tparam InstructionShape_  Shape of a single WMMA instruction. NOTE(review): WarpGemmShape_
///                            is presumably expected to be a multiple of it in every dimension,
///                            because Iterations is obtained by dividing the two shapes - confirm.
template <MatrixLayout::Kind kLayoutA_,
          typename ScalarA_,
          MatrixLayout::Kind kLayoutB_,
          typename ScalarB_,
          MatrixLayout::Kind kLayoutC_,
          typename ScalarC_,
          typename WarpGemmShape_,
          typename InstructionShape_>
struct WmmaGemmMultiplyAdd {
  /// The shape of a single WMMA instruction.
  typedef InstructionShape_ InstructionShape;
  /// Nominal thread arrangement. NOTE(review): 1 x kH x kW does not in general equal the 32
  /// threads of a warp, so this looks like a placeholder expected by surrounding traits.
  typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
  /// The shape of the tile computed by one warp.
  typedef WarpGemmShape_ WarpGemmShape;
  /// The accumulators held by one warp span the whole warp tile.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// Element type of the A operand.
  typedef ScalarA_ ScalarA;
  /// Element type of the B operand.
  typedef ScalarB_ ScalarB;
  /// Element type of the accumulators.
  typedef ScalarC_ ScalarC;
  /// Number of WMMA instructions needed to cover the warp tile in each dimension. Used to size
  /// the fragments below and as the trip counts of multiply_add().
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// One WMMA fragment of A.
  typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
  /// The A fragments of the warp, one per Iterations::kW.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// One WMMA fragment of B.
  typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
  /// The B fragments of the warp, one per Iterations::kH.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// One WMMA accumulator fragment.
  typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
  /// All accumulator fragments of the warp, stored at index j * Iterations::kW + i.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Computes d = a * b + c, one WMMA instruction per accumulator tile.
  ///
  /// Accumulator tile (j, i) is stored at index j * Iterations::kW + i of c and d. It is the
  /// product of A fragment a[i] and B fragment b[j], plus c's tile at the same index.
  /// mma_sync is warp-collective, so every thread of the warp must call this function.
  ///
  /// @param a  A fragments (Iterations::kW of them).
  /// @param b  B fragments (Iterations::kH of them).
  /// @param c  Input accumulators.
  /// @param d  Output accumulators.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < Iterations::kH; ++j) {
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction.
        nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
      }
    }
  }
};
104 
106 
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization of WmmaGemmMultiplyAdd for 1-bit operands (Vector<bin1_t, 32>) with int
/// accumulators and a Shape<128, 8, 8> instruction.
///
/// Operand layouts are row-major A, column-major B and row-major C/D. The PTX ISA accepts only
/// the .row.col operand layout combination for sub-byte wmma.mma.
///
/// Products use nvcuda::wmma::bmma_sync with the XOR bit operation and population-count
/// accumulation. Each output element is therefore c plus the popcount of (A bits XOR B bits).
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<bin1_t, 32>,
                           MatrixLayout::kColumnMajor,
                           Vector<bin1_t, 32>,
                           MatrixLayout::kRowMajor,
                           int,
                           WarpGemmShape_,
                           Shape<128, 8, 8> > {
  /// The shape of a single WMMA instruction.
  typedef Shape<128, 8, 8> InstructionShape;
  /// Nominal thread arrangement (1 x 4 x 8 = 32 threads, i.e. one warp).
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// The shape of the tile computed by one warp.
  typedef WarpGemmShape_ WarpGemmShape;
  /// The accumulators held by one warp span the whole warp tile.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// Element type of the A operand (32 packed bits).
  typedef Vector<bin1_t, 32> ScalarA;
  /// Element type of the B operand (32 packed bits).
  typedef Vector<bin1_t, 32> ScalarB;
  /// Element type of the accumulators.
  typedef int ScalarC;
  /// Number of WMMA instructions needed to cover the warp tile in each dimension.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// One WMMA fragment of A (row-major).
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<bin1_t, 32>,
                     InstructionShape> ElementA;
  /// The A fragments of the warp, one per Iterations::kW.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// One WMMA fragment of B (column-major).
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<bin1_t, 32>,
                     InstructionShape> ElementB;
  /// The B fragments of the warp, one per Iterations::kH.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// One WMMA accumulator fragment (row-major).
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kRowMajor,
                     int,
                     InstructionShape> ElementC;
  /// All accumulator fragments of the warp, stored at index j * Iterations::kW + i.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Computes d = c + popc(a XOR b) for every accumulator tile using bmma_sync.
  ///
  /// Tile (j, i) combines a[i] and b[j] with c and d at index j * Iterations::kW + i.
  /// bmma_sync is warp-collective, so every thread of the warp must call this function.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < Iterations::kH; ++j) {
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The binary wmma instruction: XOR the bits, accumulate the population count.
        nvcuda::wmma::bmma_sync(elt_d,
                                elt_a,
                                elt_b,
                                elt_c,
                                nvcuda::wmma::experimental::bmmaBitOpXOR,
                                nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
      }
    }
  }
};
#endif
189 
191 
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization of WmmaGemmMultiplyAdd for signed 4-bit operands (Vector<int4_t, 8>) with int
/// accumulators and a Shape<32, 8, 8> instruction.
///
/// Operand layouts are row-major A, column-major B and row-major C/D. The PTX ISA accepts only
/// the .row.col operand layout combination for sub-byte wmma.mma.
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<int4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           Vector<int4_t, 8>,
                           MatrixLayout::kRowMajor,
                           int,
                           WarpGemmShape_,
                           Shape<32, 8, 8> > {
  /// The shape of a single WMMA instruction.
  typedef Shape<32, 8, 8> InstructionShape;
  /// Nominal thread arrangement (1 x 4 x 8 = 32 threads, i.e. one warp).
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// The shape of the tile computed by one warp.
  typedef WarpGemmShape_ WarpGemmShape;
  /// The accumulators held by one warp span the whole warp tile.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// Element type of the A operand (8 packed signed 4-bit values).
  typedef Vector<int4_t, 8> ScalarA;
  /// Element type of the B operand (8 packed signed 4-bit values).
  typedef Vector<int4_t, 8> ScalarB;
  /// Element type of the accumulators.
  typedef int ScalarC;
  /// Number of WMMA instructions needed to cover the warp tile in each dimension.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// One WMMA fragment of A (row-major).
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<int4_t, 8>,
                     InstructionShape> ElementA;
  /// The A fragments of the warp, one per Iterations::kW.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// One WMMA fragment of B (column-major).
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<int4_t, 8>,
                     InstructionShape> ElementB;
  /// The B fragments of the warp, one per Iterations::kH.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// One WMMA accumulator fragment (row-major).
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kRowMajor,
                     int,
                     InstructionShape> ElementC;
  /// All accumulator fragments of the warp, stored at index j * Iterations::kW + i.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Computes d = a * b + c for every accumulator tile using mma_sync.
  ///
  /// Tile (j, i) combines a[i] and b[j] with c and d at index j * Iterations::kW + i.
  /// mma_sync is warp-collective, so every thread of the warp must call this function.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < Iterations::kH; ++j) {
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction.
        nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
      }
    }
  }
};
#endif
269 
271 
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization of WmmaGemmMultiplyAdd for unsigned 4-bit operands (Vector<uint4_t, 8>) with
/// int accumulators and a Shape<32, 8, 8> instruction.
///
/// Operand layouts are row-major A, column-major B and row-major C/D. The PTX ISA accepts only
/// the .row.col operand layout combination for sub-byte wmma.mma.
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<uint4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           Vector<uint4_t, 8>,
                           MatrixLayout::kRowMajor,
                           int,
                           WarpGemmShape_,
                           Shape<32, 8, 8> > {
  /// The shape of a single WMMA instruction.
  typedef Shape<32, 8, 8> InstructionShape;
  /// Nominal thread arrangement (1 x 4 x 8 = 32 threads, i.e. one warp).
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// The shape of the tile computed by one warp.
  typedef WarpGemmShape_ WarpGemmShape;
  /// The accumulators held by one warp span the whole warp tile.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// Element type of the A operand (8 packed unsigned 4-bit values).
  typedef Vector<uint4_t, 8> ScalarA;
  /// Element type of the B operand (8 packed unsigned 4-bit values).
  typedef Vector<uint4_t, 8> ScalarB;
  /// Element type of the accumulators.
  typedef int ScalarC;
  /// Number of WMMA instructions needed to cover the warp tile in each dimension.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// One WMMA fragment of A (row-major).
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<uint4_t, 8>,
                     InstructionShape> ElementA;
  /// The A fragments of the warp, one per Iterations::kW.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// One WMMA fragment of B (column-major).
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<uint4_t, 8>,
                     InstructionShape> ElementB;
  /// The B fragments of the warp, one per Iterations::kH.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// One WMMA accumulator fragment (row-major).
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kRowMajor,
                     int,
                     InstructionShape> ElementC;
  /// All accumulator fragments of the warp, stored at index j * Iterations::kW + i.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Computes d = a * b + c for every accumulator tile using mma_sync.
  ///
  /// Tile (j, i) combines a[i] and b[j] with c and d at index j * Iterations::kW + i.
  /// mma_sync is warp-collective, so every thread of the warp must call this function.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < Iterations::kH; ++j) {
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction.
        nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
      }
    }
  }
};
#endif
349 
351 
352 } // namespace gemm
353 } // namespace cutlass
354 
355 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: convert.h:33
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
Definition: matrix_traits.h:357
Definition: matrix_traits.h:159
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...