#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include <cstdint>
#include <iostream>
#include <limits>

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "fused_quantize_host.h"
#include "quartet/gemm/device/gemm_quant.h"

namespace QUTLASS::quartet {

// --- Operand element types -------------------------------------------------
// Both GEMM inputs are stored as bfloat16.
using ElementInputA = cutlass::bfloat16_t;
using ElementInputB = cutlass::bfloat16_t;

// --- Arithmetic types ------------------------------------------------------
// Accumulation and epilogue computation are carried out in float.
using ElementAccumulator = float;
using ElementComputeEpilogue = float;

// --- Output element types --------------------------------------------------
using ElementGemmOutput = cutlass::bfloat16_t;  // FIXME: float
using ElementOutput = cutlass::float_e2m1_t;    // FIXME: bfloat16_t
using ElementAuxOutput = ElementOutput;

// --- Layouts ---------------------------------------------------------------
// Every matrix handled by this launcher is row-major.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

/// Device-level fused GEMM + quantize (backward) operator for SM80 tensor cores.
///
/// Only the tiling is left open: threadblock tile, warp tile and MMA
/// instruction shape. All element types and layouts come from the aliases above.
template <typename ThreadblockTile, typename WarpTile, typename MmaInstruction>
using Gemm_ = cutlass::gemm::device::quartet::GemmQuantBwd<
    ElementInputA,     LayoutInputA,   // A operand
    ElementInputB,     LayoutInputB,   // B operand
    ElementGemmOutput, LayoutOutput,   // GEMM output
    ElementOutput,     LayoutOutput,   // quantized output
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,    // tensor-core MMA
    cutlass::arch::Sm80,               // target architecture tag
    ThreadblockTile,
    WarpTile,
    MmaInstruction>;

// Command line options parsing
struct Options {

  cutlass::gemm::GemmCoord problem_size;

  float alpha;
  float beta;

  Options(int M, int N, int K, float scale=1.f):
    beta(0.f)
  {
    problem_size = cutlass::gemm::GemmCoord{M, N, K};
    alpha = scale;
  }

  /// Compute performance in GFLOP/s
  float gflops(float runtime_s) const {
    // Two flops per multiply-add
    return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s;
  }
};

/// Helper class to run the kernel
template <typename Gemm>
struct TestbedRunner {
  uint64_t seed;


  //
  // Methods
  //

  TestbedRunner() { }


  bool run(
    Options& options,
    torch::Tensor out,  // FP32/FP16/BF16 (TODO)
    torch::Tensor out_sf,
    torch::Tensor x,    // float_e4m3_t
    torch::Tensor y,     // float_e4m3_t
    int32_t M, int32_t N, int32_t K)
  {

    using GemmCoord = cutlass::gemm::GemmCoord;

    Gemm gemmOp;

    typename Gemm::Arguments arguments{
      {static_cast<GemmCoord::Index>(M),
       static_cast<GemmCoord::Index>(N),
       static_cast<GemmCoord::Index>(K)},
      {(cutlass::bfloat16_t *)x.data_ptr(), K},
      {(cutlass::bfloat16_t *)y.data_ptr(), N},
      {(cutlass::float_e2m1_t *)out.data_ptr(), N}, //FIXME: bfloat16_t
      {(cutlass::float_e2m1_t *)out.data_ptr(), N}, //FIXME: bfloat16_t
      {(cutlass::float_ue8m0_t *)out_sf.data_ptr(), M}, //FIXME: bfloat16_t
      cutlass::bfloat16_t(0) //FIXME: float
    };

    auto status = gemmOp(arguments);

    TORCH_CHECK(status == cutlass::Status::kSuccess,
              cutlassGetStatusString(status))

    return true;
  }

};


/// Host entry point for the fused backward GEMM + quantization kernel.
///
/// A is treated as a flat, row-major (M x 32) matrix with M = A.numel() / 32.
/// It is multiplied by B, a row-major (32 x N) matrix. The kernel writes into
/// D (viewed as float_e2m1_t) and D_sf (viewed as float_ue8m0_t).
///
/// @param D     Destination for the quantized product (written in place).
/// @param D_sf  Destination for the scale factors (written in place).
/// @param A     Contiguous bf16 CUDA tensor whose numel() is a multiple of 32.
/// @param B     Contiguous 2-D bf16 CUDA tensor with exactly 32 rows.
/// @throws c10::Error (via TORCH_CHECK) if an input violates these requirements
///         or the kernel launch does not succeed.
void fusedQuantize_bwd_host(torch::Tensor& D,
                            torch::Tensor& D_sf,
                            torch::Tensor const& A,
                            torch::Tensor const& B)
{
  // Reduction length; the tile configuration below is built around it.
  constexpr int32_t K = 32;

  // The kernel receives raw pointers plus dense leading dimensions, so every
  // tensor must be on the GPU and contiguous. The inputs must already hold
  // bf16 data, because their storage is reinterpreted rather than converted.
  TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda() && D_sf.is_cuda(),
              "fusedQuantize_bwd_host: all tensors must be on a CUDA device");
  TORCH_CHECK(A.is_contiguous() && B.is_contiguous() &&
              D.is_contiguous() && D_sf.is_contiguous(),
              "fusedQuantize_bwd_host: all tensors must be contiguous");
  TORCH_CHECK(A.scalar_type() == torch::kBFloat16 &&
              B.scalar_type() == torch::kBFloat16,
              "fusedQuantize_bwd_host: A and B must be bfloat16");
  // Without this check, trailing elements of A would be silently ignored.
  TORCH_CHECK(A.numel() % K == 0,
              "fusedQuantize_bwd_host: A.numel() must be a multiple of 32");
  // The kernel reads K rows of B with leading dimension B.size(1).
  TORCH_CHECK(B.dim() == 2 && B.size(0) == K,
              "fusedQuantize_bwd_host: B must be a 2-D tensor with 32 rows");

  // Problem extents are passed as int32_t; reject sizes that would be truncated.
  const int64_t rows = A.numel() / K;
  const int64_t cols = B.size(1);
  TORCH_CHECK(rows <= std::numeric_limits<int32_t>::max() &&
              cols <= std::numeric_limits<int32_t>::max(),
              "fusedQuantize_bwd_host: problem size exceeds 32-bit range");
  const auto M = static_cast<int32_t>(rows);
  const auto N = static_cast<int32_t>(cols);

  // An empty problem has nothing to compute, so skip the launch entirely.
  if (M == 0 || N == 0) {
    return;
  }

  Options options(M, N, K, 1.f);

  using TileShape = cutlass::gemm::GemmShape<128, 32, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;  // Important: w_n == 32
  using MmaShape  = cutlass::gemm::GemmShape<16, 8, 16>;

  // run() reports failure by throwing via TORCH_CHECK, so its return value is not needed.
  TestbedRunner<Gemm_<TileShape, WarpShape, MmaShape>> runner;
  runner.run(options, D, D_sf, A, B, M, N, K);
}

} // namespace QUTLASS::quartet