Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_ref_collection.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/tensor_ref.h"
32 
33 namespace cutlass {
34 
36 //
37 // TensorRefCollection is a concept for storing a logical collection of TensorRef objects. Classes
38 // satisfying the TensorRefCollection concept must support the following:
39 //
40 // // Define storage type
41 // typedef typename TensorRefCollection::Storage Storage;
42 //
43 // // Define a type for offsets in memory
44 // typedef typename TensorRefCollection::LongIndex LongIndex;
45 //
46 // // Define a ConstIterator type satisfying TensorRefIterator
47 // typedef typename TensorRefCollection::ConstIterator TensorRefIterator;
48 //
49 // // Implement a begin() method.
50 // TensorRefIterator iterator = collection.begin();
51 //
52 //
53 // TensorRefIterator is a concept for accessing an element in a TensorRefCollection. Classes
54 // satisfying the TensorRefIterator concept must support the following:
55 //
56 // // Define a TensorRef type accessed by the iterator
57 // typedef typename TensorRefIterator::TensorRef TensorRef;
58 //
59 // // Access the TensorRef
60 // TensorRef ref = *iterator;
61 //
62 // // Pre-increment and post-increment
63 // ++iterator;
64 // iterator++;
65 //
66 // // Pre-decrement and post-decrement
67 // --iterator;
68 // iterator--;
69 //
71 
74 template <
76  typename Storage_,
78  int Rank_,
80  typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
82  int StorageRank_ = MapFunc_::kStorageRank,
84  typename Index_ = int,
86  typename LongIndex_ = long long
87 >
89  public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
90 
91  //
92  // Type definitions
93  //
94 
97 
99  typedef typename Base::Storage Storage;
100 
102  static int const kRank = Rank_;
103 
105  typedef Index_ Index;
106 
108  typedef LongIndex_ LongIndex;
109 
110 
113 
115  typedef Base TensorRef;
116 
119  public:
121  typedef Base TensorRef;
122 
123  private:
124 
126  TensorRefBatchStrided const &ref_;
127 
129  LongIndex offset_;
130 
131  public:
132 
136  TensorRefBatchStrided const &ref,
137  LongIndex offset = 0): ref_(ref), offset_(offset) { }
138 
142  TensorRef ref(ref_);
143  ref.add_pointer_offset(offset_);
144  return ref;
145  }
146 
150  offset_ += ref_.tensor_stride;
151  return *this;
152  }
153 
157  ConstIterator ret(*this);
158  offset_ += ref_.tensor_stride;
159  return ret;
160  }
161 
165  return ConstIterator(ref_, offset_ + ref_.tensor_stride * idx);
166  }
167 
171  offset_ += ref_.tensor_stride * idx;
172  return *this;
173  }
174 
178  offset_ -= ref_.tensor_stride;
179  return *this;
180  }
181 
185  ConstIterator ret(*this);
186  offset_ -= ref_.tensor_stride;
187  return ret;
188  }
189 
193  return ConstIterator(ref_, offset_ - ref_.tensor_stride * idx);
194  }
195 
199  offset_ -= ref_.tensor_stride * idx;
200  return *this;
201  }
202 
206  return offset_ - it.offset_;
207  }
208  };
209 
210  //
211  // Data members
212  //
213 
216 
217  //
218  // Methods
219  //
220 
221  // Default ctor
224 
225  // Constructs from a reference to the first tensor (copied into the TensorRef base) and the
//   stride between consecutive tensors, stored in tensor_stride (defaults to 0).
227  TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
228  TensorRef(ref),
229  tensor_stride(_tensor_stride) { }
230 
234  return idx * tensor_stride;
235  }
236 
237  // Returns a TensorRef for tensor idx: a copy of this reference with a pointer offset of
//   get_pointer_offset(idx), i.e. idx * tensor_stride, added to it.
239  TensorRef at(Index idx = 0) const {
240  TensorRef ref(*this);
241  ref.add_pointer_offset(get_pointer_offset(idx));
242  return ref;
243  }
244 
248  return ConstIterator(*this);
249  }
250 };
251 
// Helper to construct a TensorRefBatchStrided<> object using type deduction.
//   ref          - TensorRef used as the base reference of the collection
//   batch_stride - stride between consecutive tensors, forwarded to the constructor (default 0)
// The storage rank is taken from TensorRef_::kStorageRank.
253 template <typename TensorRef_>
255 TensorRefBatchStrided<
256  typename TensorRef_::Storage,
257  TensorRef_::kRank,
258  typename TensorRef_::MapFunc,
259  TensorRef_::kStorageRank,
260  typename TensorRef_::Index,
261  typename TensorRef_::LongIndex
263  TensorRef_ const &ref,
264  typename TensorRef_::LongIndex batch_stride = 0) {
265 
266  return TensorRefBatchStrided<
267  typename TensorRef_::Storage,
268  TensorRef_::kRank,
269  typename TensorRef_::MapFunc,
270  TensorRef_::kStorageRank,
271  typename TensorRef_::Index,
272  typename TensorRef_::LongIndex
273  >(ref, batch_stride);
274 }
275 
277 
286 template <
288  typename Storage_,
290  int Rank_,
292  typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
294  int StorageRank_ = MapFunc_::kStorageRank,
296  typename Index_ = int,
298  typename LongIndex_ = long long
299 >
301  //
302  // Type definitions
303  //
304 
306  typedef Storage_ Storage;
307 
309  typedef Index_ Index;
310 
312  typedef LongIndex_ LongIndex;
313 
315  static int const kStorageRank = StorageRank_;
316 
319  public:
320 
323 
324  private:
325 
327  TensorRefArray const &ref_;
328 
330  int idx_;
331 
332  public:
333 
// Constructs a ConstIterator over the TensorRef objects of a TensorRefArray.
//   ref - parent collection, held by reference in ref_
//   idx - index of the TensorRef the iterator starts at (default 0)
336  ConstIterator(TensorRefArray const &ref, int idx = 0): ref_(ref), idx_(idx) { }
337 
// Obtains the TensorRef at the iterator's current index via TensorRefArray::at().
341  return ref_.at(idx_);
342  }
343 
347  ++idx_;
348  return *this;
349  }
350 
354  ConstIterator ret(*this);
355  idx_ ++;
356  return ret;
357  }
358 
361  return ConstIterator(ref_, idx_ + idx);
362  }
363 
366  idx_ += idx;
367  return *this;
368  }
369 
372  --idx_;
373  return *this;
374  }
375 
379  ConstIterator ret(*this);
380  --idx_;
381  return ret;
382  }
383 
386  idx_ -= idx;
387  return *this;
388  }
389 
// operator-(Index idx): returns an iterator moved back by idx TensorRefs, so idx is subtracted
// from the current index (mirrors operator-=, which does idx_ -= idx).
392  return ConstIterator(ref_, idx_ - idx);
393  }
394  };
395 
398 
399  //
400  // Data members
401  //
402 
405 
408 
409  //
410  // Methods
411  //
412 
413  // Default ctor
416 
417  // Construct from an array of base pointers and an array of pointers to stride arrays.
//   _pointers - base address of each tensor; stored in pointers
//   _strides  - kStorageRank - 1 pointers, one per stride component; each is copied into
//               strides[i]. Declared as an array of pointers so the element type matches
//               the Index* elements of the strides member.
420  Storage **_pointers,
421  Index *_strides[kStorageRank - 1]): pointers(_pointers) {
422
423  // Copy pointers to strides arrays
424  for (int i = 0; i < kStorageRank - 1; ++i) {
425  strides[i] = _strides[i];
426  }
427  }
428 
429  // Returns a TensorRef at the given index in the collection.
//   strides holds kStorageRank - 1 pointers, one per stride component, so component i of
//   tensor idx is read as strides[i][idx]; the base address is pointers[idx].
431  TensorRef at(Index idx = 0) const {
432  Coord<kStorageRank - 1, Index> stride;
434  for (int i = 0; i < kStorageRank - 1; ++i) {
435  stride[i] = strides[i][idx];
436  }
437  return TensorRef(pointers[idx], stride);
438  }
439 
443  return ConstIterator(*this);
444  }
445 };
446 
448 
449 } // namespace cutlass
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > TensorRef
Containing class's tensor ref.
Definition: tensor_ref_collection.h:322
Constant iterator over tensors implied by TensorRefBatchStrided.
Definition: tensor_ref_collection.h:118
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Definition: convert.h:33
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Moves to the previous tensor.
Definition: tensor_ref_collection.h:177
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Definition: tensor_ref_collection.h:371
Index * strides[kStorageRank - 1]
Array of strides.
Definition: tensor_ref_collection.h:407
Storage_ Storage
Element pointed to by the TensorRef.
Definition: tensor_ref_collection.h:306
Definition: tensor_ref_collection.h:300
static int const kStorageRank
Rank of the stride vector.
Definition: tensor_ref_collection.h:315
Base::Storage Storage
Storage type.
Definition: tensor_ref_collection.h:99
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Definition: tensor_ref_collection.h:385
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:312
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns an iterator.
Definition: tensor_ref_collection.h:247
LongIndex tensor_stride
Stride between tensors.
Definition: tensor_ref_collection.h:215
CUTLASS_HOST_DEVICE LongIndex operator-(ConstIterator const &it)
Returns the difference in offset between two iterators.
Definition: tensor_ref_collection.h:205
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:156
Base TensorRef
Tensor reference implied by the TensorRefBatchStrided.
Definition: tensor_ref_collection.h:115
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous tensor.
Definition: tensor_ref_collection.h:184
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns a TensorRefIterator over the TensorRef objects in this collection.
Definition: tensor_ref_collection.h:442
TensorRefIterator over TensorRef objects in TensorRefArray.
Definition: tensor_ref_collection.h:318
Index_ Index
Index type.
Definition: tensor_ref.h:146
CUTLASS_HOST_DEVICE TensorRef operator*() const
Obtains a TensorRef pointed to by the iterator.
Definition: tensor_ref_collection.h:141
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > TensorRef
TensorRef type obtained from the TensorRefArray.
Definition: tensor_ref_collection.h:397
Index_ Index
Index type.
Definition: tensor_ref_collection.h:309
static int const kRank
Rank of the logical tensor.
Definition: tensor_ref_collection.h:102
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Definition: tensor_ref_collection.h:391
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref.h:149
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous TensorRef.
Definition: tensor_ref_collection.h:378
CUTLASS_HOST_DEVICE TensorRef at(Index idx=0) const
Definition: tensor_ref_collection.h:239
CUTLASS_HOST_DEVICE TensorRefBatchStrided< typename TensorRef_::Storage, TensorRef_::kRank, typename TensorRef_::MapFunc, TensorRef_::kStorageGrank, typename TensorRef_::Index, typename TensorRef_::LongIndex > make_TensorRefBatchStrided(TensorRef_ const &ref, typename TensorRef_::LongIndex batch_stride=0)
Helper to construct a TensorRefBatchStrided<> object using type deduction.
Definition: tensor_ref_collection.h:262
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:108
Coord< kRank > TensorCoord
Coordinate in logical tensor space.
Definition: tensor_ref_collection.h:112
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Returns an iterator moved backward by (idx) amount.
Definition: tensor_ref_collection.h:192
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE TensorRefBatchStrided()
Definition: tensor_ref_collection.h:223
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:331
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
Index_ Index
Index type.
Definition: tensor_ref_collection.h:105
CUTLASS_HOST_DEVICE TensorRefArray()
Definition: tensor_ref_collection.h:415
CUTLASS_HOST_DEVICE ConstIterator(TensorRefArray const &ref, int idx=0)
Constructs a ConstIterator over the TensorRef objects.
Definition: tensor_ref_collection.h:336
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Returns an iterator advanced by (idx) amount.
Definition: tensor_ref_collection.h:164
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Advances this iterator by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:170
Storage ** pointers
Base addresses.
Definition: tensor_ref_collection.h:404
CUTLASS_HOST_DEVICE LongIndex get_pointer_offset(Index idx) const
Gets the pointer offset.
Definition: tensor_ref_collection.h:233
Definition: tensor_ref_collection.h:88
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Definition: tensor_ref_collection.h:360
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Moves this iterator by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:198
Base TensorRef
TensorRef returned by the iterator.
Definition: tensor_ref_collection.h:121
CUTLASS_HOST_DEVICE TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride=0)
Definition: tensor_ref_collection.h:227
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Definition: tensor_ref_collection.h:365
CUTLASS_HOST_DEVICE ConstIterator(TensorRefBatchStrided const &ref, LongIndex offset=0)
Constructs a ConstIterator from a parent TensorRefBatchStrided.
Definition: tensor_ref_collection.h:135
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:149
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Base
Underlying TensorRef type.
Definition: tensor_ref_collection.h:96
CUTLASS_HOST_DEVICE TensorRef operator*() const
Obtains a TensorRef pointed to by this iterator.
Definition: tensor_ref_collection.h:340
CUTLASS_HOST_DEVICE TensorRefArray(Storage **_pointers, Index _strides[kStorageRank - 1])
Definition: tensor_ref_collection.h:419
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances to next TensorRef.
Definition: tensor_ref_collection.h:346
CUTLASS_HOST_DEVICE TensorRef at(Index idx=0) const
Definition: tensor_ref_collection.h:431
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances to next TensorRef.
Definition: tensor_ref_collection.h:353