44 template <
int kElements_>
52 static_assert(kElements_ % 4 == 0,
"kElements must be multiple of 4");
63 template <
typename Fragment_>
66 float4
const* src_f4 =
reinterpret_cast<float4 const*
>(&src[0]);
68 int* dst_int =
reinterpret_cast<int*
>(&dst[0]);
71 for (
int i = 0; i < kElements_ / 4; ++i) {
73 float4 f4 = src_f4[i];
76 float x = fmaxf(-128.f, fminf(127.f, f4.x));
77 float y = fmaxf(-128.f, fminf(127.f, f4.y));
78 float z = fmaxf(-128.f, fminf(127.f, f4.z));
79 float w = fmaxf(-128.f, fminf(127.f, f4.w));
88 asm volatile(
"prmt.b32 %0, %0, %1, 0x1140;" :
"+r"(ix) :
"r"(iy));
89 asm volatile(
"prmt.b32 %0, %0, %1, 0x1140;" :
"+r"(iz) :
"r"(iw));
90 asm volatile(
"prmt.b32 %0, %0, %1, 0x5410;" :
"+r"(ix) :
"r"(iz));
100 template <
typename InputScalar_,
typename OutputFragment_>
105 template <
int kElements_>
112 template <
int kElements_>
120 static_assert(kElements_ % 4 == 0,
"kElements must be multiple of 4");
131 template <
typename Fragment_>
134 int const* src_int =
reinterpret_cast<int const*
>(&src[0]);
136 float4* dst_f4 =
reinterpret_cast<float4*
>(&dst[0]);
139 for (
int i = 0; i < kElements_ / 4; ++i) {
141 int ix, iy, iz, iw = src_int[i];
144 asm volatile(
"prmt.b32 %0, 0x0, %1, 0x4440;" :
"=r"(ix) :
"r"(iw));
145 asm volatile(
"prmt.b32 %0, 0x0, %1, 0x4441;" :
"=r"(iy) :
"r"(iw));
146 asm volatile(
"prmt.b32 %0, 0x0, %1, 0x4442;" :
"=r"(iz) :
"r"(iw));
147 asm volatile(
"prmt.b32 %0, 0x0, %1, 0x4443;" :
"=r"(iw) :
"r"(iw));
150 float fx, fy, fz, fw;
153 asm volatile(
"cvt.rn.f32.s8 %0, %1;" :
"=f"(fx) :
"r"(ix));
154 asm volatile(
"cvt.rn.f32.s8 %0, %1;" :
"=f"(fy) :
"r"(iy));
155 asm volatile(
"cvt.rn.f32.s8 %0, %1;" :
"=f"(fz) :
"r"(iz));
156 asm volatile(
"cvt.rn.f32.s8 %0, %1;" :
"=f"(fw) :
"r"(iw));
159 dst_f4[i] = make_float4(fx, fy, fz, fw);
166 template <
typename InputFragment_,
typename OutputScalar_>
171 template <
int kElements_>
178 template <
typename InputScalar_,
typename OutputFragment_>
185 template <
typename IgemmConfig_,
typename EpilogueFunctor_,
typename Index_>
248 typename IgemmConfig_,
250 typename EpilogueFunctor_,
252 typename Index_ = int,
257 typename IgemmConfig_::OutputTile,
259 typename IgemmConfig_::Accumulators,
261 typename Helper_::GlobalLoadIteratorC,
263 typename Helper_::GlobalTransformerC,
265 typename Helper_::GlobalTransformerD,
267 typename Helper_::GlobalStoreIteratorD,
269 typename Helper_::SharedStoreIteratorD,
271 typename Helper_::SharedStoreTransformerD,
273 typename Helper_::SharedLoadStreamD,
275 typename Helper_::Iterations,
277 typename Helper_::Delta,
289 template <
typename GemmEpilogueTraits_,
bool = GemmEpilogueTraits_::kInt8Output>
298 :
Base(params_, shared_storage_, _problem_size) {}
303 template <
typename GemmEpilogueTraits_>
312 :
Base(params_, shared_storage_, _problem_size) {}
Definition: gemm_global_tile.h:120
Definition: igemm_epilogue.h:255
Definition: load_store.h:41
Base::Delta Delta
The iterations strides.
Definition: igemm_epilogue.h:198
Base::SharedStoreTileTraits SharedStoreTileTraits
The traits class for the shared iterator to store D to shared memory.
Definition: igemm_epilogue.h:221
IgemmGlobalStoreTransformer< Scalar, GlobalFragmentD >::Transformer GlobalTransformerD
The transformer from accumulators to shared memory fragments.
Definition: igemm_epilogue.h:218
Base::SharedLoadTileTraits SharedLoadTileTraits
The traits class for the shared iterator to load D from shared memory.
Definition: igemm_epilogue.h:235
TileLoadIterator< SharedLoadTileTraits, typename SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorD
The shared iterator to load D from shared memory.
Definition: igemm_epilogue.h:241
Definition: gemm_epilogue_traits.h:203
GemmEpilogue< GemmEpilogueTraits_ > Base
The base class.
Definition: igemm_epilogue.h:292
Traits::Params Params
The params.
Definition: gemm_epilogue.h:46
Definition: gemm_epilogue.h:42
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
CUTLASS_DEVICE IgemmInt8ToFloatConverter()
Ctor.
Definition: igemm_epilogue.h:123
SharedStoreIteratorD::Fragment SharedStoreFragmentD
The fragment that needs to be passed to that store iterator.
Definition: igemm_epilogue.h:229
EpilogueFunctor_::Scalar Scalar
The scalar.
Definition: gemm_epilogue_traits.h:205
Definition: igemm_epilogue.h:186
Definition: load_store.h:42
Fragment< int8_t, kElements_ > InputFragment
The input fragment.
Definition: igemm_epilogue.h:115
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Definition: igemm_epilogue.h:290
Definition: igemm_epilogue.h:45
CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:64
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:48
A template defining Fragment Concept.
Definition: fragment.h:99
Definition: tile_iterator.h:65
CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:126
Base::Scalar Scalar
The scalar type of the epilogue.
Definition: igemm_epilogue.h:194
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const ¶ms_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: igemm_epilogue.h:295
GlobalLoadIteratorC::Fragment GlobalFragmentC
The fragment that needs to be produced by the load iterator.
Definition: igemm_epilogue.h:205
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:909
CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:58
Fragment< int8_t, kElements_ > OutputFragment
The output fragment.
Definition: igemm_epilogue.h:49
GemmGlobalIteratorCd< GlobalStoreTileTraits > GlobalStoreIteratorD
The iterator to store to shared memory.
Definition: igemm_epilogue.h:213
IgemmSharedStoreTransformer< typename IgemmConfig::Accumulators::Element, SharedStoreFragmentD >::Transformer SharedStoreTransformerD
The transformer from accumulators to shared memory fragments.
Definition: igemm_epilogue.h:233
static bool const kInt8Output
Do we output in int8?
Definition: igemm_epilogue.h:283
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
GemmEpilogue< GemmEpilogueTraits_ > Base
The base class.
Definition: igemm_epilogue.h:306
Defines a type for restructuring a tile.
Base::GlobalLoadTileTraits GlobalLoadTileTraits
The traits class for the iterator.
Definition: igemm_epilogue.h:201
Fragment< float, kElements_ > OutputFragment
The output fragment.
Definition: igemm_epilogue.h:117
GemmEpilogueTraitsHelper< IgemmConfig_, EpilogueFunctor_, Index_ > Base
The base class.
Definition: igemm_epilogue.h:189
Definition: gemm_shared_tile.h:339
GlobalStoreIteratorD::Fragment GlobalFragmentD
The fragment that needs to be passed to that store iterator.
Definition: igemm_epilogue.h:215
GemmGlobalIteratorCd< GlobalLoadTileTraits > GlobalLoadIteratorC
The iterator to store to shared memory.
Definition: igemm_epilogue.h:203
IgemmConfig_ IgemmConfig
The config.
Definition: igemm_epilogue.h:191
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_DEVICE IgemmFloatToInt8Converter()
Ctor.
Definition: igemm_epilogue.h:55
Element_ Element
The element.
Definition: fragment.h:108
Fragment< float, kElements_ > InputFragment
The input fragment.
Definition: igemm_epilogue.h:47
Definition: gemm_epilogue_traits.h:70
Definition: gemm_global_tile.h:366
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
Base::Iterations Iterations
The iterations.
Definition: igemm_epilogue.h:196
IgemmGlobalLoadTransformer< GlobalFragmentC, Scalar >::Transformer GlobalTransformerC
The transformer from loaded data to math fragment.
Definition: igemm_epilogue.h:208
Base::GlobalStoreTileTraits GlobalStoreTileTraits
The traits class for the iterator.
Definition: igemm_epilogue.h:211
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:132
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
TileStoreIterator< SharedStoreTileTraits, typename SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal > SharedStoreIteratorD
The shared iterator to store D to shared memory.
Definition: igemm_epilogue.h:227
Defines conversion operations among Fragments of different base type.
Definition: igemm_epilogue.h:113
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const ¶ms_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: igemm_epilogue.h:309
Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load ea...
Definition: gemm_shared_tile.h:270
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841