// Enable the WMMA API when compiling device code for SM70+ with nvcc.
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
#define CUTLASS_USE_WMMA_API

// Sub-byte (b1 / s4 / u4) WMMA additionally requires CUDA 10+ and SM75+.
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && \
    (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
#define CUTLASS_USE_SUBBYTE_WMMA
#endif

#endif

#if __CUDACC_VER_MAJOR__ >= 10

/// Statically maps a cutlass MatrixLayout constant to the corresponding
/// nvcuda::wmma layout tag. Primary template: column-major.
template <MatrixLayout::Kind kLayout_>
struct WmmaLayout {
  typedef nvcuda::wmma::col_major Layout;
};
/// Specialization: MatrixLayout::kRowMajor maps to nvcuda::wmma::row_major.
template <>
struct WmmaLayout<MatrixLayout::kRowMajor> {
  typedef nvcuda::wmma::row_major Layout;
};
/// Statically maps a cutlass scalar type to the nvcuda::wmma element type.
/// Primary template: the scalar type is used as-is (half, float, int, ...).
template <typename Type_>
struct WmmaDataType {
  typedef Type_ Type;
};

#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// 1-bit binary elements packed 32 per vector map to experimental b1.
template <>
struct WmmaDataType<Vector<bin1_t, 32> > {
  typedef nvcuda::wmma::experimental::precision::b1 Type;
};

/// Signed 4-bit integers packed 8 per vector map to experimental s4.
template <>
struct WmmaDataType<Vector<int4_t, 8> > {
  typedef nvcuda::wmma::experimental::precision::s4 Type;
};

/// Unsigned 4-bit integers packed 8 per vector map to experimental u4.
template <>
struct WmmaDataType<Vector<uint4_t, 8> > {
  typedef nvcuda::wmma::experimental::precision::u4 Type;
};
#endif  // CUTLASS_USE_SUBBYTE_WMMA
/// Wrapper around an nvcuda::wmma fragment, keyed by GEMM operand, layout,
/// scalar type and WMMA tile shape. Only the kA / kB / kC specializations
/// below define any members; the primary template is intentionally empty.
template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct WmmaMatrix {};
/// WMMA fragment for the A operand. Derives from the native matrix_a
/// fragment; adds fill/load/store helpers with the cutlass calling style.
template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_a,
          // WMMA tile dimensions M, N, K drawn from the shape.
          // Assumes the cutlass convention kW=M, kH=N, kD=K — TODO confirm
          // against the shape definition (not visible in this chunk).
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          typename WmmaDataType<Scalar_>::Type,
          typename WmmaLayout<kLayout_>::Layout> {
  /// This type.
  typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;

  /// Fill every element of the fragment with the scalar x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Load the fragment from memory; stride is the leading dimension
  /// in elements.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Store the fragment to memory; stride is the leading dimension
  /// in elements.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
/// WMMA fragment for the B operand. Mirrors the kA specialization with the
/// native matrix_b fragment as the base class.
template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_b,
          // WMMA tile dimensions M, N, K drawn from the shape.
          // Assumes the cutlass convention kW=M, kH=N, kD=K — TODO confirm
          // against the shape definition (not visible in this chunk).
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          typename WmmaDataType<Scalar_>::Type,
          typename WmmaLayout<kLayout_>::Layout> {
  /// This type.
  typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;

  /// Fill every element of the fragment with the scalar x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Load the fragment from memory; stride is the leading dimension
  /// in elements.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Store the fragment to memory; stride is the leading dimension
  /// in elements.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
/// WMMA fragment for the C / D accumulator. Accumulator fragments carry no
/// layout template argument, so load/store take an explicit runtime
/// mem_row_major / mem_col_major selector derived from kLayout_.
template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::accumulator,
          // WMMA tile dimensions M, N, K drawn from the shape.
          // Assumes the cutlass convention kW=M, kH=N, kD=K — TODO confirm
          // against the shape definition (not visible in this chunk).
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          // Accumulator elements use the scalar type directly.
          Scalar_> {
  /// This type.
  typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;

  /// True when the accumulator is laid out row-major in memory.
  static bool const kIsRowMajor = (kLayout_ == MatrixLayout::kRowMajor);

  /// Fill every element of the fragment with the scalar x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Load the accumulator from memory; stride is the leading dimension
  /// in elements. Layout is selected at runtime from kIsRowMajor.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::load_matrix_sync(
        *this,
        pointer,
        stride,
        kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
  }

  /// Store the accumulator to memory; stride is the leading dimension
  /// in elements. Layout is selected at runtime from kIsRowMajor.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::store_matrix_sync(
        pointer,
        *this,
        stride,
        kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
  }
};
/// WmmaMatrix fragments are not vectorized further: a 1-lane vectorization
/// of a WmmaMatrix is the WmmaMatrix type itself.
template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct Vectorize<WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_>, 1> {
  typedef WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_> Type;
};

#endif  // defined CUTLASS_USE_WMMA_API
/* NOTE(review): the following lines are doxygen cross-reference residue left
   over from an HTML-to-text extraction (tooltip fragments, not source code).
   Preserved verbatim inside this comment so they cannot be mistaken for code:

   Kind
   Enumeration defining fundamental contiguous layouts.
   Definition: matrix_traits.h:159
   Definition: matrix_traits.h:159
   Vector< Element_, kLanes_ > Type
   Definition: vector.h:271
   Defines a 1D vector of elements held in the registers of each thread.
   Kind
   Definition: matrix_traits.h:357
   Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
   Defines properties of matrices used to denote layout and operands to GEMM kernels.
   Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
*/