Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
threadblock_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/coord.h"
32 
33 namespace cutlass {
34 namespace gemm {
35 
38 };
// Helper: linear index of the current threadblock, counting along blockIdx.x first,
// i.e. blockIdx.y * gridDim.x + blockIdx.x.
// This primary template serves every swizzleDirection::Kind without an explicit
// specialization (e.g. OneDirection). It does not use `groups`; the parameter exists
// only so that all directions share the same signature.
template <enum swizzleDirection::Kind>
CUTLASS_DEVICE int getLinearIdx(int groups) {
  return blockIdx.x + gridDim.x * blockIdx.y;
}
// Boustrophedon variant: blockIdx.y is split into bands of `groups` consecutive values.
// Inside every odd-numbered band (blockIdx.y / groups is odd) the blockIdx.x order is
// reversed, so neighbouring bands are walked in opposite directions.
template <>
CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
  bool const oddBand = (blockIdx.y / groups) % 2 != 0;
  unsigned const column = oddBand ? (gridDim.x - 1 - blockIdx.x) : blockIdx.x;
  return blockIdx.y * gridDim.x + column;
}
54 
68 
/// Identity swizzle: hands back blockIdx unchanged, component by component.
CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); }
71 
/// Launch grid for the identity mapping: x counts M-tiles, y counts N-tiles
/// (both rounded up so partial tiles get a block), z counts batches.
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
                                         Coord<3> const &OutputTile) {
  // OutputTile and problem_size are both in KNM order: [2] is M, [1] is N.
  return dim3((problem_size.m() + OutputTile[2] - 1) / OutputTile[2],
              (problem_size.n() + OutputTile[1] - 1) / OutputTile[1],
              problem_size.batch());
}
82 
/// Offset of this threadblock's output tile as a KNM-ordered Coord:
/// K = 0, N = swizzled y * OutputTile[1], M = swizzled x * OutputTile[2].
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
  dim3 const tile = swizzle();
  int const kOffset = 0;
  return make_Coord(kOffset, tile.y * OutputTile[1], tile.x * OutputTile[2]);
}
90 
/// Batch index of this threadblock: the z component of the swizzled block index.
CUTLASS_DEVICE int get_batch_id() { return swizzle().z; }
96 };
97 
99 
100 /*
101 ColumnMajorBlockSwizzle<1, OneDirection> is equivalent to IdentityBlockSwizzle.
102 groupCols controls the scheduling order of thread blocks; different groupCols
103 settings can change overall performance by affecting the L2 cache
104 hit rate.
105 
106 Consider the regular mapping between tiles of matrix C and thread blocks.
107 Note that C is column major, and the leading dimension of the thread block ID is blockIdx.x.
108 
109 Let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
110 (blockIdx.x, blockIdx.y)
111 mapping between threadblockID and C matrix:
112 -------------------------------------------------------
113 (0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
114 -------------------------------------------------------
115 (1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
116 -------------------------------------------------------
117 (2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
118 -------------------------------------------------------
119 (3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
120 -------------------------------------------------------
121 (4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
122 -------------------------------------------------------
123 (5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
124 -------------------------------------------------------
125 
126 A ColumnMajorBlockSwizzle<1, OneDirection> implies the above order, where threadblocks are
127 launched in column-major order.
128 
129 A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little,
130 -------------------------------------------------------
131 (0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) |
132 -------------------------------------------------------
133 (0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) |
134 -------------------------------------------------------
135 (1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) |
136 -------------------------------------------------------
137 (1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) |
138 -------------------------------------------------------
139 (2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) |
140 -------------------------------------------------------
141 (2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) |
142 -------------------------------------------------------
143 
144 So in memory, it would appear that we work on 2 columns at a time rather than 1.
145 Note that the indices here represent how each block maps to memory.
146 
147 A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection>
148 except that every column flips the ordering against the previous one
149 -------------------------------------------------------
150 (0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) |
151 -------------------------------------------------------
152 (1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) |
153 -------------------------------------------------------
154 (2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) |
155 -------------------------------------------------------
156 (3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) |
157 -------------------------------------------------------
158 (4,0) | (1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) |
159 -------------------------------------------------------
160 (5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) |
161 -------------------------------------------------------
162 
163 Similarly, a ColumnMajorBlockSwizzle<2, Boustrophedon> looks like
164 -------------------------------------------------------
165 (0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) |
166 -------------------------------------------------------
167 (0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) |
168 -------------------------------------------------------
169 (1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) |
170 -------------------------------------------------------
171 (1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) |
172 -------------------------------------------------------
173 (2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) |
174 -------------------------------------------------------
175 (2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) |
176 -------------------------------------------------------
177 
178 */
179 
180 template <int groupCols, enum swizzleDirection::Kind swDirection>
184 
/// Swizzles blockIdx so that threadblocks walk the output tiles in vertical strips
/// that are groupCols tile-columns wide.
///
/// The returned .x is the M-tile index and .y is the N-tile index: get_grid_layout()
/// sizes x from M and y from N, and get_threadblock_offset() scales .x by
/// OutputTile[2] and .y by OutputTile[1]. .z is passed through from blockIdx.z.
CUTLASS_DEVICE dim3 swizzle() {
  // groupCols is used as a divisor/modulus below. Reject non-positive values at
  // compile time instead of dividing by zero on the device.
  static_assert(groupCols > 0, "ColumnMajorBlockSwizzle requires groupCols > 0");
  // NOTE(review): get_grid_layout() sets grid.z = problem_size.batch(), so batched
  // launches (batch > 1) trip this assert in debug builds. Confirm whether batching
  // is meant to be supported with this swizzle.
  assert(gridDim.z == 1);
  int linearIdx = getLinearIdx<swDirection>(groupCols);
  dim3 swizzledBlockIdx;
  int currGroupCols = groupCols;

  // If gridDim.y is not divisible by groupCols, the last (gridDim.y % groupCols)
  // values of blockIdx.y form a narrower final group.
  if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
    currGroupCols = gridDim.y % groupCols;
  }

  swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
  // Column within the current group plus the first column of that group. Every
  // preceding group is full, so the group index uses groupCols, not currGroupCols.
  swizzledBlockIdx.y =
      linearIdx % currGroupCols + groupCols * (linearIdx / (groupCols * gridDim.x));
  swizzledBlockIdx.z = blockIdx.z;

  return swizzledBlockIdx;
}
205 
208  Coord<3> const &OutputTile) {
209  dim3 grid;
210  grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
211  grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
212  grid.z = problem_size.batch();
213  return grid;
214  }
215 
/// Offset of this threadblock's (swizzled) output tile as a KNM-ordered Coord:
/// K = 0, N = swizzled y * OutputTile[1], M = swizzled x * OutputTile[2].
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
  dim3 const tile = swizzle();
  int const kOffset = 0;
  return make_Coord(kOffset, tile.y * OutputTile[1], tile.x * OutputTile[2]);
}
223 
/// Batch index of this threadblock, taken from the z component of swizzle().
CUTLASS_DEVICE int get_batch_id() { return swizzle().z; }
229 };
230 
232 
233 /*
234 
235 Consider the regular mapping between tiles of matrix C and thread blocks.
236 Note that C is column major, and the leading dimension of the thread block ID is blockIdx.x.
237 
238 Let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
239 (blockIdx.x, blockIdx.y)
240 mapping between threadblockID and C matrix:
241 -------------------------------------------------------
242 (0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
243 -------------------------------------------------------
244 (1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
245 -------------------------------------------------------
246 (2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
247 -------------------------------------------------------
248 (3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
249 -------------------------------------------------------
250 (4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
251 -------------------------------------------------------
252 (5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
253 -------------------------------------------------------
254 
255 A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map
256 
257 -----------------------------------------------
258 (0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) |
259 -----------------------------------------------
260 (0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) |
261 -----------------------------------------------
262 (0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) |
263 -----------------------------------------------
264 (0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
265 -----------------------------------------------
266 (0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) |
267 ---------------------------------------------
268 (0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) |
269 -----------------------------------------------
270 (0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) |
271 -----------------------------------------------
272 
273 It would appear in memory that we are working on 1 row at a time.
274 
275 A RowMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more
276 -----------------------------------------------
277 (0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) |
278 -----------------------------------------------
279 (1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) |
280 -----------------------------------------------
281 (0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) |
282 -----------------------------------------------
283 (1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
284 -----------------------------------------------
285 (0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) |
286 ---------------------------------------------
287 (1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) |
288 -----------------------------------------------
289 (0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) |
290 -----------------------------------------------
291 
292 So in memory, it would appear that we work on 2 rows at a time rather than 1 row.
293 Note that the indices here represent how each block maps to memory.
294 
295 A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection>
296 except that every column flips the ordering against the previous one
297 
298 -----------------------------------------------
299 (0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) |
300 -----------------------------------------------
301 (0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) |
302 -----------------------------------------------
303 (0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) |
304 -----------------------------------------------
305 (0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
306 -----------------------------------------------
307 (0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) |
308 ---------------------------------------------
309 (0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) |
310 -----------------------------------------------
311 (0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) |
312 -----------------------------------------------
313 
314 Similarly, a RowMajorBlockSwizzle<2, Boustrophedon> looks like
315 -----------------------------------------------
316 (0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) |
317 -----------------------------------------------
318 (1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) |
319 -----------------------------------------------
320 (0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) |
321 -----------------------------------------------
322 (1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
323 -----------------------------------------------
324 (0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) |
325 ---------------------------------------------
326 (1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) |
327 -----------------------------------------------
328 (0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) |
329 -----------------------------------------------
330 
331 */
332 
333 template <int groupRows, enum swizzleDirection::Kind swDirection>
337 
/// Swizzles blockIdx so that threadblocks walk the output tiles in horizontal bands
/// that are groupRows tile-rows tall.
///
/// get_grid_layout() launches x over N-tiles and y over M-tiles. This mapping swaps
/// them back, so the returned .x is an M-tile index and .y is an N-tile index (the
/// convention get_threadblock_offset() relies on). .z is passed through from blockIdx.z.
CUTLASS_DEVICE dim3 swizzle() {
  // groupRows is used as a divisor/modulus below. Reject non-positive values at
  // compile time instead of dividing by zero on the device.
  static_assert(groupRows > 0, "RowMajorBlockSwizzle requires groupRows > 0");
  // NOTE(review): get_grid_layout() sets grid.z = problem_size.batch(), so batched
  // launches (batch > 1) trip this assert in debug builds. Confirm whether batching
  // is meant to be supported with this swizzle.
  assert(gridDim.z == 1);
  int linearIdx = getLinearIdx<swDirection>(groupRows);
  dim3 swizzledBlockIdx;
  int currGroupRows = groupRows;

  // gridDim.y counts M-tiles (rows of C). If it is not divisible by groupRows, the
  // last (gridDim.y % groupRows) values of blockIdx.y form a narrower final band of rows.
  if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
    currGroupRows = gridDim.y % groupRows;
  }

  // Row within the current band plus the first row of that band. Every preceding
  // band is full, so the band index uses groupRows, not currGroupRows.
  swizzledBlockIdx.x =
      linearIdx % currGroupRows + groupRows * (linearIdx / (groupRows * gridDim.x));
  swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
  swizzledBlockIdx.z = blockIdx.z;

  return swizzledBlockIdx;
}
358 
361  Coord<3> const &OutputTile) {
362  dim3 grid;
363  grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
364  grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
365  grid.z = problem_size.batch();
366  return grid;
367  }
368 
/// Offset of this threadblock's (swizzled) output tile as a KNM-ordered Coord:
/// K = 0, N = swizzled y * OutputTile[1], M = swizzled x * OutputTile[2].
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
  dim3 const tile = swizzle();
  int const kOffset = 0;
  return make_Coord(kOffset, tile.y * OutputTile[1], tile.x * OutputTile[2]);
}
376 
/// Batch index of this threadblock, i.e. the z component of swizzle().
CUTLASS_DEVICE int get_batch_id() { return swizzle().z; }
382 };
383 
385 
386 } // namespace gemm
387 } // namespace cutlass
Definition: convert.h:33
Definition: threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE IdentityBlockSwizzle()
Ctor. aka ColumnMajorBlockSwizzle<1>
Definition: threadblock_swizzle.h:67
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:92
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:217
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:318
Definition: gemm_coord.h:43
Definition: threadblock_swizzle.h:181
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:370
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle()
Ctor.
Definition: threadblock_swizzle.h:336
CUTLASS_DEVICE int getLinearIdx(int groups)
Definition: threadblock_swizzle.h:41
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:207
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:360
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:70
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:225
Definition: threadblock_swizzle.h:65
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE Coord< 3 > get_threadblock_offset(Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:84
CUTLASS_DEVICE int get_batch_id()
Definition: threadblock_swizzle.h:378
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle()
Ctor.
Definition: threadblock_swizzle.h:183
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:186
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
CUTLASS_DEVICE dim3 swizzle()
Swizzle the block index.
Definition: threadblock_swizzle.h:339
Definition: threadblock_swizzle.h:37
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
Kind
Definition: threadblock_swizzle.h:37
Definition: threadblock_swizzle.h:36
Definition: threadblock_swizzle.h:334
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, Coord< 3 > const &OutputTile)
Definition: threadblock_swizzle.h:73