#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/arch/wmma.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

////////////////////////////////////////////////////////////////////////////////

/// Primary template for compile-time GEMM configuration lookup.
///
/// Only declared here (never defined): a configuration exists solely for those
/// argument combinations that have a matching specialization.
///
/// Template parameters, in order:
///   OperatorClass       - operator class tag
///   ArchTag             - architecture tag
///   ElementA            - element type of operand A
///   ElementB            - element type of operand B
///   ElementC            - element type of operand C
///   ElementAccumulator  - element type used for accumulation
template <typename OperatorClass, typename ArchTag, typename ElementA,
          typename ElementB, typename ElementC, typename ElementAccumulator>
struct DefaultGemmConfiguration;

////////////////////////////////////////////////////////////////////////////////

/// Default GEMM configuration for the combination
///   operator class = arch::OpClassTensorOp, architecture = arch::Sm80,
///   A = float_e2m1_t, B = float_e2m1_t, accumulator = float.
/// The C element type (ElementC) remains a free template parameter.
///
/// NOTE(review): a saturating operator tag and a clamping epilogue are paired
/// here with floating-point operands and a float accumulator - confirm that
/// this pairing is intentional for this element type.
template <typename ElementC>
struct DefaultGemmConfiguration<arch::OpClassTensorOp, arch::Sm80,
                                float_e2m1_t, float_e2m1_t, ElementC, float> {

  /// Tile shapes, listed from the smallest to the largest.
  /// (presumably GemmShape<M, N, K> - verify against cutlass/gemm/gemm.h)
  using InstructionShape = GemmShape<16, 8, 64>;
  using WarpShape = GemmShape<64, 64, 128>;
  using ThreadblockShape = GemmShape<128, 256, 128>;

  /// Stage count (presumably the depth of the mainloop pipeline - confirm
  /// with the kernels that consume this configuration).
  static constexpr int kStages = 3;

  /// Alignment of the A and B operands, expressed as 128 divided by the value
  /// of sizeof_bits for the element type (i.e. elements per 128 bits, assuming
  /// sizeof_bits reports the element width in bits).
  static constexpr int kAlignmentA = 128 / sizeof_bits<float_e2m1_t>::value;
  static constexpr int kAlignmentB = 128 / sizeof_bits<float_e2m1_t>::value;

  /// Math operator tag selected for this configuration.
  using Operator = arch::OpMultiplyAddSaturate;

  /// Epilogue functor. Its second argument is 128 divided by the bit width of
  /// ElementC; the two trailing `float` arguments presumably name the
  /// accumulator and compute types - confirm against
  /// cutlass/epilogue/thread/linear_combination_clamp.h.
  using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
      ElementC,
      128 / sizeof_bits<ElementC>::value,
      float,
      float>;
};

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
