Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/convert.h"
33 #include "cutlass/coord.h"
34 #include "cutlass/fragment.h"
35 
36 namespace cutlass {
37 namespace gemm {
38 
40 
41 template <typename GemmEpilogueTraits_>
42 struct GemmEpilogue {
44  typedef GemmEpilogueTraits_ Traits;
46  typedef typename Traits::Params Params;
48  typedef typename Traits::SharedStorage SharedStorage;
49 
51  typedef typename Traits::OutputTile OutputTile;
53  typedef typename Traits::Iterations Iterations;
55  typedef typename Traits::Accumulators Accumulators;
57  typedef typename Traits::Scalar Scalar;
59  typedef typename Traits::Functor Functor;
60 
62  static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
63 
65  typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
67  typedef typename Traits::GlobalTransformerC GlobalTransformerC;
69  typedef typename Traits::GlobalTransformerD GlobalTransformerD;
71  typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
73  typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
75  typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
77  typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
78 
80  typedef typename Traits::Index Index;
81 
83  typedef typename GlobalLoadIteratorC::Scalar ScalarC;
85  typedef typename GlobalStoreIteratorD::Scalar ScalarD;
86 
88  CUTLASS_DEVICE GemmEpilogue(Params const& params_,
89  SharedStorage& shared_storage_,
90  Coord<3> const& _problem_size)
91  : params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {}
92 
94  CUTLASS_DEVICE void epilogue(Accumulators& accumulators,
95  Coord<3> const& block = make_Coord(0, 0, 0),
96  int batch_id = 0) {
97  if (functor.source_required()) {
98  epilogue_with_or_without_beta<true>(accumulators, block, batch_id);
99  } else {
100  epilogue_with_or_without_beta<false>(accumulators, block, batch_id);
101  }
102  }
103 
104  template <bool kSourceRequired>
105  CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators,
106  Coord<3> const& block,
107  int batch_id) {
108  // The C fragment.
109  typename GlobalLoadIteratorC::Fragment fragment_c;
110  // The transformed C fragment.
111  typename GlobalTransformerC::OutputFragment transformed_c;
113  for (int h = 0; h < Iterations::kH; ++h) {
114  // Compute pointer and predicate offsets for C and D global iterators.
115  int const pointer_offset =
116  ((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
117  params.iterator_d.inc_advance) *
118  Iterations::kW +
119  params.stride_h) *
120  h;
121 
122  int const predicate_offset =
123  ((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
124  params.iterator_d.predicate_inc_advance) *
125  Iterations::kW +
126  Traits::Delta::kH) *
127  h;
128 
129  // The iterator to load the elements of the C matrix.
130  GlobalLoadIteratorC global_load_iterator(
131  params.iterator_c, problem_size, block, pointer_offset, predicate_offset);
132 
133  // update C pointer offset based on batch_id and batch_stride_offset
134  global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_C);
135 
136  // The transformer for C.
137  GlobalTransformerC transformer_c;
138  // The transformer for D.
139  GlobalTransformerD transformer_d;
140 
141  // The iterator to store into the D matrix.
142  GlobalStoreIteratorD global_store_iterator(
143  params.iterator_d, problem_size, block, pointer_offset, predicate_offset);
144 
145  // update D pointer offset based on batch_id and batch_stride_offset
146  global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_D);
147 
148  SharedStoreTransformerD shared_store_transformer;
149  typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
150 
151  SharedStoreIteratorD shared_store_iterator(
152  params.shared_store_iterator_d,
153  reinterpret_cast<typename SharedStoreIteratorD::Scalar*>(shared_storage.data()));
154 
155  SharedLoadStreamD shared_load_stream(
156  params.shared_load_stream_d,
157  reinterpret_cast<typename SharedLoadStreamD::Scalar*>(shared_storage.data()));
158 
160  for (int w = 0; w < Iterations::kW; ++w) {
161  // Load the C matrix into fragment.
162  if (kSourceRequired) {
163  global_load_iterator.load_post_increment(fragment_c);
164  }
165 
166  // Make sure we can write to shared memory.
168 
169  // Copy the accumulators to shared memory.
170  int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
171 
172  shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
173 
174  shared_store_iterator.store_post_increment(shared_store_transformed_d);
175 
176  // Make sure the data is in shared memory.
178 
179  // Copy the accumulators back to registers from shared memory.
180  shared_load_stream.copy();
181  shared_load_stream.commit();
182 
183  // Do the math.
184  typename GlobalTransformerD::InputFragment fragment_d;
185  if (kSourceRequired) {
186  // Transform C fragment.
187  transformer_c.transform(fragment_c, transformed_c);
188  // Do the math.
189  functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
190  } else {
191  functor.evaluate(shared_load_stream.fragment(), fragment_d);
192  }
193 
194  // Transform D fragment.
195  typename GlobalTransformerD::OutputFragment global_transformed_d;
196  transformer_d.transform(fragment_d, global_transformed_d);
197 
198  // Copy the results to global memory.
199  global_store_iterator.store_post_increment(global_transformed_d);
200  }
201  }
202  }
203 
205  CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
206 
208  CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
209 
211  Params const& params;
216  // The functor.
218 };
219 
221 
222 } // namespace gemm
223 } // namespace cutlass
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:85
Coord< 3 > problem_size
The dimensions of the GEMM.
Definition: gemm_epilogue.h:215
Traits::SharedStoreIteratorD SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue.h:73
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Definition: convert.h:33
Traits::Params Params
The params.
Definition: gemm_epilogue.h:46
Definition: gemm_epilogue.h:42
CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators &accumulators, Coord< 3 > const &block, int batch_id)
Definition: gemm_epilogue.h:105
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Functor functor
Definition: gemm_epilogue.h:217
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:48
Traits::GlobalTransformerD GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue.h:69
CUTLASS_DEVICE GemmEpilogue(Params const &params_, SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: gemm_epilogue.h:88
Traits::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue.h:51
Traits::Accumulators Accumulators
The accumulators.
Definition: gemm_epilogue.h:55
CUTLASS_DEVICE void shared_load_fence()
The memory fence for shared loads.
Definition: gemm_epilogue.h:205
SharedStorage & shared_storage
The shared storage.
Definition: gemm_epilogue.h:213
GemmEpilogueTraits_ Traits
The traits class.
Definition: gemm_epilogue.h:44
Params const & params
The params.
Definition: gemm_epilogue.h:211
Traits::Index Index
The index.
Definition: gemm_epilogue.h:80
#define static_assert(__e, __m)
Definition: platform.h:153
Traits::SharedStoreTransformerD SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue.h:75
Traits::GlobalStoreIteratorD GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue.h:71
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:83
Traits::SharedLoadStreamD SharedLoadStreamD
The iterator to load D in shared memory.
Definition: gemm_epilogue.h:77
Traits::Functor Functor
The functor in charge of the math.
Definition: gemm_epilogue.h:59
CUTLASS_DEVICE void epilogue(Accumulators &accumulators, Coord< 3 > const &block=make_Coord(0, 0, 0), int batch_id=0)
Execute the epilogue.
Definition: gemm_epilogue.h:94
Traits::Iterations Iterations
The number of iterations.
Definition: gemm_epilogue.h:53
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Traits::Scalar Scalar
The scalar.
Definition: gemm_epilogue.h:57
Defines conversion operations among Fragments of different base type.
CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_epilogue.h:208
Traits::GlobalTransformerC GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue.h:67
Traits::GlobalLoadIteratorC GlobalLoadIteratorC
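The iterator to load the elements of the C matrix. (The note "We do not support 3D or 4D shapes." belongs to the static_assert at gemm_epilogue.h:62.)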
Definition: gemm_epilogue.h:62