63 template <
int kD_ = 1,
int kH_ = 1,
int kW_ = 1,
int kC_ = 1>
66 static int const kD = kD_;
68 static int const kH = kH_;
70 static int const kW = kW_;
72 static int const kC = kC_;
78 template <
typename Shape>
96 template <
typename A_,
int kScale_>
103 template <
typename A_,
typename B_>
110 template <
typename A_,
typename B_>
112 typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC>
Shape;
117 template <
typename A_,
typename B_>
124 template <
typename A_,
typename B_>
126 typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC>
Shape;
131 template <
typename A_,
typename B_>
133 typedef Shape<(A_::kD + B_::kD - 1) / B_::kD,
134 (A_::kH + B_::kH - 1) / B_::kH,
135 (A_::kW + B_::kW - 1) / B_::kW,
136 (A_::kC + B_::kC - 1) / B_::kC>
142 template <
typename A_,
typename B_>
145 (A_::kH > B_::kH ? A_::kH : B_::kH),
146 (A_::kW > B_::kW ? A_::kW : B_::kW),
147 (A_::kC > B_::kC ? A_::kC : B_::kC)>
153 template <
typename A_,
typename B_>
155 typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
156 (A_::kH < B_::kH ? A_::kH : B_::kH),
157 (A_::kW < B_::kW ? A_::kW : B_::kW),
158 (A_::kC < B_::kC ? A_::kC : B_::kC)>
164 template <
typename Shape_,
int elementsPerAccess>
166 typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
167 Shape_::kW * Shape_::kC,
179 template <
typename Shape_>
183 return d * Shape_::kH * Shape_::kW * Shape_::kC +
184 h * Shape_::kW * Shape_::kC +
197 template <
typename Str
ides_>
200 return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
212 template <
typename Threads_,
typename Str
ides_>
214 static CUTLASS_DEVICE
int get() {
216 int c = threadIdx.x % Threads_::kC;
217 int w = threadIdx.x / Threads_::kC % Threads_::kW;
218 int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
219 int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
222 return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
230 template <
int T_h_,
int T_w_,
int T_c_,
int S_h_,
int S_w_,
int S_c_>
232 static CUTLASS_DEVICE
int get() {
234 int c = threadIdx.x % T_c_;
235 int w = threadIdx.x / T_c_ % T_w_;
236 int h = threadIdx.x / T_c_ / T_w_ % T_h_;
239 return h * S_h_ + w * S_w_ + c * S_c_;
248 template <
int T_h_,
int T_w_,
int S_h_,
int S_w_>
250 static CUTLASS_DEVICE
int get() {
252 int w = threadIdx.x % T_w_;
253 int h = threadIdx.x / T_w_;
256 return h * S_h_ + w * S_w_;
Decompose threadIdx.x into the coordinates of a cube whose dimensions are specified by Threads_, then compute the offset of those coordinates using Strides_.
Definition: shape.h:213
static int const kWc
The number of elements per row.
Definition: shape.h:81
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, elementsPerAccess > Shape
Definition: shape.h:170
Shape< A_::kD+B_::kD, A_::kH+B_::kH, A_::kW+B_::kW, A_::kC+B_::kC > Shape
Definition: shape.h:105
Shape<(A_::kD+B_::kD - 1)/B_::kD,(A_::kH+B_::kH - 1)/B_::kH,(A_::kW+B_::kW - 1)/B_::kW,(A_::kC+B_::kC - 1)/B_::kC > Shape
Definition: shape.h:137
Shape< A_::kD *kScale_, A_::kH *kScale_, A_::kW *kScale_, A_::kC *kScale_ > Shape
Definition: shape.h:98
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
Shape< A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC > Shape
Definition: shape.h:112
static int const kH
The height of the cube.
Definition: shape.h:68
static int const kC
The number of scalars per element.
Definition: shape.h:72
Compute the offset for the given coordinates in a cube.
Definition: shape.h:180
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
static int const kDhw
The number of pixels per cube.
Definition: shape.h:87
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Compute the offset for the given coordinates in a cube.
Definition: shape.h:198
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
static int const kDhwc
The number of elements in the 4D space.
Definition: shape.h:89
static int const kW
The width of the cube.
Definition: shape.h:70
static int const kHw
The number of pixels per image.
Definition: shape.h:83
static int const kD
The depth of the cube.
Definition: shape.h:66
Shape<(A_::kD > B_::kD ? A_::kD :B_::kD),(A_::kH > B_::kH ? A_::kH :B_::kH),(A_::kW > B_::kW ? A_::kW :B_::kW),(A_::kC > B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:148
Basic include for CUTLASS macros.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:159
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
static int const kHwc
The number of elements per image.
Definition: shape.h:85