Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_desc.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
#pragma once

#include <cassert>

#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm_coord.h"

33 namespace cutlass {
34 namespace gemm {
35 
37 template <
39  typename AType_,
41  typename BType_,
43  typename CType_,
45  typename DType_,
47  typename SType_,
49  typename Index_ = int
50 > struct GemmDesc {
51  //
52  // Type definitions
53  //
54 
56  typedef Index_ Index;
57 
59  typedef AType_ AType;
60 
63 
65  typedef BType_ BType;
66 
69 
71  typedef CType_ CType;
72 
75 
77  typedef DType_ DType;
78 
81 
83  typedef SType_ SType;
84 
85  //
86  // Data members
87  //
88 
91 
94 
97 
99  long long batch_stride_A;
100 
103 
105  long long batch_stride_B;
106 
109 
112 
114  long long batch_stride_C;
115 
118 
120  long long batch_stride_D;
121 
122  //
123  // Methods
124  //
125 
128  GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {}
129 
132  GemmDesc(Coord<3> _problem_size,
133  SType _alpha,
134  TensorRefA const &_A,
135  TensorRefB const &_B,
136  SType _beta,
137  TensorRefC const &_C,
138  TensorRefD const &_D
139  ):
140  problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
141  alpha(_alpha),
142  A(_A),
143  batch_stride_A(0),
144  B(_B),
145  batch_stride_B(0),
146  beta(_beta),
147  C(_C),
148  batch_stride_C(0),
149  D(_D),
150  batch_stride_D(0) {}
151 
154  GemmDesc(GemmCoord _problem_size,
155  SType _alpha,
156  TensorRefA const &_A,
157  TensorRefB const &_B,
158  SType _beta,
159  TensorRefC const &_C,
160  TensorRefD const &_D
161  ):
162  problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
163  alpha(_alpha),
164  A(_A),
165  batch_stride_A(0),
166  B(_B),
167  batch_stride_B(0),
168  beta(_beta),
169  C(_C),
170  batch_stride_C(0),
171  D(_D),
172  batch_stride_D(0) {
173 
174  assert(_problem_size.batch() == 1);
175  }
176 
179  GemmDesc(GemmCoord _problem_size,
180  SType _alpha,
181  TensorRefA const &_A,
182  long long _batch_stride_A,
183  TensorRefB const &_B,
184  long long _batch_stride_B,
185  SType _beta,
186  TensorRefC const &_C,
187  long long _batch_stride_C,
188  TensorRefD const &_D,
189  long long _batch_stride_D
190  ):
191  problem_size(_problem_size),
192  alpha(_alpha),
193  A(_A),
194  batch_stride_A(_batch_stride_A),
195  B(_B),
196  batch_stride_B(_batch_stride_B),
197  beta(_beta),
198  C(_C),
199  batch_stride_C(_batch_stride_C),
200  D(_D),
201  batch_stride_D(_batch_stride_D) {}
202 };
203 
204 } // namespace gemm
205 } // namespace cutlass
GEMM problem description.
Definition: gemm_desc.h:50
TensorRef< CType const, 2 > TensorRefC
Tensor reference to C operand.
Definition: gemm_desc.h:74
Definition: convert.h:33
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE GemmDesc(GemmCoord _problem_size, SType _alpha, TensorRefA const &_A, long long _batch_stride_A, TensorRefB const &_B, long long _batch_stride_B, SType _beta, TensorRefC const &_C, long long _batch_stride_C, TensorRefD const &_D, long long _batch_stride_D)
Constructor for strided-batched GEMM.
Definition: gemm_desc.h:179
TensorRefC C
The source matrix C.
Definition: gemm_desc.h:111
SType alpha
The alpha scaling value.
Definition: gemm_desc.h:93
TensorRefA A
The source matrix A.
Definition: gemm_desc.h:96
GemmCoord problem_size
The dimensions of the GEMM.
Definition: gemm_desc.h:90
Definition: gemm_coord.h:43
long long batch_stride_D
batch stride for D operand
Definition: gemm_desc.h:120
TensorRefB B
The source matrix B.
Definition: gemm_desc.h:102
TensorRef< AType const, 2 > TensorRefA
Tensor reference to A operand.
Definition: gemm_desc.h:62
CUTLASS_HOST_DEVICE GemmDesc(GemmCoord _problem_size, SType _alpha, TensorRefA const &_A, TensorRefB const &_B, SType _beta, TensorRefC const &_C, TensorRefD const &_D)
Constructor for basic GEMM with batch count = 1.
Definition: gemm_desc.h:154
DType_ DType
Element type of the destination matrix D.
Definition: gemm_desc.h:77
long long batch_stride_A
batch stride for A operand
Definition: gemm_desc.h:99
SType beta
The beta scaling value.
Definition: gemm_desc.h:108
SType_ SType
Scalar type for alpha and beta.
Definition: gemm_desc.h:83
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
AType_ AType
Element type of the source matrix A.
Definition: gemm_desc.h:59
CType_ CType
Element type of the source matrix C.
Definition: gemm_desc.h:71
TensorRef< BType const, 2 > TensorRefB
Tensor reference to B operand.
Definition: gemm_desc.h:68
long long batch_stride_B
batch stride for B operand
Definition: gemm_desc.h:105
TensorRefD D
The destination matrix D.
Definition: gemm_desc.h:117
TensorRef< DType, 2 > TensorRefD
Tensor reference to D operand.
Definition: gemm_desc.h:80
BType_ BType
Element type of the source matrix B.
Definition: gemm_desc.h:65
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
CUTLASS_HOST_DEVICE GemmDesc()
Default ctor.
Definition: gemm_desc.h:128
CUTLASS_HOST_DEVICE GemmDesc(Coord< 3 > _problem_size, SType _alpha, TensorRefA const &_A, TensorRefB const &_B, SType _beta, TensorRefC const &_C, TensorRefD const &_D)
Constructor for basic GEMM with batch count = 1.
Definition: gemm_desc.h:132
long long batch_stride_C
batch stride for C operand
Definition: gemm_desc.h:114
Index_ Index
Index type for dimensions and strides.
Definition: gemm_desc.h:56
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...