Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/threadblock_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/coord.h"
32 
33 namespace cutlass {
34 namespace gemm {
35 
38 };
39 // helper template function
// Primary template: flattens the 2-D block index into a single integer,
// blockIdx.y * gridDim.x + blockIdx.x.
// `groups` is accepted for signature parity with the specialization below and
// is not read here.
40 template <enum swizzleDirection::Kind>
41 CUTLASS_DEVICE int getLinearIdx(int groups) {
42  // the group size is not needed for OneDirection swizzle
43  return blockIdx.y * gridDim.x + blockIdx.x;
44 }
// Boustrophedon specialization of getLinearIdx.
// blockIdx.y is bucketed into bands of `groups` consecutive values; in
// odd-numbered bands blockIdx.x is mirrored (gridDim.x - blockIdx.x - 1) before
// linearization, in even-numbered bands the plain linear index is returned.
// `groups` is used as a divisor and therefore must be non-zero.
45 template <>
46 CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
47  // reverse blockIdx.x when the band index (blockIdx.y / groups) is odd
48  if ((blockIdx.y / groups) % 2 == 1)
49  return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
50  else
51  return blockIdx.y * gridDim.x + blockIdx.x;
52 }
54 
68 
// Swizzle the block index. This variant is the identity: blockIdx is returned unchanged.
70  CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
71 
// Computes the launch grid for a problem.
//   grid.x = problem_size.m() divided by OutputTile[2], rounded up
//   grid.y = problem_size.n() divided by OutputTile[1], rounded up
//   grid.z = problem_size.batch()
// The rounding uses the (a + b - 1) / b integer idiom, so OutputTile[1] and
// OutputTile[2] must be non-zero.
73  CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
74  Coord<3> const &OutputTile) {
75  /* OutputTile and problem_size are both in KNM order */
76  dim3 grid;
77  grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
78  grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
79  grid.z = problem_size.batch();
80  return grid;
81  }
82 
// Get the threadblock offset, without considering the batch dimension.
// Returns make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]),
// where `block` is the result of swizzle(); the leading component is always 0.
84  CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
85  dim3 block = swizzle();
86  Coord<3> threadblock_offset =
87  make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
88  return threadblock_offset;
89  }
90 
// Returns the z component of the swizzled block index as the batch id.
92  CUTLASS_DEVICE int get_batch_id() {
93  dim3 block = swizzle();
94  return block.z;
95  }
96 
// Check if at the last partition: true exactly when get_batch_id() equals
// gridDim.z - 1, i.e. this block has the highest z index in the grid.
98  CUTLASS_DEVICE bool is_last_partition() {
99  if (get_batch_id() == (gridDim.z - 1))
100  return true;
101  else
102  return false;
103  }
104 
// Returns the (K, N, M) bounds this threadblock must respect.
//   - last partition (see is_last_partition()): problem_size.knm()
//   - any other partition: (partitionK_range, problem_size.n(), problem_size.m())
106  CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
107  int partitionK_range) {
108  // every partition except the last one has a smaller K range
109  // partitionK_range is the K bound for every partition except the last one
110  // the last partition's bounds are the same as the problem size
111  if(is_last_partition())
112  return problem_size.knm();
113  else
114  return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
115  }
116 };
117 
119 
120 /*
121 ColumnMajorBlockSwizzle<1, OneDirection> is equivalent with IdentityBlockSwizzle
122 groupCols has the effect of controlling the scheduling of thread blocks
123 settings with different groupCols can contribute to the overall performance by affecting L2 cache
124 hit rate
125 
126 consider a regular thread block mapping between matrix C and different thread blocks
127 note that C is column major, and the leading dimension of thread block id is blockIdx.x
128 
129 let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
130 (blockIdx.x, blockIdx.y)
131 mapping between threadblockID and C matrix:
132 -------------------------------------------------------
133 (0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
134 -------------------------------------------------------
135 (1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
136 -------------------------------------------------------
137 (2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
138 -------------------------------------------------------
139 (3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
140 -------------------------------------------------------
141 (4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
142 -------------------------------------------------------
143 (5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
144 -------------------------------------------------------
145 
146 A ColumnMajorBlockSwizzle<1, OneDirection> will imply the above order where threadblocks are
147 launched in column-major order
148 
149 A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little,
150 -------------------------------------------------------
151 (0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) |
152 -------------------------------------------------------
153 (0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) |
154 -------------------------------------------------------
155 (1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) |
156 -------------------------------------------------------
157 (1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) |
158 -------------------------------------------------------
159 (2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) |
160 -------------------------------------------------------
161 (2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) |
162 -------------------------------------------------------
163 
164 so in memory, it would appear that we work on 2 columns at a time rather than 1
165 Note that the index here really represents how each block maps to memory
166 
167 A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection>
168 except that every column flips the ordering against the previous one
169 -------------------------------------------------------
170 (0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) |
171 -------------------------------------------------------
172 (1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) |
173 -------------------------------------------------------
174 (2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) |
175 -------------------------------------------------------
176 (3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) |
177 -------------------------------------------------------
178 (4,0) | (1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) |
179 -------------------------------------------------------
180 (5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) |
181 -------------------------------------------------------
182 
183 similarly, a ColumnMajorBlockSwizzle<2, Boustrophedon> looks like
184 -------------------------------------------------------
185 (0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) |
186 -------------------------------------------------------
187 (0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) |
188 -------------------------------------------------------
189 (1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) |
190 -------------------------------------------------------
191 (1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) |
192 -------------------------------------------------------
193 (2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) |
194 -------------------------------------------------------
195 (2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) |
196 -------------------------------------------------------
197 
198 */
199 
200 template <int groupCols, enum swizzleDirection::Kind swDirection>
204 
// Swizzle the block index (column-major grouping, groupCols columns per group).
// Steps, as written below:
//   1. Require gridDim.z == 1 (device-side assert).
//   2. Linearize the block index with getLinearIdx<swDirection>(groupCols).
//   3. Pick the group width: groupCols normally, or gridDim.y % groupCols when
//      gridDim.y is not a multiple of groupCols and blockIdx.y lies within the
//      trailing gridDim.y % groupCols values.
//   4. Re-derive x and y from the linear index using that width; z is passed through.
// prevGroupCols is never reassigned, so it always equals groupCols.
206  CUTLASS_DEVICE dim3 swizzle() {
207  assert(gridDim.z == 1);
208  int linearIdx = getLinearIdx<swDirection>(groupCols);
209  dim3 swizzledBlockIdx;
210  int currGroupCols = groupCols;
211  int prevGroupCols = groupCols;
212 
213  if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
214  // last columns if gridDim.y is not divisible by groupCols
215  currGroupCols = gridDim.y % groupCols;
216  }
217 
218  swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
219  swizzledBlockIdx.y =
220  linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
221  swizzledBlockIdx.z = blockIdx.z;
222 
223  return swizzledBlockIdx;
224  }
225 
228  Coord<3> const &OutputTile) {
229  dim3 grid;
230  grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
231  grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
232  grid.z = problem_size.batch();
233  return grid;
234  }
235 
// Get the threadblock offset from the swizzled block index.
// Returns make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]),
// where `block` is the result of swizzle(); the leading component is always 0.
237  CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
238  dim3 block = swizzle();
239  Coord<3> threadblock_offset =
240  make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
241  return threadblock_offset;
242  }
243 
// Returns the z component of the swizzled block index as the batch id.
// Note that this goes through swizzle(), so its gridDim.z == 1 assert applies here too.
245  CUTLASS_DEVICE int get_batch_id() {
246  dim3 block = swizzle();
247  return block.z;
248  }
249 
// Check if at the last partition: true exactly when get_batch_id() equals
// gridDim.z - 1, i.e. this block has the highest z index in the grid.
251  CUTLASS_DEVICE bool is_last_partition() {
252  if (get_batch_id() == (gridDim.z - 1))
253  return true;
254  else
255  return false;
256  }
257 
// Returns the (K, N, M) bounds this threadblock must respect.
//   - last partition (see is_last_partition()): problem_size.knm()
//   - any other partition: (partitionK_range, problem_size.n(), problem_size.m())
259  CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
260  int partitionK_range) {
261  // every partition except the last one has a smaller K range
262  // partitionK_range is the K bound for every partition except the last one
263  // the last partition's bounds are the same as the problem size
264  if (is_last_partition())
265  return problem_size.knm();
266  else
267  return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
268  }
269 };
270 
272 
273 /*
274 
275 consider a regular thread block mapping between matrix C and different thread blocks
276 note that C is column major, and the leading dimension of thread block id is blockIdx.x
277 
278 let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
279 (blockIdx.x, blockIdx.y)
280 mapping between threadblockID and C matrix:
281 -------------------------------------------------------
282 (0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
283 -------------------------------------------------------
284 (1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
285 -------------------------------------------------------
286 (2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
287 -------------------------------------------------------
288 (3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
289 -------------------------------------------------------
290 (4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
291 -------------------------------------------------------
292 (5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
293 -------------------------------------------------------
294 
295 A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map
296 
297 -----------------------------------------------
298 (0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) |
299 -----------------------------------------------
300 (0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) |
301 -----------------------------------------------
302 (0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) |
303 -----------------------------------------------
304 (0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
305 -----------------------------------------------
306 (0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) |
307 ---------------------------------------------
308 (0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) |
309 -----------------------------------------------
310 (0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) |
311 -----------------------------------------------
312 
313 It would appear in memory we are working on 1 row at a time
314 
315 A RowMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more
316 -----------------------------------------------
317 (0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) |
318 -----------------------------------------------
319 (1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) |
320 -----------------------------------------------
321 (0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) |
322 -----------------------------------------------
323 (1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
324 -----------------------------------------------
325 (0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) |
326 ---------------------------------------------
327 (1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) |
328 -----------------------------------------------
329 (0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) |
330 -----------------------------------------------
331 
332 so in memory, it would appear that we work on 2 rows at a time rather than 1 row
333 Note that the index here really represents how each block maps to memory
334 
335 A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection>
336 except that every column flips the ordering against the previous one
337 
338 -----------------------------------------------
339 (0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) |
340 -----------------------------------------------
341 (0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) |
342 -----------------------------------------------
343 (0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) |
344 -----------------------------------------------
345 (0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
346 -----------------------------------------------
347 (0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) |
348 ---------------------------------------------
349 (0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) |
350 -----------------------------------------------
351 (0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) |
352 -----------------------------------------------
353 
354 similarly, a RowMajorBlockSwizzle<2, Boustrophedon> looks like
355 -----------------------------------------------
356 (0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) |
357 -----------------------------------------------
358 (1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) |
359 -----------------------------------------------
360 (0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) |
361 -----------------------------------------------
362 (1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
363 -----------------------------------------------
364 (0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) |
365 ---------------------------------------------
366 (1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) |
367 -----------------------------------------------
368 (0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) |
369 -----------------------------------------------
370 
371 */
372 
373 template <int groupRows, enum swizzleDirection::Kind swDirection>
377 
// Swizzle the block index (row-major grouping, groupRows rows per group).
// Steps, as written below:
//   1. Require gridDim.z == 1 (device-side assert).
//   2. Linearize the block index with getLinearIdx<swDirection>(groupRows).
//   3. Pick the group height: groupRows normally, or gridDim.y % groupRows when
//      gridDim.y is not a multiple of groupRows and blockIdx.y lies within the
//      trailing gridDim.y % groupRows values.
//   4. Re-derive x and y from the linear index; compared with the column-major
//      variant the x and y formulas are exchanged. z is passed through.
// prevGroupRows is never reassigned, so it always equals groupRows.
379  CUTLASS_DEVICE dim3 swizzle() {
380  assert(gridDim.z == 1);
381  int linearIdx = getLinearIdx<swDirection>(groupRows);
382  dim3 swizzledBlockIdx;
383  int currGroupRows = groupRows;
384  int prevGroupRows = groupRows;
385 
386  if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
387  // trailing group when gridDim.y is not divisible by groupRows
388  currGroupRows = gridDim.y % groupRows;
389  }
390 
391  swizzledBlockIdx.x =
392  linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
393  swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
394  swizzledBlockIdx.z = blockIdx.z;
395 
396  return swizzledBlockIdx;
397  }
398 
401  Coord<3> const &OutputTile) {
402  dim3 grid;
403  grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
404  grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
405  grid.z = problem_size.batch();
406  return grid;
407  }
408 
// Get the threadblock offset from the swizzled block index.
// Returns make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]),
// where `block` is the result of swizzle(); the leading component is always 0.
// NOTE(review): the swizzled x is scaled by OutputTile[2] and the swizzled y by
// OutputTile[1], the same pairing as the other swizzle variants — presumably
// swizzle() has already exchanged the roles of x and y; confirm against the
// grid layout this struct launches with.
410  CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
411  dim3 block = swizzle();
412  Coord<3> threadblock_offset =
413  make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
414  return threadblock_offset;
415  }
416 
// Returns the z component of the swizzled block index as the batch id.
// Note that this goes through swizzle(), so its gridDim.z == 1 assert applies here too.
418  CUTLASS_DEVICE int get_batch_id() {
419  dim3 block = swizzle();
420  return block.z;
421  }
422 
// Check if at the last partition: true exactly when get_batch_id() equals
// gridDim.z - 1, i.e. this block has the highest z index in the grid.
424  CUTLASS_DEVICE bool is_last_partition() {
425  if (get_batch_id() == (gridDim.z - 1) )
426  return true;
427  else
428  return false;
429  }
430 
// Returns the (K, N, M) bounds this threadblock must respect.
//   - last partition (see is_last_partition()): problem_size.knm()
//   - any other partition: (partitionK_range, problem_size.n(), problem_size.m())
432  CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
433  int partitionK_range) {
434  // every partition except the last one has a smaller K range
435  // partitionK_range is the K bound for every partition except the last one
436  // the last partition's bounds are the same as the problem size
437  if (is_last_partition())
438  return problem_size.knm();
439  else
440  return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
441  }
442 };
443 
445 
446 } // namespace gemm
447 } // namespace cutlass
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:432
Definition: convert.h:33
CUTLASS_HOST_DEVICE Coord< 3 > knm() const
Obtains a Coord<3> from GemmCoord.
Definition: gemm_coord.h:121
Definition: gemm/threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE IdentityBlockSwizzle()
Ctor. aka ColumnMajorBlockSwizzle<1>
Definition: gemm/threadblock_swizzle.h:67
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:92
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:237
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:424
Definition: gemm_coord.h:43
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:106
Definition: gemm/threadblock_swizzle.h:201
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:410
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle()
Ctor.
Definition: gemm/threadblock_swizzle.h:376
CUTLASS_DEVICE int getLinearIdx(int groups)
Definition: gemm/threadblock_swizzle.h:41
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:98
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:227
CUTLASS_DEVICE bool is_last_partition()
check if at the last partition
Definition: gemm/threadblock_swizzle.h:251
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:400
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:70
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:245
Definition: gemm/threadblock_swizzle.h:65
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
get threadblock offset, without considering the batch dim
Definition: gemm/threadblock_swizzle.h:84
CUTLASS_DEVICE int get_batch_id()
Definition: gemm/threadblock_swizzle.h:418
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle()
Ctor.
Definition: gemm/threadblock_swizzle.h:203
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:206
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: gemm/threadblock_swizzle.h:379
Definition: gemm/threadblock_swizzle.h:37
CUTLASS_DEVICE Coord< 3 > get_threadblock_bounds(GemmCoord const &problem_size, int partitionK_range)
Definition: gemm/threadblock_swizzle.h:259
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
Kind
Definition: gemm/threadblock_swizzle.h:37
Definition: gemm/threadblock_swizzle.h:36
Definition: gemm/threadblock_swizzle.h:374
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: gemm/threadblock_swizzle.h:73