Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
fragment.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/shape.h"
34 #include "cutlass/vector.h"
35 
36 namespace cutlass {
37 
39 
56 
73 
/// Maps an alignment, in bytes, to the unsigned integer type used as one storage word.
///
/// Alignments of 1, 2 and 4 select uint8_t, uint16_t and uint32_t respectively (see the
/// specializations below); every other alignment falls back to a 64-bit word.
template <int alignment>
struct StorageType {
  typedef uint64_t Type;
};

/// 1-byte alignment: byte-wide storage word.
template <>
struct StorageType<1> {
  typedef uint8_t Type;
};

/// 2-byte alignment: 16-bit storage word.
template <>
struct StorageType<2> {
  typedef uint16_t Type;
};

/// 4-byte alignment: 32-bit storage word.
template <>
struct StorageType<4> {
  typedef uint32_t Type;
};
91 
93 
98 template <typename Element_, int kElements_, size_t kAlignment_ = 16>
99 struct Fragment : public AlignedStruct<kAlignment_> {
101  static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
103  static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
104 
108  typedef Element_ Element;
110  static int const kElements = kElements_;
112  static int const kAlignment = kAlignment_;
113 
116  // Avoid element-wise access for sub 32b element type
117  if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
118  uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
119  for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
120  ptr[i] = uint64_t(0);
121  }
122  } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
123  uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
124  for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
125  ptr[i] = uint32_t(0);
126  }
127  } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
128  uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
129  for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
130  ptr[i] = uint16_t(0);
131  }
132  } else {
133  for (int i = 0; i < kElements; ++i) {
134  storage[i] = 0;
135  }
136  }
137  }
138 
140  CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
141 
143  CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
144  return reinterpret_cast<Element const*>(storage)[i];
145  }
146 
147  private:
150 
152  static int const kStorageCount =
153  (sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
155  StorageType storage[kStorageCount];
156 
158  static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
159 };
160 
162 
167 template <typename Fragment_, typename Iterations_, typename AccessType_>
172  typedef Fragment_ Fragment;
174  typedef Iterations_ Iterations;
176  typedef AccessType_ AccessType;
177 
179  typedef typename Fragment::Element Element;
181  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
186 
188  template <typename OtherFragment_>
189  CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
190  : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
191  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
192  }
193 
195  CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
196  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
197  return reinterpret_cast<AccessType const&>(pointer[imm]);
198  }
199 
201  CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
202  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
203  return reinterpret_cast<AccessType&>(pointer[imm]);
204  }
205 
208  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
209  }
210 
213  return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
214  }
215 
217  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
218 
221 };
222 
224 
225 template <typename Fragment_, typename Iterations_, typename AccessType_>
230  typedef Fragment_ Fragment;
232  typedef Iterations_ Iterations;
234  typedef AccessType_ AccessType;
235 
237  typedef typename Fragment::Element Element;
239  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
244 
246  template <typename OtherFragment_>
247  CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
248  : pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
249  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
250  }
254  : pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
255 
257  CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
258  int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
259  return reinterpret_cast<AccessType const&>(pointer[imm]);
260  }
261 
264  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
265  }
266 
268  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
269 
271  Element const* pointer;
272 };
273 
275 
276 } // namespace cutlass
CUTLASS_HOST_DEVICE void clear()
Clear a fragment.
Definition: fragment.h:115
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:217
Definition: convert.h:33
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, elementsPerAccess > Shape
Definition: shape.h:170
Definition: vector.h:42
Definition: fragment.h:226
CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:189
A template defining Fragment Concept.
Definition: fragment.h:99
Fragment::Element Element
The element.
Definition: fragment.h:179
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:239
Fragment_ Fragment
The fragment.
Definition: fragment.h:172
Fragment_ Fragment
The fragment.
Definition: fragment.h:230
Fragment::Element Element
The element.
Definition: fragment.h:237
Fragment< Element_, kElements_ > This_
Make sure the alignment makes sense wrt the size of elements.
Definition: fragment.h:101
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:170
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:241
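Selects the unsigned integer storage type for an alignment in bytes: uint8_t, uint16_t or uint32_t for 1, 2 or 4 bytes, and uint64_t otherwise.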
Definition: fragment.h:76
uint32_t Type
Definition: fragment.h:81
uint8_t Type
Definition: fragment.h:89
Element * pointer
The pointer.
Definition: fragment.h:220
CUTLASS_HOST_DEVICE Element const & operator[](int i) const
The accessor.
Definition: fragment.h:143
AccessType_ AccessType
The access type.
Definition: fragment.h:234
ShapeStrides< FragmentShape, kElementsPerAccess >::Shape IterationsStrides
The linear strides for iterations.
Definition: fragment.h:243
Definition: shape.h:118
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:183
CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:247
A template defining Fragment Iterator Concept.
Definition: fragment.h:168
static int const kElements
The number of elements.
Definition: fragment.h:110
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:232
CUTLASS_HOST_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:195
#define static_assert(__e, __m)
Definition: platform.h:153
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:174
CUTLASS_HOST_DEVICE FragmentConstIterator(FragmentIterator< Fragment_, Iterations_, AccessType_ > const &rhs_)
Create from non-constant FragmentIterator.
Definition: fragment.h:252
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Element_ Element
The element.
Definition: fragment.h:108
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:228
CUTLASS_HOST_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:263
CUTLASS_HOST_DEVICE Element & operator[](int i)
The accessor.
Definition: fragment.h:140
CUTLASS_HOST_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:207
uint16_t Type
Definition: fragment.h:85
Defines a 1D vector of elements held in the registers of each thread.
uint64_t Type
Definition: fragment.h:77
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:268
ShapeStrides< FragmentShape, kElementsPerAccess >::Shape Strides
The linear strides for iterations.
Definition: fragment.h:185
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
CUTLASS_HOST_DEVICE AccessType & operator[](int i)
The accessor.
Definition: fragment.h:212
CUTLASS_HOST_DEVICE AccessType & at(int d, int h, int w, int c=0)
The accessor.
Definition: fragment.h:201
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:181
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
AccessType_ AccessType
The access type.
Definition: fragment.h:176
static int const kAlignment
Alignment.
Definition: fragment.h:112
Definition: cutlass_math.h:45
CUTLASS_HOST_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:257
Element const * pointer
The pointer.
Definition: fragment.h:271