
// This code is based on torch_sparse's CUDA SpMM implementation:
// https://github.com/rusty1s/pytorch_sparse/blob/master/csrc/cuda/spmm_cuda.cu
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

// Argument-validation helpers. Both raise a c10::Error (RuntimeError in
// Python) via TORCH_CHECK; AT_ASSERTM is deprecated and reports user input
// errors through the internal-assert path.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor")
// The failing condition is appended so the error says which check fired.
#define CHECK_INPUT(x) TORCH_CHECK((x), "Input mismatch: ", #x)

// Threads per block for the kernels below. Must be a multiple of 32: each
// warp is mapped to exactly one output row and shuffles with FULL_MASK.
#define THREADS 256
// Participation mask for warp shuffles: all 32 lanes.
#define FULL_MASK 0xffffffff

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
//
// Batched CSR x dense product. For every batch b and output row m:
//   out[b][m][k] = sum_{j in [rowptr[m], rowptr[m+1])} value[j] * mat[b][col[j]][k]
// The CSR matrix (rowptr, col, value) of shape (M, N) is shared by all B
// batches; mat is read as a dense row-major (B, N, K) array and out is
// written as a dense row-major (B, M, K) array.
//
// Launch layout: 1-D blocks along x whose size is a multiple of 32, one warp
// per global output row (B * M warps in total), and ceil(K / 32) blocks along
// y, each covering a 32-column slice of K.
//
// Offsets are computed in 64-bit: rowptr/col hold int64 values, and nnz or
// B * N * K may exceed INT_MAX for large inputs. Column indices are assumed
// to be valid (< N), so a single column index fits in an int.
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const float *value_data, const float *mat_data,
                            float *out_data, int B, int M, int N, int K) {

  // blockIdx.y is ignored for the row mapping: every y-slice processes the
  // same rows, only a different range of columns.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  const int64_t row = thread_idx >> 5;                    // thread_idx / 32
  const int lane_idx = static_cast<int>(thread_idx & 31); // thread_idx % 32
  const int64_t batch_idx = row / M;

  // `row` is the same for all lanes of a warp, so whole warps exit together
  // and the __shfl_sync calls below always see all 32 lanes.
  if (batch_idx >= B) {
    return;
  }

  // Column of `mat`/`out` handled by this lane. Lanes past the last column
  // (only possible in the final y-slice) still take part in the shuffles but
  // neither read `mat` nor write `out`.
  const int mat_col_idx = lane_idx + (static_cast<int>(blockIdx.y) << 5);
  const bool in_range = mat_col_idx < K;

  const int64_t local_row = row % M;
  const int64_t row_start = __ldg(rowptr_data + local_row);
  const int64_t row_end = __ldg(rowptr_data + local_row + 1);

  // Offset of batch `batch_idx` inside `mat`.
  const int64_t mat_batch_offset = batch_idx * N * K;

  float result_sum = 0.0f;

  // Walk the row's nonzeros 32 at a time, one nonzero per lane.
  int64_t nz_idx = row_start + lane_idx;
  for (int64_t chunk = row_start; chunk < row_end; chunk += 32, nz_idx += 32) {
    // Coalesced load of this lane's nonzero; col == -1 marks padding lanes.
    int my_col = -1;
    float my_val = 0.0f;
    if (nz_idx < row_end) {
      my_col = static_cast<int>(__ldg(col_data + nz_idx));
      my_val = __ldg(value_data + nz_idx);
    }

    // Broadcast every lane's (col, value) pair to the whole warp.
    int cols[32];
    float vals[32];
#pragma unroll
    for (int i = 0; i < 32; i++) {
      cols[i] = __shfl_sync(FULL_MASK, my_col, i);
      vals[i] = __shfl_sync(FULL_MASK, my_val, i);
    }

#pragma unroll
    for (int i = 0; i < 32; i++) {
      if (in_range && cols[i] != -1) {
        // Coalesced: consecutive lanes read consecutive columns of `mat`.
        const float mat_val =
            __ldg(mat_data + mat_batch_offset +
                  static_cast<int64_t>(cols[i]) * K + mat_col_idx);
        result_sum += vals[i] * mat_val;
      }
    }
  }

  if (in_range) {
    // Coalesced write into `out`.
    out_data[row * K + mat_col_idx] = result_sum;
  }
}

// Host entry point: computes out = A @ mat, where A is an (M x N) CSR matrix
// given by (rowptr, col, value) and mat is a dense float32 tensor of shape
// (..., N, K). A is shared by every leading (batch) slice of mat.
//
// rowptr: int64, 1-D, M + 1 entries.   col:  int64, 1-D, nnz entries.
// value:  float32, 1-D, nnz entries.   mat:  float32, at least 2-D.
// All inputs must be CUDA tensors on the same device.
//
// Returns a new float32 tensor of shape (..., M, K). The kernel is enqueued
// on the current CUDA stream; this function does not synchronize.
torch::Tensor ts_spmm_fp32(torch::Tensor rowptr, torch::Tensor col,
                           torch::Tensor value, torch::Tensor mat) {

  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(value);
  CHECK_CUDA(mat);
  TORCH_CHECK(col.device() == rowptr.device() &&
                  value.device() == rowptr.device() &&
                  mat.device() == rowptr.device(),
              "rowptr, col, value and mat must be on the same device");
  cudaSetDevice(rowptr.get_device());

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(rowptr.numel() >= 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(value.dim() == 1);
  CHECK_INPUT(value.size(0) == col.size(0));
  CHECK_INPUT(mat.dim() >= 2);

  // datatypes
  TORCH_CHECK(rowptr.dtype() == torch::kInt64, "dtype rowptr");
  TORCH_CHECK(col.dtype() == torch::kInt64, "dtype col");
  TORCH_CHECK(value.dtype() == torch::kFloat32, "dtype value");
  TORCH_CHECK(mat.dtype() == torch::kFloat32, "dtype mat");

  // The kernel walks raw pointers with dense strides, so every input must be
  // contiguous (a no-op for tensors that already are).
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  value = value.contiguous();
  mat = mat.contiguous();

  const int64_t M = rowptr.numel() - 1;
  const int64_t N = mat.size(-2);
  const int64_t K = mat.size(-1);
  // Product of the leading (batch) dims; mat.numel() / (N * K) would divide
  // by zero when N or K is 0.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; ++d) {
    B *= mat.size(d);
  }

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = M;
  auto out = torch::empty(sizes, mat.options().dtype(torch::kFloat32));

  // Nothing to compute. Launching an empty grid would instead fail with
  // cudaErrorInvalidConfiguration and poison a later CUDA call.
  if (out.numel() == 0) {
    return out;
  }

  // The kernel takes 32-bit sizes; gridDim.x / gridDim.y are limited to
  // 2^31 - 1 and 65535 respectively.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(B <= kIntMax && M <= kIntMax && N <= kIntMax && K <= kIntMax,
              "ts_spmm_fp32: tensor dimensions must fit in a 32-bit int");
  const int64_t blocks_x = (32 * B * M + THREADS - 1) / THREADS;
  const int64_t blocks_y = (K + 31) / 32;
  TORCH_CHECK(blocks_x <= kIntMax && blocks_y <= 65535,
              "ts_spmm_fp32: input too large for a single kernel launch");
  const dim3 BLOCKS(static_cast<unsigned int>(blocks_x),
                    static_cast<unsigned int>(blocks_y));

  auto stream = at::cuda::getCurrentCUDAStream();
  spmm_kernel<<<BLOCKS, THREADS, 0, stream>>>(
      rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
      value.data_ptr<float>(), mat.data_ptr<float>(), out.data_ptr<float>(),
      static_cast<int>(B), static_cast<int>(M), static_cast<int>(N),
      static_cast<int>(K));
  // Surface launch-configuration errors here rather than in a later CUDA call.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "spmm_kernel launch failed: ",
              cudaGetErrorString(err));

  return out;
}

// Number of aggregations fused per output row: weighted sum, min and max.
// `w` stores one weight per aggregation for every output row.
#define NUM_AGGS 3

// Fused aggregation over the nonzeros of a CSR matrix A (rowptr, col, value)
// of shape (M, N), shared by the B batches of a dense row-major (B, N, K)
// tensor `mat`. For global output row r = b * M + m and column k:
//
//   out[r][k] = w[r][0] * sum_j value[j] * mat[b][col[j]][k]
//             + w[r][1] * min_j mat[b][col[j]][k]
//             + w[r][2] * max_j mat[b][col[j]][k]
//
// with j in [rowptr[m], rowptr[m+1]). min/max ignore `value`; for a row
// without nonzeros both are taken as 0. `w` is read as a dense row-major
// (B * M, NUM_AGGS) array and `out` is written as dense row-major (B, M, K).
//
// Launch layout: same as spmm_kernel — x-blocks sized to a multiple of 32
// with one warp per output row, ceil(K / 32) blocks along y.
// Offsets are computed in 64-bit to avoid overflow on large inputs.
__global__ void fuse_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const float *value_data, const float *mat_data,
                            float *out_data, float *w_data, int B, int M, int N,
                            int K) {

  // blockIdx.y is ignored for the row mapping: every y-slice processes the
  // same rows, only a different range of columns.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  const int64_t row = thread_idx >> 5;                    // thread_idx / 32
  const int lane_idx = static_cast<int>(thread_idx & 31); // thread_idx % 32
  const int64_t batch_idx = row / M;

  // `row` is warp-uniform, so whole warps exit together and the shuffles
  // below always see all 32 lanes.
  if (batch_idx >= B) {
    return;
  }

  // Column handled by this lane; out-of-range lanes only help with shuffles.
  const int mat_col_idx = lane_idx + (static_cast<int>(blockIdx.y) << 5);
  const bool in_range = mat_col_idx < K;

  const int64_t local_row = row % M;
  const int64_t row_start = __ldg(rowptr_data + local_row);
  const int64_t row_end = __ldg(rowptr_data + local_row + 1);

  const float sum_w = __ldg(w_data + NUM_AGGS * row + 0);
  const float min_w = __ldg(w_data + NUM_AGGS * row + 1);
  const float max_w = __ldg(w_data + NUM_AGGS * row + 2);

  // Offset of batch `batch_idx` inside `mat`.
  const int64_t mat_batch_offset = batch_idx * N * K;

  float result_sum = 0.0f;
  float result_min = std::numeric_limits<float>::max();
  float result_max = std::numeric_limits<float>::lowest();

  // Walk the row's nonzeros 32 at a time, one nonzero per lane.
  int64_t nz_idx = row_start + lane_idx;
  for (int64_t chunk = row_start; chunk < row_end; chunk += 32, nz_idx += 32) {
    // Coalesced load of this lane's nonzero; col == -1 marks padding lanes.
    int my_col = -1;
    float my_val = 0.0f;
    if (nz_idx < row_end) {
      my_col = static_cast<int>(__ldg(col_data + nz_idx));
      my_val = __ldg(value_data + nz_idx);
    }

    // Broadcast every lane's (col, value) pair to the whole warp.
    int cols[32];
    float vals[32];
#pragma unroll
    for (int i = 0; i < 32; i++) {
      cols[i] = __shfl_sync(FULL_MASK, my_col, i);
      vals[i] = __shfl_sync(FULL_MASK, my_val, i);
    }

#pragma unroll
    for (int i = 0; i < 32; i++) {
      if (in_range && cols[i] != -1) {
        // Coalesced: consecutive lanes read consecutive columns of `mat`.
        const float mat_val =
            __ldg(mat_data + mat_batch_offset +
                  static_cast<int64_t>(cols[i]) * K + mat_col_idx);
        result_min = mat_val < result_min ? mat_val : result_min;
        result_max = mat_val > result_max ? mat_val : result_max;
        result_sum += vals[i] * mat_val;
      }
    }
  }

  if (in_range) {
    // Empty row: min/max have no elements, define them as 0.
    if (row_end - row_start == 0) {
      result_min = 0.0f;
      result_max = 0.0f;
    }
    // Coalesced write into `out`.
    out_data[row * K + mat_col_idx] =
        sum_w * result_sum + min_w * result_min + max_w * result_max;
  }
}


// Same computation as fuse_kernel, but additionally stores the raw per-row
// sum / max / min aggregates in `out_mater`.
//
// NOTE: this forces the explicit materialization of values that otherwise
// don't need to be materialized. This is really a hack to force the optimizer
// to do what I want.
//
// out_mater is written as three consecutive dense (B, M, K) slabs — each the
// size of `out` — in the order sum, max, min. This matches the (3, ..., M, K)
// buffer allocated by ts_fuse_fp32, so the slab stride is B * M * K (a stride
// of B * N * K would overlap slabs when N < M and overrun the buffer when
// N > M).
//
// Launch layout and 64-bit indexing are the same as in fuse_kernel.
__global__ void fuse_mat_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const float *value_data, const float *mat_data,
                            float *out_data, float *w_data, int B, int M, int N,
                            int K, float *out_mater) {

  // blockIdx.y is ignored for the row mapping: every y-slice processes the
  // same rows, only a different range of columns.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  const int64_t row = thread_idx >> 5;                    // thread_idx / 32
  const int lane_idx = static_cast<int>(thread_idx & 31); // thread_idx % 32
  const int64_t batch_idx = row / M;

  // One slab per materialized aggregate, each shaped like `out`.
  const int64_t shift = static_cast<int64_t>(B) * M * K;
  float *sum_mat = out_mater + 0 * shift;
  float *max_mat = out_mater + 1 * shift;
  float *min_mat = out_mater + 2 * shift;

  // `row` is warp-uniform, so whole warps exit together and the shuffles
  // below always see all 32 lanes.
  if (batch_idx >= B) {
    return;
  }

  // Column handled by this lane; out-of-range lanes only help with shuffles.
  const int mat_col_idx = lane_idx + (static_cast<int>(blockIdx.y) << 5);
  const bool in_range = mat_col_idx < K;

  const int64_t local_row = row % M;
  const int64_t row_start = __ldg(rowptr_data + local_row);
  const int64_t row_end = __ldg(rowptr_data + local_row + 1);

  const float sum_w = __ldg(w_data + NUM_AGGS * row + 0);
  const float min_w = __ldg(w_data + NUM_AGGS * row + 1);
  const float max_w = __ldg(w_data + NUM_AGGS * row + 2);

  // Offset of batch `batch_idx` inside `mat`.
  const int64_t mat_batch_offset = batch_idx * N * K;

  float result_sum = 0.0f;
  float result_min = std::numeric_limits<float>::max();
  float result_max = std::numeric_limits<float>::lowest();

  // Walk the row's nonzeros 32 at a time, one nonzero per lane.
  int64_t nz_idx = row_start + lane_idx;
  for (int64_t chunk = row_start; chunk < row_end; chunk += 32, nz_idx += 32) {
    // Coalesced load of this lane's nonzero; col == -1 marks padding lanes.
    int my_col = -1;
    float my_val = 0.0f;
    if (nz_idx < row_end) {
      my_col = static_cast<int>(__ldg(col_data + nz_idx));
      my_val = __ldg(value_data + nz_idx);
    }

    // Broadcast every lane's (col, value) pair to the whole warp.
    int cols[32];
    float vals[32];
#pragma unroll
    for (int i = 0; i < 32; i++) {
      cols[i] = __shfl_sync(FULL_MASK, my_col, i);
      vals[i] = __shfl_sync(FULL_MASK, my_val, i);
    }

#pragma unroll
    for (int i = 0; i < 32; i++) {
      if (in_range && cols[i] != -1) {
        // Coalesced: consecutive lanes read consecutive columns of `mat`.
        const float mat_val =
            __ldg(mat_data + mat_batch_offset +
                  static_cast<int64_t>(cols[i]) * K + mat_col_idx);
        result_min = mat_val < result_min ? mat_val : result_min;
        result_max = mat_val > result_max ? mat_val : result_max;
        result_sum += vals[i] * mat_val;
      }
    }
  }

  if (in_range) {
    // Empty row: min/max have no elements, define them as 0.
    if (row_end - row_start == 0) {
      result_min = 0.0f;
      result_max = 0.0f;
    }
    const int64_t out_idx = row * K + mat_col_idx;
    // Coalesced writes into `out` and the materialization slabs.
    out_data[out_idx] =
        sum_w * result_sum + min_w * result_min + max_w * result_max;
    sum_mat[out_idx] = result_sum;
    max_mat[out_idx] = result_max;
    min_mat[out_idx] = result_min;
  }
}

// Host entry point for the fused aggregation computed by fuse_kernel: for
// each output row r and column k,
//   out = w[r][0] * (A @ mat)[r][k] + w[r][1] * min + w[r][2] * max,
// where min/max run over mat[b][col[j]][k] for the row's nonzeros j and are 0
// for rows without nonzeros. A is the (M x N) CSR matrix (rowptr, col, value)
// shared across the leading (batch) dims of mat (..., N, K).
//
// w: float32 tensor with the same rank as mat, shape (..., M, NUM_AGGS), and
//    B * M * NUM_AGGS elements in total (B = product of mat's batch dims).
// materialize: when true, launches fuse_mat_kernel, which additionally writes
//    the raw aggregates into a scratch buffer that is not returned.
//
// All inputs must be CUDA tensors on the same device. Returns a new float32
// tensor of shape (..., M, K); work is enqueued on the current CUDA stream
// without synchronizing.
torch::Tensor ts_fuse_fp32(torch::Tensor rowptr, torch::Tensor col,
                           torch::Tensor value, torch::Tensor mat,
                           torch::Tensor w, bool materialize) {

  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(value);
  CHECK_CUDA(mat);
  CHECK_CUDA(w);
  TORCH_CHECK(col.device() == rowptr.device() &&
                  value.device() == rowptr.device() &&
                  mat.device() == rowptr.device() &&
                  w.device() == rowptr.device(),
              "rowptr, col, value, mat and w must be on the same device");
  cudaSetDevice(rowptr.get_device());

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(rowptr.numel() >= 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(value.dim() == 1);
  CHECK_INPUT(value.size(0) == col.size(0));
  CHECK_INPUT(mat.dim() >= 2);
  CHECK_INPUT(w.dim() == mat.dim());
  CHECK_INPUT(w.size(-2) == rowptr.numel() - 1);
  CHECK_INPUT(w.size(-1) == NUM_AGGS);

  // datatypes
  TORCH_CHECK(rowptr.dtype() == torch::kInt64, "dtype rowptr");
  TORCH_CHECK(col.dtype() == torch::kInt64, "dtype col");
  TORCH_CHECK(value.dtype() == torch::kFloat32, "dtype value");
  TORCH_CHECK(mat.dtype() == torch::kFloat32, "dtype mat");
  TORCH_CHECK(w.dtype() == torch::kFloat32, "dtype w");

  // The kernels walk raw pointers with dense strides, so every input must be
  // contiguous (a no-op for tensors that already are).
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  value = value.contiguous();
  mat = mat.contiguous();
  w = w.contiguous();

  const int64_t M = rowptr.numel() - 1;
  const int64_t N = mat.size(-2);
  const int64_t K = mat.size(-1);
  // Product of the leading (batch) dims; mat.numel() / (N * K) would divide
  // by zero when N or K is 0.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; ++d) {
    B *= mat.size(d);
  }

  // The kernels read NUM_AGGS weights for every one of the B * M output
  // rows; a smaller w would be read out of bounds.
  TORCH_CHECK(w.numel() == B * M * NUM_AGGS,
              "w must have the same batch dimensions as mat");

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = M;
  auto out = torch::empty(sizes, mat.options().dtype(torch::kFloat32));

  // Nothing to compute. Launching an empty grid would instead fail with
  // cudaErrorInvalidConfiguration and poison a later CUDA call.
  if (out.numel() == 0) {
    return out;
  }

  // The kernels take 32-bit sizes; gridDim.x / gridDim.y are limited to
  // 2^31 - 1 and 65535 respectively.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(B <= kIntMax && M <= kIntMax && N <= kIntMax && K <= kIntMax,
              "ts_fuse_fp32: tensor dimensions must fit in a 32-bit int");
  const int64_t blocks_x = (32 * B * M + THREADS - 1) / THREADS;
  const int64_t blocks_y = (K + 31) / 32;
  TORCH_CHECK(blocks_x <= kIntMax && blocks_y <= 65535,
              "ts_fuse_fp32: input too large for a single kernel launch");
  const dim3 BLOCKS(static_cast<unsigned int>(blocks_x),
                    static_cast<unsigned int>(blocks_y));

  auto stream = at::cuda::getCurrentCUDAStream();
  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t *col_data = col.data_ptr<int64_t>();
  float *mat_data = mat.data_ptr<float>();
  float *value_data = value.data_ptr<float>();
  float *out_data = out.data_ptr<float>();
  float *w_data = w.data_ptr<float>();
  const int B32 = static_cast<int>(B);
  const int M32 = static_cast<int>(M);
  const int N32 = static_cast<int>(N);
  const int K32 = static_cast<int>(K);

  if (materialize) {
    // Scratch buffer of shape (3, ..., M, K) receiving the raw aggregates.
    // It is intentionally not returned (see the note in fuse_mat_kernel).
    sizes.insert(sizes.begin(), NUM_AGGS);
    auto out_mat = torch::empty(sizes, mat.options().dtype(torch::kFloat32));
    float *out_mat_data = out_mat.data_ptr<float>();
    fuse_mat_kernel<<<BLOCKS, THREADS, 0, stream>>>(
        rowptr_data, col_data, value_data, mat_data, out_data, w_data, B32,
        M32, N32, K32, out_mat_data);
  } else {
    fuse_kernel<<<BLOCKS, THREADS, 0, stream>>>(rowptr_data, col_data,
                                                value_data, mat_data, out_data,
                                                w_data, B32, M32, N32, K32);
  }
  // Surface launch-configuration errors here rather than in a later CUDA call.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "fuse kernel launch failed: ",
              cudaGetErrorString(err));

  return out;
}