Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/wmma_matrix.h"
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include "cutlass/convert.h"
34 #include "cutlass/gemm/gemm.h"
43 
44 namespace cutlass {
45 namespace gemm {
46 
48 
// WmmaGemmConfig is an empty struct whose only purpose is to select a GemmConfig<...> base
// class: it forwards its own template parameters and hard-codes the remaining arguments.
// The parameter spellings KScalarsPerLdsA_ / KscalarsPerLdsB_ differ in capitalisation from
// the other kScalars* parameters; they are local names and are left untouched here.
49  template <
51  MatrixLayout::Kind kLayoutA_,
53  MatrixLayout::Kind kLayoutB_,
55  typename OutputTile_,
57  typename ScalarA_,
59  typename ScalarB_,
61  typename ScalarC_,
63  typename Accumulator_,
65  typename WarpGemmShape_,
67  typename InstructionShape_,
69  int kScalarsPerLdgA_,
71  int kScalarsPerLdgB_,
73  int KScalarsPerLdsA_,
75  int KscalarsPerLdsB_,
77  int kScalarsPerLdgCAndStgD_,
79  int kScalarsPerStsD_,
81  int kScalarsPerLdsD_
82 >
83 struct WmmaGemmConfig : public GemmConfig<
85  ScalarA_,
87  ScalarB_,
89  ScalarC_,
// ScalarC_ is passed a second time (presumably as the scalar type of D -- confirm against
// GemmConfig's parameter list, which is not visible in this file).
91  ScalarC_,
93  OutputTile_,
// The multiply-add functor. Its fifth argument is fixed to column-major regardless of
// kLayoutA_ / kLayoutB_ (presumably the accumulator layout -- confirm).
95  WmmaGemmMultiplyAdd<kLayoutA_,
96  ScalarA_,
97  kLayoutB_,
98  ScalarB_,
99  MatrixLayout::kColumnMajor,
100  Accumulator_,
101  WarpGemmShape_,
102  InstructionShape_>,
104  kScalarsPerLdgA_,
// kScalarsPerLdgA_ appears twice in a row. The second use presumably fills the STS-width
// slot for A (the tile helpers read GemmConfig_::kScalarsPerStsA) -- confirm.
106  kScalarsPerLdgA_,
108  KScalarsPerLdsA_,
110  kScalarsPerLdgB_,
// Same duplication for B (the tile helpers read GemmConfig_::kScalarsPerStsB).
112  kScalarsPerLdgB_,
114  KscalarsPerLdsB_,
116  kScalarsPerLdgCAndStgD_,
118  kScalarsPerStsD_,
120  kScalarsPerLdsD_,
// The last four arguments are literals whose meaning is defined by GemmConfig.
// NOTE(review): the integer is presumably the stage count exposed as GemmConfig_::kStages;
// the three booleans are GemmConfig options not described in this file -- confirm.
122  1,
124  false,
126  true,
128  false> {};
129 
131 
// Primary template for the tile traits of operand A. It has no members: all typedefs and
// constants are supplied by the partial specializations on kLayout_ (and, for the sub-byte
// cases, on ScalarA_) that follow in this file.
132 template <enum MatrixLayout::Kind kLayout_,
133  typename GemmConfig_,
134  typename ScalarA_>
135 struct WmmaGemmTileTraitsHelperA {};
136 
138 
// Specialization of the A-operand tile traits for column-major A.
// It derives from GemmTileTraitsHelperA<kColumnMajor, GemmConfig_> (so Base supplies
// MultiplyAddScalar and GlobalTileTraits) and defines the padded tile, the store traits and
// the WMMA load traits on top of it. ScalarA_ is not referenced in the body.
139 template <typename GemmConfig_, typename ScalarA_>
140 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, ScalarA_>
141  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
143  typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
144 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the W extent of Tile
// (presumably padding to avoid shared-memory bank conflicts -- confirm).
146  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kD x (OutputTile::kW + kSkew).
148  typedef Shape<GemmConfig_::kStages,
149  GemmConfig_::OutputTile::kD,
150  GemmConfig_::OutputTile::kW + kSkew>
151  Tile;
152 
// NOTE(review): listing line 155 is absent between the operand kind and the scalar type;
// an argument may have been dropped from this listing -- confirm against the original header.
154  typedef WmmaMatrix<GemmOperand::kA,
156  typename Base::MultiplyAddScalar,
157  typename GemmConfig_::InstructionShape>
158  WmmaMatrix;
159 
// Traits used to store the tile (scalar type, padded Tile, thread arrangement, STS width).
161  typedef GemmSharedStoreTileAbTraits<
162  // The pointer.
163  typename Base::MultiplyAddScalar,
164  // The tile has size KxM in GEMM's terminology.
165  Tile,
166  // The threads are distributed as warps x 32 (the traits may reorganize).
167  typename Base::GlobalTileTraits::Threads,
168  // The number of scalars per STS (STS.32 or STS.128, etc).
169  GemmConfig_::kScalarsPerStsA>
170  SharedStoreTileTraits;
171 
// W extent of one instruction multiplied by the number of warps along W.
173  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
// Padded tile width multiplied by the instruction depth (kD).
175  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
177  typedef WmmaGemmSharedLoadTileATraits<
178  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 179 is absent);
// confirm against the original header.
180  // The pointer.
181  typename Base::MultiplyAddScalar,
182  // The output tile size.
183  Tile,
184  // The number of warps.
185  typename GemmConfig_::Warps,
186  // The strides between warps.
187  GemmConfig_::InstructionShape::kW,
188  // The number of iterations to load the data.
189  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
190  // The stride between iterations.
191  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
192  // The shape of the instruction.
193  typename GemmConfig_::InstructionShape>
194  SharedLoadTileTraits;
195 };
196 
198 
// Specialization of the A-operand tile traits for row-major A.
// Unlike the column-major case it has no base class: it defines Scalar, MultiplyAddScalar,
// the global tile traits, the padded tile, and the store/load traits itself.
// ScalarA_ is not referenced in the body.
199 template <typename GemmConfig_, typename ScalarA_>
200 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, ScalarA_> {
202  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
203 
// Scalar is the type taken from GemmConfig_::ScalarA; MultiplyAddScalar is the A type
// declared by the multiply-add functor.
205  typedef typename GemmConfig_::ScalarA Scalar;
207  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
208 
// NOTE(review): listing line 211 is absent between the operand kind and the scalar type;
// an argument may have been dropped from this listing -- confirm against the original header.
210  typedef WmmaMatrix<GemmOperand::kA,
212  MultiplyAddScalar,
213  typename GemmConfig_::InstructionShape>
214  WmmaMatrix;
215 
217  typedef GemmGlobalTileTraits<
218  // That's A.
220  // A is row-major.
// NOTE(review): the arguments described by the two comments above (listing lines 219 and
// 221) are absent from this listing -- confirm against the original header.
222  // The pointer is Scalar const.
223  Scalar const,
224  // The tile has size KxM in GEMM's terminology.
225  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
226  // The threads are distributed as warps x 32 (the traits may reorganize).
227  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
228  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
229  GemmConfig_::kScalarsPerLdgA>
230  GlobalTileTraits;
231 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
233  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kW x (OutputTile::kD + kSkew).
235  typedef Shape<GemmConfig_::kStages,
236  GemmConfig_::OutputTile::kW,
237  GemmConfig_::OutputTile::kD + kSkew>
238  Tile;
239 
241  typedef GemmSharedStoreTileAbTraits<
242  // The pointer.
243  MultiplyAddScalar,
244  // The tile has size KxM in GEMM's terminology.
245  Tile,
246  // The threads are distributed as warps x 32 (the traits may reorganize).
247  typename GlobalTileTraits::Threads,
248  // The number of scalars per STS (STS.32 or STS.128, etc).
249  GemmConfig_::kScalarsPerStsA>
250  SharedStoreTileTraits;
251 
// W extent of one instruction multiplied by the number of warps along W.
253  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
255  typedef WmmaGemmSharedLoadTileATraits<
256  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 257 is absent);
// confirm against the original header.
258  // The pointer.
259  MultiplyAddScalar,
260  // The tile in shared memory.
261  Tile,
262  // The number of warps.
263  typename GemmConfig_::Warps,
264  // The strides between warps.
265  GemmConfig_::InstructionShape::kW * Tile::kW,
266  // The number of iterations to load the data.
267  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
268  // The stride between iterations.
269  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
270  // The shape of the instruction.
271  typename GemmConfig_::InstructionShape>
272  SharedLoadTileTraits;
273 };
274 
276 
277 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the A-operand tile traits for row-major A when the third template
// argument is Vector<bin1_t, 32>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic row-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kBitsPerScalar.
278 template <typename GemmConfig_>
280 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<bin1_t, 32> > {
282  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
283 
285  typedef typename GemmConfig_::ScalarA Scalar;
287  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
288 
// Number of bits in one Scalar (sizeof in bytes times 8).
291  static int const kBitsPerScalar = sizeof(Scalar) * 8;
292 
// NOTE(review): listing line 295 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
294  typedef WmmaMatrix<GemmOperand::kA,
296  Vector<bin1_t, 32>,
297  typename GemmConfig_::InstructionShape>
298  WmmaMatrix;
299 
301  typedef GemmGlobalTileTraits<
302  // That's A.
304  // A is row-major.
// NOTE(review): the arguments described by the two comments above (listing lines 303 and
// 305) are absent from this listing -- confirm against the original header.
306  // The pointer is Scalar const.
307  Scalar const,
308  // The tile has size KxM in GEMM's terminology.
309  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
310  // The threads are distributed as warps x 32 (the traits may reorganize).
311  Shape<1,
312  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
313  GemmConfig_::OutputTile::kD / kBitsPerScalar>,
314  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
315  GemmConfig_::kScalarsPerLdgA / kBitsPerScalar>
316  GlobalTileTraits;
317 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
319  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kW x (OutputTile::kD / kBitsPerScalar + kSkew).
321  typedef Shape<GemmConfig_::kStages,
322  GemmConfig_::OutputTile::kW,
323  GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
324  Tile;
325 
327  typedef GemmSharedStoreTileAbTraits<
328  // The pointer.
329  MultiplyAddScalar,
330  // The tile has size KxM in GEMM's terminology.
331  Tile,
332  // The threads are distributed as warps x 32 (the traits may reorganize).
333  typename GlobalTileTraits::Threads,
334  // The number of scalars per STS (STS.32 or STS.128, etc).
335  GemmConfig_::kScalarsPerStsA / kBitsPerScalar>
336  SharedStoreTileTraits;
337 
// W extent of one instruction multiplied by the number of warps along W.
339  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
341  typedef WmmaGemmSharedLoadTileATraits<
342  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 343 is absent);
// confirm against the original header.
344  // The pointer.
345  MultiplyAddScalar,
346  // The tile in shared memory.
347  Tile,
348  // The number of warps.
349  typename GemmConfig_::Warps,
350  // The strides between warps.
351  GemmConfig_::InstructionShape::kW * Tile::kW,
352  // The number of iterations to load the data.
353  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
354  // The stride between iterations.
355  Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
356  // The shape of the instruction.
357  typename GemmConfig_::InstructionShape>
358  SharedLoadTileTraits;
359 };
360 #endif
361 
363 
364 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the A-operand tile traits for row-major A when the third template
// argument is Vector<uint4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic row-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kInt4PerScalar.
365 template <typename GemmConfig_>
367 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<uint4_t, 8> > {
369  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
370 
372  typedef typename GemmConfig_::ScalarA Scalar;
374  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
375 
// Number of 4-bit values in one Scalar (two per byte).
378  static int const kInt4PerScalar = sizeof(Scalar) * 2;
379 
// NOTE(review): listing line 382 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
381  typedef WmmaMatrix<GemmOperand::kA,
383  Vector<uint4_t, 8>,
384  typename GemmConfig_::InstructionShape>
385  WmmaMatrix;
386 
388  typedef GemmGlobalTileTraits<
389  // That's A.
391  // A is row-major.
// NOTE(review): the arguments described by the two comments above (listing lines 390 and
// 392) are absent from this listing -- confirm against the original header.
393  // The pointer is Scalar const.
394  Scalar const,
395  // The tile has size KxM in GEMM's terminology.
396  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
397  // The threads are distributed as warps x 32 (the traits may reorganize).
398  Shape<1,
399  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
400  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
401  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
402  GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
403  GlobalTileTraits;
404 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
406  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kW x (OutputTile::kD / kInt4PerScalar + kSkew).
408  typedef Shape<GemmConfig_::kStages,
409  GemmConfig_::OutputTile::kW,
410  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
411  Tile;
412 
414  typedef GemmSharedStoreTileAbTraits<
415  // The pointer.
416  MultiplyAddScalar,
417  // The tile has size KxM in GEMM's terminology.
418  Tile,
419  // The threads are distributed as warps x 32 (the traits may reorganize).
420  typename GlobalTileTraits::Threads,
421  // The number of scalars per STS (STS.32 or STS.128, etc).
422  GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
423  SharedStoreTileTraits;
424 
// W extent of one instruction multiplied by the number of warps along W.
426  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
428  typedef WmmaGemmSharedLoadTileATraits<
429  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 430 is absent);
// confirm against the original header.
431  // The pointer.
432  MultiplyAddScalar,
433  // The tile in shared memory.
434  Tile,
435  // The number of warps.
436  typename GemmConfig_::Warps,
437  // The strides between warps.
438  GemmConfig_::InstructionShape::kW * Tile::kW,
439  // The number of iterations to load the data.
440  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
441  // The stride between iterations.
442  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
443  // The shape of the instruction.
444  typename GemmConfig_::InstructionShape>
445  SharedLoadTileTraits;
446 };
447 #endif
448 
450 
451 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the A-operand tile traits for row-major A when the third template
// argument is Vector<int4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic row-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kInt4PerScalar.
452 template <typename GemmConfig_>
454 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<int4_t, 8> > {
456  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
457 
459  typedef typename GemmConfig_::ScalarA Scalar;
461  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
462 
// Number of 4-bit values in one Scalar (two per byte).
465  static int const kInt4PerScalar = sizeof(Scalar) * 2;
466 
// NOTE(review): listing line 469 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
468  typedef WmmaMatrix<GemmOperand::kA,
470  Vector<int4_t, 8>,
471  typename GemmConfig_::InstructionShape>
472  WmmaMatrix;
473 
475  typedef GemmGlobalTileTraits<
476  // That's A.
478  // A is row-major.
// NOTE(review): the arguments described by the two comments above (listing lines 477 and
// 479) are absent from this listing -- confirm against the original header.
480  // The pointer is Scalar const.
481  Scalar const,
482  // The tile has size KxM in GEMM's terminology.
483  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
484  // The threads are distributed as warps x 32 (the traits may reorganize).
485  Shape<1,
486  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
487  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
488  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
489  GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
490  GlobalTileTraits;
491 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
493  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kW x (OutputTile::kD / kInt4PerScalar + kSkew).
495  typedef Shape<GemmConfig_::kStages,
496  GemmConfig_::OutputTile::kW,
497  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
498  Tile;
499 
501  typedef GemmSharedStoreTileAbTraits<
502  // The pointer.
503  MultiplyAddScalar,
504  // The tile has size KxM in GEMM's terminology.
505  Tile,
506  // The threads are distributed as warps x 32 (the traits may reorganize).
507  typename GlobalTileTraits::Threads,
508  // The number of scalars per STS (STS.32 or STS.128, etc).
509  GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
510  SharedStoreTileTraits;
511 
// W extent of one instruction multiplied by the number of warps along W.
513  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
515  typedef WmmaGemmSharedLoadTileATraits<
516  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 517 is absent);
// confirm against the original header.
518  // The pointer.
519  MultiplyAddScalar,
520  // The tile in shared memory.
521  Tile,
522  // The number of warps.
523  typename GemmConfig_::Warps,
524  // The strides between warps.
525  GemmConfig_::InstructionShape::kW * Tile::kW,
526  // The number of iterations to load the data.
527  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
528  // The stride between iterations.
529  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
530  // The shape of the instruction.
531  typename GemmConfig_::InstructionShape>
532  SharedLoadTileTraits;
533 };
534 #endif
535 
537 
// Primary template for the tile traits of operand B. It has no members: all typedefs and
// constants are supplied by the partial specializations on kLayout_ (and, for the sub-byte
// cases, on ScalarB_) that follow in this file.
538 template <enum MatrixLayout::Kind kLayout_,
539  typename GemmConfig_,
540  typename ScalarB_>
541 struct WmmaGemmTileTraitsHelperB {};
542 
544 
// Specialization of the B-operand tile traits for row-major B.
// It derives from GemmTileTraitsHelperB<kRowMajor, GemmConfig_> (so Base supplies
// MultiplyAddScalar and GlobalTileTraits) and defines the padded tile, the store traits and
// the WMMA load traits on top of it. ScalarB_ is not referenced in the body.
545 template <typename GemmConfig_, typename ScalarB_>
546 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, ScalarB_>
547  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
549  typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
550 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
552  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kD x (OutputTile::kH + kSkew).
554  typedef Shape<GemmConfig_::kStages,
555  GemmConfig_::OutputTile::kD,
556  GemmConfig_::OutputTile::kH + kSkew>
557  Tile;
558 
// NOTE(review): listing line 561 is absent between the operand kind and the scalar type;
// an argument may have been dropped from this listing -- confirm against the original header.
560  typedef WmmaMatrix<GemmOperand::kB,
562  typename Base::MultiplyAddScalar,
563  typename GemmConfig_::InstructionShape>
564  WmmaMatrix;
565 
567  typedef GemmSharedStoreTileAbTraits<
568  // The pointer.
569  typename Base::MultiplyAddScalar,
570  // The padded tile defined above: OutputTile::kD x (OutputTile::kH + kSkew) per stage.
571  Tile,
572  // The threads are distributed as warps x 32 (the traits may reorganize).
573  typename Base::GlobalTileTraits::Threads,
574  // The number of scalars per STS (STS.32 or STS.128, etc).
575  GemmConfig_::kScalarsPerStsB>
576  SharedStoreTileTraits;
577 
// H extent of one instruction multiplied by the number of warps along H (the constant keeps
// the name kScalarsPerW used by the A helpers even though it is computed from kH).
579  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
// Padded tile width multiplied by the instruction depth (kD).
581  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
583  typedef WmmaGemmSharedLoadTileBTraits<
584  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 585 is absent);
// confirm against the original header.
586  // The pointer.
587  typename Base::MultiplyAddScalar,
588  // The output tile size.
589  Tile,
590  // The number of warps.
591  typename GemmConfig_::Warps,
592  // The strides between warps.
593  GemmConfig_::InstructionShape::kH,
594  // The number of iterations to load the data.
595  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
596  // The stride between iterations.
597  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
598  // The shape of the instruction.
599  typename GemmConfig_::InstructionShape>
600  SharedLoadTileTraits;
601 };
602 
604 
// Specialization of the B-operand tile traits for column-major B.
// Unlike the row-major case it has no base class: it defines Scalar, MultiplyAddScalar,
// the global tile traits, the padded tile, and the store/load traits itself.
// ScalarB_ is not referenced in the body.
605 template <typename GemmConfig_, typename ScalarB_>
606 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, ScalarB_> {
608  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
609 
// Scalar is the type taken from GemmConfig_::ScalarB; MultiplyAddScalar is the B type
// declared by the multiply-add functor.
611  typedef typename GemmConfig_::ScalarB Scalar;
613  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
614 
// NOTE(review): listing line 617 is absent between the operand kind and the scalar type;
// an argument may have been dropped from this listing -- confirm against the original header.
616  typedef WmmaMatrix<GemmOperand::kB,
618  MultiplyAddScalar,
619  typename GemmConfig_::InstructionShape>
620  WmmaMatrix;
621 
623  typedef GemmGlobalTileTraits<
624  // That's B.
626  // B is column-major.
// NOTE(review): the arguments described by the two comments above (listing lines 625 and
// 627) are absent from this listing -- confirm against the original header.
628  // The pointer is Scalar const.
629  Scalar const,
630  // The tile spans OutputTile::kH x OutputTile::kD.
631  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
632  // The threads are distributed as warps x 32 (the traits may reorganize).
633  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
634  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
635  GemmConfig_::kScalarsPerLdgB>
636  GlobalTileTraits;
637 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
639  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kH x (OutputTile::kD + kSkew).
641  typedef Shape<GemmConfig_::kStages,
642  GemmConfig_::OutputTile::kH,
643  GemmConfig_::OutputTile::kD + kSkew>
644  Tile;
645 
647  typedef GemmSharedStoreTileAbTraits<
648  // The pointer.
649  MultiplyAddScalar,
650  // The padded tile defined above: OutputTile::kH x (OutputTile::kD + kSkew) per stage.
651  Tile,
652  // The threads are distributed as warps x 32 (the traits may reorganize).
653  typename GlobalTileTraits::Threads,
654  // The number of scalars per STS (STS.32 or STS.128, etc).
655  GemmConfig_::kScalarsPerStsB>
656  SharedStoreTileTraits;
657 
// H extent of one instruction multiplied by the number of warps along H (the constant keeps
// the name kScalarsPerW used by the A helpers even though it is computed from kH).
659  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
661  typedef WmmaGemmSharedLoadTileBTraits<
662  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 663 is absent);
// confirm against the original header.
664  // The pointer.
665  MultiplyAddScalar,
666  // The tile in shared memory.
667  Tile,
668  // The number of warps.
669  typename GemmConfig_::Warps,
670  // The strides between warps.
671  GemmConfig_::InstructionShape::kH * Tile::kW,
672  // The number of iterations to load the data.
673  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
674  // The stride between iterations.
675  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
676  // The shape of the instruction.
677  typename GemmConfig_::InstructionShape>
678  SharedLoadTileTraits;
679 };
680 
682 
683 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the B-operand tile traits for column-major B when the third template
// argument is Vector<bin1_t, 32>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic column-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kBitsPerScalar.
684 template <typename GemmConfig_>
686 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<bin1_t, 32> > {
688  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
689 
691  typedef typename GemmConfig_::ScalarB Scalar;
693  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
694 
// Number of bits in one Scalar (sizeof in bytes times 8).
697  static int const kBitsPerScalar = sizeof(Scalar) * 8;
698 
// NOTE(review): listing line 701 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
700  typedef WmmaMatrix<GemmOperand::kB,
702  Vector<bin1_t, 32>,
703  typename GemmConfig_::InstructionShape>
704  WmmaMatrix;
705 
707  typedef GemmGlobalTileTraits<
708  // That's B.
710  // B is column-major.
// NOTE(review): the arguments described by the two comments above (listing lines 709 and
// 711) are absent from this listing -- confirm against the original header.
712  // The pointer is Scalar const.
713  Scalar const,
714  // The tile spans OutputTile::kH x (OutputTile::kD / kBitsPerScalar).
715  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
716  // The threads are distributed as warps x 32 (the traits may reorganize).
717  Shape<1,
718  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
719  GemmConfig_::OutputTile::kD / kBitsPerScalar>,
720  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
721  GemmConfig_::kScalarsPerLdgB / kBitsPerScalar>
722  GlobalTileTraits;
723 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
725  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kH x (OutputTile::kD / kBitsPerScalar + kSkew).
727  typedef Shape<GemmConfig_::kStages,
728  GemmConfig_::OutputTile::kH,
729  GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
730  Tile;
731 
733  typedef GemmSharedStoreTileAbTraits<
734  // The pointer.
735  MultiplyAddScalar,
736  // The padded tile defined above.
737  Tile,
738  // The threads are distributed as warps x 32 (the traits may reorganize).
739  typename GlobalTileTraits::Threads,
740  // The number of scalars per STS (STS.32 or STS.128, etc).
741  GemmConfig_::kScalarsPerStsB / kBitsPerScalar>
742  SharedStoreTileTraits;
743 
// H extent of one instruction multiplied by the number of warps along H (the constant keeps
// the name kScalarsPerW used by the A helpers even though it is computed from kH).
745  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
747  typedef WmmaGemmSharedLoadTileBTraits<
748  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 749 is absent);
// confirm against the original header.
750  // The pointer.
751  MultiplyAddScalar,
752  // The tile in shared memory.
753  Tile,
754  // The number of warps.
755  typename GemmConfig_::Warps,
756  // The strides between warps.
757  GemmConfig_::InstructionShape::kH * Tile::kW,
758  // The number of iterations to load the data.
759  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
760  // The stride between iterations.
761  Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
762  // The shape of the instruction.
763  typename GemmConfig_::InstructionShape>
764  SharedLoadTileTraits;
765 };
766 #endif
767 
769 
770 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the B-operand tile traits for column-major B when the third template
// argument is Vector<uint4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic column-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kInt4PerScalar.
771 template <typename GemmConfig_>
773 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<uint4_t, 8> > {
775  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
776 
778  typedef typename GemmConfig_::ScalarB Scalar;
780  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
781 
// Number of 4-bit values in one Scalar (two per byte).
784  static int const kInt4PerScalar = sizeof(Scalar) * 2;
785 
// NOTE(review): listing line 788 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
787  typedef WmmaMatrix<GemmOperand::kB,
789  Vector<uint4_t, 8>,
790  typename GemmConfig_::InstructionShape>
791  WmmaMatrix;
792 
794  typedef GemmGlobalTileTraits<
795  // That's B.
797  // B is column-major.
// NOTE(review): the arguments described by the two comments above (listing lines 796 and
// 798) are absent from this listing -- confirm against the original header.
799  // The pointer is Scalar const.
800  Scalar const,
801  // The tile spans OutputTile::kH x (OutputTile::kD / kInt4PerScalar).
802  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
803  // The threads are distributed as warps x 32 (the traits may reorganize).
804  Shape<1,
805  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
806  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
807  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
808  GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
809  GlobalTileTraits;
810 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
812  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kH x (OutputTile::kD / kInt4PerScalar + kSkew).
814  typedef Shape<GemmConfig_::kStages,
815  GemmConfig_::OutputTile::kH,
816  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
817  Tile;
818 
820  typedef GemmSharedStoreTileAbTraits<
821  // The pointer.
822  MultiplyAddScalar,
823  // The padded tile defined above.
824  Tile,
825  // The threads are distributed as warps x 32 (the traits may reorganize).
826  typename GlobalTileTraits::Threads,
827  // The number of scalars per STS (STS.32 or STS.128, etc).
828  GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
829  SharedStoreTileTraits;
830 
// H extent of one instruction multiplied by the number of warps along H (the constant keeps
// the name kScalarsPerW used by the A helpers even though it is computed from kH).
832  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
834  typedef WmmaGemmSharedLoadTileBTraits<
835  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 836 is absent);
// confirm against the original header.
837  // The pointer.
838  MultiplyAddScalar,
839  // The tile in shared memory.
840  Tile,
841  // The number of warps.
842  typename GemmConfig_::Warps,
843  // The strides between warps.
844  GemmConfig_::InstructionShape::kH * Tile::kW,
845  // The number of iterations to load the data.
846  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
847  // The stride between iterations.
848  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
849  // The shape of the instruction.
850  typename GemmConfig_::InstructionShape>
851  SharedLoadTileTraits;
852 };
853 #endif
854 
856 
857 #ifdef CUTLASS_USE_SUBBYTE_WMMA
// Specialization of the B-operand tile traits for column-major B when the third template
// argument is Vector<int4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Compared with the generic column-major case, every extent or access width measured along
// OutputTile::kD / InstructionShape::kD is divided by kInt4PerScalar.
858 template <typename GemmConfig_>
860 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<int4_t, 8> > {
862  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
863 
865  typedef typename GemmConfig_::ScalarB Scalar;
867  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
868 
// Number of 4-bit values in one Scalar (two per byte).
871  static int const kInt4PerScalar = sizeof(Scalar) * 2;
872 
// NOTE(review): listing line 875 is absent between the operand kind and the element type;
// an argument may have been dropped from this listing -- confirm against the original header.
874  typedef WmmaMatrix<GemmOperand::kB,
876  Vector<int4_t, 8>,
877  typename GemmConfig_::InstructionShape>
878  WmmaMatrix;
879 
881  typedef GemmGlobalTileTraits<
882  // That's B.
884  // B is column-major.
// NOTE(review): the arguments described by the two comments above (listing lines 883 and
// 885) are absent from this listing -- confirm against the original header.
886  // The pointer is Scalar const.
887  Scalar const,
888  // The tile spans OutputTile::kH x (OutputTile::kD / kInt4PerScalar).
889  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
890  // The threads are distributed as warps x 32 (the traits may reorganize).
891  Shape<1,
892  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
893  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
894  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
895  GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
896  GlobalTileTraits;
897 
// Number of MultiplyAddScalar elements that fit in 16 bytes; added to the last extent of
// Tile (presumably padding to avoid shared-memory bank conflicts -- confirm).
899  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Tile shape: kStages x OutputTile::kH x (OutputTile::kD / kInt4PerScalar + kSkew).
901  typedef Shape<GemmConfig_::kStages,
902  GemmConfig_::OutputTile::kH,
903  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
904  Tile;
905 
907  typedef GemmSharedStoreTileAbTraits<
908  // The pointer.
909  MultiplyAddScalar,
910  // The padded tile defined above.
911  Tile,
912  // The threads are distributed as warps x 32 (the traits may reorganize).
913  typename GlobalTileTraits::Threads,
914  // The number of scalars per STS (STS.32 or STS.128, etc).
915  GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
916  SharedStoreTileTraits;
917 
// H extent of one instruction multiplied by the number of warps along H (the constant keeps
// the name kScalarsPerW used by the A helpers even though it is computed from kH).
919  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
921  typedef WmmaGemmSharedLoadTileBTraits<
922  // The layout of the matrix.
// NOTE(review): no argument follows the comment above in this listing (line 923 is absent);
// confirm against the original header.
924  // The pointer.
925  MultiplyAddScalar,
926  // The tile in shared memory.
927  Tile,
928  // The number of warps.
929  typename GemmConfig_::Warps,
930  // The strides between warps.
931  GemmConfig_::InstructionShape::kH * Tile::kW,
932  // The number of iterations to load the data.
933  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
934  // The stride between iterations.
935  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
936  // The shape of the instruction.
937  typename GemmConfig_::InstructionShape>
938  SharedLoadTileTraits;
939 };
940 #endif
941 
943 
// WmmaGemmTraitsHelper assembles, from one set of template parameters, every type that
// WmmaGemmTraits later forwards to GemmTraits: the GemmConfig, the per-operand tile helpers,
// the global-load and shared-store iterators and streams for A and B, the shared-load
// iterators and streams, the multiply-add functor, the accumulator-clearing functor and the
// epilogue. It contains typedefs only.
944 template <
946  MatrixLayout::Kind kLayoutA_,
948  MatrixLayout::Kind kLayoutB_,
950  typename OutputTile_,
952  typename ScalarA_,
954  typename ScalarB_,
956  typename ScalarC_,
958  typename Accumulator_,
960  typename EpilogueFunctor_,
962  typename WarpGemmShape_,
964  typename InstructionShape_,
966  int kScalarsPerLdgA_,
968  int kScalarsPerLdgB_,
970  int KScalarsPerLdsA_,
972  int KscalarsPerLdsB_,
974  int kScalarsPerLdgCAndStgD_,
976  int kScalarsPerStsD_,
978  int kScalarsPerLdsD_,
980  typename Index_>
981 struct WmmaGemmTraitsHelper {
// The configuration; EpilogueFunctor_ and Index_ are not part of it.
983  typedef WmmaGemmConfig<kLayoutA_,
984  kLayoutB_,
985  OutputTile_,
986  ScalarA_,
987  ScalarB_,
988  ScalarC_,
989  Accumulator_,
990  WarpGemmShape_,
991  InstructionShape_,
992  kScalarsPerLdgA_,
993  kScalarsPerLdgB_,
994  KScalarsPerLdsA_,
995  KscalarsPerLdsB_,
996  kScalarsPerLdgCAndStgD_,
997  kScalarsPerStsD_,
998  kScalarsPerLdsD_
999  >
1000  GemmConfig;
1001 
// Tile-traits helpers, selected by layout and scalar type of each operand.
1003  typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig, ScalarA_> GemmTileTraitsHelperA;
1005  typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig, ScalarB_> GemmTileTraitsHelperB;
1006 
// Operand A: global-memory load iterator, pass-through (Copy) transformer, store iterator
// built from SharedStoreTileTraits, and the stream combining the three.
1008  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
1009  GlobalLoadIteratorA;
1011  typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
// NOTE(review): listing lines 1015-1016 are absent, so the TileStoreIterator argument list
// is not closed in this listing -- confirm the remaining arguments against the original header.
1013  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
1014  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
1017  SharedStoreIteratorA;
1019  typedef GlobalLoadStream<GemmOperand::kA,
1020  GlobalLoadIteratorA,
1021  SharedStoreIteratorA,
1022  GlobalTransformerA>
1023  GlobalLoadStreamA;
1024 
// Operand B: same structure as for A.
1026  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
1027  GlobalLoadIteratorB;
1028  // The default transformer for B.
1029  typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
// NOTE(review): listing lines 1033-1034 are absent, so the TileStoreIterator argument list
// is not closed in this listing -- confirm the remaining arguments against the original header.
1031  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
1032  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
1035  SharedStoreIteratorB;
1037  typedef GlobalLoadStream<GemmOperand::kB,
1038  GlobalLoadIteratorB,
1039  SharedStoreIteratorB,
1040  GlobalTransformerB>
1041  GlobalLoadStreamB;
1042 
// Load iterators built from SharedLoadTileTraits; they are parameterised on the helper's
// WmmaMatrix type.
// NOTE(review): listing lines 1046-1047 and 1050 are absent, so arguments (including the one
// that closes the list) are missing from this listing -- confirm against the original header.
1044  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
1045  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
1048  Index_,
1049  typename GemmTileTraitsHelperA::WmmaMatrix,
1051  SharedLoadIteratorA;
1053  typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
// NOTE(review): listing lines 1057-1058 and 1061 are absent here as well -- confirm.
1055  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
1056  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
1059  Index_,
1060  typename GemmTileTraitsHelperB::WmmaMatrix,
1062  SharedLoadIteratorB;
1064  typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
1065 
// The multiply-add functor chosen by the configuration.
1067  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
// The member typedef reuses the name of the ClearAccumulators template it instantiates, so
// inside this struct the name refers to the instantiation for MultiplyAdd::ScalarC.
1069  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
1070 
// Epilogue: a WMMA-specific traits helper feeds SimplifiedGemmEpilogueTraits, which in turn
// parameterises GemmEpilogue.
1072  typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, Accumulator_, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
1074  typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
1075  GemmEpilogueTraits;
1077  typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
1078 };
1079 
1081 
// Computes the default per-warp shape for a given output tile: Shape is whatever
// ShapeMin<OutputTile_, DefaultShape_>::Shape yields, with DefaultShape_ defaulting to
// Shape<64, 32, 64>. (Going by its name, ShapeMin presumably takes the component-wise
// minimum of the two shapes -- it is declared in another header; confirm there.)
1082 template <typename OutputTile_, typename DefaultShape_ = Shape<64, 32, 64> >
1083 struct WmmaGemmAccumulatorsPerWarp {
1084  typedef typename ShapeMin<OutputTile_, DefaultShape_>::Shape Shape;
1085 };
1086 
1088 
// WmmaGemmTraits is the user-facing traits class: an empty struct deriving from GemmTraits,
// with every GemmTraits argument taken from Helper_ (by default a WmmaGemmTraitsHelper built
// from the same parameters). Only the two layouts are mandatory; all other parameters have
// defaults (half inputs, float output/accumulator, Shape<64, 128, 128> output tile,
// Shape<16, 16, 16> instruction shape, 8 scalars per LDG/LDS for A and B, and 16-byte
// accesses for C/D expressed as 16 / sizeof(type)).
// Note the parameter order: here EpilogueFunctor_ precedes Accumulator_, whereas
// WmmaGemmTraitsHelper takes Accumulator_ first; the default Helper_ argument below passes
// them in the helper's order.
1089 template <
1091  MatrixLayout::Kind kLayoutA_,
1093  MatrixLayout::Kind kLayoutB_,
1095  typename OutputTile_ = Shape<64, 128, 128>,
1097  typename ScalarA_ = half,
1099  typename ScalarB_ = half,
1101  typename ScalarC_ = float,
1103  typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
1105  typename Accumulator_ = ScalarC_,
1107  typename WarpGemmShape_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
1109  typename InstructionShape_ = Shape<16, 16, 16>,
1111  int kScalarsPerLdgA_ = 8,
1113  int kScalarsPerLdgB_ = 8,
1115  int KScalarsPerLdsA_ = 8,
1117  int KscalarsPerLdsB_ = 8,
1119  int kScalarsPerLdgCAndStgD_ = 16 / sizeof(ScalarC_),
1121  int kScalarsPerStsD_ = 16 / sizeof(Accumulator_),
1123  int kScalarsPerLdsD_ = 16 / sizeof(Accumulator_),
1125  typename Index_ = int,
1127  typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
1128  kLayoutB_,
1129  OutputTile_,
1130  ScalarA_,
1131  ScalarB_,
1132  ScalarC_,
1133  Accumulator_,
1134  EpilogueFunctor_,
1135  WarpGemmShape_,
1136  InstructionShape_,
1137  kScalarsPerLdgA_,
1138  kScalarsPerLdgB_,
1139  KScalarsPerLdsA_,
1140  KscalarsPerLdsB_,
1141  kScalarsPerLdgCAndStgD_,
1142  kScalarsPerStsD_,
1143  kScalarsPerLdsD_,
1144  Index_> >
1145 struct WmmaGemmTraits : public GemmTraits<
1146  // The config.
1147  typename Helper_::GemmConfig,
1148  // The stream to load A from global memory to shared memory.
1149  typename Helper_::GlobalLoadStreamA,
1150  // The stream to load B from global memory to shared memory.
1151  typename Helper_::GlobalLoadStreamB,
1152  // The stream to load A from shared memory.
1153  typename Helper_::SharedLoadStreamA,
1154  // The stream to load B from shared memory.
1155  typename Helper_::SharedLoadStreamB,
1156  // The epilogue.
1157  typename Helper_::Epilogue,
1158  // The block swizzle to reorganize the grid.
1159  IdentityBlockSwizzle,
1160  // The index.
1161  Index_,
1162  // The tool used to clear accumulators.
1163  typename Helper_::ClearAccumulators> {};
1164 
1166 
1167 } // namespace gemm
1168 } // namespace cutlass
1169 
1170 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: tile_iterator.h:65
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Defines iterators for efficiently loading and storing tiles to and from shared memory.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:357
Definition: matrix_traits.h:159
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:159
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:357
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:159
Defines conversion operations among Fragments of different base type.