CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
|
Template performing a matrix multiply-add operation within a thread.
#include <hgemm_multiply_add.h>
Public Types | |
typedef Shape< 1, 1, 2, 1 > | InstructionShape |
The shape of the instruction. More... | |
typedef ThreadGemmShape_ | ThreadGemmShape |
The number of accumulators per thread. More... | |
typedef ThreadGemmShape | AccumulatorsPerThread |
Alias kept for backward compatibility; will be removed in CUTLASS v2.0. More... | |
typedef ThreadsPerWarp_ | ThreadsPerWarp |
The number of threads per warp. More... | |
typedef ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape | AccumulatorsPerWarp |
The number of accumulators per warp. More... | |
typedef half | ScalarA |
The type for A. More... | |
typedef Fragment< ScalarA, AccumulatorsPerThread::kW > | FragmentA |
The fragment for A. More... | |
typedef half | ScalarB |
The type for B. More... | |
typedef Fragment< ScalarB, AccumulatorsPerThread::kH > | FragmentB |
The fragment for B. More... | |
typedef half | ScalarC |
The type for C and D. More... | |
typedef Fragment< half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW > | Accumulators |
The accumulators. More... | |
Public Member Functions | |
CUTLASS_DEVICE | ThreadMultiplyAdd () |
Constructor. The operation requires an even number of elements in both dimensions. More... | |
CUTLASS_DEVICE void | multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) |
Multiply-add: d = a*b + c. More... | |
typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::Accumulators |
typedef ThreadGemmShape cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::AccumulatorsPerThread |
typedef ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::AccumulatorsPerWarp |
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::FragmentA |
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::FragmentB |
typedef Shape<1, 1, 2, 1> cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::InstructionShape |
typedef half cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::ScalarA |
typedef half cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::ScalarB |
typedef half cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::ScalarC |
typedef ThreadGemmShape_ cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::ThreadGemmShape |
typedef ThreadsPerWarp_ cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >::ThreadsPerWarp |
|
inline |
Constructor.
|
inline |