Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/coord.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/fragment.h"
34 #include "cutlass/load_store.h"
36 #include "cutlass/vector.h"
37 #include <cstdio>
38 
39 namespace cutlass {
40 
42 
61 
65  enum Kind { kD, kH, kW };
66 };
68 
73 template <typename Tile_,
74  typename Delta_,
75  typename Iterations_,
76  typename ThreadOffset_,
77  int AccessSize>
78 struct TileTraits {
80  typedef Tile_ Tile;
81 
83  typedef Delta_ Delta;
84 
86  typedef Iterations_ Iterations;
87 
89  //
90  // ThreadOffset should be a functor defined like:
91  //
92  // struct ThreadOffsetExample {
93  // CUTLASS_DEVICE
94  // Coord<4> operator()() const {
95  // return make_Coord(0, threadIdx.y, threadIdx.x, 0);
96  // }
97  // };
98  //
99  typedef ThreadOffset_ ThreadOffset;
100 
103 
105  static int const kAccessSize = AccessSize;
106 };
107 
109 
111 template <typename Delta_>
113  typedef Delta_ Delta;
114 
117 
121 
124  bool operator()(Coord<3> iteration, Coord<3> offset) const {
125  return (iteration[0] * Delta::kD + offset[0] < bounds[0]) &&
126  (iteration[1] * Delta::kH + offset[1] < bounds[1]) &&
127  (iteration[2] * Delta::kW + offset[2] < bounds[2]);
128  }
129 };
130 
132 
133 template <typename T>
134 struct DumpType {};
136 template <typename Traits_,
137  typename Scalar_,
140  typename Index_ = int,
141  typename FragmentElement_ = Scalar_,
143  typename Skew_ = Shape<0, 0, 0, 0> >
146  typedef Traits_ Traits;
147 
149  typedef Scalar_ Scalar;
150 
152  typedef FragmentElement_ FragmentElement;
153 
155  static IteratorAdvance::Kind const kAdvance = Advance_;
156 
158  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
159 
162 
164  typedef Index_ Index;
165 
167  typedef long long LongIndex;
168 
170  typedef Skew_ Skew;
171 
173  typedef typename Traits::Tile Tile;
174 
176  typedef typename Traits::Delta Delta;
177 
179  typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
180 
182  typedef typename Traits::Iterations Iterations;
183 
185  typedef typename Traits::ThreadOffset ThreadOffset;
186 
188  static int const kAccessSize = Traits::kAccessSize;
189 
192 
194  static int const kFragmentSize =
200 
207 
210 
211  //
212  // Params struct
213  //
214 
216  struct Params {
217 
218  //
219  // Data members
220  //
221 
225 
229 
231 
232  //
233  // Methods
234  //
235 
238  Params() : stride_d(0), stride_h(0), stride_w(0), inc_d(0), inc_h(0), inc_w(0) {}
239 
242  Params(Index _stride_d,
243  Index _stride_h,
244  Index _stride_w,
245  Index _inc_d,
246  Index _inc_h,
247  Index _inc_w,
248  Index _inc_advance)
249  : stride_d(_stride_d),
250  stride_h(_stride_h),
251  stride_w(_stride_w),
252  inc_d(_inc_d),
253  inc_h(_inc_h),
254  inc_w(_inc_w),
255  inc_advance(_inc_advance) {}
256 
259  Params(Coord<4> const &stride) {
260  initialize(stride);
261  }
262 
265  int initialize(Index _stride_d,
266  Index _stride_h,
267  Index _stride_w,
268  Index _inc_d,
269  Index _inc_h,
270  Index _inc_w,
271  Index _inc_advance) {
272  stride_d = _stride_d;
273  stride_h = _stride_h;
274  stride_w = _stride_w;
275 
276  inc_d = _inc_d;
277  inc_h = _inc_h;
278  inc_w = _inc_w;
279  inc_advance = _inc_advance;
280 
281  return 0;
282  }
283 
286  int initialize(Coord<4> const &stride) {
287  return initialize(stride[0], stride[1], stride[2]);
288  }
289 
292  int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
293  stride_d = _stride_d;
294  stride_h = _stride_h;
295  stride_w = _stride_w;
296 
297  inc_w = stride_w * Delta::kW;
298  inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
299  inc_d = stride_h * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
300  stride_w * Delta::kW * (Iterations::kW - 1);
301 
302  inc_advance = 0;
303 
304  if (kAdvance == IteratorAdvance::kH) {
305  // Advance in the H dimension.
306  inc_advance = Tile::kH * stride_h;
307  } else if (kAdvance == IteratorAdvance::kW) {
308  // Advance in the W dimension.
309  inc_advance = Tile::kW * stride_w;
310 
311  } else {
312  // Advance in the D dimension.
313  inc_advance = Tile::kD * stride_d;
314  }
315 
316  inc_advance -= stride_h * Delta::kD * (Iterations::kD - 1) +
317  stride_h * Delta::kH * (Iterations::kH - 1) +
318  stride_w * Delta::kW * (Iterations::kW - 1);
319 
320  return 0;
321  }
322 
325  stride_d = 0;
326  stride_h = 0;
327  stride_w = 1;
328 
329  inc_advance = 0;
330  inc_d = inc_h = inc_w = 0;
331 
332  return 0;
333  }
334  };
335 
337  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
338 
339  //
340  // Static function members
341  //
342 
344  template <typename PredicateIterator, typename PredicateFunctor>
345  CUTLASS_HOST_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
346  PredicateFunctor const &predicate_func,
347  Coord<3> const &offset) {
349  for (int d = 0; d < Iterations::kD; ++d) {
351  for (int h = 0; h < Iterations::kH; ++h) {
353  for (int w = 0; w < Iterations::kW; ++w) {
354  bool enable = predicate_func(make_Coord(d, h, w), offset);
355  predicate_it.set(enable);
356  ++predicate_it;
357  }
358  }
359  }
360  }
361 };
362 
364 
388 
394 template <typename Traits_,
395  typename Scalar_,
398  typename Index_ = int,
399  typename FragmentElement_ = Scalar_,
401  typename Skew_ = Shape<0, 0, 0, 0> >
402 struct TileLoadIterator : public TileIteratorBase<Traits_,
403  Scalar_,
404  Advance_,
405  MemorySpace,
406  Index_,
407  FragmentElement_,
408  FragmentElementType_,
409  Skew_> {
411  typedef TileIteratorBase<Traits_,
412  Scalar_,
413  Advance_,
414  MemorySpace,
415  Index_,
416  FragmentElement_,
417  FragmentElementType_,
418  Skew_>
420 
422  typedef typename Base::Traits Traits;
423 
425  typedef typename Base::Scalar Scalar;
426 
428  typedef FragmentElement_ FragmentElement;
429 
432 
434  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
435 
438 
440  typedef typename Base::Index Index;
441 
443  typedef typename Base::LongIndex LongIndex;
444 
446  typedef typename Base::Skew Skew;
447 
449  typedef typename Base::Tile Tile;
450 
452  typedef typename Base::Delta Delta;
453 
455  typedef typename Base::Iterations Iterations;
456 
459 
462 
464  typedef typename Base::AccessType AccessType;
465 
467  static int const kAccessSize = Base::kAccessSize;
468 
470  typedef typename Base::Fragment Fragment;
471 
474 
477 
480 
482  typedef typename Base::Storage SharedStorage;
483 
485  typedef typename Base::Params BaseParams;
486 
488  enum { kRequiresLoadFence = Tile::kD == 1 };
489 
491  typedef Scalar const *Pointer;
492 
495 
497  struct Params : public BaseParams {
499  Scalar const *pointer;
500 
501  //
502  // Methods
503  //
504 
508 
511  Params(Scalar const *ptr) : pointer(ptr) { Base::Params::initialize(); }
512 
515  Params(TensorRef const &ref): pointer(ref.data()) {
516  Base::Params::initialize(ref.stride());
517  }
518 
521  Params(Scalar const *ptr,
522  Index _stride_d,
523  Index _stride_h,
524  Index _stride_w,
525  Index _inc_d,
526  Index _inc_h,
527  Index _inc_w,
528  Index _inc_advance)
529  : pointer(ptr) {
531  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
532  }
533 
537  : pointer(ptr) {
539  }
540 
543  int initialize(TensorRef const &ref) {
544  pointer = ref.data();
545  return Base::Params::initialize(ref.stride());
546  }
547 
550  int initialize(SharedStorage const &storage) {
551  pointer = &storage[0];
553  return 0;
554  }
555 
558  int initialize(Scalar const *ptr) {
559  pointer = ptr;
561  return 0;
562  }
563 
568  pointer = ptr;
569  return 0;
570  }
571 
574  int initialize(Scalar const *ptr,
575  Index _stride_d,
576  Index _stride_h,
577  Index _stride_w,
578  Index _inc_d,
579  Index _inc_h,
580  Index _inc_w,
581  Index _inc_advance) {
582  pointer = ptr;
584  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
585  return 0;
586  }
587 
588  // Initializes params to default values
591  };
592 
593  //
594  // Data members
595  //
596 
598  Params params;
599 
602 
604  int stage;
605 
606  //
607  // Predicate initialization
608  //
609 
611  template <
613  typename PredicateIterator>
614  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
615  Coord<3> const &bounds,
616  Coord<3> const &block_offset = make_Coord(0,
617  0,
618  0)) {
620  predicate_it,
622  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
623  }
624 
626  template <
628  typename PredicateIterator,
630  typename PredicateFunctor>
631  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
632  PredicateFunctor const &functor,
633  Coord<3> const &block_offset) {
635  predicate_it,
636  functor,
637  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
638  }
639 
640  //
641  // Methods
642  //
643 
647 
650  TileLoadIterator(Params const &_params,
651  Coord<3> const &block_offset = make_Coord(0, 0, 0),
652  ThreadOffset thread_offset_func = ThreadOffset())
653  : params(_params), stage(0) {
654  thread_offset = thread_offset_func();
655 
656  Index pointer_offset = Index((block_offset[0] + thread_offset[0]) * params.stride_d) +
657  Index((block_offset[1] + thread_offset[1]) * params.stride_h) +
658  Index((block_offset[2] + thread_offset[2]) * params.stride_w);
659 
660  params.pointer += pointer_offset;
661  }
662 
665  TileLoadIterator(Params const &,
666  Scalar const *ptr,
667  Coord<3> const &block_offset = make_Coord(0, 0, 0),
668  ThreadOffset thread_offset_func = ThreadOffset())
669  : stage(0) {
670  params.pointer = ptr + thread_offset_func()[2];
671 
672  params.stride_d = 0;
673  params.stride_h = 0;
674  params.stride_w = 1;
675 
677  }
678 
681 
684 
687 
690 
692  CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
693  int const offset =
695  Load<Scalar,
696  kAccessSize,
697  kMemorySpace,
700  Tile::kW,
701  sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
702  }
703 
706  if (Tile::kD > 1) {
707  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
708  if (stage == Tile::kD - 1) {
709  params.pointer -= (Tile::kD - 1) * kStageSize;
710  stage = 0;
711  } else {
712  params.pointer += kStageSize;
713  stage = stage + 1;
714  }
715  }
716  }
717 
720  long long _offset = offset.template dot<long long>(
722  );
723 
724  params.pointer += _offset;
725  return *this;
726  }
727 
730 
732  Index stride = params.stride_h;
733  if (kAdvance == IteratorAdvance::kW) {
734  stride = params.stride_w;
735  }
736  return stride;
737  }
738 
740  template <typename Fragment, typename PredicateIterator>
741  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
742  FragmentIterator frag_iterator(fragment);
743  for (int d = 0; d < Iterations::kD; ++d) {
744  for (int h = 0; h < Iterations::kH; ++h) {
745  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
746  for (int c = 0; c < Iterations::kC; ++c) {
747  if (*pred_it) {
748  load_element(
749  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
750  }
751  }
752  if (w < Iterations::kW - 1) {
753  inc_w();
754  }
755  }
756  if (h < Iterations::kH - 1) {
757  inc_h();
758  }
759  }
760  if (d < Iterations::kD - 1) {
761  inc_d();
762  }
763  }
764  inc_advance();
765  }
766 
768  template <typename Fragment>
770  typename PredicateVector::TrivialIterator pred_it;
771  load_post_increment(fragment, pred_it);
772  }
773 
775  template <typename Fragment, typename PredicateIterator>
776  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
777  TileLoadIterator _load_it(*this);
778  _load_it.load_post_increment(fragment, pred_it);
779  }
780 
782  template <typename Fragment>
783  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
784  typename PredicateVector::TrivialIterator pred_it;
785  load(fragment, pred_it);
786  }
787 
789  template <typename Fragment>
790  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
791  FragmentIterator frag_iterator(fragment);
792  for (int h = 0; h < Iterations::kH; ++h) {
793  for (int w = 0; w < Iterations::kW; ++w) {
794  for (int c = 0; c < Iterations::kC; ++c) {
795  load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
796  }
797  }
798  }
799  }
800 };
801 
803 
827 
833 template <typename Traits_,
834  typename Scalar_,
837  typename Index_ = int,
838  typename FragmentElement_ = Scalar_,
840  typename Skew_ = Shape<0, 0, 0, 0> >
841 struct TileStoreIterator : public TileIteratorBase<Traits_,
842  Scalar_,
843  Advance_,
844  MemorySpace,
845  Index_,
846  FragmentElement_,
847  FragmentElementType_,
848  Skew_> {
850  typedef TileIteratorBase<Traits_,
851  Scalar_,
852  Advance_,
853  MemorySpace,
854  Index_,
855  FragmentElement_,
856  FragmentElementType_,
857  Skew_>
859 
861  typedef typename Base::Traits Traits;
862 
864  typedef typename Base::Scalar Scalar;
865 
868 
871 
874 
877 
879  static int const kAccessSize = Base::kAccessSize;
880 
882  typedef typename Base::Index Index;
883 
885  typedef typename Base::LongIndex LongIndex;
886 
888  typedef typename Base::Skew Skew;
889 
891  typedef typename Base::Tile Tile;
892 
894  typedef typename Base::Delta Delta;
895 
897  typedef typename Base::Iterations Iterations;
898 
901 
904 
906  typedef typename Base::AccessType AccessType;
907 
909  typedef typename Base::Fragment Fragment;
910 
913 
916 
919 
921  typedef typename Base::Storage SharedStorage;
922 
924  typedef typename Base::Params BaseParams;
925 
927  typedef Scalar *Pointer;
928 
931 
933  struct Params : public BaseParams {
936 
937  //
938  // Methods
939  //
940 
941  // Default constructor
943  Params() : pointer(0) {}
944 
945  // Default constructor
948 
951  Params(TensorRef const &ref): pointer(ref.data()) {
952  Base::Params::initialize(ref.stride());
953  }
954 
955  // Constructs params from a pointer and D/H/W strides
959  }
960 
961  // Constructs params from a pointer, strides, and increments
964  Index _stride_d,
965  Index _stride_h,
966  Index _stride_w,
967  Index _inc_d,
968  Index _inc_h,
969  Index _inc_w,
970  Index _inc_advance) {
971  initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
972  }
973 
976  int initialize(SharedStorage &storage) {
977  pointer = &storage[0];
978  return Base::Params::initialize();
979  }
980 
983  int initialize(Scalar *ptr) {
984  pointer = ptr;
985  return Base::Params::initialize();
986  }
987 
992  pointer = ptr;
993  return 0;
994  }
995 
998  int initialize(Scalar *ptr,
999  Index _stride_d,
1000  Index _stride_h,
1001  Index _stride_w,
1002  Index _inc_d,
1003  Index _inc_h,
1004  Index _inc_w,
1005  Index _inc_advance) {
1006  pointer = ptr;
1008  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
1009  return 0;
1010  }
1011 
1015  };
1016 
1017  //
1018  // Data members
1019  //
1020 
1023 
1026 
1028  int stage;
1029 
1030  //
1031  // Predicate initialization
1032  //
1033 
1035  template <
1037  typename PredicateIterator>
1038  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
1039  Coord<3> const &bounds,
1040  Coord<3> const &block_offset = make_Coord(0,
1041  0,
1042  0)) {
1044  predicate_it,
1046  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
1047  }
1048 
1050  template <
1052  typename PredicateIterator,
1054  typename PredicateFunctor>
1055  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
1056  PredicateFunctor const &functor,
1057  Coord<3> const &block_offset) {
1059  predicate_it,
1060  functor,
1061  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
1062  }
1063 
1064  //
1065  // Methods
1066  //
1067 
1071 
1074  TileStoreIterator(Params const &_params,
1075  Coord<3> const &block_offset = make_Coord(0, 0, 0),
1076  ThreadOffset thread_offset_func = ThreadOffset())
1077  : params(_params), stage(0) {
1078  thread_offset = thread_offset_func();
1079 
1080  params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
1081  (block_offset[1] + thread_offset[1]) * params.stride_h +
1082  (block_offset[2] + thread_offset[2]) * params.stride_w;
1083  }
1084 
1087  TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func = ThreadOffset())
1088  : stage(0) {
1089  params.pointer = ptr + thread_offset_func()[2];
1090  params.stride_d = 0;
1091  params.stride_h = 0;
1092  params.stride_w = 1;
1093 
1095  }
1096 
1099 
1102 
1105 
1108 
1111  if (Tile::kD > 1) {
1112  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
1113  if (stage == Tile::kD - 1) {
1114  params.pointer -= (Tile::kD - 1) * kStageSize;
1115  stage = 0;
1116  } else {
1117  params.pointer += kStageSize;
1118  stage = stage + 1;
1119  }
1120  }
1121  }
1122 
1125  params.pointer += offset.template dot<long long>(
1127  );
1128  return *this;
1129  }
1130 
1133 
1135  CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) {
1136  int const offset =
1138  Store<Scalar,
1139  kAccessSize,
1140  kMemorySpace,
1143  Tile::kW,
1144  sizeof(FragmentElement) * kAccessSize>::store(value, params.pointer, offset);
1145  }
1146 
1148  template <typename Fragment, typename PredicateIterator>
1149  CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
1150  FragmentConstIterator frag_iterator(fragment);
1151 
1152  for (int d = 0; d < Iterations::kD; ++d) {
1153  for (int h = 0; h < Iterations::kH; ++h) {
1154  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1155  for (int c = 0; c < Iterations::kC; ++c) {
1156  if (*pred_it) {
1157  store_element(
1158  reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1159  }
1160  }
1161  if (w < Iterations::kW - 1) {
1162  inc_w();
1163  }
1164  }
1165  if (h < Iterations::kH - 1) {
1166  inc_h();
1167  }
1168  }
1169  if (d < Iterations::kD - 1) {
1170  inc_d();
1171  }
1172  }
1173  inc_advance();
1174  }
1175 
1177  template <typename Fragment>
1179  typename PredicateVector::TrivialIterator pred_it;
1180  store_post_increment(fragment, pred_it);
1181  }
1182 
1184  template <typename Fragment, typename PredicateIterator>
1185  CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
1186  TileStoreIterator _store_it(*this);
1187  _store_it.store_post_increment(fragment, pred_it);
1188  }
1189 
1191  template <typename Fragment>
1192  CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const {
1193  typename PredicateVector::TrivialIterator pred_it;
1194  store(fragment, pred_it);
1195  }
1196 
1198  CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
1199  int const offset =
1201 
1202  Load<Scalar,
1203  kAccessSize,
1204  kMemorySpace,
1207  Tile::kW,
1208  sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
1209  }
1210 
1212  template <typename Fragment, typename PredicateIterator>
1213  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
1214  FragmentIterator frag_iterator(fragment);
1215 
1216  for (int d = 0; d < Iterations::kD; ++d) {
1217  for (int h = 0; h < Iterations::kH; ++h) {
1218  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1219  for (int c = 0; c < Iterations::kC; ++c) {
1220  if (*pred_it) {
1221  load_element(
1222  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1223  }
1224  }
1225  if (w < Iterations::kW - 1) {
1226  inc_w();
1227  }
1228  }
1229  if (h < Iterations::kH - 1) {
1230  inc_h();
1231  }
1232  }
1233  if (d < Iterations::kD - 1) {
1234  inc_d();
1235  }
1236  }
1237  inc_advance();
1238  }
1239 
1241  template <typename Fragment>
1243  typename PredicateVector::TrivialIterator pred_it;
1244  load_post_increment(fragment, pred_it);
1245  }
1246 
1248  template <typename Fragment, typename PredicateIterator>
1249  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
1250  TileStoreIterator _load_it(*this);
1251  _load_it.load_post_increment(fragment, pred_it);
1252  }
1253 
1255  template <typename Fragment>
1256  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
1257  typename PredicateVector::TrivialIterator pred_it;
1258  load(fragment, pred_it);
1259  }
1260 
1262  template <typename Fragment>
1263  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
1264  FragmentIterator frag_iterator(fragment);
1265  for (int h = 0; h < Iterations::kH; ++h) {
1266  for (int w = 0; w < Iterations::kW; ++w) {
1267  for (int c = 0; c < Iterations::kC; ++c) {
1268  load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
1269  }
1270  }
1271  }
1272  }
1273 };
1274 
1276 
1277 } // namespace cutlass
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:689
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:1198
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/stored by one instruction.
Definition: tile_iterator.h:191
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:891
Delta_ Delta
Definition: tile_iterator.h:113
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:918
CUTLASS_HOST_DEVICE Params()
Default constructor.
Definition: tile_iterator.h:507
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:650
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:550
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:80
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Index_ Index
Index type.
Definition: tile_iterator.h:164
Definition: convert.h:33
Defines a structure containing strides, bounds, and a pointer to tensor data.
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:455
CUTLASS_HOST_DEVICE int initialize(Coord< 4 > const &stride)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:286
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:337
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:170
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:467
Enum to specify which memory space data resides in.
Definition: load_store.h:38
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1242
Base::Index Index
Index type.
Definition: tile_iterator.h:882
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:482
int stage
The stage.
Definition: tile_iterator.h:1028
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:449
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:202
Scalar * Pointer
Pointer to underlying type.
Definition: tile_iterator.h:927
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:185
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:998
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Shape< 0, 0, 0, 0 > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: tile_iterator.h:102
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:476
A template defining Tile Traits Concept.
Definition: tile_iterator.h:78
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:428
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:419
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:614
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:182
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads the elements for D-iteration d into the fragment without advancing the iterator.
Definition: tile_iterator.h:790
Base::Delta Delta
Delta.
Definition: tile_iterator.h:452
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:741
Base::LongIndex LongIndex
Long index type.
Definition: tile_iterator.h:885
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:590
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:437
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:179
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:458
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:1014
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:470
Definition: tile_iterator.h:65
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:921
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:876
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:1038
Index inc_d
Definition: tile_iterator.h:226
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:1098
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:99
Iterator that always returns true.
Definition: predicate_vector.h:309
CUTLASS_HOST_DEVICE Params(Coord< 4 > const &stride)
Constructs params with a stride vector.
Definition: tile_iterator.h:259
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:935
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Constructs params from a pointer and D/H/W strides.
Definition: tile_iterator.h:957
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:903
Kind
Definition: load_store.h:39
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:209
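CUTLASS_HOST_DEVICE Index stride_advance(void)
Returns the stride along the advance dimension: stride_w when kAdvance is IteratorAdvance::kW, otherwise stride_h.
Definition: tile_iterator.h:731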
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:909
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:990
TensorRef< Scalar, 4 > TensorRef
Tensor reference for the store iterator.
Definition: tile_iterator.h:930
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:422
Definition: load_store.h:178
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:265
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1256
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:631
CUTLASS_HOST_DEVICE Params(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Constructs params from a pointer, strides, and increments.
Definition: tile_iterator.h:963
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: tile_iterator.h:1135
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:206
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:491
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Constructs params from a pointer, strides, and increments.
Definition: tile_iterator.h:521
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:155
Index stride_h
Definition: tile_iterator.h:223
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Parameters.
Definition: tile_iterator.h:933
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:146
Params params
Parameters structure.
Definition: tile_iterator.h:1022
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:915
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:924
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:705
CUTLASS_HOST_DEVICE int initialize()
Parameterless initialization overload.
Definition: tile_iterator.h:324
Kind
Definition: load_store.h:48
CUTLASS_HOST_DEVICE RegularTilePredicateFunctor(Coord< 3 > _bounds)
Constructs a predicate functor given the bounds of a tensor.
Definition: tile_iterator.h:120
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:1132
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:646
Definition: load_store.h:40
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:485
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:566
Params params
Parameters structure.
Definition: tile_iterator.h:598
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:204
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:558
CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:719
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:692
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:906
Definition: tile_iterator.h:488
Definition: tile_iterator.h:134
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:86
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:183
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:776
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:729
static int const kAccessSize
Access size.
Definition: tile_iterator.h:105
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:197
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:888
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:83
Defines abstractions for efficiently loading and storing vectors to memory.
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:425
Base::FragmentShape FragmentShape
Fragment shape.
Definition: tile_iterator.h:461
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:873
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1087
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:1107
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Params(Scalar *ptr)
Definition: tile_iterator.h:947
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1074
Base::LongIndex LongIndex
Long index type.
Definition: tile_iterator.h:443
Index inc_h
Definition: tile_iterator.h:227
CUTLASS_HOST_DEVICE Params()
Constructs params.
Definition: tile_iterator.h:238
Index stride_d
Definition: tile_iterator.h:222
Definition: vector.h:62
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:858
Definition: load_store.h:60
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:861
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_HOST_DEVICE Params(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:511
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:912
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:64
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs params from a TensorRef.
Definition: tile_iterator.h:515
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:1070
Definition: load_store.h:48
CUTLASS_HOST_DEVICE int initialize(TensorRef const &ref)
Initializes params to access a tensor reference.
Definition: tile_iterator.h:543
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:864
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:499
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1149
Definition: tile_iterator.h:65
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:431
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1192
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:1025
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:897
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:879
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initialize params to access storage object.
Definition: tile_iterator.h:536
Functor computing a predicate given the logical position of an access.
Definition: tile_iterator.h:112
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:173
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
Parameters.
Definition: tile_iterator.h:497
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:479
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1263
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:292
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:464
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:446
Base::Delta Delta
Delta.
Definition: tile_iterator.h:894
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:1101
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1185
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:686
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:783
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:900
Definition: tile_iterator.h:65
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:434
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:683
static CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &predicate_func, Coord< 3 > const &offset)
Initializes a predicate vector.
Definition: tile_iterator.h:345
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:149
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:601
Coord< 3 > bounds
Dimensions of the bounding volume.
Definition: tile_iterator.h:116
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:176
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:194
Index inc_advance
Definition: tile_iterator.h:230
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1178
long long LongIndex
Long index.
Definition: tile_iterator.h:167
CUTLASS_HOST_DEVICE Params()
Definition: tile_iterator.h:943
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:983
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:867
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1249
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs params from a TensorRef.
Definition: tile_iterator.h:951
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:680
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:769
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:216
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:1124
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:1110
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE bool operator()(Coord< 3 > iteration, Coord< 3 > offset) const
Computes the predicate given the logical position of an access.
Definition: tile_iterator.h:124
Base::Index Index
Index type.
Definition: tile_iterator.h:440
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:1104
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:870
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:604
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:574
CUTLASS_HOST_DEVICE Params(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Constructs params.
Definition: tile_iterator.h:242
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, Scalar const *ptr, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:665
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:976
Index inc_w
Definition: tile_iterator.h:228
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:1055
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841
Index stride_w
Definition: tile_iterator.h:224
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1213