Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_coord.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 #include "cutlass/coord.h"
34 #include "cutlass/util/platform.h"
35 
36 namespace cutlass {
37 namespace gemm {
38 
40 
43 struct GemmCoord : public Coord<4, int> {
44 
46  typedef int Index;
47 
50 
52  static int const kK = 0;
53 
55  static int const kN = 1;
56 
58  static int const kM = 2;
59 
61  static int const kBatch = 3;
62 
63  //
64  // Methods
65  //
66 
69  GemmCoord() { }
70 
73  GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { }
74 
77  GemmCoord(Coord<4, Index> const &coord): Base(coord) { }
78 
81  GemmCoord(Index coord[4]): Base(coord) { }
82 
86 
89  Index const & m() const { return this->at(kM); }
90 
93  Index & m() { return this->at(kM); }
94 
97  Index const & n() const { return this->at(kN); }
98 
101  Index & n() { return this->at(kN); }
102 
105  Index const & k() const { return this->at(kK); }
106 
109  Index & k() { return this->at(kK); }
110 
113  Index const & batch() const { return this->at(kBatch); }
114 
117  Index & batch() { return this->at(kBatch); }
118 
121  Coord<3> knm() const {
122  return make_Coord(k(), n(), m());
123  }
124 
127  Coord<2> nm() const {
128  return make_Coord(n(), m());
129  }
130 
133  Coord<2> mn() const {
134  return make_Coord(m(), n());
135  }
136 
139  Coord<2> km() const {
140  return make_Coord(k(), m());
141  }
142 
145  Coord<2> kn() const {
146  return make_Coord(k(), n());
147  }
148 
149  //
150  // Coord operators
151  //
152 
155  GemmCoord operator+(Base const& b) const {
156  return GemmCoord(Base::operator+(b));
157  }
158 
161  GemmCoord operator-(Base const& b) const {
162  return GemmCoord(Base::operator-(b));
163  }
164 
167  GemmCoord operator*(Base const& b) const {
168  return GemmCoord(Base::operator*(b));
169  }
170 
173  GemmCoord operator/(Base const& b) const {
174  return GemmCoord(Base::operator/(b));
175  }
176 
179  GemmCoord& operator+=(Base const& b) {
180  Base::operator+=(b);
181  return *this;
182  }
183 
186  GemmCoord& operator-=(Base const& b) {
187  Base::operator-=(b);
188  return *this;
189  }
190 
193  GemmCoord& operator*=(Base const& b) {
194  Base::operator*=(b);
195  return *this;
196  }
197 
200  GemmCoord& operator/=(Base const& b) {
201  Base::operator/=(b);
202  return *this;
203  }
204 };
205 
207 
208 } // namespace gemm
209 } // namespace cutlass
Definition: convert.h:33
CUTLASS_HOST_DEVICE GemmCoord & operator/=(Base const &b)
In-place division.
Definition: gemm_coord.h:200
int Index
Integer-valued index.
Definition: gemm_coord.h:46
CUTLASS_HOST_DEVICE Coord< 3 > knm() const
Obtains a Coord<3> from GemmCoord.
Definition: gemm_coord.h:121
CUTLASS_HOST_DEVICE Coord< 2 > km() const
Obtains a Coord<2> from GemmCoord.
Definition: gemm_coord.h:139
Coord< 4, Index > Base
Base type is a Coord of rank=4.
Definition: gemm_coord.h:49
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: gemm_coord.h:97
CUTLASS_HOST_DEVICE GemmCoord operator/(Base const &b) const
Element-wise division.
Definition: gemm_coord.h:173
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_HOST_DEVICE Index & m()
Returns reference to the GEMM M coordinate.
Definition: gemm_coord.h:93
Definition: gemm_coord.h:43
static int const kBatch
Batch dimension - for generalizing to larger problems.
Definition: gemm_coord.h:61
CUTLASS_HOST_DEVICE GemmCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: gemm_coord.h:161
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:197
CUTLASS_HOST_DEVICE GemmCoord(Coord< 4, Index > const &coord)
Constructs from Coord<4>
Definition: gemm_coord.h:77
C++ features that may be otherwise unimplemented for CUDA device functions.
CUTLASS_HOST_DEVICE Index & k()
Returns reference to the GEMM K coordinate.
Definition: gemm_coord.h:109
CUTLASS_HOST_DEVICE GemmCoord(Coord< 3, Index > const &coord, Index _batch=0)
Constructs from Coord<3> and a batch.
Definition: gemm_coord.h:73
CUTLASS_HOST_DEVICE GemmCoord operator+(Base const &b) const
Element-wise addition.
Definition: gemm_coord.h:155
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:188
static int const kK
GEMM K dimension - inner dimension of the GEMM problem.
Definition: gemm_coord.h:52
static int const kN
GEMM N dimension - columns of the output C matrix.
Definition: gemm_coord.h:55
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:179
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Coord< 2 > kn() const
Obtains a Coord<2> from GemmCoord.
Definition: gemm_coord.h:145
CUTLASS_HOST_DEVICE Index & at()
Gets the index of a given Coord element.
Definition: coord.h:240
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: gemm_coord.h:133
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:206
CUTLASS_HOST_DEVICE GemmCoord & operator+=(Base const &b)
In-place addition.
Definition: gemm_coord.h:179
static int const kM
GEMM M dimension - rows of the output C matrix.
Definition: gemm_coord.h:58
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
CUTLASS_HOST_DEVICE GemmCoord()
Default ctor.
Definition: gemm_coord.h:69
CUTLASS_HOST_DEVICE GemmCoord & operator-=(Base const &b)
In-place subtraction.
Definition: gemm_coord.h:186
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: gemm_coord.h:89
CUTLASS_HOST_DEVICE GemmCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: gemm_coord.h:167
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: gemm_coord.h:105
CUTLASS_HOST_DEVICE GemmCoord & operator*=(Base const &b)
In-place multiplication.
Definition: gemm_coord.h:193
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
Basic include for CUTLASS macros.
CUTLASS_HOST_DEVICE Index & n()
Returns reference to the GEMM N coordinate.
Definition: gemm_coord.h:101
CUTLASS_HOST_DEVICE Coord< 2 > nm() const
Obtains a Coord<2> from GemmCoord.
Definition: gemm_coord.h:127
CUTLASS_HOST_DEVICE GemmCoord(Index k, Index n, Index m, Index batch=0)
Helper to construct from K, N, M, and batch variables.
Definition: gemm_coord.h:85
CUTLASS_HOST_DEVICE Index & batch()
Returns reference to the GEMM batch coordinate.
Definition: gemm_coord.h:117
CUTLASS_HOST_DEVICE GemmCoord(Index coord[4])
Constructs from an array of coordinate elements.
Definition: gemm_coord.h:81