#include <cuda.h>
#include <cuda_runtime.h>

// ==------------------------------------------------------------------------==
// Warp-reduce
// ==------------------------------------------------------------------------==

template <typename dtype>
__device__ __forceinline__ dtype warpReduceSum(dtype sum, size_t num_threads) {
  // Butterfly (XOR-shuffle) sum over aligned groups of consecutive lanes.
  //
  // An XOR offset `o` from {16, 8, 4, 2, 1} is applied only when
  // num_threads >= 2 * o. Effectively the group size is the largest power of
  // two <= num_threads, capped at 32 (one warp). After the call, every lane of
  // a group holds that group's total. This never reduces across warps, so with
  // num_threads > 32 each warp only gets its own partial sum. With
  // num_threads <= 1 the value is returned unchanged.
  //
  // The shuffles use a full-warp mask: all 32 lanes of the calling warp must
  // execute this call.
  constexpr unsigned int kFullWarp = 0xffffffffu;
#pragma unroll
  for (int lane_offset = 16; lane_offset > 0; lane_offset /= 2) {
    if (num_threads >= static_cast<size_t>(2 * lane_offset)) {
      sum += __shfl_xor_sync(kFullWarp, sum, lane_offset);
    }
  }
  return sum;
}

// ==------------------------------------------------------------------------==
// GEMV
// ==------------------------------------------------------------------------==
// Adds to `acc` the dot product of the eight 4-bit unsigned values packed in
// `a` with the eight packed in `b`, pairing values at the same bit position.
static __device__ __forceinline__ unsigned int gemv_int4_dp4a_u4x8(
    unsigned int a, unsigned int b, unsigned int acc) {
  // Selects the low nibble of every byte of a 32-bit word.
  constexpr unsigned int kLowNibbles = 0x0F0F0F0Fu;
  // First the nibbles at bits [8j, 8j+4), then those at [8j+4, 8j+8).
  // __dp4a multiplies the four matching unsigned bytes and adds them to acc.
  // A dot product does not depend on summation order, so this equals summing
  // the eight nibble products one by one.
  acc = __dp4a(a & kLowNibbles, b & kLowNibbles, acc);
  acc = __dp4a((a >> 4) & kLowNibbles, (b >> 4) & kLowNibbles, acc);
  return acc;
}

// y[row] = sum_k x[k] * W[row][k], for 4-bit unsigned values packed eight per
// 32-bit word (32 per uint4), accumulated in 32-bit integers via __dp4a.
//
// Launch layout expected by the indexing below:
//   * Block bx handles rows bx*blockDim.y + threadIdx.y. Rows >= W_rows are
//     skipped, so gridDim.x may be a ceil-division of W_rows by blockDim.y.
//   * blockDim.x threads cooperate on one row. The per-row reduction is a
//     single warp shuffle reduction, so blockDim.x must be a power of two
//     <= 32. blockDim.x * blockDim.y must be a multiple of 32 because the
//     shuffles use a full-warp mask.
//   * Dynamic shared memory: at least
//     (num_per_threads / 32) * blockDim.x * sizeof(uint4) bytes.
//   * x_ptr must hold at least (num_per_threads / 32) * blockDim.x uint4s.
//     W is row-major with a row stride of x_cols / 32 uint4s. Presumably the
//     launcher sets num_per_threads * blockDim.x == x_cols (TODO confirm).
//     Any remainder of num_per_threads / 32 is ignored.
//   * __dp4a requires compute capability 6.1+.
//
// TODO / NOTE(review): w_zero is currently NOT applied. The output is the raw
// unsigned-nibble dot product. Confirm that callers apply any zero-point
// correction themselves.
__global__ void gemv_WeightInt4_ActInt4_OutInt32_dp4a(
    const uint4 *__restrict__ x_ptr, const uint4 *__restrict__ W_ptr,
    int32_t *y_ptr, const int w_zero, const unsigned int num_per_threads,
    const unsigned int x_cols, const unsigned int W_rows) {
  (void)w_zero;  // Accepted for interface compatibility; see TODO above.

  // One uint4 = 128 bits = 32 packed 4-bit values.
  constexpr unsigned int kElemsPerUint4 = 32;

  const unsigned int lane_in_row = threadIdx.x;
  const unsigned int threads_per_row = blockDim.x;
  const unsigned int row = blockIdx.x * blockDim.y + threadIdx.y;
  // Out-of-range rows must still reach the barrier and the warp shuffles
  // (both require every thread), but must not touch W or y.
  const bool row_in_range = row < W_rows;

  // Each thread handles `iters` uint4s per row, strided by threads_per_row,
  // so the row loop reads x indices [0, x_vecs).
  const unsigned int iters = num_per_threads / kElemsPerUint4;
  const unsigned int x_vecs = iters * threads_per_row;

  extern __shared__ uint4 x_shared[];

  // Stage exactly those x_vecs uint4s in shared memory. All threads of the
  // block share the work, so each element is written once and consecutive
  // threads read consecutive addresses.
  const unsigned int block_threads = blockDim.x * blockDim.y;
  const unsigned int flat_tid = threadIdx.y * blockDim.x + threadIdx.x;
  for (unsigned int i = flat_tid; i < x_vecs; i += block_threads) {
    x_shared[i] = x_ptr[i];
  }
  __syncthreads();

  unsigned int acc = 0;
  if (row_in_range) {
    // size_t so that row * row_stride cannot wrap around for very large W.
    const uint4 *W_row =
        W_ptr + static_cast<size_t>(row) * (x_cols / kElemsPerUint4);
#pragma unroll
    for (unsigned int iter = 0; iter < iters; ++iter) {
      const unsigned int idx = lane_in_row + iter * threads_per_row;
      const uint4 w = __ldg(&W_row[idx]);
      const uint4 x = x_shared[idx];
      acc = gemv_int4_dp4a_u4x8(x.x, w.x, acc);
      acc = gemv_int4_dp4a_u4x8(x.y, w.y, acc);
      acc = gemv_int4_dp4a_u4x8(x.z, w.z, acc);
      acc = gemv_int4_dp4a_u4x8(x.w, w.w, acc);
    }
  }

  // Every lane takes part in the reduction, including lanes of skipped rows,
  // which contribute 0.
  int32_t sum = static_cast<int32_t>(acc);
  sum = warpReduceSum(sum, blockDim.x);
  if (row_in_range && lane_in_row == 0) {
    y_ptr[row] = sum;
  }
}
