40 template <enum swizzleDirection::Kind>
43 return blockIdx.y * gridDim.x + blockIdx.x;
46 CUTLASS_DEVICE
int getLinearIdx<swizzleDirection::Boustrophedon>(
int groups) {
48 if ((blockIdx.y / groups) % 2 == 1)
49 return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
51 return blockIdx.y * gridDim.x + blockIdx.x;
70 CUTLASS_DEVICE dim3
swizzle() {
return blockIdx; }
77 grid.x = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
78 grid.y = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
79 grid.z = problem_size.
batch();
87 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
88 return threadblock_offset;
107 int partitionK_range) {
112 return problem_size.
knm();
114 return make_Coord(partitionK_range, problem_size.
n(), problem_size.
m());
200 template <
int groupCols, enum swizzleDirection::Kind swDirection>
207 assert(gridDim.z == 1);
208 int linearIdx = getLinearIdx<swDirection>(groupCols);
209 dim3 swizzledBlockIdx;
210 int currGroupCols = groupCols;
211 int prevGroupCols = groupCols;
213 if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
215 currGroupCols = gridDim.y % groupCols;
218 swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
220 linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
221 swizzledBlockIdx.z = blockIdx.z;
223 return swizzledBlockIdx;
230 grid.x = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
231 grid.y = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
232 grid.z = problem_size.
batch();
240 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
241 return threadblock_offset;
260 int partitionK_range) {
265 return problem_size.
knm();
267 return make_Coord(partitionK_range, problem_size.
n(), problem_size.
m());
373 template <
int groupRows, enum swizzleDirection::Kind swDirection>
380 assert(gridDim.z == 1);
381 int linearIdx = getLinearIdx<swDirection>(groupRows);
382 dim3 swizzledBlockIdx;
383 int currGroupRows = groupRows;
384 int prevGroupRows = groupRows;
386 if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
388 currGroupRows = gridDim.y % groupRows;
392 linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
393 swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
394 swizzledBlockIdx.z = blockIdx.z;
396 return swizzledBlockIdx;
403 grid.x = (problem_size.
n() + OutputTile[1] - 1) / OutputTile[1];
404 grid.y = (problem_size.
m() + OutputTile[2] - 1) / OutputTile[2];
405 grid.z = problem_size.
batch();
413 make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
414 return threadblock_offset;
433 int partitionK_range) {
438 return problem_size.
knm();
440 return make_Coord(partitionK_range, problem_size.
n(), problem_size.
m());
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:432
CUTLASS_HOST_DEVICE Coord< 3 > knm() const
Obtains a Coord<3> from GemmCoord.
Definition: gemm_coord.h:121
Definition: gemm/threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE IdentityBlockSwizzle()
Ctor. aka ColumnMajorBlockSwizzle<1>
Definition: gemm/threadblock_swizzle.h:67
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:92
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:237
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:424
Definition: gemm_coord.h:43
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:106
Definition: gemm/threadblock_swizzle.h:201
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:410
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle()
Ctor.
Definition: gemm/threadblock_swizzle.h:376
CUTLASS_DEVICE int getLinearIdx(int groups)
Definition: gemm/threadblock_swizzle.h:41
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:98
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:227
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:251
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:400
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:70
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:245
Definition: gemm/threadblock_swizzle.h:65
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
get threadblock offset, without considering the batch dim
Definition: gemm/threadblock_swizzle.h:84
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:418
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle()
Ctor.
Definition: gemm/threadblock_swizzle.h:203
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:206
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:379
Definition: gemm/threadblock_swizzle.h:37
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:259
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
Kind
Definition: gemm/threadblock_swizzle.h:37
Definition: gemm/threadblock_swizzle.h:36
Definition: gemm/threadblock_swizzle.h:374
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:73