31 #if !defined(__CUDACC_RTC__) 44 template <
typename batched_reduction_>
45 __global__
__launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params) {
47 batched_reduction_ batched_reduction(params);
48 batched_reduction.run();
51 template <
typename BatchedReductionTraits_>
56 typedef BatchedReductionTraits_
Traits;
58 typedef typename Traits::Params
Params;
68 CUTLASS_DEVICE
void run() {
69 #if (__CUDA_ARCH__ >= 600) 71 typename Traits::BlockSwizzle block_swizzle;
73 block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
75 int subTileSize = gridDim.x * Traits::SubTile::kW;
76 int tileSize =
params.problem_size[1] *
params.problem_size[2];
77 int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;
81 typename Traits::ScalarA inRegs[Traits::maxInReg];
82 typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
84 for (
int subTile = 0; subTile < tileSize; subTile += subTileSize) {
85 int tileOffset = subTileBase + subTileOffset;
87 for (
int i = 0; i < Traits::ThreadShape::kW; i++)
88 AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
90 typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
91 for (
int i = 0; i< Traits::ThreadShape::kW; i++)
92 c0[i] = static_cast<typename Traits::ScalarAccum>(
params.d_c[tileOffset + i]);
96 for (
int s = 0; s < Traits::ReductionSize; s++) {
97 int inRegOffset = s * Traits::ThreadShape::kW;
98 int dOffset = (s * tileSize) + tileOffset;
100 for (
int i = 0; i< Traits::ThreadShape::kW; i++) {
101 inRegs[inRegOffset + i] =
params.d_a[dOffset + i];
107 for (
int s = 0; s < Traits::ReductionSize; s++) {
108 int inRegOffset = s * Traits::ThreadShape::kW;
110 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
113 AccumRegs[i] =
static_cast<typename Traits::ScalarAccum
>(inRegs[inRegOffset + i]) + AccumRegs[i];
117 functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);
121 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
122 params.d_d[tileOffset + i] =
static_cast<typename Traits::ScalarD
>(AccumRegs[i]);
126 subTileBase += subTileSize;
128 #endif //#if (__CUDA_ARCH__ >= 600) 131 template<
bool ThreadShapeMultiple2>
132 CUTLASS_DEVICE
void functor_caller(
typename Traits::ScalarAccum
const *accum,
typename Traits::ScalarAccum
const *old,
typename Traits::ScalarAccum *output) {
133 if (ThreadShapeMultiple2 ==
true) {
134 for (
int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
135 functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
139 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
140 functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
148 #if !defined(__CUDACC_RTC__) 151 cudaStream_t stream = cudaStreamDefault) {
153 typename Traits::BlockSwizzle block_swizzle;
154 dim3 grid = block_swizzle.get_grid_layout(
params.problem_size,
155 make_Coord_from_shape<typename Traits::OutputTile>());
158 block.x = Traits::kThreads;
159 batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(
params);
160 return cudaGetLastError();
Params const & params
The params.
Definition: batched_reduction.h:169
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params)
Definition: batched_reduction.h:45
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_DEVICE void run()
Definition: batched_reduction.h:68
BatchedReduction< BatchedReductionTraits_ > This_
This class.
Definition: batched_reduction.h:54
Functor functor
Definition: batched_reduction.h:171
Definition: batched_reduction.h:52
CUTLASS_DEVICE BatchedReduction(Params const &params_)
Constructor.
Definition: batched_reduction.h:63
Traits::Params Params
Params.
Definition: batched_reduction.h:58
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output)
Definition: batched_reduction.h:132
Traits::Functor Functor
functor
Definition: batched_reduction.h:60
BatchedReductionTraits_ Traits
The traits.
Definition: batched_reduction.h:56
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel.
Definition: batched_reduction.h:150