Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_stream.h
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/convert.h"
#include "cutlass/fragment.h"
#include "cutlass/gemm/gemm_global_tile.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Identifies the multiplicand of the GEMM (A or B).
    GemmOperand::Kind Operand,
    /// The load iterator.
    typename LoadIterator_,
    /// The store iterator to write to shared memory.
    typename StoreIterator_,
    /// The transformer applied to the data after it has been fetched from global memory.
    typename Transformer_>

struct GlobalLoadStream {
  /// Indicates the type of GEMM operand.
  static GemmOperand::Kind const kOperand = Operand;
  /// The load iterator.
  typedef LoadIterator_ LoadIterator;
  /// The transformer.
  typedef Transformer_ Transformer;
  /// The store iterator to write to shared memory.
  typedef StoreIterator_ StoreIterator;

  /// The fragment fetched from global memory.
  typedef typename LoadIterator::Fragment FetchedFragment;
  /// The fragment obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "");
  /// The output fragment.
  typedef TransformedFragment Fragment;
  /// Make sure the transformed fragment is the same as the store fragment.
  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
                "");

  /// The layout of the operand.
  static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
  /// The scalar type of the iterator.
  typedef typename LoadIterator::Scalar Scalar;
  /// The pointer.
  typedef typename LoadIterator::Pointer Pointer;
  /// The index.
  typedef typename LoadIterator::Index Index;
  /// The long index.
  typedef typename LoadIterator::LongIndex LongIndex;
  /// The tile.
  typedef typename LoadIterator::Tile Tile;

  /// Shared memory allocation for the tile.
  typedef TileAllocation<typename StoreIterator::Scalar, typename StoreIterator::Tile>
      ThreadblockTileStorage;

  /// Tensor reference to the threadblock tile.
  typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;

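  // The fragment types chain together: the load iterator produces a
  // FetchedFragment, the transformer consumes it and emits a
  // TransformedFragment, and the static_asserts above verify that each
  // stage's output type matches the next stage's input, through to the
  // fragment the store iterator writes to shared memory.
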
  /// The params.
  struct Params {
    // The load iterator.
    typename LoadIterator::Params load_iterator;

    // Batch stride in global memory.
    LongIndex batch_stride;

    // The store iterator.
    typename StoreIterator::Params store_iterator;

    // Offset to residue.
    Index offset_to_residue;

    // Offset to residue for the last partition.
    Index offset_to_residue_last_partition;

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
                                       LongIndex batch_stride_,
                                       Index ldm,
                                       Index offset_to_residue_,
                                       Index offset_to_residue_last_partition_) {

      int error_code = load_iterator.initialize(pointer, ldm, ldm);
      if (error_code) {
        return error_code;
      }

      batch_stride = batch_stride_;
      offset_to_residue = offset_to_residue_;
      offset_to_residue_last_partition = offset_to_residue_last_partition_;

      return store_iterator.initialize();
    }

    CUTLASS_DEVICE Index get_offset_to_residue() {
      if (blockIdx.z == gridDim.z - 1) {  // last partition
        return offset_to_residue_last_partition;
      }
      else {
        return offset_to_residue;
      }
    }
  };
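
  // initialize() forwards the pointer and leading dimension to the load
  // iterator, records the batching and residue parameters, and returns a
  // nonzero error code if either iterator rejects its arguments. A minimal
  // host-side sketch, assuming `Stream` names a concrete GlobalLoadStream
  // instantiation:
  //
  //   typename Stream::Params params;
  //   int error = params.initialize(pointer, batch_stride, ldm,
  //                                 offset_to_residue,
  //                                 offset_to_residue_last_partition);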

  /// Shared memory storage.
  struct SharedStorage {};

  //
  // Static member functions
  //

  /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory.
  CUTLASS_HOST_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
    bool const kKstrided =
        GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;
    Coord<3> tile_coord = ProjectOperand<kOperand, kKstrided>::project(coord);
    return make_Coord(
        tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
  }
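
  // project_coordinate() reduces a (K, N, M) coordinate to the two dimensions
  // relevant to this operand (via ProjectOperand), divides the contiguous
  // coordinate by the tile's vector width Tile::kC, and adds d_offset to the
  // leading coordinate. The constructor below projects the GEMM bounds with
  // d_offset = 1 and the threadblock offset with the default of 0.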

  /// Ctor.
  CUTLASS_DEVICE GlobalLoadStream(
      Params const& _params,
      SharedStorage& shared_storage,
      ThreadblockTileRef const& threadblock_tile_ref,
      Coord<3> const bounds,
      Coord<3> const& _threadblock_offset)
      : params(_params),
        threadblock_offset(project_coordinate(_threadblock_offset)),
        multiplicand_bounds(project_coordinate(bounds, 1)),
        load_iterator(params.load_iterator, multiplicand_bounds, threadblock_offset),
        transformer(),
        store_iterator(params.store_iterator, threadblock_tile_ref.data()) {
    load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
    fetched_fragment.clear();
  }
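
  // The constructor projects the problem bounds and threadblock offset into
  // operand coordinates, binds the load iterator to global memory and the
  // store iterator to the shared-memory tile, and primes the guard predicates
  // so out-of-bounds loads are masked from the first copy() onward.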

  /// Load the data from global memory into the fetched fragment.
  CUTLASS_DEVICE void copy() {
    load_iterator.load_post_increment(fetched_fragment);
  }

  /// Commit the data: transform the fetched fragment and store it to shared memory.
  CUTLASS_DEVICE void commit() {
    transformer.transform(fetched_fragment, transformed_fragment);
    store_iterator.store_post_increment(transformed_fragment);
    store_iterator.inc_stage();
  }
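
  // copy() and commit() form a two-stage pipeline: copy() issues global-memory
  // loads into fetched_fragment (held in registers), while commit() runs the
  // transformer on the fetched data, writes the result to shared memory, and
  // advances the store iterator to its next stage. Splitting the two lets a
  // mainloop overlap the global loads for the next tile with math on the
  // current one.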

  /// Execute the residue code.
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    load_iterator.residue(k);
    if (!skip_clear) {
      fetched_fragment.clear();
    }
  }

  /// Move to the residue portion.
  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
    Index kResidue = k % kTileK;
    if (kResidue) {
      residue(kResidue);
      Index this_offset_residue = params.get_offset_to_residue();
      load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
    }
  }
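
  // When K is not a multiple of the tile extent kTileK, the mainloop consumes
  // the k % kTileK remainder first: residue() narrows the load iterator's
  // guard predicates to the partial tile, and the pointer is then advanced by
  // the precomputed residue offset (scaled by the iterator's stride advance).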

  /// Rollback to the beginning of the first tile.
  CUTLASS_DEVICE void rollback(void) {
    load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);

    int const kBlock = kOperand == GemmOperand::kA
                           ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
                           : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
    Index this_offset_residue = params.get_offset_to_residue();
    load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
                                     load_iterator.stride_advance());
  }
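
  // rollback() undoes move_to_residue(): it restores the full guard
  // predicates and steps the pointer back by the residue offset plus one
  // tile extent along K. kBlock selects Tile::kH or Tile::kW according to
  // which tile dimension maps to K for this operand/layout combination.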

  /// Adds a Coord<3> to the underlying global load iterator.
  CUTLASS_DEVICE GlobalLoadStream& operator+=(Coord<3> const& offset) {
    load_iterator += offset;
    return *this;
  }

  /// Adds an offset based on the batch stride.
  CUTLASS_DEVICE GlobalLoadStream& add_batch_offset(int batch_id) {
    load_iterator.add_pointer_offset(batch_id * params.batch_stride);
    return *this;
  }
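
  // add_batch_offset() repositions the stream for batched GEMM: consecutive
  // batches are batch_stride scalar elements apart in global memory, so
  // batch i begins i * batch_stride elements from the base pointer.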

  //
  // Data members
  //

  /// Parameters.
  Params params;

  /// Threadblock offset.
  Coord<3> threadblock_offset;

  /// Multiplicand bounds.
  Coord<3> multiplicand_bounds;

  /// The load iterator.
  LoadIterator load_iterator;

  /// The fragment fetched from global memory.
  FetchedFragment fetched_fragment;

  /// The transformer.
  Transformer transformer;

  /// The fragment holding the data after it has been transformed.
  TransformedFragment transformed_fragment;

  /// The store iterator.
  StoreIterator store_iterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
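
Taken together, a threadblock-scoped GEMM mainloop drives the stream in a fixed sequence: position on the residue, prefetch and commit the first tile, then rewind for the steady-state loop. The sketch below illustrates that call sequence; it is not part of this header, and the `prologue` helper, its template parameters, and `kTileK` are assumptions made for the example.

#include "cutlass/cutlass.h"

/// Illustrative prologue for one operand stream. `Stream` is assumed to be a
/// concrete GlobalLoadStream instantiation and `Index` its index type.
template <typename Stream, typename Index>
CUTLASS_DEVICE void prologue(Stream& stream, Index k, Index kTileK) {
  // Position the stream on the residue (partial) tile, if any.
  stream.move_to_residue(k, kTileK);

  // Fetch the first tile into registers...
  stream.copy();

  // ...then transform it and stage it to shared memory.
  stream.commit();

  // If a residue tile was consumed, rewind to the beginning of the first
  // full tile before entering the steady-state mainloop.
  if (k % kTileK) {
    stream.rollback();
  }
}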