Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
igemm_global_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include "cutlass/coord.h"
37 #include "cutlass/matrix_traits.h"
38 
39 namespace cutlass {
40 namespace gemm {
41 
43 
44 template <GemmOperand::Kind kOperand_,
45  MatrixLayout::Kind kLayout_,
46  typename Scalar_,
47  typename Tile_,
48  typename Threads_,
49  int kAccessSize_>
51  // Which GEMM operand?
52  kOperand_,
53  // The layout.
54  kLayout_,
55  // The scalar.
56  Scalar_,
57  // The tile.
58  Tile_,
59  // The threads.
60  Threads_,
61  // The number of scalars per LDG/STG.
62  kAccessSize_> {
66  typedef typename Base::Threads Threads;
70  typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 4,
71  4,
72  Base::VectorizedTile::kW / Base::Threads::kW,
73  Base::VectorizedTile::kC / Base::kAccessSize>
75 
77  struct ThreadOffset {
79  Coord<4> operator()() const {
80  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
81  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
82 
83  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
84  }
85  };
86 
87  public:
90 };
91 
93 
/// Global-memory iterator for the A/B operands of an integer GEMM.
///
/// Extends GemmGlobalIteratorAb with a 32-bit byte mask. The mask is derived in
/// initialize_predicates() and ANDed into the first 32 bits of every value returned by
/// load_element().
template <typename TileTraits_, typename Index_ = int>
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
  /// The base class. Needed explicitly: the base is a dependent type, so its members are not
  /// found by unqualified lookup.
  typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
  /// The functor to compute the thread offset.
  typedef typename TileTraits_::ThreadOffset ThreadOffset;

  /// Constructor.
  ///
  /// Forwards all arguments to the base iterator. The mask starts as all ones, so no bytes are
  /// cleared until initialize_predicates() computes a narrower mask.
  CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
                                       const Coord<3>& threadblock_offset,
                                       ThreadOffset thread_offset_func = ThreadOffset())
      : Base(_params, threadblock_offset, thread_offset_func), mask_(0xffffffff) {}

  /// Initializes the base predicates, then recomputes this thread's byte mask.
  ///
  /// @param bounds              forwarded to Base::initialize_predicates; bounds[1] modulo
  ///                            TileTraits_::Tile::kW gives the residue used for the mask.
  /// @param threadblock_offset  forwarded unchanged to Base::initialize_predicates.
  CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds,
                                            const Coord<3>& threadblock_offset) {
    Base::initialize_predicates(bounds, threadblock_offset);

    // The number of elements read in a single iteration.
    int const kBlock = TileTraits_::Tile::kW;
    // The residue.
    int const kResidue = (int)(bounds[1] % kBlock);

    // Number of valid elements starting at this thread's offset in dimension 2.
    int const left = kResidue - Base::thread_offset[2];

    // When 0 < left < 4 keep only the low `left` bytes; otherwise keep everything. The mask is
    // assigned on every call, so re-initializing never leaves a stale, narrower mask behind.
    // NOTE(review): 8 bits per element and the bound of 4 assume 1-byte scalars packed four per
    // 32-bit access -- confirm against the instantiating tile traits.
    mask_ = (left > 0 && left < 4) ? ((1u << (8 * left)) - 1u) : 0xffffffffu;
  }

  /// Loads one element through the base iterator, then clears the masked-off bytes in the
  /// first 32 bits of the loaded value.
  CUTLASS_DEVICE void load_element(
      typename Base::AccessType& value, int d, int h, int w, int c) const {
    Base::load_element(value, d, h, w, c);
    reinterpret_cast<uint32_t&>(value) &= mask_;
  }

  /// The mask to clean up the values; 0xffffffff keeps every byte.
  uint32_t mask_;
};
131 
133 
134 } // namespace gemm
135 } // namespace cutlass
Definition: convert.h:33
Base::Threads Threads
The threads.
Definition: igemm_global_tile.h:66
Computes the thread offset in (H, W) based on thread ID.
Definition: igemm_global_tile.h:77
Defines iterators for efficiently loading and storing to global memory.
Definition: gemm_global_tile.h:70
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Shape< Base::VectorizedTile::kH/Base::Threads::kH/4, 4, Base::VectorizedTile::kW/Base::Threads::kW, Base::VectorizedTile::kC/Base::kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: igemm_global_tile.h:74
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:267
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: igemm_global_tile.h:79
Definition: gemm_global_tile.h:163
static int const kH
The height of the cube.
Definition: shape.h:68
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Definition: igemm_global_tile.h:50
CUTLASS_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Definition: igemm_global_tile.h:122
GemmGlobalIteratorAb< TileTraits_, Index_ > Base
The base class.
Definition: igemm_global_tile.h:97
Definition: igemm_global_tile.h:95
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Definition: vector.h:62
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
TileTraits_::ThreadOffset ThreadOffset
The functor to compute the thread offset.
Definition: igemm_global_tile.h:99
uint32_t mask_
The mask to clean up the values.
Definition: igemm_global_tile.h:129
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
Shape< 1, 4, Base::VectorizedTile::kC > ThreadsDelta
The threads strides.
Definition: igemm_global_tile.h:89
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:194
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &threadblock_offset)
Definition: igemm_global_tile.h:107
Parameters.
Definition: tile_iterator.h:497
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Kind
Definition: matrix_traits.h:357
Shape< Base::Threads::kH *4, 1, Base::Threads::kW, Base::kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: igemm_global_tile.h:68
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block_offset)
Definition: gemm_global_tile.h:219
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:213
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const &_params, const Coord< 3 > &threadblock_offset, ThreadOffset thread_offset_func=ThreadOffset())
Constructor.
Definition: igemm_global_tile.h:102
GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: igemm_global_tile.h:64