Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
hgemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/convert.h"
31 #include "cutlass/reshape_tile.h"
32 
33 #include "cutlass/gemm/gemm.h"
42 
43 namespace cutlass {
44 namespace gemm {
45 
47 
// Half-precision (HGEMM) configuration. It declares no members of its own: it
// only fixes the argument list of GemmConfig, so callers choose just the output
// tile, the per-thread GEMM shape, and the number of scalars per global load
// (LDG) for A and for B (both default to 2).
// NOTE(review): this listing has lost the per-argument doc lines (the gaps in
// the embedded numbering) and GemmConfig's own parameter list is not visible
// here, so every "presumably" below must be confirmed against gemm_config.h.
48 template <
50  typename OutputTile_,
52  typename ThreadGemmShape_,
54  int kScalarsPerLdgA_ = 2,
56  int kScalarsPerLdgB_ = 2>
57 struct HgemmConfig : public GemmConfig<
    // Four scalar types, all half -- presumably for A, B, C and D; confirm.
59  half,
61  half,
63  half,
65  half,
    // The caller's output tile is forwarded unchanged.
67  OutputTile_,
    // Per-thread multiply-add functor built from the caller's thread shape, with
    // half for all three scalar types. Shape<1, 4, 8> is presumably the
    // arrangement of threads within a warp (4 x 8 = 32) -- confirm.
69  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, half, half, half>,
    // A: the caller's LDG width is passed twice, followed by a fixed 8 --
    // presumably the LDG, STS and LDS widths for A; confirm.
71  kScalarsPerLdgA_,
73  kScalarsPerLdgA_,
75  8,
    // B: same pattern as for A.
77  kScalarsPerLdgB_,
79  kScalarsPerLdgB_,
81  8,
    // Remaining arguments are fixed for every HGEMM instantiation. Their
    // meaning is set by GemmConfig's parameter order, which is not visible
    // here (the tile-traits helpers in this file read GemmConfig_::kStages, so
    // one of the integers is presumably the stage count -- confirm which).
83  2,
85  8,
87  2,
89  2,
91  false,
93  true,
95  false
96  > {};
97 
99 
// Selects, per layout of A, the `Transformer` type used on the global-load
// path for A: per the symbol index at the end of this page,
// HgemmTransformerA<layout, GlobalLoadIteratorA>::Transformer becomes
// GlobalTransformerA, which is an argument of GlobalLoadStreamA.
// NOTE: this listing elides line 101 (presumably the primary template's
// `struct HgemmTransformerA` declaration) and the single typedef inside each
// specialization (lines 105 and 110), so both bodies look empty here.
100 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
102 

// Column-major A. Per the index, line 105 declares
//   Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer
// i.e. a conversion whose source and destination fragment types are the same.
103 template <typename Iterator_>
104 struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
106 };
107 

// Row-major A. Per the index, line 110 declares
//   HgemmSwizzle<Iterator_> Transformer
// (HgemmSwizzle is defined in hgemm_swizzle.h, not shown here).
108 template <typename Iterator_>
109 struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
111 };
112 
114 
// Selects, per layout of B, the `Transformer` type used on the global-load
// path for B: per the symbol index at the end of this page,
// HgemmTransformerB<layout, GlobalLoadIteratorB>::Transformer becomes
// GlobalTransformerB, which is an argument of GlobalLoadStreamB.
// NOTE: this listing elides line 116 (presumably the primary template's
// `struct HgemmTransformerB` declaration) and the single typedef inside each
// specialization (lines 120 and 125), so both bodies look empty here.
// The layout-to-transformer mapping is the mirror image of the one for A.
115 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
117 

// Row-major B. Per the index, line 120 declares
//   Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer
// i.e. a conversion whose source and destination fragment types are the same.
118 template <typename Iterator_>
119 struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
121 };
122 

// Column-major B. Per the index, line 125 declares
//   HgemmSwizzle<Iterator_> Transformer
// (HgemmSwizzle is defined in hgemm_swizzle.h, not shown here).
123 template <typename Iterator_>
124 struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
126 };
127 
129 
// Tile-traits helper for operand A. The primary template adds nothing: for any
// layout it simply inherits everything from the generic GemmTileTraitsHelperA.
// It exists so that individual layouts can be specialized with HGEMM-specific
// traits (this file does so for MatrixLayout::kRowMajor).
130 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
131 struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
132 
134 
// Row-major A specialization: on top of the generic helper it inherits from,
// it declares three traits typedefs (global tile, shared-store tile,
// shared-load tile) and the skew constant kSkewA.
// NOTE: this listing elides several source lines (see the gaps in the embedded
// numbering), including the opening line of every typedef. According to the
// symbol index at the end of this page:
//   - line 139 declares `Base` = GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>;
//   - line 154 names the first typedef `GlobalTileTraits`; it is an
//     HgemmCrosswiseGlobalTileTraits<GemmOperand::kA, MatrixLayout::kRowMajor, ...>.
// The class templates behind SharedStoreTileTraits and SharedLoadTileTraits
// cannot be recovered from this listing.
135 template <typename GemmConfig_>
136 struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
137  : public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
140 

    // --- GlobalTileTraits: traits for loading A from global memory. ---
    // The argument lines after "The layout." (145) and after "The tile has
    // size MxK" (149) are elided; per the index they are
    // MatrixLayout::kRowMajor and
    // Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
    // so OutputTile::kW plays the role of M and OutputTile::kD the role of K.
144  // The layout.
146  // The pointer.
147  half const,
148  // The tile has size MxK in GEMM's terminology.
150  // The threads are distributed as (threads / K ) x K (the traits may reorganize).
151  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
152  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
153  GemmConfig_::kScalarsPerLdgA>
155 

    // Candidate skew: 128 / sizeof(half) (the number of half values that fit in
    // 128 bytes), divided by the W extent of the global-load thread
    // arrangement, then halved.
    // NOTE(review): the rationale for 128 and for the final "/ 2" is not
    // stated in this listing -- confirm before changing either.
156  static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
157 

    // --- SharedStoreTileTraits: traits for storing A to shared memory. ---
    // (Opening typedef lines 158-159 are elided.)
160  // The pointer.
161  half,
162  // The tile has size KxM in GEMM's terminology.
163  Shape<GemmConfig_::kStages,
164  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
165  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
166  // The threads reuse the arrangement chosen for the global loads (the traits may reorganize).
167  typename GlobalTileTraits::Threads,
168  // The number of scalars per STS (STS.32 or STS.128, etc).
169  2,
170  // The skew to avoid bank conflicts added in the tile W dimension.
    // Careful when reading the next line: kSkewA is the integer constant
    // declared above, so "kSkewA<..." is a less-than comparison inside a
    // ternary, not a template-id. The argument is therefore
    // max(kSkewA, GemmConfig_::kScalarsPerLdsA), and the trailing '>' closes
    // the argument list of the traits template.
171  kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
172  SharedStoreTileTraits;
173 

    // --- SharedLoadTileTraits: traits for loading A from shared memory. ---
    // (Opening typedef lines 174-175 are elided.)
176  // The pointer.
177  half const,
178  // The output tile size.
179  typename GemmConfig_::OutputTile,
180  // The number of warps.
181  typename GemmConfig_::Warps,
182  // The number of threads per warp.
183  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
184  // The shape of the FMA instruction.
185  typename GemmConfig_::InstructionShape,
186  // The number of stages.
187  GemmConfig_::kStages,
188  // The number of scalars per LDS.
189  8,
190  // The skew.
    // Taken from the store-side traits rather than recomputed here.
191  SharedStoreTileTraits::kSkew>
192  SharedLoadTileTraits;
193 };
194 
196 
// Tile-traits helper for operand B. The primary template adds nothing: for any
// layout it simply inherits everything from the generic GemmTileTraitsHelperB.
// It exists so that individual layouts can be specialized with HGEMM-specific
// traits (this file does so for MatrixLayout::kColumnMajor).
197 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
198 struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
199 
201 
// Column-major B specialization: on top of the generic helper it inherits
// from, it declares three traits typedefs (global tile, shared-store tile,
// shared-load tile) and the skew constant kSkewB. It mirrors the row-major A
// specialization, using OutputTile::kH where A uses OutputTile::kW.
// NOTE: this listing elides several source lines (see the gaps in the embedded
// numbering), including the opening line of every typedef. According to the
// symbol index at the end of this page:
//   - line 206 declares `Base` = GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>;
//   - line 221 names the first typedef `GlobalTileTraits`; it is an
//     HgemmCrosswiseGlobalTileTraits<GemmOperand::kB, MatrixLayout::kColumnMajor, ...>.
// The class templates behind SharedStoreTileTraits and SharedLoadTileTraits
// cannot be recovered from this listing.
202 template <typename GemmConfig_>
203 struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
204  : public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
207 

    // --- GlobalTileTraits: traits for loading B from global memory. ---
    // The argument lines after "The layout." (212) and after "The tile has
    // size KxN" (216) are elided; per the index they are
    // MatrixLayout::kColumnMajor and
    // Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
    // so OutputTile::kH plays the role of N and OutputTile::kD the role of K.
211  // The layout.
213  // The pointer.
214  half const,
215  // The tile has size KxN in GEMM's terminology.
217  // The threads are distributed as (threads / K) x K (the traits may reorganize).
218  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
219  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
220  GemmConfig_::kScalarsPerLdgB>
222 

    // Candidate skew: 128 / sizeof(half) (the number of half values that fit in
    // 128 bytes), divided by the W extent of the global-load thread
    // arrangement, then halved.
    // NOTE(review): the rationale for 128 and for the final "/ 2" is not
    // stated in this listing -- confirm before changing either.
223  static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
224 

    // --- SharedStoreTileTraits: traits for storing B to shared memory. ---
    // (Opening typedef lines 225-226 are elided.)
227  // The pointer.
228  half,
229  // The tile has size KxN in GEMM's terminology.
230  Shape<GemmConfig_::kStages,
231  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
232  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
233  // The threads are distributed as (threads / K) x K (the traits may reorganize).
234  typename GlobalTileTraits::Threads,
235  // The number of scalars per STS (STS.32 or STS.128, etc).
236  2,
237  // The skew to avoid bank conflicts added in the tile W dimension.
    // Careful when reading the next line: kSkewB is the integer constant
    // declared above, so "kSkewB<..." is a less-than comparison inside a
    // ternary, not a template-id. The argument is therefore
    // max(kSkewB, GemmConfig_::kScalarsPerLdsB), and the trailing '>' closes
    // the argument list of the traits template.
238  kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
239  SharedStoreTileTraits;
240 

    // --- SharedLoadTileTraits: traits for loading B from shared memory. ---
    // (Opening typedef lines 241-242 are elided.)
243  // The pointer.
244  half const,
245  // The output tile size.
246  typename GemmConfig_::OutputTile,
247  // The number of warps.
248  typename GemmConfig_::Warps,
249  // The number of threads per warp.
250  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
251  // The shape of the FMA instruction.
252  typename GemmConfig_::InstructionShape,
253  // The number of stages.
254  GemmConfig_::kStages,
255  // The number of scalars per LDS.
256  8,
257  // The skew.
    // Taken from the store-side traits rather than recomputed here.
258  SharedStoreTileTraits::kSkew>
259  SharedLoadTileTraits;
260 };
261 
263 
// Helper that derives every type HgemmTraits needs from the layouts of A and
// B, the output tile, the epilogue functor and the per-thread GEMM shape.
// NOTE: most of this definition is missing from the listing. The struct
// declaration itself (line 281) is elided -- presumably
// `struct HgemmTraitsHelper`, since HgemmTraits instantiates a template of
// that name with this exact parameter order -- and only the first line or two
// of a few typedefs survive. The symbol index at the end of this page lists
// the members and the line each one is defined on:
//   283 GemmConfig             = HgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_>
//   285 GemmTileTraitsHelperA  = HgemmTileTraitsHelperA<kLayoutA_, GemmConfig>
//   287 GemmTileTraitsHelperB  = HgemmTileTraitsHelperB<kLayoutB_, GemmConfig>
//   291 GlobalLoadIteratorA    = GemmGlobalIteratorAb<GemmTileTraitsHelperA::GlobalTileTraits, Index_>
//   294 GlobalTransformerA     = HgemmTransformerA<GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA>::Transformer
//   300 SharedStoreIteratorA   = TileStoreIterator<..., IteratorAdvance::kH, MemorySpace::kShared>
//   306 GlobalLoadStreamA      = GlobalLoadStream<GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
//   310 / 313 / 319 / 325      the same four types for B
//   332 SharedLoadIteratorA    = TileLoadIterator<..., IteratorAdvance::kH, MemorySpace::kShared>
//   334 SharedLoadStreamA      = SharedLoadStream<SharedLoadIteratorA>
//   340 / 342                  the same two types for B
//   345 MultiplyAdd            = GemmConfig::MultiplyAdd
//   347 ClearAccumulators      = ClearAccumulators<typename MultiplyAdd::ScalarC>
//   350 GemmEpilogueTraits     = SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_>
//   352 Epilogue               = GemmEpilogue<GemmEpilogueTraits>
264 template <
266  MatrixLayout::Kind kLayoutA_,
268  MatrixLayout::Kind kLayoutB_,
270  typename OutputTile_,
272  typename EpilogueFunctor_,
274  typename ThreadGemmShape_,
276  int kScalarsPerLdgA_ = 2,
278  int kScalarsPerLdgB_ = 2,
280  typename Index_ = int>
288 
    // First line of GlobalTransformerA (continuation on line 294 is elided).
293  typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
    // First two lines of SharedStoreIteratorA (remainder is elided).
296  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
297  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
307 
311  // The default transformer for B.
    // First line of GlobalTransformerB (continuation on line 313 is elided).
312  typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
    // First two lines of SharedStoreIteratorB (remainder is elided).
315  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
316  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
326 
    // First two lines of SharedLoadIteratorA (remainder is elided).
328  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
329  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
    // First two lines of SharedLoadIteratorB (remainder is elided).
336  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
337  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
343 
348 
353 };
354 
356 
// Traits for a half-precision GEMM. It declares no members of its own: it
// forwards the types computed by Helper_ to GemmTraits.
//
// Only the layouts of A and B must be supplied. Defaults for the rest:
//   OutputTile_       Shape<8, 128, 128>
//   EpilogueFunctor_  LinearScaling<half>
//   ThreadGemmShape_  Shape<8, 8, 16>
//   kScalarsPerLdgA_  2
//   kScalarsPerLdgB_  2
//   Index_            int
//   Helper_           HgemmTraitsHelper instantiated with all of the above
//
// A custom Helper_ may be substituted; it must expose the nested types read
// below: GemmConfig, GlobalLoadStreamA/B, SharedLoadStreamA/B, Epilogue and
// ClearAccumulators. The block swizzle is not configurable here: it is always
// IdentityBlockSwizzle.
357 template <
359  MatrixLayout::Kind kLayoutA_,
361  MatrixLayout::Kind kLayoutB_,
363  typename OutputTile_ = Shape<8, 128, 128>,
365  typename EpilogueFunctor_ = LinearScaling<half>,
367  typename ThreadGemmShape_ = Shape<8, 8, 16>,
369  int kScalarsPerLdgA_ = 2,
371  int kScalarsPerLdgB_ = 2,
373  typename Index_ = int,
375  typename Helper_ = HgemmTraitsHelper<kLayoutA_,
376  kLayoutB_,
377  OutputTile_,
378  EpilogueFunctor_,
379  ThreadGemmShape_,
380  kScalarsPerLdgA_,
381  kScalarsPerLdgB_,
382  Index_> >
383 struct HgemmTraits : public GemmTraits<
384  // The config.
385  typename Helper_::GemmConfig,
386  // The stream to load A from global memory to shared memory.
387  typename Helper_::GlobalLoadStreamA,
388  // The stream to load B from global memory to shared memory.
389  typename Helper_::GlobalLoadStreamB,
390  // The stream to load A from shared memory.
391  typename Helper_::SharedLoadStreamA,
392  // The stream to load B from shared memory.
393  typename Helper_::SharedLoadStreamB,
394  // The epilogue.
395  typename Helper_::Epilogue,
396  // The block swizzle to reorganize the grid.
397  IdentityBlockSwizzle,
398  // The index.
399  Index_,
400  // The tool used to clear accumulators.
401  typename Helper_::ClearAccumulators> {};
402 
404 
405 } // namespace gemm
406 } // namespace cutlass
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: hgemm_traits.h:342
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The iterator to load B from global memory.
Definition: hgemm_traits.h:310
Definition: load_store.h:41
HgemmSwizzle< Iterator_ > Transformer
Definition: hgemm_traits.h:125
Definition: convert.h:33
HgemmConfig< OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_ > GemmConfig
The HGEMM config.
Definition: hgemm_traits.h:283
Definition: gemm_shared_tile.h:128
Definition: gemm_epilogue.h:42
Defines iterators for efficiently loading and storing to global memory.
SimplifiedGemmEpilogueTraits< GemmConfig, EpilogueFunctor_, Index_ > GemmEpilogueTraits
The traits class for the epilogue.
Definition: hgemm_traits.h:350
Defines structural properties of complete GEMM computation.
HgemmCrosswiseGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, half const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: hgemm_traits.h:221
Definition: hgemm_traits.h:101
GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > Base
The base config.
Definition: hgemm_traits.h:206
GemmEpilogue< GemmEpilogueTraits > Epilogue
The epilogue.
Definition: hgemm_traits.h:352
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
Definition: hgemm_traits.h:105
GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: hgemm_traits.h:306
Definition: hgemm_traits.h:383
HgemmSwizzle< Iterator_ > Transformer
Definition: hgemm_traits.h:110
TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: hgemm_traits.h:340
HgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
The default transformer for A.
Definition: hgemm_traits.h:294
Definition: tile_iterator.h:65
Definition: gemm_shared_tile.h:200
Definition: gemm_global_tile.h:163
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Definition: gemm_global_stream.h:52
Definition: gemm_traits.h:191
Definition: hgemm_traits.h:131
HgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
The GEMM config for A.
Definition: hgemm_traits.h:285
Definition: hgemm_traits.h:116
GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > Base
The base config.
Definition: hgemm_traits.h:139
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Defines iterators for efficiently loading and storing tiles to and from shared memory.
Definition: gemm_shared_stream.h:45
HgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
Definition: hgemm_traits.h:313
Defines a type for restructuring a tile.
ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
The object to clear accumulators.
Definition: hgemm_traits.h:347
Specialization implementing multiply-add operation on half-precision floating point fragments...
Definition: gemm_config.h:76
TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: hgemm_traits.h:332
Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for...
Definition: gemm_traits.h:52
Definition: matrix_traits.h:357
Definition: hgemm_traits.h:198
GemmConfig::MultiplyAdd MultiplyAdd
The functor to do the multiply-add in the main loop.
Definition: hgemm_traits.h:345
Definition: gemm_traits.h:349
HgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
The GEMM config for B.
Definition: hgemm_traits.h:287
Definition: hgemm_global_tile.h:48
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
Definition: gemm_epilogue_traits.h:340
Definition: matrix_traits.h:159
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:44
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The iterator to load A from global memory.
Definition: hgemm_traits.h:291
Definition: hgemm_traits.h:281
HgemmCrosswiseGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, half const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: hgemm_traits.h:154
Tile traits used to construct global tile iterator for HGEMM. This is intended to partition the threa...
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
Definition: convert.h:38
Definition: matrix_traits.h:357
Implements a software-pipelined efficient GEMM.
GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: hgemm_traits.h:325
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: hgemm_traits.h:334
Defines structural properties of the GEMM epilogue.
TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: hgemm_traits.h:319
TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: hgemm_traits.h:300
Definition: hgemm_swizzle.h:40
Defines conversion operations among Fragments of different base type.
Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
Definition: hgemm_traits.h:120
Definition: hgemm_traits.h:57
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841