Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_shared_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_>
42  typedef Scalar_* Pointer;
46  typedef Threads_ Threads;
48  typedef Shape<0, ShapeCount<Tile>::kWc, Tile::kC, kScalarsPerSts_> ThreadsStrides;
50  static int const kSkew = 0;
52  static int const kAccessSize = kScalarsPerSts_;
55 
57  typedef Shape<1,
58  Tile::kH / Threads::kH,
59  Tile::kW / Threads::kW,
60  Tile::kC / Threads::kC / kAccessSize>
67 
68  struct ThreadOffset {
70  Coord<4> operator()() const {
72  return make_Coord(0, 0, offset, 0);
73  }
74  };
75 };
76 
78 
79 template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_, int kSkew_>
84  typedef Scalar_* Pointer;
89  kScalarsPerSts_>::Tile Tile;
91  typedef Threads_ Threads;
93  static int const kSkew = kSkew_;
95  static int const kAccessSize = kScalarsPerSts_;
98 
100  typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH> Iterations;
105 
106  struct ThreadOffset {
109  return make_Coord(0, 0, offset, 0);
110  }
111  };
112 
113  protected:
116 };
117 
119 
120 template <typename Scalar_,
121  typename OutputTile_,
122  typename Warps_,
123  typename ThreadsPerWarp_,
124  typename InstructionShape_,
125  int kStages_,
126  int kScalarsPerLds_,
127  int kSkew_ = 0>
133  typedef Scalar_* Pointer;
135  typedef Shape<kStages_,
136  OutputTile_::kD / InstructionShape_::kD,
137  GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
146  typedef Warps_ Warps;
148  typedef ThreadsPerWarp_ ThreadsPerWarp;
150  // static int const kScalarsPerLds = kScalarsPerLds_;
151  static int const kAccessSize = kScalarsPerLds_;
153  static int const kSkew = kSkew_;
156 
161 
163  typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
169 
171  struct ThreadOffset {
173  // Extract the warp.
174  int const warp = threadIdx.x / kWarpSize;
175  // Extract the slice.
176  int const slice = warp / (Warps::kH * Warps::kW);
177  // Compute the row offset for each warp.
178  int const warp_row = warp % Warps::kW;
179  // Compute the row offset for each thread.
180  int const lane_row = (threadIdx.x & 0x0e) / 2;
181  // The offset.
182  int const offset =
183  slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
184  // Embed the offset in a 4D coordinate vector.
185  return make_Coord(0, 0, offset, 0);
186  }
187  };
188 };
189 
191 
192 template <typename Scalar_,
193  typename OutputTile_,
194  typename Warps_,
195  typename ThreadsPerWarp_,
196  typename InstructionShape_,
197  int kStages_,
198  int kScalarsPerLds_,
199  int kSkew_ = 0>
205  typedef Scalar_* Pointer;
207  typedef Shape<kStages_,
208  OutputTile_::kD / InstructionShape_::kD,
209  GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
218  typedef Warps_ Warps;
220  typedef ThreadsPerWarp_ ThreadsPerWarp;
222  static int const kAccessSize = kScalarsPerLds_;
224  static int const kSkew = kSkew_;
227 
232 
234  typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
239 
241  struct ThreadOffset {
243  // Extract the warp.
244  int const warp = threadIdx.x / kWarpSize;
245  // Extract the slice.
246  int const slice = warp / (Warps::kH * Warps::kW);
247  // The warp in the slice.
248  int const warp_in_slice = warp % (Warps::kH * Warps::kW);
249  // Compute the column offset for each warp.
250  int const warp_col = warp_in_slice / Warps::kW;
251  // Compute the column offset for each thread.
252  int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
253  // The offset.
254  int const offset =
255  slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
256  // Embed the offset in a 4D coordinate.
257  return make_Coord(0, 0, offset, 0);
258  }
259  };
260 };
261 
263 
264 template <typename Scalar_,
265  typename OutputTile_,
266  typename Warps_,
267  typename ThreadsPerWarp_,
268  int kScalarsPerSts_,
269  int kSkew_ = 0>
274  typedef Scalar_* Pointer;
276  typedef OutputTile_ OutputTile;
278  typedef Warps_ Warps;
280  typedef ThreadsPerWarp_ ThreadsPerWarp;
282  static int const kAccessSize = kScalarsPerSts_;
284  static int const kSkew = kSkew_;
287 
289  static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
291  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
293  static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
294 
303 
305  struct ThreadOffset {
307  // The warp.
308  int const warp = threadIdx.x / kWarpSize;
309 
310  // The position of the warp in the 2D tile.
311  int const warp_row = warp % Warps::kW;
312  int const warp_col = warp / Warps::kW;
313 
314  // We assume that the elements are distributed in a warp as 4 columns of 8 elements. The
315  // columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
316  // col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
317  int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
318  int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
319 
320  // Odd threads go to the second half of shared memory.
321  int const row = threadIdx.x & 0x01;
322  int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
323  lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
324  // Embed the offset in a 4D coordinate.
325  return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
326  }
327  };
328 };
329 
331 
332 template <typename Scalar_,
333  typename OutputTile_,
334  typename Warps_,
335  typename ThreadsPerWarp_,
336  int kTileH_,
337  int kScalarsPerLds_,
338  int kSkew_ = 0>
343  typedef Scalar_* Pointer;
345  typedef OutputTile_ OutputTile;
347  typedef Warps_ Warps;
349  typedef ThreadsPerWarp_ ThreadsPerWarp;
351  static int const kAccessSize = kScalarsPerLds_;
353  static int const kSkew = kSkew_;
356 
358  static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
360  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
362  static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
363 
367 
368  // Compute the number of iterations per warp in the Tile::kH dimension.
369  static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
370 
371  // As explained above, the shared memory tile is composed of 2 rows and each row is made of
372  // kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
373  // back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
374  // However, in some cases, we have only 1 iteration per warp. In that case, we must define the
375  // shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
376  // to keep the number of elements to reduce for split-K.
377  static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
378  // As soon as we know kIterationsH, it is trivial to compute kIterationsD:
380 
381  // If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
382  // threads to grab the other elements.
383  static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
384 
386  typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
393 
395  struct ThreadOffset {
397  // Each warp works on a different column.
398  int const h = threadIdx.x / kWarpSize;
399  // Compute the row.
400  int const w = (threadIdx.x & (kWarpSize - 1)) * kAccessSize;
401  int offset = 0;
402  if (Iterations::kH == 1) {
403  int const row = h & 0x1;
404  int const col = h / 2;
405  offset = row * ShapeCount<Tile>::kWc + col * OutputTile::kW * Iterations::kD + w;
406  } else {
407  offset = h * OutputTile::kW * Iterations::kD + w;
408  }
409  return make_Coord(0, 0, offset, 0);
410  }
411  };
412 };
413 
415 
416 } // namespace gemm
417 } // namespace cutlass
static int const kAccessSize
The number of scalars per STS.
Definition: gemm_shared_tile.h:95
static CUTLASS_DEVICE int get()
Definition: shape.h:214
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:216
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:144
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:214
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:355
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:358
Definition: load_store.h:41
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:242
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:234
Definition: convert.h:33
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:229
Definition: gemm_shared_tile.h:128
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:42
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:293
T type
Definition: platform.h:377
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:131
Shape< 1, 1, kScalarsPerThread/kAccessSize > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:298
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
Definition: gemm_shared_tile.h:238
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:351
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:148
Definition: reshape_tile.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Shape< 0, ShapeCount< Tile >::kWc, Tile::kC, kScalarsPerSts_ > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:48
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:282
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:353
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:218
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:224
Definition: gemm_shared_tile.h:38
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:203
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:395
Definition: gemm_shared_tile.h:200
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:155
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:40
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:129
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:286
Kind
Definition: load_store.h:39
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:212
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:151
static int const kH
The height of the cube.
Definition: shape.h:68
Shape< 1, Tile::kH/Threads::kH, Tile::kW/Threads::kW, Tile::kC/Threads::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:61
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:93
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:164
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:276
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:362
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:205
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:133
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:274
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:289
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:306
Shape< kIterationsD, kIterationsH, OutputTile::kW/kWarpSize/kAccessSize, Warps::kD > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:387
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:54
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:50
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:231
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:241
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:104
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
The tile.
Definition: gemm_shared_tile.h:296
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:52
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:44
static int const kIterationsInHPerWarp
Definition: gemm_shared_tile.h:369
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize, kSplitK > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:390
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:284
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:142
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:66
ReshapeTile< Shape< Tile_::kD, Tile_::kH, Tile_::kW+kSkew_ >, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:89
Shape< 0, kScalarsPerSts_, ShapeCount< Tile >::kHwc/Threads::kW > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:115
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile TileWithoutSkew
The tile without skews.
Definition: gemm_shared_tile.h:86
static int const kIterationsD
Definition: gemm_shared_tile.h:379
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
Definition: gemm_shared_tile.h:168
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:158
Definition: matrix_traits.h:357
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:280
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:171
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:138
Definition: gemm_shared_tile.h:339
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:345
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:82
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:302
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:201
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
Definition: gemm_shared_tile.h:366
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize, kSplitK > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:392
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:160
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:84
static int const kSplitK
Definition: gemm_shared_tile.h:383
Shape< 1, TileWithoutSkew::kH/Threads::kW, TileWithoutSkew::kW/Threads::kH > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:100
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:210
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:46
Definition: gemm_operand.h:50
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:237
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:63
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:360
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:146
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:97
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:226
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:70
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:172
static int const kD
The depth of the cube.
Definition: shape.h:66
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:305
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:347
Tile_ Tile
Definition: reshape_tile.h:43
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:102
static int const kIterationsH
Definition: gemm_shared_tile.h:377
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:300
Kind
Definition: matrix_traits.h:357
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:153
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:349
Definition: matrix_traits.h:357
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:343
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:291
Shape< TileWithSkew::kW *Warps::kD, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:167
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:220
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:140
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:278
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:107
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:396
Definition: gemm_shared_tile.h:270
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:222