31 #ifdef CUTLASS_USE_WMMA_API 63 typename Accumulator_,
65 typename WarpGemmShape_,
67 typename InstructionShape_,
77 int kScalarsPerLdgCAndStgD_,
83 struct WmmaGemmConfig :
public GemmConfig<
95 WmmaGemmMultiplyAdd<kLayoutA_,
99 MatrixLayout::kColumnMajor,
116 kScalarsPerLdgCAndStgD_,
133 typename GemmConfig_,
135 struct WmmaGemmTileTraitsHelperA {};
139 template <
typename GemmConfig_,
typename ScalarA_>
140 struct WmmaGemmTileTraitsHelperA<MatrixLayout::
kColumnMajor, GemmConfig_, ScalarA_>
141 :
public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
143 typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
146 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
148 typedef Shape<GemmConfig_::kStages,
149 GemmConfig_::OutputTile::kD,
150 GemmConfig_::OutputTile::kW + kSkew>
156 typename Base::MultiplyAddScalar,
157 typename GemmConfig_::InstructionShape>
161 typedef GemmSharedStoreTileAbTraits<
163 typename Base::MultiplyAddScalar,
167 typename Base::GlobalTileTraits::Threads,
169 GemmConfig_::kScalarsPerStsA>
170 SharedStoreTileTraits;
173 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
175 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
177 typedef WmmaGemmSharedLoadTileATraits<
181 typename Base::MultiplyAddScalar,
185 typename GemmConfig_::Warps,
187 GemmConfig_::InstructionShape::kW,
189 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
191 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
193 typename GemmConfig_::InstructionShape>
194 SharedLoadTileTraits;
199 template <
typename GemmConfig_,
typename ScalarA_>
200 struct WmmaGemmTileTraitsHelperA<MatrixLayout::
kRowMajor, GemmConfig_, ScalarA_> {
205 typedef typename GemmConfig_::ScalarA Scalar;
207 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
213 typename GemmConfig_::InstructionShape>
217 typedef GemmGlobalTileTraits<
225 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
227 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
229 GemmConfig_::kScalarsPerLdgA>
233 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
235 typedef Shape<GemmConfig_::kStages,
236 GemmConfig_::OutputTile::kW,
237 GemmConfig_::OutputTile::kD + kSkew>
241 typedef GemmSharedStoreTileAbTraits<
247 typename GlobalTileTraits::Threads,
249 GemmConfig_::kScalarsPerStsA>
250 SharedStoreTileTraits;
253 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
255 typedef WmmaGemmSharedLoadTileATraits<
263 typename GemmConfig_::Warps,
265 GemmConfig_::InstructionShape::kW * Tile::kW,
267 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
269 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
271 typename GemmConfig_::InstructionShape>
272 SharedLoadTileTraits;
277 #ifdef CUTLASS_USE_SUBBYTE_WMMA 278 template <
typename GemmConfig_>
280 struct WmmaGemmTileTraitsHelperA<MatrixLayout::
kRowMajor, GemmConfig_, Vector<bin1_t, 32> > {
285 typedef typename GemmConfig_::ScalarA Scalar;
287 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
291 static int const kBitsPerScalar =
sizeof(Scalar) * 8;
297 typename GemmConfig_::InstructionShape>
301 typedef GemmGlobalTileTraits<
309 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
312 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
313 GemmConfig_::OutputTile::kD / kBitsPerScalar>,
315 GemmConfig_::kScalarsPerLdgA / kBitsPerScalar>
319 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
321 typedef Shape<GemmConfig_::kStages,
322 GemmConfig_::OutputTile::kW,
323 GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
327 typedef GemmSharedStoreTileAbTraits<
333 typename GlobalTileTraits::Threads,
335 GemmConfig_::kScalarsPerStsA / kBitsPerScalar>
336 SharedStoreTileTraits;
339 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
341 typedef WmmaGemmSharedLoadTileATraits<
349 typename GemmConfig_::Warps,
351 GemmConfig_::InstructionShape::kW * Tile::kW,
353 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
355 Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
357 typename GemmConfig_::InstructionShape>
358 SharedLoadTileTraits;
364 #ifdef CUTLASS_USE_SUBBYTE_WMMA 365 template <
typename GemmConfig_>
367 struct WmmaGemmTileTraitsHelperA<MatrixLayout::
kRowMajor, GemmConfig_, Vector<uint4_t, 8> > {
372 typedef typename GemmConfig_::ScalarA Scalar;
374 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
378 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
384 typename GemmConfig_::InstructionShape>
388 typedef GemmGlobalTileTraits<
396 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
399 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
400 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
402 GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
406 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
408 typedef Shape<GemmConfig_::kStages,
409 GemmConfig_::OutputTile::kW,
410 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
414 typedef GemmSharedStoreTileAbTraits<
420 typename GlobalTileTraits::Threads,
422 GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
423 SharedStoreTileTraits;
426 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
428 typedef WmmaGemmSharedLoadTileATraits<
436 typename GemmConfig_::Warps,
438 GemmConfig_::InstructionShape::kW * Tile::kW,
440 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
442 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
444 typename GemmConfig_::InstructionShape>
445 SharedLoadTileTraits;
451 #ifdef CUTLASS_USE_SUBBYTE_WMMA 452 template <
typename GemmConfig_>
454 struct WmmaGemmTileTraitsHelperA<MatrixLayout::
kRowMajor, GemmConfig_, Vector<int4_t, 8> > {
459 typedef typename GemmConfig_::ScalarA Scalar;
461 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
465 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
471 typename GemmConfig_::InstructionShape>
475 typedef GemmGlobalTileTraits<
483 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
486 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
487 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
489 GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
493 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
495 typedef Shape<GemmConfig_::kStages,
496 GemmConfig_::OutputTile::kW,
497 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
501 typedef GemmSharedStoreTileAbTraits<
507 typename GlobalTileTraits::Threads,
509 GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
510 SharedStoreTileTraits;
513 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
515 typedef WmmaGemmSharedLoadTileATraits<
523 typename GemmConfig_::Warps,
525 GemmConfig_::InstructionShape::kW * Tile::kW,
527 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
529 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
531 typename GemmConfig_::InstructionShape>
532 SharedLoadTileTraits;
539 typename GemmConfig_,
541 struct WmmaGemmTileTraitsHelperB {};
545 template <
typename GemmConfig_,
typename ScalarB_>
546 struct WmmaGemmTileTraitsHelperB<MatrixLayout::
kRowMajor, GemmConfig_, ScalarB_>
547 :
public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
549 typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
552 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
554 typedef Shape<GemmConfig_::kStages,
555 GemmConfig_::OutputTile::kD,
556 GemmConfig_::OutputTile::kH + kSkew>
562 typename Base::MultiplyAddScalar,
563 typename GemmConfig_::InstructionShape>
567 typedef GemmSharedStoreTileAbTraits<
569 typename Base::MultiplyAddScalar,
573 typename Base::GlobalTileTraits::Threads,
575 GemmConfig_::kScalarsPerStsB>
576 SharedStoreTileTraits;
579 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
581 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
583 typedef WmmaGemmSharedLoadTileBTraits<
587 typename Base::MultiplyAddScalar,
591 typename GemmConfig_::Warps,
593 GemmConfig_::InstructionShape::kH,
595 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
597 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
599 typename GemmConfig_::InstructionShape>
600 SharedLoadTileTraits;
605 template <
typename GemmConfig_,
typename ScalarB_>
606 struct WmmaGemmTileTraitsHelperB<MatrixLayout::
kColumnMajor, GemmConfig_, ScalarB_> {
611 typedef typename GemmConfig_::ScalarB Scalar;
613 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
619 typename GemmConfig_::InstructionShape>
623 typedef GemmGlobalTileTraits<
631 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
633 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
635 GemmConfig_::kScalarsPerLdgB>
639 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
641 typedef Shape<GemmConfig_::kStages,
642 GemmConfig_::OutputTile::kH,
643 GemmConfig_::OutputTile::kD + kSkew>
647 typedef GemmSharedStoreTileAbTraits<
653 typename GlobalTileTraits::Threads,
655 GemmConfig_::kScalarsPerStsB>
656 SharedStoreTileTraits;
659 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
661 typedef WmmaGemmSharedLoadTileBTraits<
669 typename GemmConfig_::Warps,
671 GemmConfig_::InstructionShape::kH * Tile::kW,
673 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
675 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
677 typename GemmConfig_::InstructionShape>
678 SharedLoadTileTraits;
683 #ifdef CUTLASS_USE_SUBBYTE_WMMA 684 template <
typename GemmConfig_>
686 struct WmmaGemmTileTraitsHelperB<MatrixLayout::
kColumnMajor, GemmConfig_, Vector<bin1_t, 32> > {
691 typedef typename GemmConfig_::ScalarB Scalar;
693 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
697 static int const kBitsPerScalar =
sizeof(Scalar) * 8;
703 typename GemmConfig_::InstructionShape>
707 typedef GemmGlobalTileTraits<
715 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
718 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
719 GemmConfig_::OutputTile::kD / kBitsPerScalar>,
721 GemmConfig_::kScalarsPerLdgB / kBitsPerScalar>
725 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
727 typedef Shape<GemmConfig_::kStages,
728 GemmConfig_::OutputTile::kH,
729 GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
733 typedef GemmSharedStoreTileAbTraits<
739 typename GlobalTileTraits::Threads,
741 GemmConfig_::kScalarsPerStsB / kBitsPerScalar>
742 SharedStoreTileTraits;
745 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
747 typedef WmmaGemmSharedLoadTileBTraits<
755 typename GemmConfig_::Warps,
757 GemmConfig_::InstructionShape::kH * Tile::kW,
759 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
761 Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
763 typename GemmConfig_::InstructionShape>
764 SharedLoadTileTraits;
770 #ifdef CUTLASS_USE_SUBBYTE_WMMA 771 template <
typename GemmConfig_>
773 struct WmmaGemmTileTraitsHelperB<MatrixLayout::
kColumnMajor, GemmConfig_, Vector<uint4_t, 8> > {
778 typedef typename GemmConfig_::ScalarB Scalar;
780 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
784 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
790 typename GemmConfig_::InstructionShape>
794 typedef GemmGlobalTileTraits<
802 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
805 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
806 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
808 GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
812 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
814 typedef Shape<GemmConfig_::kStages,
815 GemmConfig_::OutputTile::kH,
816 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
820 typedef GemmSharedStoreTileAbTraits<
826 typename GlobalTileTraits::Threads,
828 GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
829 SharedStoreTileTraits;
832 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
834 typedef WmmaGemmSharedLoadTileBTraits<
842 typename GemmConfig_::Warps,
844 GemmConfig_::InstructionShape::kH * Tile::kW,
846 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
848 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
850 typename GemmConfig_::InstructionShape>
851 SharedLoadTileTraits;
857 #ifdef CUTLASS_USE_SUBBYTE_WMMA 858 template <
typename GemmConfig_>
860 struct WmmaGemmTileTraitsHelperB<MatrixLayout::
kColumnMajor, GemmConfig_, Vector<int4_t, 8> > {
865 typedef typename GemmConfig_::ScalarB Scalar;
867 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
871 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
877 typename GemmConfig_::InstructionShape>
881 typedef GemmGlobalTileTraits<
889 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
892 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
893 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
895 GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
899 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
901 typedef Shape<GemmConfig_::kStages,
902 GemmConfig_::OutputTile::kH,
903 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
907 typedef GemmSharedStoreTileAbTraits<
913 typename GlobalTileTraits::Threads,
915 GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
916 SharedStoreTileTraits;
919 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
921 typedef WmmaGemmSharedLoadTileBTraits<
929 typename GemmConfig_::Warps,
931 GemmConfig_::InstructionShape::kH * Tile::kW,
933 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
935 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
937 typename GemmConfig_::InstructionShape>
938 SharedLoadTileTraits;
950 typename OutputTile_,
958 typename Accumulator_,
960 typename EpilogueFunctor_,
962 typename WarpGemmShape_,
964 typename InstructionShape_,
966 int kScalarsPerLdgA_,
968 int kScalarsPerLdgB_,
970 int KScalarsPerLdsA_,
972 int KscalarsPerLdsB_,
974 int kScalarsPerLdgCAndStgD_,
976 int kScalarsPerStsD_,
978 int kScalarsPerLdsD_,
981 struct WmmaGemmTraitsHelper {
983 typedef WmmaGemmConfig<kLayoutA_,
996 kScalarsPerLdgCAndStgD_,
1003 typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig, ScalarA_> GemmTileTraitsHelperA;
1005 typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig, ScalarB_> GemmTileTraitsHelperB;
1008 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
1009 GlobalLoadIteratorA;
1011 typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
1013 typedef TileStoreIterator<
typename GemmTileTraitsHelperA::SharedStoreTileTraits,
1014 typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
1017 SharedStoreIteratorA;
1020 GlobalLoadIteratorA,
1021 SharedStoreIteratorA,
1026 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
1027 GlobalLoadIteratorB;
1029 typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
1031 typedef TileStoreIterator<
typename GemmTileTraitsHelperB::SharedStoreTileTraits,
1032 typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
1035 SharedStoreIteratorB;
1038 GlobalLoadIteratorB,
1039 SharedStoreIteratorB,
1044 typedef TileLoadIterator<
typename GemmTileTraitsHelperA::SharedLoadTileTraits,
1045 typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
1049 typename GemmTileTraitsHelperA::WmmaMatrix,
1051 SharedLoadIteratorA;
1053 typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
1055 typedef TileLoadIterator<
typename GemmTileTraitsHelperB::SharedLoadTileTraits,
1056 typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
1060 typename GemmTileTraitsHelperB::WmmaMatrix,
1062 SharedLoadIteratorB;
1064 typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
1069 typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
1072 typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, Accumulator_, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
1074 typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
1077 typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
1082 template <
typename OutputTile_,
typename DefaultShape_ = Shape<64, 32, 64> >
1083 struct WmmaGemmAccumulatorsPerWarp {
1095 typename OutputTile_ = Shape<64, 128, 128>,
1097 typename ScalarA_ = half,
1099 typename ScalarB_ = half,
1101 typename ScalarC_ = float,
1103 typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
1105 typename Accumulator_ = ScalarC_,
1107 typename WarpGemmShape_ =
typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
1109 typename InstructionShape_ = Shape<16, 16, 16>,
1111 int kScalarsPerLdgA_ = 8,
1113 int kScalarsPerLdgB_ = 8,
1115 int KScalarsPerLdsA_ = 8,
1117 int KscalarsPerLdsB_ = 8,
1119 int kScalarsPerLdgCAndStgD_ = 16 /
sizeof(ScalarC_),
1121 int kScalarsPerStsD_ = 16 /
sizeof(Accumulator_),
1123 int kScalarsPerLdsD_ = 16 /
sizeof(Accumulator_),
1125 typename Index_ =
int,
1127 typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
1141 kScalarsPerLdgCAndStgD_,
1145 struct WmmaGemmTraits :
public GemmTraits<
1147 typename Helper_::GemmConfig,
1149 typename Helper_::GlobalLoadStreamA,
1151 typename Helper_::GlobalLoadStreamB,
1153 typename Helper_::SharedLoadStreamA,
1155 typename Helper_::SharedLoadStreamB,
1157 typename Helper_::Epilogue,
1159 IdentityBlockSwizzle,
1163 typename Helper_::ClearAccumulators> {};
1170 #endif // defined CUTLASS_USE_WMMA_API Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: load_store.h:48
Definition: tile_iterator.h:65
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Defines iterators for efficiently loading and storing tiles to and from shared memory.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:357
Definition: matrix_traits.h:159
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:159
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:357
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:159
Defines conversion operations among Fragments of different base type.