30 #include "tools/util/type_traits.h" 40 typename ReductionTraits_
51 typedef typename ReductionTraits::ScalarAlphaBeta
Scalar;
59 typedef typename ReductionTraits::ScalarC
ScalarC;
61 typedef typename ReductionTraits::ScalarD
ScalarD;
126 typename GemmTraits::Epilogue::Scalar>
129 typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(1.0f),
132 typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(0.0f),
GEMM problem description.
Definition: gemm_desc.h:50
int initialize(Scalar alpha_, ScalarA const *d_a_, Index lda_, ScalarB const *d_b_, Index ldb_, Scalar beta_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_, ScalarAccum *workspace_ptr_)
Definition: device_gemm_traits.h:105
GemmTraits::ScalarB ScalarB
Definition: device_gemm_traits.h:55
Definition: device_gemm_traits.h:67
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:394
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
Definition: device_gemm.h:40
Definition: gemm_coord.h:43
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:374
Params()
ctor
Definition: device_gemm_traits.h:83
bool problem_size_initialized
Flag indicating whether the params have been initialized.
Definition: device_gemm_traits.h:72
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:393
ReductionTraits::Params ReductionParams
The Params for the second kernel.
Definition: device_gemm_traits.h:80
Definition: device_gemm_traits.h:42
Parameters object constructable on the host.
Definition: gemm_traits.h:416
static MatrixLayout::Kind const kLayoutB
The layout of B. Can be deduced from the layout set in batched GEMM.
Definition: device_gemm_traits.h:65
GemmTraits::Index Index
Definition: device_gemm_traits.h:49
int workspace_size
Definition: device_gemm_traits.h:76
ReductionTraits::ScalarAlphaBeta Scalar
Definition: device_gemm_traits.h:51
GemmTraits::Params GemmParams
The Params for the first kernel.
Definition: device_gemm_traits.h:78
Implements a software-pipelined efficient GEMM.
GemmTraits::ScalarA ScalarA
Definition: device_gemm_traits.h:53
Definition: tensor_ref.h:131
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:435
Device-level GEMM implemented by more than one kernel.
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:381
ReductionTraits_ ReductionTraits
Definition: device_gemm_traits.h:44
GemmTraits_ GemmTraits
Definition: device_gemm_traits.h:43
SplitkPIGemmTraits< GemmTraits_, ReductionTraits_ > This_
Definition: device_gemm_traits.h:45
cutlass::gemm::DeviceGemm< This_ > KernelClass
Definition: device_gemm_traits.h:46
GemmTraits::ScalarD ScalarAccum
Definition: device_gemm_traits.h:57
Index_ Index
The index.
Definition: gemm_traits.h:399
int required_workspace_memory_in_byte()
Definition: device_gemm_traits.h:158
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:372
static MatrixLayout::Kind const kLayoutA
The layout of A. Can be deduced from the layout set in batched GEMM.
Definition: device_gemm_traits.h:63
Defines properties of matrices used to denote layout and operands to GEMM kernels.
GemmCoord problem_size
The dimensions of the GEMM in K, N, M order.
Definition: device_gemm_traits.h:69
Params(Index m_, Index n_, Index k_)
ctor
Definition: device_gemm_traits.h:87
void init_problem(Index m_, Index n_, Index k_)
Calling init_problem is needed when the default ctor is used.
Definition: device_gemm_traits.h:98
ReductionTraits::ScalarD ScalarD
Definition: device_gemm_traits.h:61
ScalarAccum * workspace_ptr
The pointer to workspace memory.
Definition: device_gemm_traits.h:74
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:379
ReductionTraits::ScalarC ScalarC
Definition: device_gemm_traits.h:59