40 template <enum swizzleDirection::Kind>
43 return blockIdx.y * gridDim.x + blockIdx.x;
46 CUTLASS_DEVICE
int getLinearIdx<swizzleDirection::Boustrophedon>(
int groups) {
48 if ((blockIdx.y / groups) % 2 == 1)
49 return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
51 return blockIdx.y * gridDim.x + blockIdx.x;
/// Swizzle the block index. This variant applies no remapping: it returns the
/// built-in blockIdx unchanged.
70 CUTLASS_DEVICE dim3
swizzle() {
return blockIdx; }
77 grid.x = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
78 grid.y = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
79 grid.z = problem_size.
batch();
87 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
88 return threadblock_offset;
180 template <
int groupCols, enum swizzleDirection::Kind swDirection>
187 assert(gridDim.z == 1);
188 int linearIdx = getLinearIdx<swDirection>(groupCols);
189 dim3 swizzledBlockIdx;
190 int currGroupCols = groupCols;
191 int prevGroupCols = groupCols;
193 if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
195 currGroupCols = gridDim.y % groupCols;
198 swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
200 linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
201 swizzledBlockIdx.z = blockIdx.z;
203 return swizzledBlockIdx;
210 grid.x = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
211 grid.y = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
212 grid.z = problem_size.
batch();
220 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
221 return threadblock_offset;
333 template <
int groupRows, enum swizzleDirection::Kind swDirection>
340 assert(gridDim.z == 1);
341 int linearIdx = getLinearIdx<swDirection>(groupRows);
342 dim3 swizzledBlockIdx;
343 int currGroupRows = groupRows;
344 int prevGroupRows = groupRows;
346 if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
348 currGroupRows = gridDim.y % groupRows;
352 linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
353 swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
354 swizzledBlockIdx.z = blockIdx.z;
356 return swizzledBlockIdx;
363 grid.x = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
364 grid.y = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
365 grid.z = problem_size.
batch();
373 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
374 return threadblock_offset;
Definition: threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE IdentityBlockSwizzle()
Constructor. Also known as ColumnMajorBlockSwizzle<1>.
Definition: threadblock_swizzle.h:67
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:92
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:217
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:318
Definition: gemm_coord.h:43
Definition: threadblock_swizzle.h:181
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:370
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle()
Ctor.
Definition: threadblock_swizzle.h:336
CUTLASS_DEVICE int getLinearIdx(int groups)
Definition: threadblock_swizzle.h:41
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:207
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:360
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:70
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:225
Definition: threadblock_swizzle.h:65
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:84
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:378
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle()
Ctor.
Definition: threadblock_swizzle.h:183
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:186
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:339
Definition: threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
Kind
Definition: threadblock_swizzle.h:37
Definition: threadblock_swizzle.h:36
Definition: threadblock_swizzle.h:334
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:73