37 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerSts_>
58 Tile::kH / Threads::kH,
59 Tile::kW / Threads::kW,
79 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerSts_,
int kSkew_>
93 static int const kSkew = kSkew_;
100 typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH>
Iterations;
120 template <
typename Scalar_,
121 typename OutputTile_,
123 typename ThreadsPerWarp_,
124 typename InstructionShape_,
135 typedef Shape<kStages_,
136 OutputTile_::kD / InstructionShape_::kD,
174 int const warp = threadIdx.x / kWarpSize;
176 int const slice = warp / (Warps::kH * Warps::kW);
178 int const warp_row = warp % Warps::kW;
180 int const lane_row = (threadIdx.x & 0x0e) / 2;
183 slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) *
kAccessSize;
192 template <
typename Scalar_,
193 typename OutputTile_,
195 typename ThreadsPerWarp_,
196 typename InstructionShape_,
207 typedef Shape<kStages_,
208 OutputTile_::kD / InstructionShape_::kD,
244 int const warp = threadIdx.x / kWarpSize;
246 int const slice = warp / (Warps::kH * Warps::kW);
248 int const warp_in_slice = warp % (Warps::kH * Warps::kW);
250 int const warp_col = warp_in_slice / Warps::kW;
252 int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
255 slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) *
kAccessSize;
264 template <
typename Scalar_,
265 typename OutputTile_,
267 typename ThreadsPerWarp_,
308 int const warp = threadIdx.x / kWarpSize;
311 int const warp_row = warp % Warps::kW;
312 int const warp_col = warp / Warps::kW;
317 int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
318 int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
321 int const row = threadIdx.x & 0x01;
322 int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
323 lo_halfwarp_offset *
kAccessSize + hi_halfwarp_offset;
332 template <
typename Scalar_,
333 typename OutputTile_,
335 typename ThreadsPerWarp_,
383 static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
398 int const h = threadIdx.x / kWarpSize;
400 int const w = (threadIdx.x & (kWarpSize - 1)) *
kAccessSize;
403 int const row = h & 0x1;
404 int const col = h / 2;
static int const kAccessSize
The number of scalars per STS.
Definition: gemm_shared_tile.h:95
static CUTLASS_DEVICE int get()
Definition: shape.h:214
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:216
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:144
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:214
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:355
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:358
Definition: load_store.h:41
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:242
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:234
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:229
Definition: gemm_shared_tile.h:128
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:42
Definition: gemm_shared_tile.h:80
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:293
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:131
Definition: gemm_shared_tile.h:106
Shape< 1, 1, kScalarsPerThread/kAccessSize > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:298
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
Definition: gemm_shared_tile.h:238
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:351
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:148
Definition: reshape_tile.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Shape< 0, ShapeCount< Tile >::kWc, Tile::kC, kScalarsPerSts_ > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:48
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:282
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:353
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:218
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:224
Definition: gemm_shared_tile.h:38
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:203
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:395
Definition: gemm_shared_tile.h:200
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:155
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:40
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:129
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:286
Kind
Definition: load_store.h:39
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:212
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:151
static int const kH
The height of the cube.
Definition: shape.h:68
Shape< 1, Tile::kH/Threads::kH, Tile::kW/Threads::kW, Tile::kC/Threads::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:61
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:93
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:164
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:276
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:362
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:205
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:133
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:274
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:289
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:306
Shape< kIterationsD, kIterationsH, OutputTile::kW/kWarpSize/kAccessSize, Warps::kD > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:387
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:54
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:50
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:231
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:241
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:104
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
The tile.
Definition: gemm_shared_tile.h:296
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:52
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:44
Definition: gemm_shared_tile.h:68
static int const kIterationsInHPerWarp
Definition: gemm_shared_tile.h:369
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize, kSplitK > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:390
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:284
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:142
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:66
ReshapeTile< Shape< Tile_::kD, Tile_::kH, Tile_::kW+kSkew_ >, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:89
Shape< 0, kScalarsPerSts_, ShapeCount< Tile >::kHwc/Threads::kW > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:115
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile TileWithoutSkew
The tile without skews.
Definition: gemm_shared_tile.h:86
static int const kIterationsD
Definition: gemm_shared_tile.h:379
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
Definition: gemm_shared_tile.h:168
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:158
Definition: matrix_traits.h:357
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:280
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:171
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:138
Definition: gemm_shared_tile.h:339
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:345
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:82
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:302
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:201
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
Definition: gemm_shared_tile.h:366
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize, kSplitK > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:392
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:160
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:84
static int const kSplitK
Definition: gemm_shared_tile.h:383
Shape< 1, TileWithoutSkew::kH/Threads::kW, TileWithoutSkew::kW/Threads::kH > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:100
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:210
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:46
Definition: gemm_operand.h:50
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:237
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:63
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:360
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:146
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:97
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:226
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:70
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:172
static int const kD
The depth of the cube.
Definition: shape.h:66
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:305
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:347
Tile_ Tile
Definition: reshape_tile.h:43
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:102
static int const kIterationsH
Definition: gemm_shared_tile.h:377
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:300
Kind
Definition: matrix_traits.h:357
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:153
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:349
Definition: matrix_traits.h:357
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:343
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:291
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:167
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:220
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
Computes derived counts for a Layout Concept based class.
Definition: shape.h:79
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:140
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:278
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:107
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:396
Definition: gemm_shared_tile.h:270
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:222