Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
predicate_vector.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include <stdint.h>
33 
34 #include "cutlass/cutlass.h"
35 #include "cutlass/shape.h"
36 
37 #include "cutlass/util/platform.h"
38 
39 namespace cutlass {
40 
42 
59 
79 
95 
/// Statically sized array of bits.
///
/// Every predicate occupies one bit. Predicates are packed kPredicatesPerByte_ to a byte, starting
/// at bit kPredicateStart_ of each byte, and the bytes are packed into 32-bit storage words.
template <
    /// Number of predicates contained in the predicate vector
    int kPredicates_,
    /// Number of predicates contained in each byte of storage
    int kPredicatesPerByte_ = 4,
    /// Location of the first predicate within each byte of storage
    int kPredicateStart_ = 0>
struct PredicateVector {
  /// Number of bits stored by the PredicateVector
  static int const kPredicates = kPredicates_;

  /// Number of bits stored within each byte of the predicate bit vector
  static int const kPredicatesPerByte = kPredicatesPerByte_;

  /// First bit within each byte containing predicates
  static int const kPredicateStart = kPredicateStart_;

  // Make sure no one tries to put more than 8 bits in a byte :)
  static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
  // Make sure the "offsetted" bits fit in one byte.
  static_assert(kPredicateStart + kPredicatesPerByte <= 8,
                "The offsetted predicates must fit within an actual byte.");

  /// Storage type of individual elements
  typedef uint32_t Storage;

  /// Number of bytes needed
  static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;

  /// Number of storage elements needed
  static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);

 private:
  //
  // Data members
  //

  /// Words of the bit vector
  Storage storageData[kWordCount];

  //
  // Methods
  //

  /// Computes the storage word, and the bit within that word, holding logical predicate `idx`.
  CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
    CUTLASS_ASSERT(idx >= 0 && idx < kPredicates);

    int byte = (idx / kPredicatesPerByte);
    int bit_offset = (idx % kPredicatesPerByte);

    word = byte / sizeof(Storage);
    int byte_offset = (byte % sizeof(Storage));

    bit = byte_offset * 8 + bit_offset + kPredicateStart;
  }

  /// Returns a mask selecting only the bits of storage word `word` that hold valid predicates.
  ///
  /// Bits outside [kPredicateStart, kPredicateStart + kPredicatesPerByte) of each byte, unused
  /// predicate slots in the last partially filled byte, and padding bytes past kBytes are excluded.
  static CUTLASS_HOST_DEVICE Storage predicateMask(int word) {
    int const kBytesPerWord = static_cast<int>(sizeof(Storage));

    Storage mask(0);
    for (int byte = 0; byte < kBytesPerWord; ++byte) {
      // Valid predicates held by this byte: kPredicatesPerByte for a full byte, fewer for the
      // last partially used byte, and none (<= 0) for padding bytes.
      int count = kPredicates - (word * kBytesPerWord + byte) * kPredicatesPerByte;
      if (count > kPredicatesPerByte) {
        count = kPredicatesPerByte;
      }
      if (count > 0) {
        Storage byte_mask = (((Storage(1) << count) - 1) << kPredicateStart);
        mask |= (byte_mask << (byte * 8));
      }
    }
    return mask;
  }

  /// Accesses a given word with optional assertions
  CUTLASS_HOST_DEVICE Storage &storage(int word) {
    CUTLASS_ASSERT(word < kWordCount);
    return storageData[word];
  }

  /// Accesses a given word with optional assertions
  CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
    CUTLASS_ASSERT(word < kWordCount);
    return storageData[word];
  }

 public:
  //
  // Iterator
  //

  /// A const iterator enabling sequential read-only access to predicates.
  ///
  /// Holds a reference to the vector it traverses; comparisons only consider the bit index.
  class ConstIterator {
    /// Reference to the PredicateVector instance
    PredicateVector const &vec_;

    /// Index into the PredicateVector
    int bit_;

   public:
    /// Copy constructor
    CUTLASS_HOST_DEVICE
    ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

    /// Constructs an iterator from a PredicateVector, positioned at predicate `_start`
    CUTLASS_HOST_DEVICE
    ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE
    ConstIterator &operator++() {
      ++bit_;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    ConstIterator &operator--() {
      --bit_;
      return *this;
    }

    /// Post-increment: advances this iterator and returns its previous position
    CUTLASS_HOST_DEVICE
    ConstIterator operator++(int) {
      ConstIterator ret(*this);
      ++bit_;
      return ret;
    }

    /// Post-decrement: steps this iterator back and returns its previous position
    CUTLASS_HOST_DEVICE
    ConstIterator operator--(int) {
      ConstIterator ret(*this);
      --bit_;
      return ret;
    }

    /// Returns true if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }

    /// Returns false if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return vec_[bit_]; }
  };

  /// An iterator enabling sequential read and write access to predicates.
  ///
  /// Holds a reference to the vector it traverses; comparisons only consider the bit index.
  class Iterator {
    /// Reference to the PredicateVector instance
    PredicateVector &vec_;

    /// Index into the PredicateVector
    int bit_;

   public:
    /// Copy constructor
    CUTLASS_HOST_DEVICE
    Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

    /// Constructs an iterator from a PredicateVector, positioned at predicate `_start`
    CUTLASS_HOST_DEVICE
    Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE
    Iterator &operator++() {
      ++bit_;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    Iterator &operator--() {
      --bit_;
      return *this;
    }

    /// Post-increment: advances this iterator and returns its previous position
    CUTLASS_HOST_DEVICE
    Iterator operator++(int) {
      Iterator ret(*this);
      ++bit_;
      return ret;
    }

    /// Post-decrement: steps this iterator back and returns its previous position
    CUTLASS_HOST_DEVICE
    Iterator operator--(int) {
      Iterator ret(*this);
      --bit_;
      return ret;
    }

    /// Returns true if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator==(Iterator const &it) const { return bit_ == it.bit_; }

    /// Returns false if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool get() const { return vec_[bit_]; }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return vec_[bit_]; }

    /// Sets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    void set(bool value = true) { vec_.set(bit_, value); }
  };

  /// Iterator that always returns true
  struct TrivialIterator {
    /// Constructor
    CUTLASS_HOST_DEVICE
    TrivialIterator() {}

    /// Constructs a TrivialIterator from an Iterator; the argument is ignored
    CUTLASS_HOST_DEVICE
    TrivialIterator(Iterator const &it) {}

    /// Constructs an iterator from a PredicateVector; the argument is ignored
    CUTLASS_HOST_DEVICE
    TrivialIterator(PredicateVector const &_vec) {}

    /// Pre-increment (no-op)
    CUTLASS_HOST_DEVICE
    TrivialIterator &operator++() { return *this; }

    /// Post-increment (no-op)
    CUTLASS_HOST_DEVICE
    TrivialIterator operator++(int) { return *this; }

    /// Dereferences iterator; always true
    CUTLASS_HOST_DEVICE
    bool operator*() const { return true; }
  };

 public:
  //
  // Methods
  //

  /// Initialize the predicate vector, setting every predicate to `value`
  CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }

  /// Fills all predicates with a given value
  CUTLASS_HOST_DEVICE void fill(bool value = true) {
    // Whole words are written, so bits that do not hold a predicate take the same value.
    Storage item = (value ? ~Storage(0) : Storage(0));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = item;
    }
  }

  /// Accesses a bit within the predicate vector
  CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }

  /// Accesses a bit within the predicate vector
  CUTLASS_HOST_DEVICE bool at(int idx) const {
    int bit, word;
    computeStorageOffset(word, bit, idx);

    return ((storage(word) >> bit) & 1);
  }

  /// Set a bit within the predicate vector
  CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
    int bit, word;
    computeStorageOffset(word, bit, idx);

    // Clear the target bit, then OR in the new value.
    Storage disable_mask = (~(Storage(1) << bit));
    Storage enable_mask = (Storage(value) << bit);

    storage(word) = ((storage(word) & disable_mask) | enable_mask);
  }

  /// Computes the intersection of two identical predicate vectors
  CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = (storage(i) & predicates.storage(i));
    }
    return *this;
  }

  /// Computes the union of two identical predicate vectors
  CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = (storage(i) | predicates.storage(i));
    }
    return *this;
  }

  /// Returns true if every predicate in the vector is zero
  CUTLASS_HOST_DEVICE bool is_zero() const {
    // Only bits that actually hold predicates are inspected: fill(true) also sets the padding
    // bits of each word, and those must not make a fully cleared vector look non-zero.
    Storage result = 0;
    for (int word = 0; word < kWordCount; ++word) {
      result |= (storage(word) & predicateMask(word));
    }
    return result == 0;
  }

  /// Returns an iterator to the start of the bit vector
  CUTLASS_DEVICE
  Iterator begin() { return Iterator(*this); }

  /// Returns an iterator positioned one past the last predicate
  CUTLASS_DEVICE
  Iterator end() { return Iterator(*this, kPredicates); }

  /// Returns a ConstIterator to the start of the bit vector
  CUTLASS_DEVICE
  ConstIterator const_begin() const { return ConstIterator(*this); }

  /// Returns a ConstIterator positioned one past the last predicate
  CUTLASS_DEVICE
  ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
};
423 
/// Predicate tile adapter whose predicate is always true.
struct TrivialPredicateTileAdapter {
  /// Ctor.
  CUTLASS_HOST_DEVICE TrivialPredicateTileAdapter() {}

  /// The value at location (d, h, w, c); the coordinates are ignored and the result is true.
  CUTLASS_HOST_DEVICE bool at(int, int, int, int) const { return true; }
};
434 
436 
438 template <typename PredicateVector_, typename Iterations_>
441  typedef PredicateVector_ PredicateVector;
443  typedef Iterations_ Iterations;
444 
445  private:
447  PredicateVector &predicates;
448 
449  public:
451  CUTLASS_DEVICE PredicateTileAdapter(PredicateVector &predicates_) : predicates(predicates_) {}
452 
454  CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
455  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
456  return predicates.at(bit);
457  }
458 
460  CUTLASS_DEVICE void set(int d, int h, int w, int c, bool value) {
461  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
462  predicates.set(bit, value);
463  }
464 };
465 
467 
469 template <typename PredicateVector_, typename Iterations_>
472  typedef PredicateVector_ PredicateVector;
474  typedef Iterations_ Iterations;
475 
476  private:
478  PredicateVector const &predicates;
479 
480  public:
482  CUTLASS_DEVICE ConstPredicateTileAdapter(PredicateVector const &predicates_)
483  : predicates(predicates_) {}
484 
486  CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
487  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
488  return predicates.at(bit);
489  }
490 };
491 
493 
494 } // namespace cutlass
CUTLASS_HOST_DEVICE Iterator(PredicateVector &_vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:255
CUTLASS_HOST_DEVICE bool operator!=(ConstIterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:229
CUTLASS_HOST_DEVICE PredicateVector & operator|=(PredicateVector const &predicates)
Computes the union of two identical predicate vectors.
Definition: predicate_vector.h:385
CUTLASS_HOST_DEVICE TrivialIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:324
#define CUTLASS_PRAGMA_UNROLL
Definition: performance_tuning.h:35
Definition: convert.h:33
CUTLASS_HOST_DEVICE bool is_zero() const
Returns true if entire predicate array is zero.
Definition: predicate_vector.h:394
uint32_t Storage
Storage type of individual elements.
Definition: predicate_vector.h:116
CUTLASS_HOST_DEVICE TrivialIterator(PredicateVector const &_vec)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:320
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:202
static int const kBytes
Number of bytes needed.
Definition: predicate_vector.h:125
CUTLASS_DEVICE ConstIterator const_begin() const
Returns a ConstIterator to the start of the bit vector.
Definition: predicate_vector.h:417
CUTLASS_HOST_DEVICE ConstIterator(PredicateVector const &_vec, int _start=0)
Definition: predicate_vector.h:191
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:357
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:195
PredicateVector_ PredicateVector
The vector of predicates.
Definition: predicate_vector.h:441
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:209
CUTLASS_HOST_DEVICE Iterator operator++(int)
Post-increment.
Definition: predicate_vector.h:273
Adapter to enable random access to predicates via logical coordinate within a tile.
Definition: predicate_vector.h:439
CUTLASS_HOST_DEVICE TrivialIterator(Iterator const &it)
Constructs a TrivialIterator from an Iterator (the argument is ignored).
Definition: predicate_vector.h:316
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:181
C++ features that may be otherwise unimplemented for CUDA device functions.
Iterator that always returns true.
Definition: predicate_vector.h:309
CUTLASS_HOST_DEVICE TrivialIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:328
CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:289
CUTLASS_DEVICE PredicateTileAdapter(PredicateVector &predicates_)
Ctor.
Definition: predicate_vector.h:451
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const
Get the value at location (d, h, w, c).
Definition: predicate_vector.h:454
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const
Get the value at location (d, h, w, c).
Definition: predicate_vector.h:486
CUTLASS_HOST_DEVICE Iterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:266
PredicateVector_ PredicateVector
The vector of predicates.
Definition: predicate_vector.h:472
CUTLASS_HOST_DEVICE PredicateVector & operator &=(PredicateVector const &predicates)
Computes the intersection of two identical predicate vectors.
Definition: predicate_vector.h:376
CUTLASS_HOST_DEVICE Iterator(Iterator const &it)
Copy constructor.
Definition: predicate_vector.h:251
CUTLASS_HOST_DEVICE bool operator[](int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:354
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:301
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:332
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:344
static int const kPredicates
Number of bits stored by the PredicateVector.
Definition: predicate_vector.h:107
CUTLASS_DEVICE Iterator end()
Returns an iterator positioned one past the last predicate of the bit vector.
Definition: predicate_vector.h:413
#define CUTLASS_ASSERT(x)
Definition: cutlass.h:50
CUTLASS_HOST_DEVICE bool at(int, int, int, int) const
The value at location (d, h, w, c).
Definition: predicate_vector.h:432
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
static int const kPredicatesPerByte
Number of bits stored within each byte of the predicate bit vector.
Definition: predicate_vector.h:110
#define static_assert(__e, __m)
Definition: platform.h:153
Statically sized array of bits implementing the Predicate Vector Concept.
Definition: predicate_vector.h:105
static int const kWordCount
Number of storage elements needed.
Definition: predicate_vector.h:128
CUTLASS_DEVICE ConstIterator const_end() const
Returns a ConstIterator positioned one past the last predicate of the bit vector.
Definition: predicate_vector.h:421
Always returns true predicate.
Definition: predicate_vector.h:427
CUTLASS_HOST_DEVICE Iterator & operator++()
Pre-increment.
Definition: predicate_vector.h:259
A const iterator implementing Predicate Iterator Concept enabling sequential read-only access to pred...
Definition: predicate_vector.h:177
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:365
CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:225
Iterations_ Iterations
The iterations.
Definition: predicate_vector.h:474
Iterations_ Iterations
The iterations.
Definition: predicate_vector.h:443
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:233
CUTLASS_HOST_DEVICE bool operator!=(Iterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:293
static int const kPredicateStart
First bit within each byte containing predicates.
Definition: predicate_vector.h:113
CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it)
Copy constructor.
Definition: predicate_vector.h:187
CUTLASS_HOST_DEVICE TrivialPredicateTileAdapter()
Ctor.
Definition: predicate_vector.h:429
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:217
Adapter to enable random access to predicates via logical coordinate within a tile.
Definition: predicate_vector.h:470
CUTLASS_DEVICE ConstPredicateTileAdapter(PredicateVector const &predicates_)
Ctor.
Definition: predicate_vector.h:482
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
CUTLASS_HOST_DEVICE PredicateVector(bool value=true)
Initialize the predicate vector.
Definition: predicate_vector.h:341
CUTLASS_DEVICE Iterator begin()
Returns an iterator to the start of the bit vector.
Definition: predicate_vector.h:409
Basic include for CUTLASS macros.
An iterator implementing Predicate Iterator Concept enabling sequential read and write access to pred...
Definition: predicate_vector.h:241
CUTLASS_HOST_DEVICE Iterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:281
CUTLASS_HOST_DEVICE TrivialIterator()
Constructor.
Definition: predicate_vector.h:312