#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>

// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// Software atomic add for doubles, built on a 64-bit compare-and-swap.
// Adds `val` to the double stored at `address` and returns the value that was
// stored there immediately before this thread's update was applied.
__device__ double atomicAdd(
    double* address,
    double val
) {
  // atomicCAS works on 64-bit integers, so view the double's storage as one.
  unsigned long long int* const slot =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int observed = *slot;

  for (;;) {
    const unsigned long long int expected = observed;
    // Try to publish (current + val); this only succeeds if nobody changed the
    // slot since `expected` was read.
    observed = atomicCAS(
      slot,
      expected,
      __double_as_longlong(val + __longlong_as_double(expected))
    );

    // Compare raw bit patterns rather than doubles: a NaN never compares
    // equal to itself, which would otherwise make this loop spin forever.
    if (observed == expected) {
      break;
    }
  }

  return __longlong_as_double(observed);
}
#endif

// Threads per block for the dense kernels; each block also handles a tile of
// BLOCKWIDTH activation rows and BLOCKWIDTH output columns.
constexpr int BLOCKWIDTH   = 128;
// Packed 32-bit words that hold BLOCKWIDTH codes at each bit width
// (BLOCKWIDTH * bits / 32).
constexpr int BLOCKHEIGHT2 = BLOCKWIDTH * 2 / 32;  // 8
constexpr int BLOCKHEIGHT3 = BLOCKWIDTH * 3 / 32;  // 12
constexpr int BLOCKHEIGHT4 = BLOCKWIDTH * 4 / 32;  // 16

// Returns the 32-bit pattern of `i` as an unsigned value, so that subsequent
// right shifts are logical (zero-filling) rather than arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  return static_cast<unsigned int>(i);
}

// Returns `i` unchanged (the argument and the result are both int).
// NOTE(review): the name suggests the parameter may have been meant to be
// unsigned -- confirm with callers before changing the signature.
__device__ inline int as_int(int i) {
  return i;
}

//batched version (2-bit)
// Forward declaration so the host launcher can reference the kernel before its
// definition later in this file.
//   vec          - input activations, read as vec[row * N + n]
//   mat          - packed 2-bit codes, one int per 16 codes, read as mat[n * K + word]
//   mul          - output, accumulated with atomicAdd, row stride num_data
//   lookup_table - dequantization values, read as lookup_table[n * 4 + code]
__global__ void VecQuant2MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K, 
    int num_data
);

//batched version (3-bit)
// Forward declaration so the host launcher can reference the kernel before its
// definition later in this file.
//   vec          - input activations, read as vec[row * N + n]
//   mat          - packed 3-bit codes, 32 codes per 3 consecutive ints
//   mul          - output, accumulated with atomicAdd, row stride num_data
//   lookup_table - dequantization values, read as lookup_table[n * 8 + code]
__global__ void VecQuant3MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K,
    int num_data
);

//batched version (4-bit)
// Forward declaration so the host launcher can reference the kernel before its
// definition later in this file.
//   vec          - input activations, read as vec[row * N + n]
//   mat          - packed 4-bit codes, one int per 8 codes, read as mat[n * K + word]
//   mul          - output, accumulated with atomicAdd, row stride num_data
//   lookup_table - dequantization values, read as lookup_table[n * 16 + code]
__global__ void VecQuant4MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K, 
    int num_data
);

// Forward declaration of the batched sparse mat-vec kernel defined later in
// this file.
//   rows - offsets; rows[r] .. rows[r + 1] delimit output column r's entries
//   cols - for each stored entry, the index into a row of vec
//   mat  - the stored entry values
//   vec  - dense input, read as vec[b * N + col]
//   mul  - dense output, accumulated with atomicAdd at mul[b * K + r]
__global__ void SPMV_ATOMIC_BATCHED(
  const   int* __restrict__ rows,
  const   int* __restrict__ cols,
  const float* __restrict__ mat,
  const float* __restrict__ vec,
        float* __restrict__ mul,
  const int M,
  const int N,
  const int K
);

// 2-bit batched matvec kernel (LUT-based)
//
// Host launcher for VecQuant2MatMulKernelNUQPerChannelBatched.
//   vec          : float tensor, read as an M x N matrix
//                  (M = vec.size(0), N = vec.size(1))
//   mat          : int32 tensor, read as an N x K matrix of packed words
//                  (K = mat.size(1)); each word holds sixteen 2-bit codes
//   mul          : float tensor, written as an M x num_data matrix with
//                  num_data = K * 32 / 2; results are atomically ADDED to
//                  whatever it already contains
//   lookup_table : float tensor, read as lookup_table[n * 4 + code]
// All tensors are accessed through raw pointers with the row-major indexing
// above, so they are presumably expected to be contiguous CUDA tensors; that
// is not validated here (TODO confirm against callers).
void vecquant2matmul_nuq_perchannel_batched_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  //vec: M * N matrix
  //mat: N * K matrix,  num_data = K * 32 / 2
  int M = vec.size(0);
  int N = vec.size(1);
  int K = mat.size(1);
  int num_data = K * 32 / 2;

  // A grid with a zero-sized dimension is an invalid launch configuration,
  // and there is nothing to accumulate in that case anyway.
  if (M == 0 || N == 0 || K == 0) {
    return;
  }

  dim3 blocks(
    (M + BLOCKWIDTH - 1) / BLOCKWIDTH,
    N,
    (K + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2
  );
  dim3 threads(BLOCKWIDTH);

  // Launch on PyTorch's current stream so the kernel is ordered correctly with
  // the ops that produced the inputs when the caller uses a non-default stream.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  VecQuant2MatMulKernelNUQPerChannelBatched<<<blocks, threads, 0, stream>>>(
    vec.data_ptr<float>(),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    lookup_table.data_ptr<float>(),
    M, N, K, num_data
  );

  // Kernel launches do not report configuration errors themselves.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "VecQuant2MatMulKernelNUQPerChannelBatched launch failed: ",
    cudaGetErrorString(err)
  );
}

// 3-bit batched matvec kernel (LUT-based)
//
// Host launcher for VecQuant3MatMulKernelNUQPerChannelBatched.
//   vec          : float tensor, read as an M x N matrix
//                  (M = vec.size(0), N = vec.size(1))
//   mat          : int32 tensor, read as an N x K matrix of packed words
//                  (K = mat.size(1)); every 3 words hold thirty-two 3-bit codes
//   mul          : float tensor, written as an M x num_data matrix with
//                  num_data = K * 32 / 3; results are atomically ADDED to
//                  whatever it already contains
//   lookup_table : float tensor, read as lookup_table[n * 8 + code]
// All tensors are accessed through raw pointers with the row-major indexing
// above, so they are presumably expected to be contiguous CUDA tensors; that
// is not validated here (TODO confirm against callers).
void vecquant3matmul_nuq_perchannel_batched_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  //vec: M * N matrix
  //mat: N * K matrix,  num_data = K * 32 / 3
  int M = vec.size(0);
  int N = vec.size(1);
  int K = mat.size(1);
  int num_data = K * 32 / 3;

  // A grid with a zero-sized dimension is an invalid launch configuration,
  // and there is nothing to accumulate in that case anyway.
  if (M == 0 || N == 0 || K == 0) {
    return;
  }

  dim3 blocks(
    (M + BLOCKWIDTH - 1) / BLOCKWIDTH,
    N,
    (K + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3
  );
  dim3 threads(BLOCKWIDTH);

  // Launch on PyTorch's current stream so the kernel is ordered correctly with
  // the ops that produced the inputs when the caller uses a non-default stream.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  VecQuant3MatMulKernelNUQPerChannelBatched<<<blocks, threads, 0, stream>>>(
    vec.data_ptr<float>(),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    lookup_table.data_ptr<float>(),
    M, N, K, num_data
  );

  // Kernel launches do not report configuration errors themselves.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "VecQuant3MatMulKernelNUQPerChannelBatched launch failed: ",
    cudaGetErrorString(err)
  );
}

// 4-bit batched matvec kernel (LUT-based)
//
// Host launcher for VecQuant4MatMulKernelNUQPerChannelBatched.
//   vec          : float tensor, read as an M x N matrix
//                  (M = vec.size(0), N = vec.size(1))
//   mat          : int32 tensor, read as an N x K matrix of packed words
//                  (K = mat.size(1)); each word holds eight 4-bit codes
//   mul          : float tensor, written as an M x num_data matrix with
//                  num_data = K * 32 / 4; results are atomically ADDED to
//                  whatever it already contains
//   lookup_table : float tensor, read as lookup_table[n * 16 + code]
// All tensors are accessed through raw pointers with the row-major indexing
// above, so they are presumably expected to be contiguous CUDA tensors; that
// is not validated here (TODO confirm against callers).
void vecquant4matmul_nuq_perchannel_batched_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  //vec: M * N matrix
  //mat: N * K matrix,  num_data = K * 32 / 4
  int M = vec.size(0);
  int N = vec.size(1);
  int K = mat.size(1);
  int num_data = K * 32 / 4;

  // A grid with a zero-sized dimension is an invalid launch configuration,
  // and there is nothing to accumulate in that case anyway.
  if (M == 0 || N == 0 || K == 0) {
    return;
  }

  dim3 blocks(
    (M + BLOCKWIDTH - 1) / BLOCKWIDTH,
    N,
    (K + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4
  );
  dim3 threads(BLOCKWIDTH);

  // Launch on PyTorch's current stream so the kernel is ordered correctly with
  // the ops that produced the inputs when the caller uses a non-default stream.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  VecQuant4MatMulKernelNUQPerChannelBatched<<<blocks, threads, 0, stream>>>(
    vec.data_ptr<float>(),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    lookup_table.data_ptr<float>(),
    M, N, K, num_data
  );

  // Kernel launches do not report configuration errors themselves.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "VecQuant4MatMulKernelNUQPerChannelBatched launch failed: ",
    cudaGetErrorString(err)
  );
}

//NUQ + Sparse
//
// Host launcher for SPMV_ATOMIC_BATCHED: one thread per output column.
//   rows : int32 offsets; rows[r] .. rows[r + 1] delimit the stored entries of
//          output column r, so the kernel reads rows[0 .. K]
//   cols : int32 index of each stored entry into a row of vec
//   mat  : float value of each stored entry
//   vec  : float tensor, read as an M x N matrix
//   mul  : float tensor, written as an M x K matrix; results are atomically
//          ADDED to whatever it already contains
//   K    : number of output columns
// Tensors are accessed through raw pointers, so they are presumably expected
// to be contiguous CUDA tensors; not validated here (TODO confirm).
void vecquant_spmv_cuda(
  torch::Tensor rows,
  torch::Tensor cols,
  torch::Tensor mat,
  torch::Tensor vec,
  torch::Tensor mul,
  int K
) {
  //vec: M * N matrix
  //mat: N * K sparse matrix
  //mul: M * K matrix
  int M = vec.size(0);
  int N = vec.size(1);

  // K <= 0 would produce a zero/negative block count (invalid launch), and
  // with M == 0 the kernel has no batch rows to process.
  if (K <= 0 || M == 0) {
    return;
  }

  int block_size = BLOCKWIDTH;
  int num_blocks = (K + BLOCKWIDTH - 1) / BLOCKWIDTH;

  // Launch on PyTorch's current stream so the kernel is ordered correctly with
  // the ops that produced the inputs when the caller uses a non-default stream.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  // data_ptr<T>() replaces the deprecated data<T>() accessor and matches the
  // other launchers in this file.
  SPMV_ATOMIC_BATCHED<<<num_blocks, block_size, 0, stream>>>(
    rows.data_ptr<int>(),
    cols.data_ptr<int>(),
    mat.data_ptr<float>(),
    vec.data_ptr<float>(),
    mul.data_ptr<float>(),
    M, N, K
  );

  // Kernel launches do not report configuration errors themselves.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "SPMV_ATOMIC_BATCHED launch failed: ",
    cudaGetErrorString(err)
  );
}


//batched version (2-bit)
//
// Accumulates one tile of mul += vec x dequant(mat) for 2-bit LUT codes.
//
// Launch layout (as used by the host launcher in this file):
//   blockDim.x == BLOCKWIDTH
//   blockIdx.x : tile of BLOCKWIDTH rows of vec
//   blockIdx.y : n, the column of vec / row of mat being multiplied
//   blockIdx.z : tile of BLOCKHEIGHT2 packed words (= BLOCKWIDTH output columns)
// Static shared memory: BLOCKWIDTH floats.
//
// Each thread decodes ONE weight (row n, output column mul_col_idx) and then
// adds weight * vec[r][n] into mul[r][mul_col_idx] for every row r of the tile.
// Different n-blocks hit the same mul element, hence the atomicAdd; mul is
// accumulated into, not overwritten.
__global__ void VecQuant2MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K,
    int num_data
) {
  int m = blockIdx.x;
  int n = blockIdx.y;
  int k = blockIdx.z;
  int vec_row_start = m * BLOCKWIDTH;
  int vec_row_idx = vec_row_start + threadIdx.x;
  // Column n of vec for the BLOCKWIDTH rows handled by this block. Entries for
  // rows >= M are left unwritten and are never read (see MAX_ITER below).
  __shared__ float blockvec[BLOCKWIDTH];
  if(vec_row_idx < M)
  {
    blockvec[threadIdx.x] = vec[vec_row_idx * N + n];
  }
  __syncthreads();

  //--- threadIdx.x = group_idx * 16 + off ---//
  int group_idx = threadIdx.x / 16; // 0 - 7: which packed word of the tile
  int off = threadIdx.x % 16; // 0 - 15: which 2-bit code inside that word
  int word_in_row = k * BLOCKHEIGHT2 + group_idx;
  int mul_col_idx = k * BLOCKWIDTH + threadIdx.x;

  // The grid is ceil-divided over K, so the last z-block can extend past the
  // end of the row when K is not a multiple of BLOCKHEIGHT2. Skip those
  // threads instead of reading mat / writing mul out of range. This is placed
  // after the barrier so every thread still reaches __syncthreads().
  if (word_in_row >= K || mul_col_idx >= num_data)
  {
    return;
  }

  unsigned int tmp = as_unsigned(mat[n * K + word_in_row]);
  float mat_value = lookup_table[n * 4 + ((tmp >> (2 * off)) & 0x3)];

  // Only the rows of this tile that actually exist in vec.
  int MAX_ITER = min(BLOCKWIDTH, M - vec_row_start);
  for(int i = 0; i < MAX_ITER; i ++)
  {
    atomicAdd(&mul[(vec_row_start + i) * num_data + mul_col_idx], mat_value * blockvec[i]);
  }
}

//batched version (3-bit)
//
// Accumulates one tile of mul += vec x dequant(mat) for 3-bit LUT codes.
//
// Launch layout (as used by the host launcher in this file):
//   blockDim.x == BLOCKWIDTH
//   blockIdx.x : tile of BLOCKWIDTH rows of vec
//   blockIdx.y : n, the column of vec / row of mat being multiplied
//   blockIdx.z : tile of BLOCKHEIGHT3 packed words (= BLOCKWIDTH output columns)
// Static shared memory: BLOCKWIDTH floats.
//
// Codes are packed 32 per group of 3 consecutive words (96 bits); codes 10 and
// 21 straddle a word boundary and are stitched together from two words.
// Each thread decodes ONE weight and adds weight * vec[r][n] into
// mul[r][mul_col_idx] for every row r of the tile. Different n-blocks hit the
// same mul element, hence the atomicAdd; mul is accumulated into.
__global__ void VecQuant3MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K,
    int num_data
) {
  int m = blockIdx.x;
  int n = blockIdx.y;
  int k = blockIdx.z;
  int vec_row_start = m * BLOCKWIDTH;
  int vec_row_idx = vec_row_start + threadIdx.x;
  // Column n of vec for the BLOCKWIDTH rows handled by this block. Entries for
  // rows >= M are left unwritten and are never read (see MAX_ITER below).
  __shared__ float blockvec[BLOCKWIDTH];
  if(vec_row_idx < M)
  {
    blockvec[threadIdx.x] = vec[vec_row_idx * N + n];
  }
  __syncthreads();

  unsigned int tmp;
  unsigned int tmp1;
  unsigned int tmp2;

  //--- threadIdx.x = group_idx * 32 + off ---//
  int group_idx = threadIdx.x / 32; // 0, 1, 2, 3: which 3-word group of the tile
  int off = threadIdx.x % 32; // 0 - 31: which 3-bit code inside that group
  int word_in_row = k * BLOCKHEIGHT3 + group_idx * 3;
  int mul_col_idx = k * BLOCKWIDTH + threadIdx.x;

  // Highest word of the group this thread touches: the code occupies bits
  // [3*off, 3*off + 2] of the 96-bit group.
  int last_word_in_row = word_in_row + (3 * off + 2) / 32;

  // The grid is ceil-divided over K, so the last z-block can extend past the
  // end of the row when K is not a multiple of BLOCKHEIGHT3. Skip those
  // threads instead of reading mat / writing mul out of range. This is placed
  // after the barrier so every thread still reaches __syncthreads().
  if (last_word_in_row >= K || mul_col_idx >= num_data)
  {
    return;
  }

  int mat_start_idx = n * K + word_in_row;
  float mat_value;
  if (off < 10) // 0 - 9: entirely inside word 0
  {
    tmp = as_unsigned(mat[mat_start_idx]);
    mat_value = lookup_table[n * 8 + ((tmp >> (3 * off)) & 0x7)];
  }
  else if (off == 10) // 10: top 2 bits of word 0 + lowest bit of word 1
  {
    tmp1 = as_unsigned(mat[mat_start_idx]);
    tmp2 = as_unsigned(mat[mat_start_idx + 1]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    mat_value = lookup_table[n * 8 + ((tmp) & 0x7)];
  }
  else if (off < 21) // 11 - 20: entirely inside word 1
  {
    tmp = as_unsigned(mat[mat_start_idx + 1]);
    mat_value = lookup_table[n * 8 + ((tmp >> (3 * off - 32)) & 0x7)];
  }
  else if (off == 21) // 21: top bit of word 1 + lowest 2 bits of word 2
  {
    tmp1 = as_unsigned(mat[mat_start_idx + 1]);
    tmp2 = as_unsigned(mat[mat_start_idx + 2]);
    tmp = (tmp1 >> 31) | ((tmp2 << 1) & 0x6);
    mat_value = lookup_table[n * 8 + ((tmp) & 0x7)];
  }
  else // 22 - 31: entirely inside word 2
  {
    tmp = as_unsigned(mat[mat_start_idx + 2]);
    mat_value = lookup_table[n * 8 + ((tmp >> (3 * off - 64)) & 0x7)];
  }

  // Only the rows of this tile that actually exist in vec.
  int MAX_ITER = min(BLOCKWIDTH, M - vec_row_start);
  for(int i = 0; i < MAX_ITER; i ++)
  {
    atomicAdd(&mul[(vec_row_start + i) * num_data + mul_col_idx], mat_value * blockvec[i]);
  }
}

//batched version (4-bit)
//
// Accumulates one tile of mul += vec x dequant(mat) for 4-bit LUT codes.
//
// Launch layout (as used by the host launcher in this file):
//   blockDim.x == BLOCKWIDTH
//   blockIdx.x : tile of BLOCKWIDTH rows of vec
//   blockIdx.y : n, the column of vec / row of mat being multiplied
//   blockIdx.z : tile of BLOCKHEIGHT4 packed words (= BLOCKWIDTH output columns)
// Static shared memory: BLOCKWIDTH floats.
//
// Each thread decodes ONE weight (row n, output column mul_col_idx) and then
// adds weight * vec[r][n] into mul[r][mul_col_idx] for every row r of the tile.
// Different n-blocks hit the same mul element, hence the atomicAdd; mul is
// accumulated into, not overwritten.
__global__ void VecQuant4MatMulKernelNUQPerChannelBatched(
    const float* __restrict__ vec,
    const   int* __restrict__ mat,
          float* __restrict__ mul,
    const float* __restrict__ lookup_table,
    int M,
    int N,
    int K,
    int num_data
) {
  int m = blockIdx.x;
  int n = blockIdx.y;
  int k = blockIdx.z;
  int vec_row_start = m * BLOCKWIDTH;
  int vec_row_idx = vec_row_start + threadIdx.x;
  // Column n of vec for the BLOCKWIDTH rows handled by this block. Entries for
  // rows >= M are left unwritten and are never read (see MAX_ITER below).
  __shared__ float blockvec[BLOCKWIDTH];
  if(vec_row_idx < M)
  {
    blockvec[threadIdx.x] = vec[vec_row_idx * N + n];
  }
  __syncthreads();

  //--- threadIdx.x = group_idx * 8 + off ---//
  int group_idx = threadIdx.x / 8; // 0 - 15: which packed word of the tile
  int off = threadIdx.x % 8; // 0 - 7: which 4-bit code inside that word
  int word_in_row = k * BLOCKHEIGHT4 + group_idx;
  int mul_col_idx = k * BLOCKWIDTH + threadIdx.x;

  // The grid is ceil-divided over K, so the last z-block can extend past the
  // end of the row when K is not a multiple of BLOCKHEIGHT4. Skip those
  // threads instead of reading mat / writing mul out of range. This is placed
  // after the barrier so every thread still reaches __syncthreads().
  if (word_in_row >= K || mul_col_idx >= num_data)
  {
    return;
  }

  unsigned int tmp = as_unsigned(mat[n * K + word_in_row]);
  float mat_value = lookup_table[n * 16 + ((tmp >> (4 * off)) & 0xf)];

  // Only the rows of this tile that actually exist in vec.
  int MAX_ITER = min(BLOCKWIDTH, M - vec_row_start);
  for(int i = 0; i < MAX_ITER; i ++)
  {
    atomicAdd(&mul[(vec_row_start + i) * num_data + mul_col_idx], mat_value * blockvec[i]);
  }
}



// Batched sparse mat-vec: one thread per output column.
// Thread `out_col` walks the stored entries delimited by rows[out_col] and
// rows[out_col + 1]; for every batch row b it sums vec[b * N + cols[e]] * mat[e]
// over those entries and atomically adds the sum to mul[b * K + out_col].
// Threads whose global index is >= K do nothing.
__global__ void SPMV_ATOMIC_BATCHED(
  const   int* __restrict__ rows,
  const   int* __restrict__ cols,
  const float* __restrict__ mat,
  const float* __restrict__ vec,
        float* __restrict__ mul,
  const int M,
  const int N,
  const int K
) {
    const int out_col = blockIdx.x * blockDim.x + threadIdx.x;
    if (out_col < K)
    {
      // Half-open range [first_entry, past_last_entry) of this column's entries.
      const int first_entry = rows[out_col];
      const int past_last_entry = rows[out_col + 1];

      int batch = 0;
      while (batch < M)
      {
        float acc = 0;
        int e = first_entry;
        while (e < past_last_entry)
        {
          acc += vec[batch * N + cols[e]] * mat[e];
          ++e;
        }
        // Accumulate into (not overwrite) the output element.
        atomicAdd(&mul[batch * K + out_col], acc);
        ++batch;
      }
    }
}



