#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/layout/permute.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/epilogue/thread/linear_combination_quant.h"
#include "cutlass_extensions/gemm/kernel/default_gemm_quant.h"
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Device-level (host-side) handle for the GemmQuantWushMx kernel.
///
/// The class selects a kernel through kernel::DefaultGemmQuantWushMx, converts
/// a user-facing Arguments structure into the kernel's Params, clears the
/// split-K semaphore workspace when serial split-K is in use, and launches
/// cutlass::Kernel<GemmKernel> on a CUDA stream.
///
/// Typical use: check can_implement(), query get_workspace_size(), then call
/// initialize() followed by run() (or the operator() overloads that combine
/// them).
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type forwarded to the kernel as ElementC; it also parameterizes
    /// the default epilogue output op and the default GEMM configuration.
    /// Note that ref_C / ref_D are typed on ElementOut_, not on this type.
    typename ElementC_,
    /// Layout type forwarded to the kernel as LayoutC; also the layout of the
    /// ref_D_sf tensor in Arguments.
    typename LayoutC_,
    /// Element type of the ref_C and ref_D tensors in Arguments
    typename ElementOut_,
    /// Layout type of the ref_C and ref_D tensors in Arguments
    typename LayoutOut_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassTensorOp,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm80,  // FIXME: default architecture is hard-coded to Sm80
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Forwarded unchanged to kernel::DefaultGemmQuantWushMx; its meaning is
    /// defined there (not visible from this header).
    bool is_quartet = true,
    /// Forwarded unchanged to kernel::DefaultGemmQuantWushMx; its meaning is
    /// defined there (not visible from this header).
    int RotationSize = 32,
    /// Epilogue output operator
    typename EpilogueOutputOp_ =
        cutlass::epilogue::thread::LinearCombinationQuantWushMx<
            ElementOut_,
            128 / cutlass::sizeof_bits<ElementC_>::value,
            ElementAccumulator_,
            ElementC_,
            cutlass::epilogue::thread::MyScaleType::Quantize,
            cutlass::FloatRoundStyle::round_to_nearest,  // RLC: change?
            ElementC_>,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        typename threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages =  // 1,
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// If true, kernel supports split-K with serial reduction
    bool SplitKSerial = false,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator,
    /// Gather operand A by using an index array
    bool GatherA = false,
    /// Gather operand B by using an index array
    bool GatherB = false,
    /// Scatter result D by using an index array
    bool ScatterD = false,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute>
class GemmQuantWushMx {
 public:
  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementOut = ElementOut_;
  using LayoutOut = LayoutOut_;
  // C and D tensor references use the *output* element type and layout.
  using TensorRefC = TensorRef<ElementOut const, LayoutOut>;
  using TensorRefD = TensorRef<ElementOut, LayoutOut>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Define the kernel
  using GemmKernel = typename kernel::DefaultGemmQuantWushMx<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB, kAlignmentB,
      ElementC, LayoutC,
      ElementOut, LayoutOut,
      ElementAccumulator,
      OperatorClass,
      ArchTag,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOutputOp,
      ThreadblockSwizzle,
      kStages,
      kSplitKSerial,
      Operator,
      SharedMemoryClearOption::kNone,
      GatherA, GatherB, ScatterD, is_quartet, RotationSize, PermuteDLayout>::GemmKernel;

  /// Argument structure
  struct Arguments {
    //
    // Data members
    //

    /// GEMM problem extent
    GemmCoord problem_size;
    /// Operand A
    TensorRef<ElementA const, LayoutA> ref_A;
    /// Operand B
    TensorRef<ElementB const, LayoutB> ref_B;
    /// Source tensor C (read-only)
    TensorRef<ElementOut const, LayoutOut> ref_C;
    /// Destination tensor D
    TensorRef<ElementOut, LayoutOut> ref_D;
    /// float_ue8m0_t tensor passed to the kernel alongside D.
    /// NOTE(review): presumably the per-block scale factors of the quantized
    /// D output -- confirm against the kernel/epilogue implementation.
    TensorRef<cutlass::float_ue8m0_t, LayoutC> ref_D_sf;
    /// Epilogue output-op parameters
    typename EpilogueOutputOp::Params epilogue;
    /// Number of split-K slices; values > 1 require SplitKSerial == true
    int split_k_slices;
    // For gather+scatter operations
    int const *gather_A_indices;
    int const *gather_B_indices;
    int const *scatter_D_indices;

    //
    // Methods
    //

    /// Default ctor: empty problem, a single K slice and no gather/scatter
    /// index arrays.
    ///
    /// The index pointers are explicitly nulled: initialize() copies them into
    /// the kernel Params unconditionally, so leaving them indeterminate would
    /// hand garbage pointers to the kernel.
    CUTLASS_HOST_DEVICE
    Arguments()
        : problem_size(0, 0, 0),
          split_k_slices(1),
          gather_A_indices(nullptr),
          gather_B_indices(nullptr),
          scatter_D_indices(nullptr) {}

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(GemmCoord problem_size_,
              TensorRef<ElementA const, LayoutA> ref_A_,
              TensorRef<ElementB const, LayoutB> ref_B_,
              TensorRef<ElementOut const, LayoutOut> ref_C_,
              TensorRef<ElementOut, LayoutOut> ref_D_,
              TensorRef<cutlass::float_ue8m0_t, LayoutC> ref_D_sf_,
              typename EpilogueOutputOp::Params epilogue_ =
                  typename EpilogueOutputOp::Params(),
              int split_k_slices = 1,
              int const *gather_A_indices_ = nullptr,
              int const *gather_B_indices_ = nullptr,
              int const *scatter_D_indices_ = nullptr)
        : problem_size(problem_size_),
          ref_A(ref_A_),
          ref_B(ref_B_),
          ref_C(ref_C_),
          ref_D(ref_D_),
          ref_D_sf(ref_D_sf_),
          epilogue(epilogue_),
          split_k_slices(split_k_slices),
          gather_A_indices(gather_A_indices_),
          gather_B_indices(gather_B_indices_),
          scatter_D_indices(scatter_D_indices_) {}
  };

 private:
  /// Kernel parameters object
  typename GemmKernel::Params params_;

 public:
  /// Constructs the GEMM.
  GemmQuantWushMx() {}

  /// Determines whether the GEMM can execute the given problem.
  ///
  /// Currently only rejects split-K requests when the kernel was not built
  /// with serial split-K support; no alignment or shape checks are performed.
  static Status can_implement(Arguments const &args) {
    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }

    // TODO (later): delegate to GemmKernel::can_implement.
    // NOTE: the disabled code below is stale -- it refers to Arguments members
    // (ref_row_vec, ref_col_vec, ref_vec_a_add, ref_vec_b_add) that this
    // Arguments struct does not declare, and must be adapted before enabling.
    /* Status status = GemmKernel::can_implement(
        args.problem_size, args.ref_A.non_const_ref(),
        args.ref_B.non_const_ref(), args.ref_C.non_const_ref(), args.ref_D,
        args.ref_row_vec.non_const_ref(), args.ref_col_vec.non_const_ref(),
        args.ref_vec_a_add.non_const_ref(), args.ref_vec_b_add.non_const_ref());

    if (status != Status::kSuccess) {
      return status;
    } */

    return Status::kSuccess;
  }

  /// Gets the workspace size in bytes.
  ///
  /// A workspace is only needed for serial split-K with more than one slice:
  /// one int per (M, N) threadblock tile. Otherwise returns 0.
  static size_t get_workspace_size(Arguments const &args) {
    size_t bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.split_k_slices);

    if (kSplitKSerial && args.split_k_slices > 1) {
      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
  }

  /// Initializes GEMM state from arguments.
  ///
  /// Computes the tiled grid shape, zeroes the semaphore workspace on `stream`
  /// when serial split-K is active, and builds the kernel Params.
  ///
  /// Returns kErrorWorkspaceNull if split-K needs a workspace and none was
  /// given, kErrorInvalidProblem if split-K is requested but unsupported,
  /// kErrorInternal if the workspace memset fails, kSuccess otherwise.
  Status initialize(Arguments const &args, void *workspace = nullptr,
                    cudaStream_t stream = nullptr) {
    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.split_k_slices);

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        // The workspace is zero-filled before being handed to the kernel as
        // its semaphore array.
        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    } else {
      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Initialize the Params structure
    params_ = typename GemmKernel::Params{args.problem_size,
                                          grid_shape,
                                          args.ref_A.non_const_ref(),
                                          args.ref_B.non_const_ref(),
                                          args.ref_C.non_const_ref(),
                                          args.ref_D,
                                          args.ref_D_sf,
                                          args.epilogue,
                                          static_cast<int *>(workspace),
                                          args.gather_A_indices,
                                          args.gather_B_indices,
                                          args.scatter_D_indices};

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments.
  ///
  /// Only the tensor base pointers (A, B, C, D, D_sf), the epilogue params and
  /// the semaphore pointer are refreshed. The problem size, grid shape and
  /// gather/scatter index pointers keep the values set by the last
  /// initialize(); the workspace is NOT re-zeroed here.
  Status update(Arguments const &args, void *workspace = nullptr) {
    if (kSplitKSerial && args.split_k_slices > 1) {
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }
    }

    params_.ref_A.reset(args.ref_A.non_const_ref().data());
    params_.ref_B.reset(args.ref_B.non_const_ref().data());
    params_.ref_C.reset(args.ref_C.non_const_ref().data());
    params_.ref_D.reset(args.ref_D.data());
    params_.ref_D_sf.reset(args.ref_D_sf.data());
    params_.output_op = args.epilogue;
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  ///
  /// The launch is asynchronous with respect to the host: the returned status
  /// reflects only attribute-setting and launch errors reported by
  /// cudaGetLastError(), not errors raised while the kernel executes.
  Status run(cudaStream_t stream = nullptr) {
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    // The kernel's SharedStorage is passed as dynamic shared memory.
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    // Single spelling of the kernel entry point, used both for the attribute
    // call and for the launch so the two can never diverge.
    auto kernel_entry = cutlass::Kernel<GemmKernel>;

    // Dynamic shared memory of 48 KiB or more requires an explicit opt-in.
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(kernel_entry,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    kernel_entry<<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) { return run(stream); }

  /// Initializes from `args` and, on success, runs the kernel.
  Status operator()(Arguments const &args, void *workspace = nullptr,
                    cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
