#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif

// (remaining #includes and the cutlass::gemm namespace opening are elided in this listing)

/// GEMM kernel with launch bounds specified.
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
  // Declare shared memory, construct the GEMM object, and run the GEMM.
  __shared__ typename Gemm_::SharedStorage shared_storage;
  Gemm_ gemm(params, shared_storage);
  gemm.multiply_add();
}

/// GEMM kernel without launch bounds specified.
template <typename Gemm_>
__global__ void gemm_kernel_nolb(typename Gemm_::Params params) {
  // Declare shared memory, construct the GEMM object, and run the GEMM.
  __shared__ typename Gemm_::SharedStorage shared_storage;
  Gemm_ gemm(params, shared_storage);
  gemm.multiply_add();
}
/// Partial specialization for launching the GEMM kernel with or without launch bounds.
template <typename Gemm, bool WithLaunchBounds>
struct Launch {
  Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
  }
};

/// Partial specialization that launches the kernel without launch bounds.
template <typename Gemm>
struct Launch<Gemm, false> {
  Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
  }
};
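// Editorial note: host code does not normally instantiate Launch<> directly; the
// Gemm<>::launch() member below is the intended entry point and presumably dispatches
// through this helper. For a hypothetical, fully specified Gemm type `MyGemm`, the
// compile-time selection amounts to:
//
//   dim3 block(MyGemm::kThreads);
//   Launch<MyGemm, true>(params, grid, block, stream);   // gemm_kernel (with launch bounds)
//   Launch<MyGemm, false>(params, grid, block, stream);  // gemm_kernel_nolb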
template <typename GemmTraits_>
struct Gemm {
  typedef Gemm<GemmTraits_> This_;                           ///< This class.
  typedef GemmTraits_ Traits;                                ///< The traits.
  typedef typename Traits::SharedStorage SharedStorage;      ///< The shared storage.
  typedef typename Traits::ScalarA ScalarA;                  ///< The scalar for A.
  typedef typename Traits::ScalarB ScalarB;                  ///< The scalar for B.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;  ///< The scalar in the epilogue.
  typedef typename Traits::Epilogue::ScalarC ScalarC;        ///< The scalar for C.
  typedef typename Traits::Epilogue::ScalarD ScalarD;        ///< The scalar for D.
  typedef typename Traits::Index Index;                      ///< The index.
  typedef typename Traits::MultiplyAdd MultiplyAdd;          ///< The mainloop iteration size.

  /// The number of threads.
  static int const kThreads = Traits::GemmConfig::kThreads;

  /// The number of warp-level GEMM steps executed per mainloop iteration.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  /// Use the params object defined in traits.
  typedef typename Traits::Params Params;
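  // Worked example (hypothetical numbers): with AccumulatorsPerWarp::kD == 8 and
  // MultiplyAdd::InstructionShape::kD == 1, kWarpGemmSteps == 8, i.e. each mainloop
  // iteration issues eight warp-level multiply-accumulate steps along the K dimension
  // of the threadblock tile.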
  // The host-side launch functions are omitted when compiling with NVRTC.
#if !defined(__CUDACC_RTC__)

  /// Launch the kernel with the CUDA runtime API.
  static __host__ cudaError_t launch(Params const& params,
                                     cudaStream_t stream = cudaStreamDefault) {
    // Grid/block setup and the kernel launch itself are elided in this listing.

    return cudaGetLastError();
  }
  /// Launch the kernel with the CUDA driver API.
  static __host__ cudaError_t launch(CUfunction kernel,
                                     Params const& params,
                                     CUstream stream = CU_STREAM_LEGACY) {
    // Grid and block dimensions are set up as in the runtime-API overload
    // (those lines are elided in this listing).

    // Package the kernel parameters.
    void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};

    // Launch the kernel.
    CUresult result = cuLaunchKernel(
        kernel,
        grid.x, grid.y, grid.z,
        block.x, block.y, block.z,
        0, stream, params_, 0);

    if (result != CUDA_SUCCESS) {
      return cudaErrorLaunchFailure;
    }

    return cudaSuccess;
  }

#endif
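  // Editorial sketch of how the driver-API overload might be used from host code when the
  // kernel is compiled with NVRTC (module/name handling is application-specific and omitted):
  //
  //   CUfunction kernel;
  //   cuModuleGetFunction(&kernel, module, lowered_kernel_name);
  //   Gemm::launch(kernel, params, stream);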
  /// Ctor.
  CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}

  /// Computes a warp-level GEMM on data held in shared memory.
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {
    // If this is the residue tile, update the predicates so out-of-bounds data is not loaded.
    if (Residue && outer_k <= Traits::OutputTile::kD) {
      global_to_shared_stream.residue(outer_k);
    }

    // Load the global-memory tile for the next mainloop iteration (unless this is the last one).
    if (!LastIteration) {
      global_to_shared_stream.copy();
    }

    for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
      // Trigger the shared-memory load for the next step and commit the current step's fragments.
      shared_load_stream.copy(step + 1);
      shared_load_stream.commit(step);

      // Do the math on the fragments of the current step.
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators, accumulators);
    }

    // Make sure the data in shared memory has been entirely consumed.
    Traits::shared_load_fence(true);

    // Commit the next tile to shared memory (unless this is the last iteration).
    if (!LastIteration) {
      global_to_shared_stream.commit();
    }

    // Make sure the data is visible in shared memory.
    Traits::shared_store_fence(true);

    if (!LastIteration) {
      // Move to the next stage and trigger the shared-memory load for the next iteration.
      shared_load_stream.inc_stage();
      shared_load_stream.copy(0);
    }

    // (The final warp-level GEMM step of this tile follows here; elided in this listing.)
  }
  /// Do the GEMM.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the block indices to improve cache reuse and compute this block's tile offset.
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());

    // The bounds for this threadblock; they may differ from the full problem size.
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(
        params.problem_size /* remaining arguments elided in this listing */);

    // The stream that reads A/B tiles from global memory into shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream /* remaining arguments elided in this listing */);

    // Offset the A/B pointers for this batch (batched / strided GEMM).
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Prologue: move to the residue portion of K, then fetch and commit the first tile.
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
    global_to_shared_stream.copy();
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Roll the global pointers back to the beginning of the first full tile.
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);

    // The stream that reads A/B fragments from shared memory.
    typename Traits::SharedStream shared_load_stream(
        /* constructor arguments elided in this listing */);
    shared_load_stream.copy(0);

    // Allocate and clear the accumulators.
    typename MultiplyAdd::Accumulators accumulators;
    typename Traits::ClearAccumulators clear;  // (declaration reconstructed; not shown in the listing)
    clear.clear(accumulators);

    // The K extent remaining after the prologue.
    Index outer_k = bounds[0] - Traits::OutputTile::kD;

    if (Traits::GemmConfig::kResidueInProlog) {
      // Residue was handled in the prologue: run every iteration but the last...
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
      // ...then the last iteration, which does not fetch another tile.
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        consume_tile<false, true>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    } else {
      // Otherwise, optionally run all full tiles without residue checks...
      if (Traits::GemmConfig::kResidueSeparate) {
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }
      // ...and finish the remaining tiles with residue handling enabled.
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }

    // The epilogue updates the output matrix in global memory.
    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(/* constructor arguments elided in this listing */);
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }
  /// The params.
  Params const& params;

  /// The shared storage.
  SharedStorage& shared_storage;
};