80 typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
82 int StorageRank_ = MapFunc_::kStorageRank,
84 typename Index_ = int,
86 typename LongIndex_ =
long long 89 public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
143 ref.add_pointer_offset(offset_);
206 return offset_ - it.offset_;
253 template <
typename TensorRef_>
255 TensorRefBatchStrided<
256 typename TensorRef_::Storage,
258 typename TensorRef_::MapFunc,
259 TensorRef_::kStorageGrank,
260 typename TensorRef_::Index,
261 typename TensorRef_::LongIndex
263 TensorRef_
const &ref,
264 typename TensorRef_::LongIndex batch_stride = 0) {
267 typename TensorRef_::Storage,
269 typename TensorRef_::MapFunc,
270 TensorRef_::kStorageGrank,
271 typename TensorRef_::Index,
272 typename TensorRef_::LongIndex
273 >(ref, batch_stride);
292 typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
294 int StorageRank_ = MapFunc_::kStorageRank,
296 typename Index_ = int,
298 typename LongIndex_ =
long long 341 return ref_.reference(idx_);
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > TensorRef
Containing class's tensor ref.
Definition: tensor_ref_collection.h:322
Constant iterator over tensors implied by TensorRefBatchStrided.
Definition: tensor_ref_collection.h:118
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Moves to the previous tensor.
Definition: tensor_ref_collection.h:177
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Definition: tensor_ref_collection.h:371
Index * strides[kStorageRank - 1]
Array of strides.
Definition: tensor_ref_collection.h:407
Storage_ Storage
Element pointed to by the TensorRef.
Definition: tensor_ref_collection.h:306
Definition: tensor_ref_collection.h:300
static int const kStorageRank
Rank of the stride vector.
Definition: tensor_ref_collection.h:315
Base::Storage Storage
Storage type.
Definition: tensor_ref_collection.h:99
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Definition: tensor_ref_collection.h:385
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:312
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns an iterator.
Definition: tensor_ref_collection.h:247
LongIndex tensor_stride
Stride between tensors.
Definition: tensor_ref_collection.h:215
CUTLASS_HOST_DEVICE LongIndex operator-(ConstIterator const &it)
Returns the difference in offset between two iterators.
Definition: tensor_ref_collection.h:205
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:156
Base TensorRef
Tensor reference implied by the TensorRefBatchStrided.
Definition: tensor_ref_collection.h:115
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous tensor.
Definition: tensor_ref_collection.h:184
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns a TensorRefIterator over the TensorRef objects in this collection.
Definition: tensor_ref_collection.h:442
TensorRefIterator over TensorRef objects in TensorRefArray.
Definition: tensor_ref_collection.h:318
Index_ Index
Index type.
Definition: tensor_ref.h:146
CUTLASS_HOST_DEVICE TensorRef operator*() const
Obtains a TensorRef pointed to by the iterator.
Definition: tensor_ref_collection.h:141
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > TensorRef
TensorRef type obtained from the TensorRefArray.
Definition: tensor_ref_collection.h:397
Index_ Index
Index type.
Definition: tensor_ref_collection.h:309
static int const kRank
Rank of the logical tensor.
Definition: tensor_ref_collection.h:102
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Definition: tensor_ref_collection.h:391
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref.h:149
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous TensorRef.
Definition: tensor_ref_collection.h:378
CUTLASS_HOST_DEVICE TensorRef at(Index idx=0) const
Definition: tensor_ref_collection.h:239
CUTLASS_HOST_DEVICE TensorRefBatchStrided< typename TensorRef_::Storage, TensorRef_::kRank, typename TensorRef_::MapFunc, TensorRef_::kStorageGrank, typename TensorRef_::Index, typename TensorRef_::LongIndex > make_TensorRefBatchStrided(TensorRef_ const &ref, typename TensorRef_::LongIndex batch_stride=0)
Helper to construct a TensorRefBatchStrided<> object using type deduction.
Definition: tensor_ref_collection.h:262
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:108
Coord< kRank > TensorCoord
Coordinate in logical tensor space.
Definition: tensor_ref_collection.h:112
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Returns an iterator moved backward by (idx) amount.
Definition: tensor_ref_collection.h:192
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE TensorRefBatchStrided()
Definition: tensor_ref_collection.h:223
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:331
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
Index_ Index
Index type.
Definition: tensor_ref_collection.h:105
CUTLASS_HOST_DEVICE TensorRefArray()
Definition: tensor_ref_collection.h:415
CUTLASS_HOST_DEVICE ConstIterator(TensorRefArray const &ref, int idx=0)
Constructs a ConstIterator over the TensorRef objects.
Definition: tensor_ref_collection.h:336
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Returns an iterator advanced by (idx) amount.
Definition: tensor_ref_collection.h:164
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Advances this iterator by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:170
Storage ** pointers
Base addresses.
Definition: tensor_ref_collection.h:404
CUTLASS_HOST_DEVICE LongIndex get_pointer_offset(Index idx) const
Gets the pointer offset.
Definition: tensor_ref_collection.h:233
Definition: tensor_ref_collection.h:88
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Definition: tensor_ref_collection.h:360
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Moves this iterator backward by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:198
Base TensorRef
TensorRef returned by the iterator.
Definition: tensor_ref_collection.h:121
CUTLASS_HOST_DEVICE TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride=0)
Definition: tensor_ref_collection.h:227
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Definition: tensor_ref_collection.h:365
CUTLASS_HOST_DEVICE ConstIterator(TensorRefBatchStrided const &ref, LongIndex offset=0)
Constructs a ConstIterator from a parent TensorRefBatchStrided.
Definition: tensor_ref_collection.h:135
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:149
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Base
Underlying TensorRef type.
Definition: tensor_ref_collection.h:96
CUTLASS_HOST_DEVICE TensorRef operator*() const
Obtains a TensorRef pointed to by this iterator.
Definition: tensor_ref_collection.h:340
CUTLASS_HOST_DEVICE TensorRefArray(Storage **_pointers, Index _strides[kStorageRank - 1])
Definition: tensor_ref_collection.h:419
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances to next TensorRef.
Definition: tensor_ref_collection.h:346
CUTLASS_HOST_DEVICE TensorRef at(Index idx=0) const
Definition: tensor_ref_collection.h:431
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances to next TensorRef.
Definition: tensor_ref_collection.h:353