#include <Python.h>
#include <torch/script.h>
#include <torch/extension.h>

#include "cpu/relabel_cpu.h"
#include "cpu/partitioning.h"
#include "cuda/async_cuda.h"
#include "cpu/masking.h"


/// CPU-only entry point for one-hop relabeling; forwards to
/// `relabel_one_hop_cpu`.
///
/// @param rowptr          CSR-style row pointer tensor (presumably; confirm
///                        against relabel_one_hop_cpu). Must be on CPU.
/// @param col             Column/neighbour index tensor. Must be on CPU.
/// @param optional_value  Optional per-entry values. Must be on CPU if given.
/// @param idx             Index tensor forwarded to the CPU kernel. Must be on CPU.
/// @param bipartite       Forwarded unchanged to `relabel_one_hop_cpu`.
/// @return The 4-tuple returned by `relabel_one_hop_cpu`.
/// @throws c10::Error if any tensor argument lives on a CUDA device.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
           torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
                torch::optional<torch::Tensor> optional_value,
                torch::Tensor idx, bool bipartite) {
  // Previously only `rowptr` was checked, so a CUDA `col`/`idx`/value tensor
  // would still be handed to the CPU implementation. TORCH_CHECK replaces the
  // deprecated AT_ERROR macro and throws the same c10::Error type.
  const bool any_cuda =
      rowptr.is_cuda() || col.is_cuda() || idx.is_cuda() ||
      (optional_value.has_value() && optional_value->is_cuda());
  TORCH_CHECK(!any_cuda,
              "Do not use relabeling for CUDA tensor... plz .cpu()");
  return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite);
}

/// CPU-only entry point for contiguous heterograph generation; forwards to
/// `generate_contiguous_heterograph_cpu`.
///
/// @param rowptr          CSR-style row pointer tensor (presumably; confirm
///                        against the CPU implementation). Must be on CPU.
/// @param col             Column/neighbour index tensor. Must be on CPU.
/// @param optional_value  Optional per-entry values. Must be on CPU if given.
/// @param idx             Index tensor forwarded to the CPU kernel. Must be on CPU.
/// @param bipartite       Forwarded unchanged to the CPU implementation.
/// @return The 4-tuple returned by `generate_contiguous_heterograph_cpu`.
/// @throws c10::Error if any tensor argument lives on a CUDA device.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
           torch::Tensor>
generate_contiguous_heterograph(torch::Tensor rowptr, torch::Tensor col,
                torch::optional<torch::Tensor> optional_value,
                torch::Tensor idx, bool bipartite) {
  // Check every tensor input, not just `rowptr`, before handing them to the
  // CPU implementation. TORCH_CHECK replaces the deprecated AT_ERROR macro
  // and throws the same c10::Error type.
  const bool any_cuda =
      rowptr.is_cuda() || col.is_cuda() || idx.is_cuda() ||
      (optional_value.has_value() && optional_value->is_cuda());
  TORCH_CHECK(!any_cuda,
              "Do not use heterograph generation for CUDA tensor... plz .cpu()");
  return generate_contiguous_heterograph_cpu(rowptr, col, optional_value, idx,
                                             bipartite);
}

// Python-facing synchronization entry points. Each one forwards, with no extra
// logic, to its *_cuda counterpart declared in cuda/async_cuda.h.

/// Forwards to `synchronize_cuda()`.
void synchronize() { synchronize_cuda(); }

/// Forwards to `d2h_synchronize_cuda()` (device-to-host direction, per its name).
void d2h_synchronize() { d2h_synchronize_cuda(); }

/// Forwards to `h2d_synchronize_cuda()` (host-to-device direction, per its name).
void h2d_synchronize() { h2d_synchronize_cuda(); }

/// Forwards `src`/`dst` unchanged to `upload_async_cuda`.
void upload_async(torch::Tensor src, torch::Tensor dst) { upload_async_cuda(src, dst); }

/// Forwards `src`/`dst` unchanged to `fill_async_cuda`.
void fill_async(torch::Tensor src, torch::Tensor dst) { fill_async_cuda(src, dst); }

/// Forwards a many-to-one copy request to `gather_async_cuda`.
///
/// @param pid      Identifier forwarded as-is (presumably a partition id; confirm).
/// @param srcs     Source tensors.
/// @param dst      Destination tensor.
/// @param bndries  Boundary tensors forwarded unchanged to the CUDA helper.
void gather_async(int pid,
                  std::vector<torch::Tensor> srcs,
                  torch::Tensor dst,
                  std::vector<torch::Tensor> bndries) {
  gather_async_cuda(pid, srcs, dst, bndries);
}

/// Forwards a one-to-many copy request to `scatter_async_cuda`.
///
/// @param pid      Identifier forwarded as-is (presumably a partition id; confirm).
/// @param src      Source tensor.
/// @param dsts     Destination tensors.
/// @param bndries  Boundary tensors forwarded unchanged to the CUDA helper.
void scatter_async(int pid,
                   torch::Tensor src,
                   std::vector<torch::Tensor> dsts,
                   std::vector<torch::Tensor> bndries) {
  scatter_async_cuda(pid, src, dsts, bndries);
}

/// Forwards an indexed read to `read_async_cuda`.
///
/// `optional_offset` / `optional_count` may be absent; they are passed through
/// as-is, and the CUDA helper decides how to treat a missing value.
void read_async(torch::Tensor src,
                torch::optional<torch::Tensor> optional_offset,
                torch::optional<torch::Tensor> optional_count,
                torch::Tensor index,
                torch::Tensor dst,
                torch::Tensor buffer) {
  read_async_cuda(src, optional_offset, optional_count, index, dst, buffer);
}

/// Forwards an offset/count-described write to `write_async_cuda`.
void write_async(torch::Tensor src,
                 torch::Tensor offset,
                 torch::Tensor count,
                 torch::Tensor dst) {
  write_async_cuda(src, offset, count, dst);
}

/// Forwards to `contiguous_write_async_cuda`.
void contiguous_write_async(torch::Tensor src, torch::Tensor dst) {
  contiguous_write_async_cuda(src, dst);
}

/// Forwards an index-driven write-with-reduction to
/// `write_with_reduction_async_cuda`.
void write_with_reduction_async(torch::Tensor src,
                                torch::Tensor dst,
                                torch::Tensor index) {
  write_with_reduction_async_cuda(src, dst, index);
}

/// Forwards to `conti_write_with_reduction_async_cuda`. This is the index-free
/// ("conti" = contiguous, presumably) sibling of write_with_reduction_async.
void conti_write_with_reduction_async(torch::Tensor src, torch::Tensor dst) {
  conti_write_with_reduction_async_cuda(src, dst);
}

/// Spinner partitioning entry point.
///
/// Dispatches on `async`:
/// - `true` selects `spinner_async_cpu`;
/// - `false` selects `spinner_cpu`.
///
/// Every other argument is forwarded to the chosen implementation in the same
/// order, and that implementation's tensor is returned.
torch::Tensor
spinner(torch::Tensor rowptr, torch::Tensor col,
        int num_parts, float capacity, float beta, int max_iter,
        float halting_eps, int halting_window,
        bool async, bool log, int num_threads) {
  if (!async) {
    return spinner_cpu(rowptr, col, num_parts, capacity, beta, max_iter,
                       halting_eps, halting_window, log, num_threads);
  }
  return spinner_async_cpu(rowptr, col, num_parts, capacity, beta, max_iter,
                           halting_eps, halting_window, log, num_threads);
}

/// Spinner partitioning with GAS; forwards to `spinner_gas_cpu`.
///
/// The parameter list matches `spinner()`. `async` is NOT used: there is a
/// single implementation, and the flag is never forwarded.
torch::Tensor
spinner_gas(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, float capacity, float beta, int max_iter,
            float halting_eps, int halting_window,
            bool async, bool log, int num_threads) {
  static_cast<void>(async);  // accepted for signature parity only
  return spinner_gas_cpu(rowptr, col, num_parts, capacity, beta, max_iter,
                         halting_eps, halting_window, log, num_threads);
}


/// Grinnder partitioning; always forwards to `grinnder_async_cpu`.
///
/// `async` is accepted but ignored: the async CPU implementation is used
/// unconditionally. `orig_labels` is optional and passed through as-is.
torch::Tensor
grinnder(torch::Tensor rowptr, torch::Tensor col,
          int num_parts, float capacity, float beta, int max_iter,
          float halting_eps, int halting_window,
          bool async, bool reuse_aware, bool refine, bool log, int num_threads,
          torch::optional<torch::Tensor> orig_labels) {
  static_cast<void>(async);  // not forwarded; see doc comment
  return grinnder_async_cpu(rowptr, col, num_parts, capacity, beta,
                            max_iter, halting_eps, halting_window,
                            reuse_aware, refine, log, num_threads,
                            orig_labels);
}

/// Fast Grinnder partitioning; always forwards to `grinnder_fast_async_cpu`.
///
/// It takes the same arguments as `grinnder()`, plus `progressive_window`,
/// which is forwarded right after `max_iter`. `async` is accepted but ignored.
torch::Tensor
fast_grinnder(torch::Tensor rowptr, torch::Tensor col,
          int num_parts, float capacity, float beta, int max_iter,
          int progressive_window,
          float halting_eps, int halting_window,
          bool async, bool reuse_aware, bool refine, bool log, int num_threads,
          torch::optional<torch::Tensor> orig_labels) {
  static_cast<void>(async);  // not forwarded; see doc comment
  return grinnder_fast_async_cpu(rowptr, col, num_parts, capacity, beta,
                                 max_iter, progressive_window,
                                 halting_eps, halting_window,
                                 reuse_aware, refine, log, num_threads,
                                 orig_labels);
}

/// Random partitioning baseline; forwards all arguments to `random_cpu`.
torch::Tensor
random_partition(torch::Tensor rowptr, torch::Tensor col,
                 int num_parts, bool log) {
  return random_cpu(rowptr, col, num_parts, log);
}

/// Builds cache masks by forwarding all arguments to `gen_cache_mask_cpu`.
///
/// @return The pair of tensor vectors produced by `gen_cache_mask_cpu`.
std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
gen_cache_mask(std::vector<torch::Tensor> caches,
               int dst_id, torch::Tensor dst,
               std::vector<torch::Tensor> dst_bndries,
               std::vector<torch::Tensor> reuse_masks,
               int num_threads) {
  return gen_cache_mask_cpu(caches, dst_id, dst, dst_bndries, reuse_masks,
                            num_threads);
}

/// Builds reuse masks between a source and a destination by forwarding all
/// arguments to `gen_reuse_mask_cpu`.
///
/// @return The pair of tensor vectors produced by `gen_reuse_mask_cpu`.
std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
gen_reuse_mask(int src_id, torch::Tensor src,
               std::vector<torch::Tensor> src_bndries,
               int dst_id, torch::Tensor dst,
               std::vector<torch::Tensor> dst_bndries,
               int num_threads) {
  return gen_reuse_mask_cpu(src_id, src, src_bndries,
                            dst_id, dst, dst_bndries,
                            num_threads);
}

// Python bindings. Registration order and names are unchanged. The docstrings
// become `__doc__` on the Python side, so they are kept accurate: the
// "congtiguous" typo is fixed, and the two reduction writers now have
// distinct descriptions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Partitioning.
  m.def("spinner", &spinner, "spinner partitioning");
  m.def("spinner_gas", &spinner_gas, "spinner partitioning with GAS");
  m.def("grinnder", &grinnder, "grinnder partitioning");
  m.def("fast_grinnder", &fast_grinnder, "fast grinnder partitioning");
  m.def("random_partition", &random_partition, "random partition");

  // Mask generation.
  m.def("gen_cache_mask", &gen_cache_mask, "generate cache mask");
  m.def("gen_reuse_mask", &gen_reuse_mask, "generate reuse mask");

  // Graph relabeling (CPU only).
  m.def("relabel_one_hop", &relabel_one_hop, "relabel one hop");
  m.def("generate_contiguous_heterograph", &generate_contiguous_heterograph,
        "make contiguous heterograph");

  // Synchronization.
  m.def("synchronize", &synchronize, "synchronize func");
  m.def("d2h_synchronize", &d2h_synchronize, "d2h synchronize func");
  m.def("h2d_synchronize", &h2d_synchronize, "h2d synchronize func");

  // Async data movement.
  m.def("upload_async", &upload_async, "upload async manner");
  m.def("fill_async", &fill_async, "fill async manner");
  m.def("gather_async", &gather_async, "gather async manner");
  m.def("scatter_async", &scatter_async, "scatter async manner");
  m.def("read_async", &read_async, "read async manner");
  m.def("write_async", &write_async, "write async manner");
  m.def("contiguous_write_async", &contiguous_write_async,
        "write async in a contiguous manner");
  m.def("write_with_reduction_async", &write_with_reduction_async,
        "write async and reduction");
  m.def("conti_write_with_reduction_async", &conti_write_with_reduction_async,
        "write async and reduction in a contiguous manner");
}