
#include "kernel.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

// Tensor validation helpers built on TORCH_CHECK. The macro argument is
// parenthesized so that an arbitrary expression can be passed; the
// stringified argument (#x) is used verbatim in the error message.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous.")
// Runs both checks. Wrapped in do { ... } while (0) so the macro behaves as a
// single statement (e.g. `if (cond) CHECK_INPUT(t);` guards both checks).
#define CHECK_INPUT(x)                                                         \
  do {                                                                         \
    CHECK_CUDA(x);                                                             \
    CHECK_CONTIGUOUS(x);                                                       \
  } while (0)



/**
 * Host-side launcher for gemv_WeightInt4_ActInt4_OutInt32_dp4a.
 *
 * Computes matmul(x, W.T) for a single-row activation `x` and returns the
 * result as a new int32 tensor of shape [x.size(0), W.size(0)] on x's device.
 *
 * Both inputs are read as raw uint8 storage; the logical column count is
 * taken to be twice the stored column count (two values per byte).
 * NOTE(review): the nibble order inside each byte is defined by the kernel in
 * kernel.cuh and is not visible here.
 *
 * @param x  Contiguous 2-D CUDA uint8 tensor, shape [1, cols / 2].
 * @param W  Contiguous 2-D CUDA uint8 tensor, shape [rows, cols / 2], on the
 *           same device as x.
 * @param w_zero  Forwarded unchanged to the kernel (presumably the weight
 *                zero-point -- confirm against kernel.cuh).
 * @param num_threads_per_row  blockDim.x of the launch; must be > 0. Also
 *                             divides the logical column count to obtain the
 *                             per-thread work size passed to the kernel.
 * @param cols_per_warp  blockDim.y of the launch; must be > 0. The grid has
 *                       ceil(rows / cols_per_warp) blocks.
 * @return int32 tensor of shape [1, rows]. It is NOT zero-initialized by this
 *         function; contents are whatever the kernel writes.
 *
 * Raises (via TORCH_CHECK) when an input is not CUDA/contiguous/2-D/uint8,
 * when shapes disagree, when batch size != 1, or when a launch parameter is 0.
 * Launch-configuration errors are reported by C10_CUDA_KERNEL_LAUNCH_CHECK.
 */
torch::Tensor fast_gemv(torch::Tensor x, torch::Tensor W, const int w_zero,
                        unsigned int num_threads_per_row,
                        unsigned int cols_per_warp) {
  CHECK_INPUT(x);
  CHECK_INPUT(W);
  TORCH_CHECK(x.device() == W.device(), "x and W must be on the same device.");
  TORCH_CHECK(x.dim() == 2, "x must be a 2-D tensor.");
  TORCH_CHECK(W.dim() == 2, "W must be a 2-D tensor.");
  TORCH_CHECK(x.scalar_type() == c10::ScalarType::Byte,
              "x must be a uint8 tensor.");
  TORCH_CHECK(W.scalar_type() == c10::ScalarType::Byte,
              "W must be a uint8 tensor.");
  // Both values are used as divisors below; reject 0 before dividing.
  TORCH_CHECK(num_threads_per_row > 0, "num_threads_per_row must be > 0.");
  TORCH_CHECK(cols_per_warp > 0, "cols_per_warp must be > 0.");

  const size_t x_rows = x.size(0);
  const size_t x_cols_packed = x.size(1);
  const size_t W_rows = W.size(0);
  const size_t W_cols_packed = W.size(1);
  // matmul(x, W.T)
  //  [x_rows, x_cols] @ [W_rows, W_cols].T
  // Two logical values are stored per byte, hence the factor of 2.
  const size_t W_cols = W_cols_packed * 2;
  const size_t x_cols = x_cols_packed * 2;
  TORCH_CHECK(x_rows == 1, "Only batch-size=1 is supported.");
  TORCH_CHECK(x_cols == W_cols, "Invalid x_cols == W_cols.");

  // Make x's device current for the allocation and the launch, and launch on
  // PyTorch's current stream so the kernel is ordered with surrounding ops.
  const c10::cuda::CUDAGuard device_guard(x.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto y = torch::empty(
      {static_cast<int64_t>(x_rows), static_cast<int64_t>(W_rows)},
      torch::TensorOptions().dtype(c10::ScalarType::Int).device(x.device()));

  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (W_rows == 0) {
    return y;
  }

  // Inputs / outputs ptr. The kernel takes the byte buffers as uint4 (16-byte)
  // vectors. NOTE(review): this assumes the kernel's accesses stay 16-byte
  // aligned and within the row length -- confirm against kernel.cuh.
  const uint4 *x_ptr =
      reinterpret_cast<const uint4 *>(x.const_data_ptr<uint8_t>());
  const uint4 *W_ptr =
      reinterpret_cast<const uint4 *>(W.const_data_ptr<uint8_t>());
  int32_t *y_ptr = y.data_ptr<int32_t>();

  // Integer division: any remainder of W_cols / num_threads_per_row is dropped.
  const unsigned num_per_threads = W_cols / num_threads_per_row;
  const dim3 block(num_threads_per_row, cols_per_warp);
  const dim3 grid((W_rows + cols_per_warp - 1) / cols_per_warp, 1);

  // sizeof(uint4) == 16, so this equals x_cols / 2 == x_cols_packed bytes,
  // i.e. the byte size of the packed x row.
  const size_t shared_mem_size = x_cols * sizeof(uint4) / 32;

  // Launch kernel
  gemv_WeightInt4_ActInt4_OutInt32_dp4a<<<grid, block, shared_mem_size,
                                          stream>>>(
      x_ptr, W_ptr, y_ptr, w_zero, num_per_threads, x_cols, W_rows);

  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return y;
}
