Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
device_gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25 #pragma once
26 #include <assert.h>
28 #include "cutlass/matrix_traits.h"
29 #include "cutlass/gemm/gemm_desc.h"
30 #include "tools/util/type_traits.h"
31 #include <iostream>
32 
33 namespace cutlass {
34 namespace gemm {
35 
36 template <
38  typename GemmTraits_,
40  typename ReductionTraits_
41 >
43  typedef GemmTraits_ GemmTraits;
44  typedef ReductionTraits_ ReductionTraits;
47 
49  typedef typename GemmTraits::Index Index;
51  typedef typename ReductionTraits::ScalarAlphaBeta Scalar;
53  typedef typename GemmTraits::ScalarA ScalarA;
55  typedef typename GemmTraits::ScalarB ScalarB;
57  typedef typename GemmTraits::ScalarD ScalarAccum;
59  typedef typename ReductionTraits::ScalarC ScalarC;
61  typedef typename ReductionTraits::ScalarD ScalarD;
66 
67  struct Params {
70 
80  typename ReductionTraits::Params ReductionParams;
81 
83  Params() :
84  workspace_size(0),
85  problem_size_initialized(false) {}
88  Index n_,
89  Index k_
90  ):
91  problem_size(k_, n_, m_, 1),
92  workspace_size(0),
94 
95  }
96 
99  Index n_,
100  Index k_){
101  problem_size = GemmCoord(k_, n_, m_, 1);
103  }
104 
105  int initialize(Scalar alpha_,
106  ScalarA const* d_a_,
107  Index lda_,
108  ScalarB const* d_b_,
109  Index ldb_,
110  Scalar beta_,
111  ScalarC const* d_c_,
112  Index ldc_,
113  ScalarD* d_d_,
114  Index ldd_,
115  ScalarAccum *workspace_ptr_) {
116 
117  workspace_ptr = workspace_ptr_;
118 
119  // Set up the params for GemmTraits (the first kernel).
120  // For the first kernel, A is A and B is B, while C and D both refer to the workspace.
121  // alpha is one, beta is zero, and partitionK_count is ReductionTraits::ReductionSize.
123  typename GemmTraits::ScalarB,
124  typename GemmTraits::ScalarC,
125  typename GemmTraits::ScalarD,
126  typename GemmTraits::Epilogue::Scalar>
127  desc(
128  problem_size,
129  typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(1.0f), /*alpha*/
132  typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(0.0f), /*beta*/
133  TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
134  TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
135  );
136  GemmParams.initialize(desc, ReductionTraits::ReductionSize);
137 
138 
139  // Set up the params for the batched reduction (the second kernel).
140  ReductionParams.initialize(problem_size.m(), /*m*/
141  problem_size.n(), /*n*/
142  alpha_, /*alpha*/
143  beta_, /*beta*/
144  problem_size.n() * problem_size.m() /*reduction_stride*/,
146  problem_size.m(),
147  d_c_,
148  ldc_,
149  d_d_,
150  ldd_);
151 
152  return 0;
153  }
154 
155  // The workspace stores D (the output) of the first GEMM kernel, not the D of the entire GEMM.
156  // Note: ScalarAccum is a typedef of GemmTraits::ScalarD.
157  // The workspace size in bytes is M * N * ReductionTraits::ReductionSize * sizeof(ScalarAccum).
159  assert(problem_size_initialized == true);
160  workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
161  return workspace_size;
162  }
163 
164 
165  };
166 
167 };
168 
169 } // namespace gemm
170 } // namespace cutlass
GEMM problem description.
Definition: gemm_desc.h:50
int initialize(Scalar alpha_, ScalarA const *d_a_, Index lda_, ScalarB const *d_b_, Index ldb_, Scalar beta_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_, ScalarAccum *workspace_ptr_)
Definition: device_gemm_traits.h:105
GemmTraits::ScalarB ScalarB
Definition: device_gemm_traits.h:55
Definition: convert.h:33
Definition: device_gemm_traits.h:67
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:394
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
Definition: device_gemm.h:40
Definition: gemm_coord.h:43
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:374
Params()
ctor
Definition: device_gemm_traits.h:83
bool problem_size_initialized
Check if params are init.
Definition: device_gemm_traits.h:72
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:393
ReductionTraits::Params ReductionParams
The Params for the second kernel.
Definition: device_gemm_traits.h:80
Definition: device_gemm_traits.h:42
Parameters object constructable on the host.
Definition: gemm_traits.h:416
static MatrixLayout::Kind const kLayoutB
The layout of B. It can be deduced from the layout set in the batched GEMM.
Definition: device_gemm_traits.h:65
GemmTraits::Index Index
Definition: device_gemm_traits.h:49
int workspace_size
Definition: device_gemm_traits.h:76
ReductionTraits::ScalarAlphaBeta Scalar
Definition: device_gemm_traits.h:51
GemmTraits::Params GemmParams
The Params for the first kernel.
Definition: device_gemm_traits.h:78
Implements a software-pipelined efficient GEMM.
GemmTraits::ScalarA ScalarA
Definition: device_gemm_traits.h:53
Definition: tensor_ref.h:131
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:435
Device-level GEMM implemented by more than one kernel.
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:381
ReductionTraits_ ReductionTraits
Definition: device_gemm_traits.h:44
GemmTraits_ GemmTraits
Definition: device_gemm_traits.h:43
SplitkPIGemmTraits< GemmTraits_, ReductionTraits_ > This_
Definition: device_gemm_traits.h:45
cutlass::gemm::DeviceGemm< This_ > KernelClass
Definition: device_gemm_traits.h:46
GemmTraits::ScalarD ScalarAccum
Definition: device_gemm_traits.h:57
Index_ Index
The index.
Definition: gemm_traits.h:399
int required_workspace_memory_in_byte()
Definition: device_gemm_traits.h:158
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:372
static MatrixLayout::Kind const kLayoutA
The layout of A. It can be deduced from the layout set in the batched GEMM.
Definition: device_gemm_traits.h:63
Defines properties of matrices used to denote layout and operands to GEMM kernels.
GemmCoord problem_size
The dimensions of the GEMM in K, N, M order.
Definition: device_gemm_traits.h:69
Params(Index m_, Index n_, Index k_)
ctor
Definition: device_gemm_traits.h:87
void init_problem(Index m_, Index n_, Index k_)
Initializes the problem size; needed when the default constructor is used.
Definition: device_gemm_traits.h:98
ReductionTraits::ScalarD ScalarD
Definition: device_gemm_traits.h:61
ScalarAccum * workspace_ptr
The pointer to workspace memory.
Definition: device_gemm_traits.h:74
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:379
ReductionTraits::ScalarC ScalarC
Definition: device_gemm_traits.h:59