Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_epilogue_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/convert.h"
31 #include "cutlass/coord.h"
35 #include "cutlass/reshape_tile.h"
36 #include "cutlass/tile_iterator.h"
37 
38 namespace cutlass {
39 namespace gemm {
40 
42 
43 template <
45  typename OutputTile_,
47  typename Accumulators_,
49  typename GlobalLoadIteratorC_,
51  typename GlobalTransformerC_,
53  typename GlobalTransformerD_,
55  typename GlobalStoreIteratorD_,
57  typename SharedStoreIteratorD_,
59  typename SharedStoreTransformerD_,
61  typename SharedLoadStreamD_,
63  typename Iterations_,
65  typename Delta_,
67  typename Functor_,
69  typename Index_ = int>
71  //
73  typedef OutputTile_ OutputTile;
76  typedef Accumulators_ Accumulators;
78  typedef GlobalLoadIteratorC_ GlobalLoadIteratorC;
80  typedef GlobalTransformerC_ GlobalTransformerC;
82  typedef GlobalTransformerD_ GlobalTransformerD;
84  typedef GlobalStoreIteratorD_ GlobalStoreIteratorD;
86  typedef SharedStoreIteratorD_ SharedStoreIteratorD;
88  typedef SharedStoreTransformerD_ SharedStoreTransformerD;
90  typedef SharedLoadStreamD_ SharedLoadStreamD;
92  typedef Iterations_ Iterations;
94  typedef Delta_ Delta;
95 
97  typedef Functor_ Functor;
99  typedef Index_ Index;
101  typedef long long LongIndex;
102 
104  static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
105 
107  typedef typename Functor::Scalar Scalar;
109  typedef typename GlobalLoadIteratorC::Scalar ScalarC;
111  typedef typename GlobalStoreIteratorD::Scalar ScalarD;
112 
114  struct Params {
118  typename GlobalLoadIteratorC::Params iterator_c;
119 
122 
124  typename GlobalStoreIteratorD::Params iterator_d;
125 
128 
130  typename SharedStoreIteratorD::Params shared_store_iterator_d;
132  typename SharedLoadStreamD::Params shared_load_stream_d;
134  typename Functor::Params functor;
135 
137  template <typename GemmDesc_>
138  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
139 
140  // The parameters for the functor.
141  int error_code = functor.initialize(desc);
142  if (error_code) {
143  return error_code;
144  }
145 
146  // At the end of the H iteration, we jump over a number of columns.
147  this->stride_h = desc.D.leading_dim() * Delta::kH;
148  // Nothing to do here.
149  this->stride_w = 0;
150  // Setup the params for the global memory iterator for C.
151  error_code = iterator_c.initialize(desc.C.data(),
152  desc.C.leading_dim(),
153  desc.C.leading_dim(),
154  desc.problem_size[1],
155  stride_w,
156  Delta::kW);
157 
158  batch_stride_C = desc.batch_stride_C;
159 
160  if (error_code) {
161  return error_code;
162  }
163 
164  // Setup the params for the global memory iterator for D.
165  error_code = iterator_d.initialize(desc.D.data(),
166  desc.D.leading_dim(),
167  desc.D.leading_dim(),
168  desc.problem_size[1],
169  stride_w,
170  Delta::kW);
171 
172  batch_stride_D = desc.batch_stride_D;
173 
174  return error_code;
175  }
176  };
177 
180  // The storage for the store iterator.
181  typename SharedStoreIteratorD::SharedStorage store;
182  // The storage for the load stream.
183  typename SharedLoadStreamD::SharedStorage load;
184  };
185 
187  struct SharedStorage {
188  // The storage for the shared stream D.
190 
191  //
192  //
193  //
194 
195  CUTLASS_DEVICE
196  ScalarD* data() { return reinterpret_cast<ScalarD*>(&shared_stream.load); }
197  };
198 };
199 
201 
202 template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
205  typedef typename EpilogueFunctor_::Scalar Scalar;
207  typedef typename GemmConfig_::OutputTile OutputTile;
208 
210  typedef Shape<1,
211  GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
212  GemmConfig_::kAccumulatorsPerLdsB,
213  GemmConfig_::kAccumulatorsPerLdsB>
215  // The iteration strides in the H/W dimension.
216  typedef Shape<0,
217  GemmConfig_::kAccumulatorsPerLdsB*(
218  GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
219  0>
222  typedef EpilogueFunctor_ Functor;
223 
226  // The pointer is float.
227  // typename Functor::Scalar,
228  // Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
229  // In this case Functor::ScalarAccum is needed
230  typename Functor::ScalarAccum,
231  // The output tile size.
232  typename GemmConfig_::OutputTile,
233  // The number of warps.
234  typename GemmConfig_::Warps,
235  // The number of threads per warp.
236  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
237  // The number of scalars per STS.
238  GemmConfig_::kScalarsPerStsD,
239  // The skew -- 128 / sizeof(ScalarD) / kScalarsPerStsD is the number of threads involved in
240  // a single STS. We divide by 2 as our objective is to add a skew to the odd threads to
241  // avoid bank conflicts between odd and even threads.
242  128 / sizeof(typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
243  GemmConfig_::kScalarsPerStsD>
245 
252 
255 
258  // The pointer is float.
259  // typename Functor::Scalar,
260  // Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
261  // In this case Functor::ScalarAccum is needed
262  typename Functor::ScalarAccum,
263  // The output tile size.
264  typename GemmConfig_::OutputTile,
265  // The number of warps.
266  typename GemmConfig_::Warps,
267  // The number of threads per warp.
268  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
269  // The number of columns of the output tile written by iteration.
270  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
271  // The number of scalars per LDS.
272  GemmConfig_::kScalarsPerLdsD,
273  // The skew.
276 
285 
287  typedef GemmGlobalTileCdTraits<
288  // The pointer is float const.
289  typename GemmConfig_::ScalarC const,
290  // The tile has size (N / Iterations)xM in GEMM's terminology.
291  Shape<1,
292  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
293  GemmConfig_::OutputTile::kW>,
294  // The threads are distributed as warps x 32 (the traits may reorganize).
296  // How many elements do we jump over at each iteration?
298  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
299  GemmConfig_::kScalarsPerLdgC>
301 
306 
308  typedef GemmGlobalTileCdTraits<
309  // The pointer is float.
310  typename GemmConfig_::ScalarD,
311  // The tile has size (N / Iterations)xM in GEMM's terminology.
312  Shape<1,
313  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
314  GemmConfig_::OutputTile::kW>,
315  // The threads are distributed as warps x 32 (the traits may reorganize).
317  // How many elements do we jump over at each iteration?
319  // The number of scalars per STG (STG.32 or STG.128, etc).
320  GemmConfig_::kScalarsPerStgD>
322 
327 };
328 
330 
331 template <
333  typename GemmConfig_,
335  typename EpilogueFunctor_,
337  typename Index_ = int,
341  // The output tile.
342  typename GemmConfig_::OutputTile,
343  // The accumulators.
344  typename GemmConfig_::Accumulators,
345  // The global iterator for C.
346  typename Helper_::GlobalLoadIteratorC,
347  // The transformer for C.
348  typename Helper_::GlobalTransformerC,
349  // The transformer for D.
350  typename Helper_::GlobalTransformerD,
351  // The global iterator for D.
352  typename Helper_::GlobalStoreIteratorD,
353  // The iterator to store D to shared memory.
354  typename Helper_::SharedStoreIteratorD,
355  // The shared store transformer for D.
356  typename Helper_::SharedStoreTransformerD,
357  // The stream to load D from shared memory.
358  typename Helper_::SharedLoadStreamD,
359  // The number of iterations.
360  typename Helper_::Iterations,
361  // The strides between iterations.
362  typename Helper_::Delta,
363  // The functor to be used in the epilogue.
364  EpilogueFunctor_,
365  // The index.
366  Index_> {};
367 
369 
370 } // namespace gemm
371 } // namespace cutlass
Definition: gemm_global_tile.h:120
SharedStoreTransformerD_ SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue_traits.h:88
Iterations_ Iterations
typedef typename GemmConfig::EpilogueIterations Iterations;
Definition: gemm_epilogue_traits.h:92
CUTLASS_DEVICE ScalarD * data()
Definition: gemm_epilogue_traits.h:196
Definition: load_store.h:41
Definition: convert.h:33
GemmGlobalTileCdTraits< typename GemmConfig_::ScalarC const, Shape< 1, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, Iterations::kW, GemmConfig_::kScalarsPerLdgC > GlobalLoadTileTraits
The traits class to build the iterator to load data from global memory for C^N.
Definition: gemm_epilogue_traits.h:300
GlobalLoadIteratorC_ GlobalLoadIteratorC
The iterator for C in global memory.
Definition: gemm_epilogue_traits.h:78
Definition: gemm_epilogue_traits.h:203
Functor::Params functor
The functor params.
Definition: gemm_epilogue_traits.h:134
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
SharedLoadStreamD::SharedStorage load
Definition: gemm_epilogue_traits.h:183
long long LongIndex
The long index.
Definition: gemm_epilogue_traits.h:101
Implements the BLAS linear scaling function alpha*AB + beta*C.
The shared memory storage to exchange data.
Definition: gemm_epilogue_traits.h:179
EpilogueFunctor_::Scalar Scalar
The scalar.
Definition: gemm_epilogue_traits.h:205
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: convert.h:69
GlobalTransformerC_ GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue_traits.h:80
GemmGlobalIteratorCd< GlobalLoadTileTraits, Index_ > GlobalLoadIteratorC
The iterator to load C.
Definition: gemm_epilogue_traits.h:303
GlobalTransformerD_ GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue_traits.h:82
Definition: tile_iterator.h:65
TileStoreIterator< SharedStoreTileTraits, typename SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorD
The iterator to store D to shared memory.
Definition: gemm_epilogue_traits.h:251
SharedLoadStream< SharedLoadIteratorD > SharedLoadStreamD
The stream to load D.
Definition: gemm_epilogue_traits.h:284
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue_traits.h:111
LongIndex batch_stride_C
Batch stride for C matrix.
Definition: gemm_epilogue_traits.h:121
GemmGlobalIteratorCd< GlobalStoreTileTraits, Index_ > GlobalStoreIteratorD
The iterator to store D.
Definition: gemm_epilogue_traits.h:324
GlobalStoreIteratorD::Params iterator_d
The params for the D global iterator.
Definition: gemm_epilogue_traits.h:124
SharedLoadStreamD::Params shared_load_stream_d
The params for the D shared load stream.
Definition: gemm_epilogue_traits.h:132
Copy< typename SharedStoreIteratorD::Fragment > SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue_traits.h:254
Shape< 1, GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH/GemmConfig_::kAccumulatorsPerLdsB, GemmConfig_::kAccumulatorsPerLdsB > Iterations
The number of iterations in the epilogue.
Definition: gemm_epilogue_traits.h:214
GemmGlobalTileCdTraits< typename GemmConfig_::ScalarD, Shape< 1, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, Iterations::kW, GemmConfig_::kScalarsPerStgD > GlobalStoreTileTraits
The traits class to build the iterator to store data to global memory for D^N.
Definition: gemm_epilogue_traits.h:321
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Index stride_h
The strides for H and W in the different iterations of the epilogue.
Definition: gemm_epilogue_traits.h:116
Functor_ Functor
The functor in charge of the math.
Definition: gemm_epilogue_traits.h:97
Definition: gemm_shared_stream.h:45
Accumulators_ Accumulators
Definition: gemm_epilogue_traits.h:76
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:284
Defines a type for restructuring a tile.
GlobalLoadIteratorC::Params iterator_c
The params for the C iterator.
Definition: gemm_epilogue_traits.h:118
LongIndex batch_stride_D
Batch stride for D matrix.
Definition: gemm_epilogue_traits.h:127
SharedStoreIteratorD::SharedStorage store
Definition: gemm_epilogue_traits.h:181
Index stride_w
Definition: gemm_epilogue_traits.h:116
SharedLoadStreamD_ SharedLoadStreamD
The stream to load D from shared memory.
Definition: gemm_epilogue_traits.h:90
OutputTile_ OutputTile
The output tile.
Definition: gemm_epilogue_traits.h:73
Definition: gemm_shared_tile.h:339
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
#define static_assert(__e, __m)
Definition: platform.h:153
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: gemm_epilogue_traits.h:340
EpilogueFunctor_ Functor
The functor to do the math in the epilogue.
Definition: gemm_epilogue_traits.h:222
TileLoadIterator< SharedLoadTileTraits, typename SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorD
The iterator to load D from shared memory.
Definition: gemm_epilogue_traits.h:282
GemmConfig_::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue_traits.h:207
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Setup the params.
Definition: gemm_epilogue_traits.h:138
StreamSharedStorage shared_stream
Definition: gemm_epilogue_traits.h:189
Index_ Index
The index.
Definition: gemm_epilogue_traits.h:99
Definition: gemm_epilogue_traits.h:70
GemmSharedLoadTileDTraits< typename Functor::ScalarAccum, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::kScalarsPerLdsD, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for D.
Definition: gemm_epilogue_traits.h:275
Definition: gemm_global_tile.h:366
static int const kW
The width of the cube.
Definition: shape.h:70
Delta_ Delta
The iterations strides.
Definition: gemm_epilogue_traits.h:94
GlobalStoreIteratorD_ GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue_traits.h:84
Copy< typename GlobalStoreIteratorD::Fragment > GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue_traits.h:326
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue_traits.h:109
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
The params.
Definition: gemm_epilogue_traits.h:114
The shared memory to swizzle the data in the epilogue.
Definition: gemm_epilogue_traits.h:187
Copy< typename GlobalLoadIteratorC::Fragment > GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue_traits.h:305
SharedStoreIteratorD::Params shared_store_iterator_d
The params for the D shared store iterator.
Definition: gemm_epilogue_traits.h:130
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Computes derived counts for a Layout Concept based class.
Definition: shape.h:79
Defines conversion operations among Fragments of different base type.
Functor::Scalar Scalar
We do not support 3D or 4D shapes.
Definition: gemm_epilogue_traits.h:104
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
SharedStoreIteratorD_ SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue_traits.h:86
GemmSharedStoreTileDTraits< typename Functor::ScalarAccum, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, GemmConfig_::kScalarsPerStsD, 128/sizeof(typename GemmConfig_::ScalarD)/GemmConfig_::kScalarsPerStsD/2 *GemmConfig_::kScalarsPerStsD > SharedStoreTileTraits
The traits class to build the iterator to store to shared memory for D.
Definition: gemm_epilogue_traits.h:244
Definition: gemm_shared_tile.h:270
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841