#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"

#include "cutlass_extensions/gemm/warp/mma_mx_tensor_op.h"
#include "cutlass_extensions/gemm/warp/default_mma_mx_tensor_op.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////


/// Partial Specialization - inputs are mixed types  - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
template <
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Element type of A matrix
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Element type of B matrix
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
  WarpShape_,
  GemmShape<16, 8, 16>,                 // InstructionShape
  ElementA,                             // Element type of A matrix in Global Memory
  LayoutA,                              // Layout of A matrix in Global Memory
  ElementB,                             // Element type of B matrix in Global Memory
  LayoutB,                              // Layout of B matrix in Global Memory
  ElementC,                             // Element type of C matrix in Global Memory
  LayoutC,                              // Layout of C matrix in Global Memory
  arch::OpMultiplyAddMixedInputUpcast,  // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
  PartitionsK, AccumulatorsInRowMajor> {


  // Check if the ElementA and ElementB are of different data types
  static_assert(!platform::is_same<ElementA, ElementB>::value,
    "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                    ElementA, ElementB>::type;

  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
  using ElementAMma = ElementOperand;
  using ElementBMma = ElementOperand;
  using MmaElementC = ElementC;

  // Uses
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
        GemmShape<16, 8, 16>,
        32,
        ElementAMma, cutlass::layout::RowMajor,
        ElementBMma, cutlass::layout::ColumnMajor,
        MmaElementC, cutlass::layout::RowMajor,
        arch::OpMultiplyAdd
      >,
      cutlass::MatrixShape<1, 1> >;

  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types  - uses wider datatype internally.
/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
template <
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Element type of A matrix
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Element type of B matrix
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
  WarpShape_,
  GemmShape<16, 8, 32>,                 // InstructionShape
  ElementA,                             // Element type of A matrix in Global Memory
  LayoutA,                              // Layout of A matrix in Global Memory
  ElementB,                             // Element type of B matrix in Global Memory
  LayoutB,                              // Layout of B matrix in Global Memory
  ElementC,                             // Element type of C matrix in Global Memory
  LayoutC,                              // Layout of C matrix in Global Memory
  arch::OpMultiplyAddMixedInputUpcast,  // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
  PartitionsK, AccumulatorsInRowMajor> {


  // Check if the ElementA and ElementB are of different data types
  static_assert(!platform::is_same<ElementA, ElementB>::value,
    "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                    ElementA, ElementB>::type;

  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
  using MmaElementA = ElementOperand;
  using MmaElementB = ElementOperand;
  using MmaElementC = ElementC;

  // Uses
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
        GemmShape<16, 8, 32>,
        32,
        MmaElementA, cutlass::layout::RowMajor,
        MmaElementB, cutlass::layout::ColumnMajor,
        MmaElementC, cutlass::layout::RowMajor,
        arch::OpMultiplyAddSaturate
      >,
      cutlass::MatrixShape<1, 1> >;

  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
