Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/coord.h"
31 #include "cutlass/util/platform.h"
32 
34 #include "cutlass/matrix_traits.h"
36 #include "cutlass/reshape_tile.h"
37 #include "cutlass/tile_iterator.h"
38 
39 namespace cutlass {
40 namespace gemm {
41 
43 
44 // The following functor reshapes a tile of threads to match a tile of data. The idea is that when
45 // the user wants to build the iterator traits, he/she may want to specify the tile independently
46 // from the number of scalars loaded/stored per instruction. For example, in the row-major version
47 // with a tile of size 128x8 - the user may want the iterator to work with 32x8 threads if
48 // each thread loads 1 scalar per LDG. If the user changes to 4 scalars per LDG, then the tile of
49 // threads has to change. The code below detects that and corrects the code automatically - it is
50 // a helper when the user does not specify the right configuration.
51 
52 template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
53 struct ReshapeThreads {
54  typedef Threads_ Threads;
55 };
56 
57 template <typename Tile_, typename Threads_>
59  typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
60 };
61 
63 
64 template <GemmOperand::Kind kOperand_,
65  MatrixLayout::Kind kLayout_,
66  typename Scalar_,
67  typename Tile_,
68  typename Threads_,
69  int kAccessSize_>
72  static GemmOperand::Kind const kOperand = kOperand_;
74  static MatrixLayout::Kind const kLayout = kLayout_;
78  typedef Scalar_* Pointer;
80  static int const kAccessSize = kAccessSize_;
84  typedef Tile_ Tile;
93 
97  typedef Shape<1,
98  VectorizedTile::kH / Threads::kH,
99  VectorizedTile::kW / Threads::kW,
100  VectorizedTile::kC / kAccessSize>
102 
104 
106  struct ThreadOffset {
109  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
110  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
111 
112  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
113  }
114  };
115 };
116 
118 
119 template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
120 struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
121  MatrixLayout::kColumnMajor,
122  Scalar_,
123  Tile_,
124  Threads_,
125  kAccessSize_> {
129  Scalar_,
130  Tile_,
131  Threads_,
132  kAccessSize_>
134 
136  static int const kStrideH = kStrideH_;
139 
140  typedef typename Base::Iterations Iterations;
141 
142  typedef typename Base::Threads Threads;
143 
145 
147 
149  struct ThreadOffset {
152  int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
153  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
154 
155  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
156  }
157  };
158 };
159 
161 
162 template <typename TileTraits_, typename Index_ = int>
164  : public TileLoadIterator<TileTraits_,
165  typename TileTraits_::Scalar,
166  TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
167  : IteratorAdvance::kW,
168  MemorySpace::kGlobal,
169  Index_> {
172  typedef TileLoadIterator<TileTraits_,
173  typename TileTraits_::Scalar,
174  TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
177  Index_>
180  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
182  typedef typename TileTraits_::Tile Tile;
184  typedef typename Base::Fragment Fragment;
186  typedef typename TileTraits_::Scalar Scalar;
188  typedef typename TileTraits_::Threads Threads;
190  typedef Index_ Index;
192  typedef long long LongIndex;
194  typedef typename TileTraits_::ThreadOffset ThreadOffset;
197 
199 
201  typedef typename Base::Params BaseParams;
202 
203  struct Params : public BaseParams {
206  Index stride_d,
207  Index stride_h) {
209  }
210  };
211 
218 
219  CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
220  // Setup the masks to control loads.
221  predicates.fill(0);
222 
223  // Fill in the bits of the predicate vector.
224  for (int d = 0; d < Base::Iterations::kD; ++d) {
225  for (int h = 0; h < Base::Iterations::kH; ++h) {
226  for (int w = 0; w < Base::Iterations::kW; ++w) {
227  for (int c = 0; c < Base::Iterations::kC; ++c) {
228  bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
229  if (kAdvance == IteratorAdvance::kH) {
230  flag =
231  flag &&
232  (h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
233  bounds[1];
234  } else {
235  flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
236  }
237  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
238  predicates.set(bit, flag);
239  }
240  }
241  }
242  }
243  }
244 
247  const Coord<3>& threadblock_offset,
248  ThreadOffset thread_offset_func = ThreadOffset())
249  : params(_params) {
250  thread_offset = thread_offset_func();
251  // Setup the pointer.
252  params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
253  (threadblock_offset[2] + thread_offset[2]));
254 
255  }
256 
265 
268  typename Base::AccessType& value, int d, int h, int w, int c) const {
269  int const offset =
271  Load<Scalar,
275  typename Base::FragmentElement,
276  Base::Tile::kW,
277  Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
278  }
279 
282  // Update the predicate vector.
283  for (int d = 0; d < Base::Iterations::kD; ++d) {
284  for (int h = 0; h < Base::Iterations::kH; ++h) {
285  for (int w = 0; w < Base::Iterations::kW; ++w) {
286  for (int c = 0; c < Base::Iterations::kC; ++c) {
287  Index offset = 0;
288  if (kAdvance == IteratorAdvance::kH) {
289  offset += thread_offset[1] + h * Base::Delta::kH + d * Base::Delta::kD;
290  } else {
291  offset += thread_offset[2] + w * Base::Delta::kW;
292  }
293 
294  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
295  if (offset >= k) {
296  predicates.set(bit, false);
297  }
298  }
299  }
300  }
301  }
302  }
303 
305  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
306  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
307  return predicates[bit];
308  }
309 
312 
313  LongIndex _offset = offset.template dot<LongIndex>(
315  );
316 
317  params.pointer += _offset;
318  return *this;
319  }
320 
322 
324  Index stride = params.stride_h;
325  if (kAdvance == IteratorAdvance::kW) {
326  stride = params.stride_w;
327  }
328  return stride;
329  }
330 
331  template <typename Fragment>
333  typename Base::FragmentIterator frag_iterator(fragment);
334  for (int d = 0; d < Base::Iterations::kD; ++d) {
335  for (int h = 0; h < Base::Iterations::kH; ++h) {
336  for (int w = 0; w < Base::Iterations::kW; ++w) {
337  for (int c = 0; c < Base::Iterations::kC; ++c) {
338  if (valid(d, h, w, c)) {
339  load_element(
340  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
341  d,
342  h,
343  w,
344  c);
345  }
346  }
347  if (w < Base::Iterations::kW - 1) {
348  inc_w();
349  }
350  }
351  if (h < Base::Iterations::kH - 1) {
352  inc_h();
353  }
354  }
355  if (d < Base::Iterations::kD - 1) {
356  inc_d();
357  }
358  }
359  inc_advance();
360  }
361 };
362 
364 
365 template <typename TileTraits_, typename Index_ = int>
366 struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
367  typename TileTraits_::Scalar,
368  IteratorAdvance::kH,
369  MemorySpace::kGlobal,
370  Index_> {
374  typedef TileIteratorBase<TileTraits_,
375  typename TileTraits_::Scalar,
378  Index_>
380 
382  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
383 
385  typedef typename TileTraits_::Scalar Scalar;
387  typedef typename TileTraits_::Pointer Pointer;
389  typedef typename TileTraits_::Threads Threads;
391  typedef Index_ Index;
393  typedef long long LongIndex;
395  typedef typename TileTraits_::ThreadOffset ThreadOffset;
396 
398  struct Params {
402  long long stride_d;
411 
414  int stride_d_,
415  Index ldm,
416  Index bound,
417  Index epilogue_stride_w,
418  Index epilogue_delta_w) {
419  // The pointer.
420  this->pointer = pointer;
421  // Stride per batch
422  stride_d = stride_d_;
423  // Each column of the matrix.
424  stride_h = TileTraits_::ThreadsDelta::kH * ldm;
425  // Each thread output 1 column per iteration. The stride between columns is given by the
426  // number of scalars that are loaded per LDS for B.
427  inc_h = ldm * TileTraits_::kStrideH;
428  inc_advance =
429  (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
430 
431  predicate_offset = bound;
432  predicate_inc_h = TileTraits_::kStrideH;
434  -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
435 
436  return 0;
437  }
438 
439  CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h,
440  Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h,
441  Index _predicate_offset) {
442  this->pointer = pointer;
443  stride_d = _stride_d;
444  stride_h = _stride_h;
445  inc_advance = _inc_advance;
446  inc_h = _inc_h;
447  predicate_inc_advance = _predicate_inc_advance;
448  predicate_inc_h = _predicate_inc_h;
449  predicate_offset = _predicate_offset;
450 
451  return 0;
452  }
453  };
454 
461 
464  const Coord<3>& bounds,
465  const Coord<3>& block,
466  int offset = 0,
467  int pred_offset = 0,
468  ThreadOffset thread_offset_func = ThreadOffset())
469  : params(_params) {
470  thread_offset = thread_offset_func();
471  // Each warp works on a different column of the tile.
472  int const h = thread_offset[1] + block[1];
473  // Each lane writes a different element.
474  int const w = thread_offset[2] + block[2];
475  // Setup the pointer.
476  params.pointer += ((h * params.stride_h + w) + offset);
477 
478  // Prepare the vector of predicates.
479  for (int i = 0; i < Base::Iterations::kW; ++i) {
480  predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
481  }
482  params.predicate_offset -= (h + pred_offset);
483  }
484 
493  }
500  }
501 
504  LongIndex _offset = offset.template dot<LongIndex>(
506  );
507  params.pointer += _offset;
508  return *this;
509  }
510 
513  typename Base::AccessType& value, int d, int h, int w, int c) const {
514  int const offset =
516  Load<Scalar,
520  typename Base::FragmentElement,
521  Base::Tile::kW,
522  Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
523  }
524 
527  typename Base::AccessType const& value, int d, int h, int w, int c) {
528  int const offset =
530  Store<Scalar,
534  typename Base::FragmentElement,
535  Base::Tile::kW,
536  Base::kAccessSize * sizeof(Scalar)>::store(value, params.pointer, offset);
537  }
538 
540  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
541  return predicates.at(w) && params.predicate_offset > 0;
542  }
543 
546 
548  template <typename Fragment>
550  typename Base::FragmentIterator frag_iterator(fragment);
551  for (int d = 0; d < Base::Iterations::kD; ++d) {
552  for (int h = 0; h < Base::Iterations::kH; ++h) {
553  for (int w = 0; w < Base::Iterations::kW; ++w) {
554  for (int c = 0; c < Base::Iterations::kC; ++c) {
555  if (valid(d, h, w, c)) {
556  load_element(
557  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
558  d,
559  h,
560  w,
561  c);
562  }
563  }
564  if (w < Base::Iterations::kW - 1) {
565  inc_w();
566  }
567  }
568  if (h < Base::Iterations::kH - 1) {
569  inc_h();
570  }
571  }
572  if (d < Base::Iterations::kD - 1) {
573  inc_d();
574  }
575  }
576  inc_advance();
577  }
578 
579  template <typename Fragment>
581  typename Base::FragmentIterator frag_iterator(fragment);
582  for (int d = 0; d < Base::Iterations::kD; ++d) {
583  for (int h = 0; h < Base::Iterations::kH; ++h) {
584  for (int w = 0; w < Base::Iterations::kW; ++w) {
585  for (int c = 0; c < Base::Iterations::kC; ++c) {
586  if (valid(d, h, w, c)) {
588  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
589  d,
590  h,
591  w,
592  c);
593  }
594  }
595  if (w < Base::Iterations::kW - 1) {
596  inc_w();
597  }
598  }
599  if (h < Base::Iterations::kH - 1) {
600  inc_h();
601  }
602  }
603  if (d < Base::Iterations::kD - 1) {
604  inc_d();
605  }
606  }
607  inc_advance();
608  }
609 };
610 
612 
613 } // namespace gemm
614 } // namespace cutlass
Definition: gemm_global_tile.h:120
Shape< 0, Threads::kH, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:92
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:503
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
add pointer offset
Definition: gemm_global_tile.h:545
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h, Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h, Index _predicate_offset)
Definition: gemm_global_tile.h:439
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:332
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:406
Definition: convert.h:33
cutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount > PredicateVector
Definition: gemm_global_tile.h:198
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:180
T type
Definition: platform.h:377
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:467
Base::Params BaseParams
Iterator parameters type.
Definition: gemm_global_tile.h:201
CUTLASS_HOST_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: gemm_global_tile.h:486
Index_ Index
The index.
Definition: gemm_global_tile.h:391
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
ReshapeTile< Tile_, kAccessSize_ >::Tile VectorizedTile
The vectorized tile shape.
Definition: gemm_global_tile.h:86
GemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:372
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:382
Definition: gemm_global_tile.h:70
Scalar_ * Pointer
The pointer.
Definition: gemm_global_tile.h:78
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:202
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h)
Initializes params to load a strip-mined tile, given pointer and stride_h.
Definition: gemm_global_tile.h:205
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:357
Definition: load_store.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:368
Shape< 1, 1, VectorizedTile::kC > ThreadsDelta
The relative offset between two elements in the H/W dimension in adjacent threads.
Definition: gemm_global_tile.h:90
GemmMultiplicandTraits< Tile, kOperand, kLayout > MultiplicandTraits
Definition: gemm_global_tile.h:103
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:428
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_global_tile.h:82
long long LongIndex
Long index.
Definition: gemm_global_tile.h:192
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:267
TileIteratorBase< TileTraits_, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:379
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:138
Index predicate_inc_h
Definition: gemm_global_tile.h:408
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:590
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:437
Tile_ Tile
The tile shape.
Definition: gemm_global_tile.h:84
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:470
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads and increments iterator.
Definition: gemm_global_tile.h:549
TileLoadIterator< TileTraits_, typename TileTraits_::Scalar, TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:178
Definition: gemm_global_tile.h:203
Index inc_d
Definition: tile_iterator.h:226
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block, int offset=0, int pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:463
Definition: matrix_traits.h:357
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:181
C++ features that may be otherwise unimplemented for CUDA device functions.
Definition: gemm_global_tile.h:163
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: gemm_global_tile.h:133
Kind
Definition: load_store.h:39
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:262
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: gemm_global_tile.h:196
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:188
static int const kStrideH
The stride in the H dimension.
Definition: gemm_global_tile.h:136
static int const kH
The height of the cube.
Definition: shape.h:68
Definition: load_store.h:178
Shape< Threads_::kD, Threads_::kH *Threads_::kW/Tile_::kW, Tile_::kW, 1 > Threads
Definition: gemm_global_tile.h:59
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:408
static GemmOperand::Kind const kOperand
Identity of the operand.
Definition: gemm_global_tile.h:72
Index stride_h
Definition: tile_iterator.h:223
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Test the validity of the.
Definition: gemm_global_tile.h:540
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:495
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const &_params, const Coord< 3 > &threadblock_offset, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:246
PredicateVector predicates
The predicates.
Definition: gemm_global_tile.h:217
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_global_tile.h:76
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:260
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:311
Base::Fragment Fragment
Fragment type loaded by the iterator.
Definition: gemm_global_tile.h:184
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:389
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:151
Definition: gemm_operand.h:67
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:106
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:258
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator..
Definition: tile_iterator.h:776
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:344
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:580
Base::Threads Threads
Definition: gemm_global_tile.h:142
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:404
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:264
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:108
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:512
Index inc_h
Definition: tile_iterator.h:227
Index stride_d
Definition: tile_iterator.h:222
Shape< 0, 0, Threads::kW *ThreadsDelta::kW, kAccessSize > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: gemm_global_tile.h:95
Statically sized array of bits implementing.
Definition: predicate_vector.h:105
Definition: vector.h:62
Definition: load_store.h:60
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:395
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
Base::ImmediateOffsetStrides ImmediateOffsetStrides
Definition: gemm_global_tile.h:146
long long LongIndex
The index.
Definition: gemm_global_tile.h:393
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:385
Index inc_h
Definition: gemm_global_tile.h:406
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: gemm_global_tile.h:460
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:400
GemmGlobalIteratorAb< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:171
TileTraits_::Tile Tile
The tile.
Definition: gemm_global_tile.h:182
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
Shape< 1, VectorizedTile::kH/Threads::kH, VectorizedTile::kW/Threads::kW, VectorizedTile::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_global_tile.h:101
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:497
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:490
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:488
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:499
Params params
Parameters.
Definition: gemm_global_tile.h:456
Definition: gemm_global_tile.h:366
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:431
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:458
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:194
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:365
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Definition: gemm_global_tile.h:321
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
Parameters.
Definition: tile_iterator.h:497
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:149
long long stride_d
The stride in the D dimension.
Definition: gemm_global_tile.h:402
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: gemm_global_tile.h:323
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:686
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Tile_ Tile
Definition: reshape_tile.h:43
Definition: tile_iterator.h:65
Base::Iterations Iterations
Definition: gemm_global_tile.h:140
Index_ Index
The index.
Definition: gemm_global_tile.h:190
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:387
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the valid?
Definition: gemm_global_tile.h:305
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Salar or WmmaMatrix)
Definition: tile_iterator.h:434
Kind
Definition: matrix_traits.h:357
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:186
Index inc_advance
Definition: tile_iterator.h:230
Threads_ Threads
Definition: gemm_global_tile.h:54
Params params
The parameters.
Definition: gemm_global_tile.h:215
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, int stride_d_, Index ldm, Index bound, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: gemm_global_tile.h:413
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block_offset)
Definition: gemm_global_tile.h:219
The params.
Definition: gemm_global_tile.h:398
Base::ThreadsDelta ThreadsDelta
Definition: gemm_global_tile.h:144
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:213
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE void residue(Index k)
That&#39;s the residue! Update the predicates.
Definition: gemm_global_tile.h:281
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:410
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:74
CUTLASS_HOST_DEVICE void store_element(typename Base::AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: gemm_global_tile.h:526
Index stride_w
Definition: tile_iterator.h:224