73 template <
typename Tile_,
76 typename ThreadOffset_,
111 template <
typename Delta_>
125 return (iteration[0] * Delta::kD + offset[0] <
bounds[0]) &&
126 (iteration[1] * Delta::kH + offset[1] <
bounds[1]) &&
127 (iteration[2] * Delta::kW + offset[2] <
bounds[2]);
133 template <
typename T>
136 template <
typename Traits_,
140 typename Index_ = int,
141 typename FragmentElement_ = Scalar_,
173 typedef typename Traits::Tile
Tile;
176 typedef typename Traits::Delta
Delta;
271 Index _inc_advance) {
287 return initialize(stride[0], stride[1], stride[2]);
300 stride_w * Delta::kW * (Iterations::kW - 1);
317 stride_h * Delta::kH * (Iterations::kH - 1) +
318 stride_w * Delta::kW * (Iterations::kW - 1);
344 template <
typename PredicateIterator,
typename PredicateFunctor>
346 PredicateFunctor
const &predicate_func,
349 for (
int d = 0; d < Iterations::kD; ++d) {
351 for (
int h = 0; h < Iterations::kH; ++h) {
353 for (
int w = 0; w < Iterations::kW; ++w) {
354 bool enable = predicate_func(
make_Coord(d, h, w), offset);
355 predicate_it.set(enable);
394 template <
typename Traits_,
398 typename Index_ = int,
399 typename FragmentElement_ = Scalar_,
401 typename Skew_ = Shape<0, 0, 0, 0> >
408 FragmentElementType_,
417 FragmentElementType_,
531 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
581 Index _inc_advance) {
584 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
613 typename PredicateIterator>
628 typename PredicateIterator,
630 typename PredicateFunctor>
632 PredicateFunctor
const &functor,
707 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
708 if (
stage == Tile::kD - 1) {
720 long long _offset = offset.template dot<long long>(
740 template <
typename Fragment,
typename PredicateIterator>
743 for (
int d = 0; d < Iterations::kD; ++d) {
744 for (
int h = 0; h < Iterations::kH; ++h) {
745 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
746 for (
int c = 0; c < Iterations::kC; ++c) {
749 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
752 if (w < Iterations::kW - 1) {
756 if (h < Iterations::kH - 1) {
760 if (d < Iterations::kD - 1) {
768 template <
typename Fragment>
775 template <
typename Fragment,
typename PredicateIterator>
782 template <
typename Fragment>
785 load(fragment, pred_it);
789 template <
typename Fragment>
792 for (
int h = 0; h < Iterations::kH; ++h) {
793 for (
int w = 0; w < Iterations::kW; ++w) {
794 for (
int c = 0; c < Iterations::kC; ++c) {
795 load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
833 template <
typename Traits_,
837 typename Index_ = int,
838 typename FragmentElement_ = Scalar_,
840 typename Skew_ = Shape<0, 0, 0, 0> >
847 FragmentElementType_,
856 FragmentElementType_,
970 Index _inc_advance) {
971 initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
1005 Index _inc_advance) {
1008 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
1037 typename PredicateIterator>
1052 typename PredicateIterator,
1054 typename PredicateFunctor>
1056 PredicateFunctor
const &functor,
1112 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
1113 if (
stage == Tile::kD - 1) {
1148 template <
typename Fragment,
typename PredicateIterator>
1152 for (
int d = 0; d < Iterations::kD; ++d) {
1153 for (
int h = 0; h < Iterations::kH; ++h) {
1154 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1155 for (
int c = 0; c < Iterations::kC; ++c) {
1158 reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1161 if (w < Iterations::kW - 1) {
1165 if (h < Iterations::kH - 1) {
1169 if (d < Iterations::kD - 1) {
1177 template <
typename Fragment>
1184 template <
typename Fragment,
typename PredicateIterator>
1191 template <
typename Fragment>
1194 store(fragment, pred_it);
1212 template <
typename Fragment,
typename PredicateIterator>
1216 for (
int d = 0; d < Iterations::kD; ++d) {
1217 for (
int h = 0; h < Iterations::kH; ++h) {
1218 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1219 for (
int c = 0; c < Iterations::kC; ++c) {
1222 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1225 if (w < Iterations::kW - 1) {
1229 if (h < Iterations::kH - 1) {
1233 if (d < Iterations::kD - 1) {
1241 template <
typename Fragment>
1248 template <
typename Fragment,
typename PredicateIterator>
1255 template <
typename Fragment>
1258 load(fragment, pred_it);
1262 template <
typename Fragment>
1265 for (
int h = 0; h < Iterations::kH; ++h) {
1266 for (
int w = 0; w < Iterations::kW; ++w) {
1267 for (
int c = 0; c < Iterations::kC; ++c) {
1268 load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:689
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:1198
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/store by one instruction.
Definition: tile_iterator.h:191
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:891
Delta_ Delta
Definition: tile_iterator.h:113
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:918
CUTLASS_HOST_DEVICE Params()
Initialize params to access storage object.
Definition: tile_iterator.h:507
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:650
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:550
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:80
Index_ Index
Index type.
Definition: tile_iterator.h:164
Defines a structure containing strides, bounds, and a pointer to tensor data.
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:455
CUTLASS_HOST_DEVICE int initialize(Coord< 4 > const &stride)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:286
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:337
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:170
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:467
Enum to specify which memory space data resides in.
Definition: load_store.h:38
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1242
Base::Index Index
Index type.
Definition: tile_iterator.h:882
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:482
int stage
The stage.
Definition: tile_iterator.h:1028
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:449
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:202
Scalar * Pointer
Pointer to underlying type.
Definition: tile_iterator.h:927
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:185
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:998
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Shape< 0, 0, 0, 0 > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: tile_iterator.h:102
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:476
A template defining Tile Traits Concept.
Definition: tile_iterator.h:78
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:428
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:419
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:614
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:182
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:790
Base::Delta Delta
Delta.
Definition: tile_iterator.h:452
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:741
Definition: load_store.h:48
Base::LongIndex LongIndex
Long index type.
Definition: tile_iterator.h:885
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:590
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:437
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:179
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:458
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:1014
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:470
Definition: tile_iterator.h:65
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:921
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:876
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:1038
Index inc_d
Definition: tile_iterator.h:226
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:1098
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:99
Iterator that always returns true.
Definition: predicate_vector.h:309
CUTLASS_HOST_DEVICE Params(Coord< 4 > const &stride)
Constructs params with a stride vector.
Definition: tile_iterator.h:259
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:935
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Definition: tile_iterator.h:957
Base::FragmentShape FragmentShape
Fragment shape.
Definition: tile_iterator.h:903
Kind
Definition: load_store.h:39
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:209
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: tile_iterator.h:731
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:909
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:990
TensorRef< Scalar, 4 > TensorRef
Tensor reference for the store iterator.
Definition: tile_iterator.h:930
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:422
TensorRef< Scalar const, 4 > TensorRef
Tensor reference for the load iterator.
Definition: tile_iterator.h:494
Definition: load_store.h:178
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:265
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1256
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:631
CUTLASS_HOST_DEVICE Params(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Definition: tile_iterator.h:963
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: tile_iterator.h:1135
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:206
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:491
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initialize params to access storage object.
Definition: tile_iterator.h:521
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:155
Index stride_h
Definition: tile_iterator.h:223
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Parameters.
Definition: tile_iterator.h:933
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:146
Params params
Parameters structure.
Definition: tile_iterator.h:1022
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:915
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:924
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:705
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:324
Kind
Definition: load_store.h:48
CUTLASS_HOST_DEVICE RegularTilePredicateFunctor(Coord< 3 > _bounds)
Constructs a predicate functor given the bounds of a tensor.
Definition: tile_iterator.h:120
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:1132
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:646
Definition: load_store.h:40
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:485
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:566
Params params
Parameters structure.
Definition: tile_iterator.h:598
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:204
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:558
CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:719
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:692
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:906
Definition: tile_iterator.h:488
Definition: tile_iterator.h:134
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:86
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:183
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:776
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:729
static int const kAccessSize
Access size.
Definition: tile_iterator.h:105
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:197
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:888
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:83
Defines abstractions for efficiently loading and storing vectors to memory.
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:425
Base::FragmentShape FragmentShape
Fragment shape.
Definition: tile_iterator.h:461
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:873
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1087
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:1107
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Params(Scalar *ptr)
Definition: tile_iterator.h:947
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1074
Base::LongIndex LongIndex
Long index type.
Definition: tile_iterator.h:443
Index inc_h
Definition: tile_iterator.h:227
CUTLASS_HOST_DEVICE Params()
Constructs params.
Definition: tile_iterator.h:238
Index stride_d
Definition: tile_iterator.h:222
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:858
Definition: load_store.h:60
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:861
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_HOST_DEVICE Params(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:511
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:912
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:64
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:515
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:1070
Definition: load_store.h:48
CUTLASS_HOST_DEVICE int initialize(TensorRef const &ref)
Initializes params from a tensor reference.
Definition: tile_iterator.h:543
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:864
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:499
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1149
Definition: tile_iterator.h:65
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:431
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1192
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:1025
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:897
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:879
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initialize params to access storage object.
Definition: tile_iterator.h:536
Functor computing a predicate given the logical position of an access.
Definition: tile_iterator.h:112
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:173
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
Parameters.
Definition: tile_iterator.h:497
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:479
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1263
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:292
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:464
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:446
Base::Delta Delta
Delta.
Definition: tile_iterator.h:894
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:1101
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1185
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:686
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:783
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:900
Definition: tile_iterator.h:65
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:434
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:683
static CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &predicate_func, Coord< 3 > const &offset)
Initializes a predicate vector.
Definition: tile_iterator.h:345
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:149
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:601
Coord< 3 > bounds
Dimensions of the bounding volume.
Definition: tile_iterator.h:116
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:176
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:194
Index inc_advance
Definition: tile_iterator.h:230
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1178
long long LongIndex
Long index.
Definition: tile_iterator.h:167
CUTLASS_HOST_DEVICE Params()
Definition: tile_iterator.h:943
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:983
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:867
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1249
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:951
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:680
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:769
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:216
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:1124
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:1110
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE bool operator()(Coord< 3 > iteration, Coord< 3 > offset) const
Computes the predicate given the logical position of an access.
Definition: tile_iterator.h:124
Base::Index Index
Index type.
Definition: tile_iterator.h:440
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:1104
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:870
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:604
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:574
CUTLASS_HOST_DEVICE Params(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Constructs params.
Definition: tile_iterator.h:242
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, Scalar const *ptr, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:665
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:976
Index inc_w
Definition: tile_iterator.h:228
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:1055
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841
Index stride_w
Definition: tile_iterator.h:224
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1213