41 template <
typename GemmEpilogueTraits_>
46 typedef typename Traits::Params
Params;
57 typedef typename Traits::Scalar
Scalar;
62 static_assert(Iterations::kD == 1 && Iterations::kC == 1,
"Unsupported 3D/4D shapes");
80 typedef typename Traits::Index
Index;
83 typedef typename GlobalLoadIteratorC::Scalar
ScalarC;
85 typedef typename GlobalStoreIteratorD::Scalar
ScalarD;
97 if (
functor.source_required()) {
98 epilogue_with_or_without_beta<true>(accumulators, block, batch_id);
100 epilogue_with_or_without_beta<false>(accumulators, block, batch_id);
104 template <
bool kSourceRequired>
109 typename GlobalLoadIteratorC::Fragment fragment_c;
111 typename GlobalTransformerC::OutputFragment transformed_c;
113 for (
int h = 0; h < Iterations::kH; ++h) {
115 int const pointer_offset =
116 ((
params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
117 params.iterator_d.inc_advance) *
122 int const predicate_offset =
123 ((
params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
124 params.iterator_d.predicate_inc_advance) *
134 global_load_iterator.add_pointer_offset(batch_id *
params.batch_stride_C);
146 global_store_iterator.add_pointer_offset(batch_id *
params.batch_stride_D);
149 typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
152 params.shared_store_iterator_d,
153 reinterpret_cast<typename SharedStoreIteratorD::Scalar*>(
shared_storage.data()));
156 params.shared_load_stream_d,
157 reinterpret_cast<typename SharedLoadStreamD::Scalar*>(
shared_storage.data()));
160 for (
int w = 0; w < Iterations::kW; ++w) {
162 if (kSourceRequired) {
163 global_load_iterator.load_post_increment(fragment_c);
170 int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
172 shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
174 shared_store_iterator.store_post_increment(shared_store_transformed_d);
180 shared_load_stream.copy();
181 shared_load_stream.commit();
184 typename GlobalTransformerD::InputFragment fragment_d;
185 if (kSourceRequired) {
187 transformer_c.transform(fragment_c, transformed_c);
189 functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
191 functor.evaluate(shared_load_stream.fragment(), fragment_d);
195 typename GlobalTransformerD::OutputFragment global_transformed_d;
196 transformer_d.transform(fragment_d, global_transformed_d);
199 global_store_iterator.store_post_increment(global_transformed_d);
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:85
Coord< 3 > problem_size
The dimensions of the GEMM.
Definition: gemm_epilogue.h:215
Traits::SharedStoreIteratorD SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue.h:73
Traits::Params Params
The params.
Definition: gemm_epilogue.h:46
Definition: gemm_epilogue.h:42
CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators &accumulators, Coord< 3 > const &block, int batch_id)
Definition: gemm_epilogue.h:105
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Functor functor
Definition: gemm_epilogue.h:217
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:48
Traits::GlobalTransformerD GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue.h:69
CUTLASS_DEVICE GemmEpilogue(Params const &params_, SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: gemm_epilogue.h:88
Traits::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue.h:51
Traits::Accumulators Accumulators
The accumulators.
Definition: gemm_epilogue.h:55
CUTLASS_DEVICE void shared_load_fence()
The memory fence for shared loads.
Definition: gemm_epilogue.h:205
SharedStorage & shared_storage
The shared storage.
Definition: gemm_epilogue.h:213
GemmEpilogueTraits_ Traits
The traits class.
Definition: gemm_epilogue.h:44
Params const & params
The params.
Definition: gemm_epilogue.h:211
Traits::Index Index
The index.
Definition: gemm_epilogue.h:80
Traits::SharedStoreTransformerD SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue.h:75
Traits::GlobalStoreIteratorD GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue.h:71
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:83
Traits::SharedLoadStreamD SharedLoadStreamD
The iterator to load D in shared memory.
Definition: gemm_epilogue.h:77
Traits::Functor Functor
The functor in charge of the math.
Definition: gemm_epilogue.h:59
CUTLASS_DEVICE void epilogue(Accumulators &accumulators, Coord< 3 > const &block=make_Coord(0, 0, 0), int batch_id=0)
Execute the epilogue.
Definition: gemm_epilogue.h:94
Traits::Iterations Iterations
The number of iterations.
Definition: gemm_epilogue.h:53
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Traits::Scalar Scalar
The scalar.
Definition: gemm_epilogue.h:57
Defines conversion operations among Fragments of different base type.
CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_epilogue.h:208
Traits::GlobalTransformerC GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue.h:67
Traits::GlobalLoadIteratorC GlobalLoadIteratorC
We do not support 3D or 4D shapes.
Definition: gemm_epilogue.h:62