Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_epilogue_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/wmma_matrix.h"
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include "cutlass/convert.h"
34 #include "cutlass/coord.h"
40 #include "cutlass/reshape_tile.h"
41 #include "cutlass/tile_iterator.h"
42 
43 namespace cutlass {
44 namespace gemm {
45 
47 
/// Helper that assembles the type aliases used by the WMMA GEMM epilogue: traits and
/// iterators for storing/loading the D tile (types named "...Shared..."), global
/// iterator traits and iterators for C and D (types named "...Global..."), and Copy
/// transformers over each iterator's Fragment type.
///
/// Template parameters:
///   GemmConfig_      - GEMM configuration; this helper reads OutputTile,
///                      AccumulatorsPerWarp, InstructionShape, Warps, kWarpSize,
///                      ScalarC, ScalarD, kScalarsPerLdsD, kScalarsPerLdgC and
///                      kScalarsPerStgD from it.
///   Accumulator_     - accumulator element type; only sizeof(Accumulator_) is used here.
///   EpilogueFunctor_ - epilogue functor; its nested Scalar type becomes Scalar below.
///   Index_           - index type forwarded to the iterators (defaults to int).
///
/// NOTE(review): this rendered listing skips numbered source lines inside the struct
/// (68, 80, 88-89, 92, 116-117), so several template-argument lists below appear
/// incomplete as shown; confirm against the original header before relying on them.
48 template <typename GemmConfig_, typename Accumulator_, typename EpilogueFunctor_, typename Index_ = int>
49 struct WmmaGemmEpilogueTraitsHelper {
  /// Scalar type of the epilogue functor.
51  typedef typename EpilogueFunctor_::Scalar Scalar;
  /// Output tile shape, taken directly from the GEMM configuration.
53  typedef typename GemmConfig_::OutputTile OutputTile;
54 
  /// Number of WMMA instruction tiles along H per warp:
  /// AccumulatorsPerWarp::kH / InstructionShape::kH.
56  static int const kWmmasPerH =
57  GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
  /// Iteration counts: 1 x 1 x kWmmasPerH. Its total (ShapeCount<Iterations>::kCount)
  /// divides OutputTile::kH in the global C/D tile shapes further down.
59  typedef Shape<1, 1, kWmmasPerH> Iterations;
60  // The iteration strides in the H/W dimension (all zero here).
61  typedef Shape<0, 0, 0> Delta;
  /// The epilogue functor type.
63  typedef EpilogueFunctor_ Functor;
64 
  /// Traits for storing the D tile through WMMA (a "Shared" store, per the type name).
66  typedef WmmaGemmSharedStoreTileDTraits<
67  // The output layout.
  // NOTE(review): listing line 68 (the layout argument itself) is absent here — confirm.
69  // The scalar type: the functor's Scalar (not necessarily float).
70  typename Functor::Scalar,
71  // The output tile size.
72  typename GemmConfig_::OutputTile,
73  // The number of warps.
74  typename GemmConfig_::Warps,
75  // The shape of the instruction.
76  typename GemmConfig_::InstructionShape>
77  SharedStoreTileTraits;
78 
  /// WMMA matrix type for operand C, built from Scalar and the instruction shape.
79  typedef WmmaMatrix<GemmOperand::kC,
  // NOTE(review): listing line 80 is absent here (presumably a layout argument) — confirm.
81  Scalar,
82  typename GemmConfig_::InstructionShape>
83  WmmaMatrix;
84 
  /// Iterator that stores D fragments using SharedStoreTileTraits and WmmaMatrix.
86  typedef TileStoreIterator<SharedStoreTileTraits,
87  typename SharedStoreTileTraits::Scalar,
  // NOTE(review): listing lines 88-89 are absent here — confirm the omitted arguments.
90  Index_,
91  WmmaMatrix,
  // NOTE(review): listing line 92 is absent here (the argument list's closing line) — confirm.
93  SharedStoreIteratorD;
94 
  /// Transformer applied to the stored D fragment: Copy over the iterator's Fragment
  /// (presumably a plain element-wise copy — see convert.h).
96  typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
97 
  /// Traits for loading the D tile back (a "Shared" load, per the type name).
99  typedef WmmaGemmSharedLoadTileDTraits<
100  // The pointer.
101  typename Functor::Scalar,
102  // The tile size: the tile shape of the D store iterator.
103  typename SharedStoreIteratorD::Tile,
104  // The thread arrangement: 1 x (number of warps) x kWarpSize.
105  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
106  // The number of scalars per LDS.
107  GemmConfig_::kScalarsPerLdsD,
108  // Ratio sizeof(Accumulator_) / sizeof(ScalarD); helps with swizzling when accum is fp32 and output is fp16.
109  sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD)
110  >
111  SharedLoadTileTraits;
112 
  /// Iterator that loads D fragments using SharedLoadTileTraits.
114  typedef TileLoadIterator<SharedLoadTileTraits,
115  typename SharedLoadTileTraits::Scalar,
  // NOTE(review): listing lines 116-117 are absent here (remaining arguments and closing '>') — confirm.
118  SharedLoadIteratorD;
119 
  /// Stream wrapper around the D load iterator.
121  typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
122 
  /// Traits for the iterator that reads the C tile (a "Global" iterator, per the type name).
124  typedef WmmaGemmGlobalIteratorCdTraits<
125  // The scalar type: ScalarC, const-qualified (C is only read).
126  typename GemmConfig_::ScalarC const,
127  // The tile has size (N / Iterations)xM in GEMM's terminology: OutputTile::kH / ShapeCount<Iterations>::kCount by OutputTile::kW.
128  Shape<1,
129  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
130  GemmConfig_::OutputTile::kW>,
131  // The threads are distributed as warps x kWarpSize (the traits may reorganize).
132  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
133  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
134  GemmConfig_::kScalarsPerLdgC>
135  GlobalLoadTileTraits;
136 
  /// Iterator that loads C, parameterized by GlobalLoadTileTraits and Index_.
138  typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
  /// Transformer applied to loaded C fragments (Copy over the iterator's Fragment).
140  typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
141 
  /// Traits for the iterator that writes the D tile (a "Global" iterator, per the type name).
143  typedef WmmaGemmGlobalIteratorCdTraits<
144  // The scalar type: ScalarD (non-const, D is written).
145  typename GemmConfig_::ScalarD,
146  // The tile has size (N / Iterations)xM in GEMM's terminology: OutputTile::kH / ShapeCount<Iterations>::kCount by OutputTile::kW.
147  Shape<1,
148  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
149  GemmConfig_::OutputTile::kW>,
150  // The threads are distributed as warps x kWarpSize (the traits may reorganize).
151  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
152  // The number of scalars per STG (kScalarsPerStgD), i.e. the global-store vector width.
153  GemmConfig_::kScalarsPerStgD>
154  GlobalStoreTileTraits;
155 
  /// Iterator that stores D, parameterized by GlobalStoreTileTraits and Index_.
157  typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
  /// Transformer applied to D fragments before the store (Copy over the iterator's Fragment).
159  typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
160 };
161 
163 
164 } // namespace gemm
165 } // namespace cutlass
166 
167 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Definition: convert.h:33
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: tile_iterator.h:65
Definition: matrix_traits.h:357
Defines a type for restructuring a tile.
Definition: matrix_traits.h:159
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Implements efficient loading of the thread block-level tile from global memory and storing to shared memory.
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline.
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragments to and from shared memory...