Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_view.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
37 #pragma once
38 
39 #include <cmath>
40 
41 #include "cutlass/cutlass.h"
42 #include "cutlass/tensor_ref.h"
43 
44 namespace cutlass {
45 
47 
49 template <
51  typename Storage_,
53  int Rank_ = 4,
55  typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
57  int StorageRank_ = MapFunc_::kStorageRank,
59  typename Index_ = int,
61  typename LongIndex_ = long long
62 >
63 class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
64  public:
67 
69  typedef TensorRef<
71  Rank_,
72  MapFunc_,
73  StorageRank_,
74  Index_,
75  LongIndex_> ConstTensorRef;
76 
78  typedef Base TensorRef;
79 
81  typedef typename Base::Storage Storage;
82 
84  typedef typename Base::Index Index;
85 
88 
91 
95 
97  typedef TensorView<
99  Rank_,
100  MapFunc_,
101  StorageRank_,
102  Index_,
103  LongIndex_> ConstTensorView;
104 
105  //
106  // Definitions included for backwards compatibility - to be removed in next major release
107  //
108 
111 
113  static int const Rank = Base::kRank;
114 
116  typedef typename Base::LongIndex Offset_t;
117 
120 
123 
124  private:
125  //
126  // Data members
127  //
128 
130  TensorCoord size_;
131 
132  public:
133  //
134  // Device and Host Methods
135  //
136 
140 
/// Constructs a TensorView from a TensorRef and size.
///
/// @param _ref   tensor reference copied into the Base sub-object
/// @param _size  logical extent stored in size_
143  TensorView(Base const& _ref, TensorCoord const& _size) : Base(_ref), size_(_size) {}
144 
148  Storage *ptr,
149  StrideVector const &stride,
150  TensorCoord const& size
151  ):
152  Base(ptr, stride), size_(size) {}
153 
157  Storage *ptr,
158  StorageCoord const &stride,
159  TensorCoord const& size
160  ):
161  Base(ptr, stride), size_(size) {}
162 
165  void reset(Base const& _ref = Base(), TensorCoord const& _size = TensorCoord()) {
166  Base::operator=(_ref);
167  size_ = _size;
168  }
169 
/// Accesses the size.
///
/// @return const reference to the stored logical extent (size_)
172  TensorCoord const& size() const { return size_; }
173 
/// Accesses the size along a single dimension.
///
/// @param dim  index of the dimension to query; forwarded to size_.at()
/// @return the element of size_ selected by dim
176  Index size(int dim) const { return size_.at(dim); }
177 
180  TensorView& operator=(TensorView const& _tensor) {
181  Base::operator=(_tensor);
182  size_ = _tensor.size_;
183  return *this;
184  }
185 
188  bool contains(TensorCoord const& coord) const {
190  for (int dim = 0; dim < Rank_; ++dim) {
191  if (coord[dim] >= size_[dim]) {
192  return false;
193  }
194  }
195  return true;
196  }
197 
/// Returns a TensorRef pointing to the first element of the tensor.
///
/// The result is constructed from *this, so it carries the base reference
/// but not the size held by the view.
200  TensorRef ref() const {
201  return TensorRef(*this);
202  }
203 
207  return ConstTensorRef(*this);
208  }
209 
212  TensorView subview(TensorCoord const& location, TensorCoord size) const {
213  return TensorView((*this) + location, size.clamp(size_ - location));
214  }
215 
218  size_t capacity() const {
219  int max_rank = 0;
220 
221  StorageCoord mapped_size(this->map(size()));
222 
224  for (int i = 0; i < Base::kStorageRank; ++i) {
225  if (!i ||
226  this->stride(i) * mapped_size[i] > this->stride(max_rank) * mapped_size[max_rank]) {
227  max_rank = i;
228  }
229  }
230  return this->stride(max_rank) * mapped_size[max_rank];
231  }
232 
235  TensorView operator+(TensorCoord const& b) const {
236  TensorView result(*this);
237  result.add_pointer_offset(this->offset(b));
238  return result;
239  }
240 
244  this->add_pointer_offset(this->offset(b));
245  return *this;
246  }
247 
250  TensorView operator-(TensorCoord const& b) const {
251  TensorRef result(*this);
252  result.add_pointer_offset(-this->offset(b));
253  return result;
254  }
255 
259  this->add_pointer_offset(-this->offset(b));
260  return *this;
261  }
262 };
263 
265 
266 } // namespace cutlass
Base::Index Index
Index type.
Definition: tensor_view.h:84
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Definition: convert.h:33
CUTLASS_HOST_DEVICE void reset(Base const &_ref=Base(), TensorCoord const &_size=TensorCoord())
Updates the reference and size of a Tensor_view object.
Definition: tensor_view.h:165
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE TensorView(Storage *ptr, StorageCoord const &stride, TensorCoord const &size)
Constructs a TensorView from a pointer, a stride vector, and size.
Definition: tensor_view.h:156
T type
Definition: platform.h:377
TensorRef TensorRef_t
Base class.
Definition: tensor_view.h:119
CUTLASS_HOST_DEVICE Index size(int dim) const
Accesses the size.
Definition: tensor_view.h:176
Base::Storage Storage
Storage type.
Definition: tensor_view.h:81
Base TensorRef
Base tensor reference.
Definition: tensor_view.h:78
TensorView< typename platform::remove_const< Storage >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorView
TensorView of constant value.
Definition: tensor_view.h:103
Storage_ Storage
Data type of individual access.
Definition: tensor_ref.h:134
CUTLASS_HOST_DEVICE StorageCoord stride() const
Returns the stride of the tensor.
Definition: tensor_ref.h:300
CUTLASS_HOST_DEVICE TensorCoord const & size() const
Accesses the size.
Definition: tensor_view.h:172
CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex delta)
Adds an offset to each pointer.
Definition: tensor_ref.h:357
Index_ Index
Index type.
Definition: tensor_ref.h:146
TensorRef< typename platform::remove_const< Storage_ >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorRef
Tensor reference to a constant value.
Definition: tensor_view.h:75
CUTLASS_HOST_DEVICE TensorView & operator-=(TensorCoord const &b)
Moves this TensorView backwards in place by the offset of a given coordinate and returns a reference to it.
Definition: tensor_view.h:258
CUTLASS_HOST_DEVICE StorageCoord map(TensorCoord const &coord) const
Maps a logical coordinate to an n-D array in memory.
Definition: tensor_ref.h:325
Defines a view into a logical tensor.
Definition: tensor_view.h:63
CUTLASS_HOST_DEVICE TensorView(Base const &_ref, TensorCoord const &_size)
Constructs a TensorView from a TensorRef and size.
Definition: tensor_view.h:143
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Base
Base tensor reference.
Definition: tensor_view.h:66
Definition: tensor_ref.h:131
CUTLASS_HOST_DEVICE TensorView & operator=(TensorView const &_tensor)
Assigns the Tensor_view.
Definition: tensor_view.h:180
CUTLASS_HOST_DEVICE size_t capacity() const
Returns the number of scalar elements needed to store tensor.
Definition: tensor_view.h:218
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Index & at()
Gets the index of a given Coord element.
Definition: coord.h:240
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:331
TensorRef::TensorCoord TensorCoord
Coordinate in logical tensor space.
Definition: tensor_view.h:87
Base::LongIndex Offset_t
Type used to compute the offset of an element to the base of a tensor.
Definition: tensor_view.h:116
TensorCoord Coord_t
Coordinate in logical tensor space.
Definition: tensor_view.h:110
CUTLASS_HOST_DEVICE TensorView subview(TensorCoord const &location, TensorCoord size) const
Returns a Tensor_view given location and size quantities.
Definition: tensor_view.h:212
CUTLASS_HOST_DEVICE TensorRef ref() const
Returns a TensorRef pointing to the first element of the tensor.
Definition: tensor_view.h:200
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
static int const Rank
Logical rank of tensor index space.
Definition: tensor_view.h:113
CUTLASS_HOST_DEVICE TensorView & operator+=(TensorCoord const &b)
Advances this TensorView in place by the offset of a given coordinate and returns a reference to it.
Definition: tensor_view.h:243
CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const
Returns a TensorRef pointing to the first element of the tensor.
Definition: tensor_view.h:206
TensorRef::StrideVector StrideVector
Definition: tensor_view.h:94
CUTLASS_HOST_DEVICE TensorView operator-(TensorCoord const &b) const
Returns a TensorView moved backwards by the offset of a given coordinate.
Definition: tensor_view.h:250
CUTLASS_HOST_DEVICE TensorView()
Default constructor.
Definition: tensor_view.h:139
CUTLASS_HOST_DEVICE TensorView(Storage *ptr, StrideVector const &stride, TensorCoord const &size)
Constructs a TensorView from a pointer, a stride vector, and size.
Definition: tensor_view.h:147
CUTLASS_HOST_DEVICE Coord & clamp(Coord< kRank > const &max, Coord< kRank > const &min=Coord< kRank >())
Clamps a coordinate to a range specified by maximum and minimum values.
Definition: coord.h:274
CUTLASS_HOST_DEVICE TensorView operator+(TensorCoord const &b) const
Returns a TensorView offset by a given amount.
Definition: tensor_view.h:235
TensorRef::ConstTensorRef ConstTensorRef_t
TensorRef to const-valued type.
Definition: tensor_view.h:122
Basic include for CUTLASS macros.
TensorRef::StorageCoord StorageCoord
Coordinate in storage n-D array.
Definition: tensor_view.h:90
CUTLASS_HOST_DEVICE bool contains(TensorCoord const &coord) const
Determines whether a location is within a tensor.
Definition: tensor_view.h:188