Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #if !defined(__CUDACC_RTC__)
31 #include <cuda.h>
32 #endif
33 
34 #include "cutlass/coord.h"
35 #include "cutlass/util/platform.h"
36 namespace cutlass {
37 namespace gemm {
38 
40 
/// GEMM kernel with launch bounds specified: __launch_bounds__ caps the block
/// size at Gemm_::kThreads so the compiler can budget registers for that limit.
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
  // Shared-memory scratch used by the main loop and the epilogue.
  __shared__ typename Gemm_::SharedStorage smem;

  // Build the GEMM object over the per-launch params and the shared scratch,
  // then run the whole computation.
  Gemm_ op(params, smem);
  op.multiply_add();
}
53 
55 
/// GEMM kernel without launch bounds specified: identical body to gemm_kernel,
/// but the compiler is free to pick register usage without a thread cap.
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
  // Shared-memory scratch used by the main loop and the epilogue.
  __shared__ typename Gemm_::SharedStorage smem;

  // Construct the GEMM object and execute it.
  Gemm_ op(params, smem);
  op.multiply_add();
}
68 
70 
/// Helper for launching the GEMM kernel with or without launch bounds.
/// Primary template (WithLaunchBounds == true) dispatches to the
/// __launch_bounds__-annotated kernel.
template <typename Gemm, bool WithLaunchBounds>
struct Launch {
  /// Launches gemm_kernel<Gemm> with the given grid/block on the given stream.
  Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
  }
};
78 
80 
/// Partial specialization for WithLaunchBounds == false: dispatches to the
/// kernel compiled without __launch_bounds__.
template <typename Gemm>
struct Launch<Gemm, false> {
  /// Launches gemm_kernel_nolb<Gemm> with the given grid/block on the given stream.
  Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
  }
};
88 
90 
/// Device-side GEMM: executes the software-pipelined main loop and the
/// epilogue for one threadblock tile. All tile shapes, streams and scalar
/// types come from GemmTraits_. (This body restores several lines dropped by
/// the documentation extraction: the This_ typedef, the Launch<> invocation,
/// the MultiplyAdd declarations, the loop pragmas, and the shared_storage
/// member — each corroborated by the member listing accompanying this dump.)
template <typename GemmTraits_>
struct Gemm {
  /// The traits describing this GEMM configuration.
  typedef GemmTraits_ Traits;
  /// This class (used to instantiate the Launch<> helper).
  typedef Gemm<GemmTraits_> This_;
  /// The shared storage used by the main loop and the epilogue.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index type.
  typedef typename Traits::Index Index;

  /// Define the mainloop iteration size.
  typedef typename Traits::MultiplyAdd MultiplyAdd;

  /// The number of threads per threadblock.
  static int const kThreads = Traits::GemmConfig::kThreads;

  // Number of warp-level multiply-accumulate steps executed by each warp.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  // Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");

  /// Use the params object defined in traits.
  typedef typename Traits::Params Params;

  //
  // Static function members
  //

  // Host-side launchers are unavailable when compiling under NVRTC.
#if !defined(__CUDACC_RTC__)
  /// Launch the kernel via the CUDA runtime API.
  /// Returns the status reported by the launch.
  static __host__ cudaError_t launch(Params const& params,
                                     cudaStream_t stream = cudaStreamDefault) {

    // Launch the kernel through the Launch<> helper so that the
    // __launch_bounds__ choice follows the configuration.
    Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
        params, params.grid, params.block, stream);

    return cudaGetLastError();
  }

  /// Launch the kernel via the CUDA driver API (e.g. for NVRTC-compiled
  /// modules). Returns cudaSuccess, or cudaErrorLaunchFailure on any
  /// driver-level failure.
  static __host__ cudaError_t launch(CUfunction kernel,
                                     Params const& params,
                                     CUstream stream = CU_STREAM_LEGACY) {

    // The driver API takes kernel arguments as an array of pointers.
    void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};

    CUresult result = cuLaunchKernel(
        kernel,
        params.grid.x, params.grid.y, params.grid.z,
        params.block.x, params.block.y, params.block.z,
        0, stream, params_, 0);

    if (result != CUDA_SUCCESS) {
      return cudaErrorLaunchFailure;
    }
    return cudaSuccess;
  }

#endif

  //
  // Methods
  //

  /// Ctor: binds the launch params and the shared-memory storage.
  CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}

  /// Computes a warp-level GEMM on data held in shared memory: one iteration
  /// of the software-pipelined main loop over a threadblock tile of K.
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {
    // If residue portion and not calculating residue in prolog, update residue predicates now.
    if (Residue && outer_k <= Traits::OutputTile::kD) {
      global_to_shared_stream.residue(outer_k);
    }

    // Load data for the next iteration of the main loop (unless it's the last iteration).
    if (!LastIteration) {
      global_to_shared_stream.copy();
    }

    CUTLASS_PRAGMA_UNROLL
    for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
      // Trigger the copy from shared memory for the next A/B values.
      shared_load_stream.copy(step + 1);

      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(step);

      MultiplyAdd multiply_add;

      // Do the math on the fragments of the current iteration.
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators,
                                accumulators);
    }

    // Make sure the data from shared memory has been entirely consumed.
    Traits::shared_load_fence(true);

    // Commit the data in shared memory for A/B.
    if (!LastIteration) {
      global_to_shared_stream.commit();
    }
    // Make sure the data is in shared memory.
    Traits::shared_store_fence(true);

    if (!LastIteration) {
      // Move to the next stage for the load (if it makes sense).
      shared_load_stream.inc_stage();
      // Trigger the copy from shared memory for the next loop iteration.
      shared_load_stream.copy(0);
    }
    // Make sure the values are available for the current iteration to do the multiply-add.
    shared_load_stream.commit(kWarpGemmSteps - 1);

    MultiplyAdd multiply_add;

    // Do the math on the fragments of the current iteration.
    multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
                              shared_load_stream.fragment_b(kWarpGemmSteps - 1),
                              accumulators,
                              accumulators);
  }

  /// Do the GEMM: prolog (first tile into shared memory), pipelined main loop
  /// over K, then epilogue writing the accumulators out.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // Get the bounds for each thread; may differ from problem_size (e.g. for partitioned K).
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
                                                           params.partitionK_range);

    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream,
        shared_storage.main_loop.global_to_shared_stream,
        shared_storage.main_loop.threadblock_tile.reference(),
        bounds,
        threadblock_offset);

    // update A and B pointer offset based on batch_id and batch_stride_offset
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Create the accumulator clear.
    ClearAccumulators clear;

    // Deal with residue in prolog.
    // NOTE: uses bounds[0], not problem_size[0], so partitioned-K blocks see their own range.
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);

    // Fetch the fragments for A and B from global memory.
    global_to_shared_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Rollback to the beginning of the first tile (if residue exists).
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);

    // The stream of data from shared memory to fragments.
    typename Traits::SharedStream shared_load_stream(
        params.shared_stream,
        shared_storage.main_loop.threadblock_tile.reference());

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;

    // Clear the accumulators.
    clear.clear(accumulators);

    // Initial index. problem_size[0] might be bigger than bounds[0].
    Index outer_k = bounds[0] - Traits::OutputTile::kD;

    // Check if we are computing residue in prolog or not.
    if (Traits::GemmConfig::kResidueInProlog) {
      // Execute all mainloop iterations but the last one.
      CUTLASS_GEMM_LOOP
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }

      // Don't load data for the last "residue" portion since we've already computed the residue.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        consume_tile<false, true>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    } else {
      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
      // consideration for K-residue or predicate updates. This improves the steady state of some
      // kernels.
      if (Traits::GemmConfig::kResidueSeparate) {
        CUTLASS_GEMM_LOOP
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }

      // Execute remaining tiles with K-residue predicate updates enabled.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }

    // Epilogue.
    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }

  //
  // Data members
  //

  /// The params.
  Params const& params;
  /// The shared storage.
  SharedStorage& shared_storage;
};
351 
353 
354 } // namespace gemm
355 } // namespace cutlass
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Definition: convert.h:33
SharedStorage & shared_storage
The shared storage.
Definition: gemm.h:349
Traits::Epilogue::ScalarD ScalarD
The scalar for D.
Definition: gemm.h:109
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm.h:98
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream=0)
Definition: gemm.h:74
#define CUTLASS_GEMM_LOOP
Definition: performance_tuning.h:39
Params const & params
The params.
Definition: gemm.h:347
Traits::Epilogue::ScalarC ScalarC
The scalar for C.
Definition: gemm.h:107
C++ features that may be otherwise unimplemented for CUDA device functions.
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream=0)
Definition: gemm.h:84
CUTLASS_DEVICE void multiply_add()
Do the GEMM.
Definition: gemm.h:237
GemmTraits_ Traits
The traits.
Definition: gemm.h:96
Traits::Epilogue::Scalar ScalarEpilogue
The scalar in the epilogue.
Definition: gemm.h:105
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream &global_to_shared_stream, typename Traits::SharedStream &shared_load_stream, typename MultiplyAdd::Accumulators &accumulators, Index outer_k)
Computes a warp-level GEMM on data held in shared memory.
Definition: gemm.h:178
Traits::ScalarB ScalarB
The scalar for B.
Definition: gemm.h:103
Definition: clear_accumulators.h:38
Definition: gemm.h:92
Traits::MultiplyAdd MultiplyAdd
Define the mainloop iteration size.
Definition: gemm.h:114
static __host__ cudaError_t launch(CUfunction kernel, Params const &params, CUstream stream=CU_STREAM_LEGACY)
Launch the kernel.
Definition: gemm.h:147
static Index const kWarpGemmSteps
Definition: gemm.h:120
Partial specialization for launching the GEMM kernel with or without launch bounds.
Definition: gemm.h:73
__global__ __launch_bounds__(Gemm_::kThreads) void gemm_kernel(typename Gemm_::Params params)
GEMM kernel with launch bounds specified.
Definition: gemm.h:43
Gemm< GemmTraits_ > This_
This class.
Definition: gemm.h:94
CUTLASS_DEVICE Gemm(Params const &params_, SharedStorage &shared_storage_)
Ctor.
Definition: gemm.h:173
#define static_assert(__e, __m)
Definition: platform.h:153
Traits::ScalarA ScalarA
The scalar for A.
Definition: gemm.h:101
CUTLASS_DEVICE void clear(Fragment_ &fragment)
Clear the fragment.
Definition: clear_accumulators.h:50
__global__ void gemm_kernel_nolb(typename Gemm_::Params params)
GEMM kernel without launch bounds specified.
Definition: gemm.h:59
static int const kThreads
The number of threads.
Definition: gemm.h:117
Traits::Params Params
Use the params object defined in traits.
Definition: gemm.h:124
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel (runtime API; unavailable under NVRTC).
Definition: gemm.h:136
Traits::Index Index
The index.
Definition: gemm.h:111