#include "async_cuda.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include "../thread.h"

// Accessor for the general-purpose Thread object used by this file.
// The object is a function-local static, so every call returns a reference
// to the same instance, constructed the first time this function runs.
Thread &getThread() {
  static Thread shared_instance;
  return shared_instance;
}

// Accessor for the Thread object that the host-to-device helpers in this
// file (upload/gather) submit their jobs to. Backed by a function-local
// static, so all callers share one instance.
Thread &getH2DThread() {
  static Thread h2d_instance;
  return h2d_instance;
}

// Accessor for the Thread object that the device-to-host helpers in this
// file (fill/scatter) submit their jobs to. Backed by a function-local
// static, so all callers share one instance.
Thread &getD2HThread() {
  static Thread d2h_instance;
  return d2h_instance;
}

void h2d_synchronize_cuda() { getH2DThread().synchronize(); }
void d2h_synchronize_cuda() { getD2HThread().synchronize(); }

void synchronize_cuda() { getThread().synchronize(); }

// Copies the CUDA tensor `src` into the CPU tensor `dst` (device-to-host).
//
// The copy of src.numel() elements is issued with cudaMemcpyAsync on the
// current CUDA stream of src's device, from inside a job handed to
// getD2HThread().run(). Presumably run() queues the job for a background
// worker and d2h_synchronize_cuda() waits for it -- confirm in thread.h.
//
// Preconditions (checked before the job is submitted):
//   * src is a contiguous CUDA tensor, dst is a contiguous CPU tensor;
//   * both tensors have the same dtype;
//   * dst has room for at least src.numel() elements;
//   * the current stream of src's device is not the default stream.
// A failing cudaMemcpyAsync raises from inside the job.
void fill_async_cuda(torch::Tensor src, torch::Tensor dst) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  // Validate up front: a dtype mismatch would otherwise only surface when
  // data_ptr<scalar_t>() is called inside the job, and an undersized target
  // would be silently overrun by the raw memcpy.
  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  AT_ASSERTM(src.numel() <= dst.numel(),
             "Target tensor is too small to hold the source tensor");

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous fill requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "fill_async", [&] {
    // Tensors and the stream are captured by value so the job owns its own
    // handles independent of this call frame.
    getD2HThread().run([=] {
      auto src_data = src.data_ptr<scalar_t>();
      auto dst_data = dst.data_ptr<scalar_t>();
      AT_CUDA_CHECK(cudaMemcpyAsync(dst_data, src_data,
                                    src.numel() * sizeof(scalar_t),
                                    cudaMemcpyDeviceToHost, stream));
    });
  });
}

// Copies the CPU tensor `src` into the CUDA tensor `dst` (host-to-device).
//
// The copy of src.numel() elements is issued with cudaMemcpyAsync on the
// current CUDA stream of dst's device, from inside a job handed to
// getH2DThread().run(). Presumably run() queues the job for a background
// worker and h2d_synchronize_cuda() waits for it -- confirm in thread.h.
//
// Preconditions (checked before the job is submitted):
//   * src is a contiguous CPU tensor, dst is a contiguous CUDA tensor;
//   * both tensors have the same dtype;
//   * dst has room for at least src.numel() elements;
//   * the current stream of dst's device is not the default stream.
// A failing cudaMemcpyAsync raises from inside the job.
void upload_async_cuda(torch::Tensor src, torch::Tensor dst) {
  AT_ASSERTM(!src.is_cuda(), "Source tensor must be a CPU tensor");
  AT_ASSERTM(dst.is_cuda(), "Target tensor must be a CUDA tensor");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  AT_ASSERTM(src.numel() <= dst.numel(),
             "Target tensor is too small to hold the source tensor");

  // The stream must come from the CUDA tensor. `src` lives on the CPU, so
  // its device index does not identify the GPU that receives the data.
  auto stream = at::cuda::getCurrentCUDAStream(dst.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(dst.get_device()),
             "Asynchronous upload requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "upload_async", [&] {
    getH2DThread().run([=] {
      auto src_data = src.data_ptr<scalar_t>();
      auto dst_data = dst.data_ptr<scalar_t>();
      AT_CUDA_CHECK(cudaMemcpyAsync(dst_data, src_data,
                                    src.numel() * sizeof(scalar_t),
                                    cudaMemcpyHostToDevice, stream));
    });
  });
}

void scatter_async_cuda(int pid, torch::Tensor src,
                        std::vector<torch::Tensor> dsts,
                        std::vector<torch::Tensor> bndries) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  for (auto &dst : dsts) {
    AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");
    AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
               "Source and target tensor must have the same data type");
  }
  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  for (auto &dst : dsts) {
    AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");
  }

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous scatter requires a non-default CUDA stream");
  
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_async", [&] {
    getD2HThread().run([=] {
      auto src_data = src.data_ptr<scalar_t>();
      int64_t size = src.numel() / src.size(0); // other dim size

      int64_t total_offset = 0; // for checking the final size

      // We need to implement accumulate scatter
      auto self_dst_data = dsts[pid].data_ptr<scalar_t>();
      auto self_dst_size = dsts[pid].size(0);

      auto self_cpied = dsts[pid].detach().clone();

      cudaMemcpyAsync(self_dst_data, src_data,
                      self_dst_size * size * sizeof(scalar_t),
                      cudaMemcpyDeviceToHost, stream);

      // Accumulation
      dsts[pid].data() += self_cpied.data();
      total_offset += self_dst_size;

      // copy offset: to dst
      // as the following code requires accumulation
      // we need to use omp
      // omp_set_num_threads(16); // Set the number of threads to 16
      // #pragma omp parallel for
      for (int i = 0; i < dsts.size(); i++) {
        if (i == pid) { continue; }
        auto dst = dsts[i];
        auto index = bndries[i];
        auto indexed_dst = dst.index({index});
        auto indexed_dst_data = indexed_dst.data_ptr<scalar_t>();
        cudaMemcpyAsync(indexed_dst_data, src_data + total_offset * size,
                        index.size(0) * size * sizeof(scalar_t),
                        cudaMemcpyDeviceToHost, stream);
        dst.index_put_({index}, indexed_dst, /*accumulate=*/true);
        total_offset += index.size(0);
      }

      // final check
      AT_ASSERTM(total_offset == src.size(0), "We did not copy all the data");
    });
  });
}

void gather_async_cuda(int pid, std::vector<torch::Tensor> srcs,
                       torch::Tensor dst, std::vector<torch::Tensor> bndries) {
  AT_ASSERTM(dst.is_cuda(), "Target tensor must be a CUDA tensor");
  for (auto &src : srcs) {
    AT_ASSERTM(!src.is_cuda(), "Source tensor must be a CPU tensor");
    AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
               "Source and target tensor must have the same data type");
  }

  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");
  for (auto &src : srcs) {
    AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  }

  auto stream = at::cuda::getCurrentCUDAStream(srcs[0].get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(srcs[0].get_device()),
             "Asynchronous gather requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(dst.scalar_type(), "gather_async", [&] {
    getH2DThread().run([=] {
      auto dst_data = dst.data_ptr<scalar_t>();
      int64_t size = dst.numel() / dst.size(0); // other dim size

      int64_t offset = srcs[pid].size(0);

      // copy 0:offset to dst
      auto src_data = srcs[pid].data_ptr<scalar_t>();
      cudaMemcpyAsync(dst_data, src_data,
                      offset * size * sizeof(scalar_t),
                      cudaMemcpyHostToDevice, stream);

      // copy offset: to dst
      for (int i = 0; i < srcs.size(); i++) {
        if (i == pid) { continue; }
        auto src = srcs[i];
        auto src_data = src.data_ptr<scalar_t>();
        auto bndry = bndries[i];
        // auto bndry_data = bndry.data_ptr<int64_t>();
        torch::Tensor selected_src = torch::index_select(src, 0, bndry);
        AT_ASSERTM(!selected_src.is_cuda(), "Selected source tensor must be a CPU tensor");
        auto selected_src_data = selected_src.data_ptr<scalar_t>();
        int64_t selected_src_size = selected_src.size(0);
        cudaMemcpyAsync(dst_data + (offset * size), selected_src_data,
                        selected_src_size * size * sizeof(scalar_t),
                        cudaMemcpyHostToDevice, stream);
        offset += selected_src_size;
      }

      // final check
      AT_ASSERTM(offset == dst.size(0), "We did not copy all the data");
    });
  });
}

// Reads rows of the CPU tensor `src` into the CUDA tensor `dst`.
//
// Two groups of rows are transferred, in this order in `dst`:
//   1. (only if `optional_offset` is set) for each k, the contiguous rows
//      src[offset[k] : offset[k] + count[k]], packed back to back starting
//      at dst row 0. This path requires `src` to be pinned.
//   2. the rows src[index], first gathered into the pinned `buffer` with
//      index_select_out and then copied to dst starting at row sum(count)
//      (row 0 if no offsets were given).
//
// The copies are issued with cudaMemcpyAsync on the current stream of dst's
// device, from inside a job handed to getThread().run(); checks inside the
// job raise from wherever that job executes (presumably a background worker
// -- confirm in thread.h).
//
// Parameters:
//   src             - contiguous CPU tensor (pinned when offsets are given).
//   optional_offset - optional 1-D contiguous int64 CPU tensor of start rows.
//   optional_count  - 1-D contiguous int64 CPU tensor of row counts; required
//                     whenever optional_offset is set, same length.
//   index           - 1-D CPU index tensor.
//   dst             - contiguous CUDA tensor, same dtype as src.
//   buffer          - contiguous pinned CPU staging tensor, same dtype as src.
void read_async_cuda(torch::Tensor src,
                     torch::optional<torch::Tensor> optional_offset,
                     torch::optional<torch::Tensor> optional_count,
                     torch::Tensor index, torch::Tensor dst,
                     torch::Tensor buffer) {

  AT_ASSERTM(!src.is_cuda(), "Source tensor must be a CPU tensor");
  AT_ASSERTM(!index.is_cuda(), "Index tensor must be a CPU tensor");
  AT_ASSERTM(dst.is_cuda(), "Target tensor must be a CUDA tensor");
  AT_ASSERTM(!buffer.is_cuda(), "Buffer tensor must be a CPU tensor");

  AT_ASSERTM(buffer.is_pinned(), "Buffer tensor must be pinned");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");
  AT_ASSERTM(buffer.is_contiguous(), "Buffer tensor must be contiguous");

  AT_ASSERTM(index.dim() == 1, "Index tensor must be one-dimensional");

  // Report dtype mismatches here; otherwise they only surface when
  // data_ptr<scalar_t>() is called inside the job.
  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  AT_ASSERTM(src.scalar_type() == buffer.scalar_type(),
             "Source and buffer tensor must have the same data type");

  int64_t numel = 0;
  if (optional_offset.has_value()) {
    AT_ASSERTM(src.is_pinned(), "Source tensor must be pinned");
    auto offset = optional_offset.value();
    AT_ASSERTM(!offset.is_cuda(), "Offset tensor must be a CPU tensor");
    AT_ASSERTM(offset.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(offset.dim() == 1, "Offset tensor must be one-dimensional");
    AT_ASSERTM(optional_count.has_value(), "Count tensor is undefined");
    auto count = optional_count.value();
    AT_ASSERTM(!count.is_cuda(), "Count tensor must be a CPU tensor");
    AT_ASSERTM(count.is_contiguous(), "Count tensor must be contiguous");
    AT_ASSERTM(count.dim() == 1, "Count tensor must be one-dimensional");
    AT_ASSERTM(offset.numel() == count.numel(), "Size mismatch");
    // Both tensors are read through data_ptr<int64_t>() in the job below.
    AT_ASSERTM(offset.scalar_type() == at::kLong,
               "Offset tensor must be of type int64");
    AT_ASSERTM(count.scalar_type() == at::kLong,
               "Count tensor must be of type int64");
    numel = count.sum().item<int64_t>();
  }

  AT_ASSERTM(numel + index.numel() <= buffer.size(0),
             "Buffer tensor size too small");
  AT_ASSERTM(numel + index.numel() <= dst.size(0),
             "Target tensor size too small");

  // The stream must come from the CUDA tensor. `src` lives on the CPU, so
  // its device index does not identify the receiving GPU.
  auto stream = at::cuda::getCurrentCUDAStream(dst.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(dst.get_device()),
             "Asynchronous read requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "read_async", [&] {
    getThread().run([=] {
      const int64_t src_rows = src.size(0);
      // Elements per row; guarded so an empty source does not divide by zero.
      const int64_t size = src_rows == 0 ? 0 : src.numel() / src_rows;
      auto src_data = src.data_ptr<scalar_t>();
      auto dst_data = dst.data_ptr<scalar_t>();

      if (optional_offset.has_value()) {
        auto offset = optional_offset.value();
        auto count = optional_count.value();
        auto offset_data = offset.data_ptr<int64_t>();
        auto count_data = count.data_ptr<int64_t>();
        int64_t dst_offset = 0;
        for (int64_t i = 0; i < offset.numel(); i++) {
          const int64_t src_offset = offset_data[i];
          const int64_t c = count_data[i];
          // Negative values would otherwise pass the upper-bound test and
          // address memory before the start of the tensors.
          AT_ASSERTM(src_offset >= 0 && c >= 0 && src_offset + c <= src_rows,
                     "Invalid index");
          AT_ASSERTM(dst_offset + c <= dst.size(0), "Invalid index");
          AT_CUDA_CHECK(cudaMemcpyAsync(
              dst_data + (dst_offset * size), src_data + (src_offset * size),
              c * size * sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
          dst_offset += c;
        }
      }

      // NOTE(review): this overwrites the head of `buffer` on the host. If a
      // copy issued by an earlier call is still reading from the same
      // buffer, the two race -- confirm callers synchronize between calls.
      auto staging = buffer.narrow(0, 0, index.numel()); // convert to non-const
      torch::index_select_out(staging, src, 0, index);
      // narrow() starts at row 0, so `staging` normally begins at the same
      // address as `buffer`.
      AT_CUDA_CHECK(cudaMemcpyAsync(dst_data + numel * size,
                                    staging.data_ptr<scalar_t>(),
                                    index.numel() * size * sizeof(scalar_t),
                                    cudaMemcpyHostToDevice, stream));
    });
  });
}

// Writes consecutive row chunks of the CUDA tensor `src` into the pinned CPU
// tensor `dst` at the row positions given by `offset`.
//
// For each k, the next count[k] rows of `src` (consumed front to back) are
// copied to dst[offset[k] : offset[k] + count[k]]. The copies are issued
// with cudaMemcpyAsync on the current stream of src's device, directly from
// the calling thread; this function does not wait for them to complete.
//
// Parameters:
//   src    - contiguous CUDA tensor.
//   offset - 1-D contiguous int64 CPU tensor of destination start rows.
//   count  - 1-D contiguous int64 CPU tensor of chunk lengths, same length
//            as `offset`.
//   dst    - contiguous pinned CPU tensor with the same dtype as src.
void write_async_cuda(torch::Tensor src, torch::Tensor offset,
                      torch::Tensor count, torch::Tensor dst) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  AT_ASSERTM(!offset.is_cuda(), "Offset tensor must be a CPU tensor");
  AT_ASSERTM(!count.is_cuda(), "Count tensor must be a CPU tensor");
  AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");

  AT_ASSERTM(dst.is_pinned(), "Target tensor must be pinned");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(offset.is_contiguous(), "Offset tensor must be contiguous");
  AT_ASSERTM(count.is_contiguous(), "Count tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  AT_ASSERTM(offset.dim() == 1, "Offset tensor must be one-dimensional");
  AT_ASSERTM(count.dim() == 1, "Count tensor must be one-dimensional");
  AT_ASSERTM(offset.numel() == count.numel(), "Size mismatch");

  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  // Both tensors are read through data_ptr<int64_t>() below.
  AT_ASSERTM(offset.scalar_type() == at::kLong,
             "Offset tensor must be of type int64");
  AT_ASSERTM(count.scalar_type() == at::kLong,
             "Count tensor must be of type int64");

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous write requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "write_async", [&] {
    const int64_t src_rows = src.size(0);
    // Elements per row; guarded so an empty source does not divide by zero.
    const int64_t size = src_rows == 0 ? 0 : src.numel() / src_rows;
    auto src_data = src.data_ptr<scalar_t>();
    auto offset_data = offset.data_ptr<int64_t>();
    auto count_data = count.data_ptr<int64_t>();
    auto dst_data = dst.data_ptr<scalar_t>();
    int64_t src_offset = 0;
    for (int64_t i = 0; i < offset.numel(); i++) {
      const int64_t dst_offset = offset_data[i];
      const int64_t c = count_data[i];
      // Negative values would otherwise pass the upper-bound tests and write
      // before the start of the host buffer.
      AT_ASSERTM(c >= 0 && src_offset + c <= src_rows, "Invalid index");
      AT_ASSERTM(dst_offset >= 0 && dst_offset + c <= dst.size(0),
                 "Invalid index");
      // Also guard the raw element range in case dst rows are narrower than
      // src rows.
      AT_ASSERTM((dst_offset + c) * size <= dst.numel(), "Invalid index");
      AT_CUDA_CHECK(cudaMemcpyAsync(
          dst_data + (dst_offset * size), src_data + (src_offset * size),
          c * size * sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
      src_offset += c;
    }
  });
}

// Copies all src.numel() elements of the CUDA tensor `src` to the start of
// the CPU tensor `dst`.
//
// The copy is issued with cudaMemcpyAsync on the current stream of src's
// device, directly from the calling thread. `dst` is not required to be
// pinned (that check is intentionally disabled below).
//
// Preconditions: src is a contiguous CUDA tensor; dst is a contiguous CPU
// tensor of the same dtype holding at least src.numel() elements; the
// current stream of src's device is not the default stream.
void contiguous_write_async_cuda(torch::Tensor src, torch::Tensor dst) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");

  // AT_ASSERTM(dst.is_pinned(), "Target tensor must be pinned");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  AT_ASSERTM(src.numel() <= dst.numel(), "Dst size musth be larger than the Src size");

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous write requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "contiguous_write_async", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto dst_data = dst.data_ptr<scalar_t>();

    AT_CUDA_CHECK(cudaMemcpyAsync(
        dst_data, src_data,
        src.numel() * sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
  });
}

// Adds the contents of the CUDA tensor `src` to the entries of the CPU
// tensor `dst` selected by `index`, i.e. dst[index] += src.
//
// Steps, all on the calling thread:
//   1. dst.index({index}) allocates a staging tensor with the shape of the
//      selected entries;
//   2. `src` is copied device-to-host into that staging tensor on the
//      current stream of src's device;
//   3. the stream is synchronized so the staging tensor is complete;
//   4. the staging tensor is accumulated into dst via
//      index_put_(..., accumulate=true).
//
// Preconditions: src is a contiguous CUDA tensor; dst is a contiguous CPU
// tensor of the same dtype; src.numel() equals the number of elements that
// `index` selects from dst; the current stream of src's device is not the
// default stream.
void write_with_reduction_async_cuda(torch::Tensor src, torch::Tensor dst, torch::Tensor index) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous write requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "write_with_reduction_async", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto indexed_dst = dst.index({index});
    // The raw copy below fills the staging tensor element by element. A
    // larger source would overrun it; a smaller one would leave stale dst
    // values in the tail that would then be added to dst a second time.
    AT_ASSERTM(src.numel() == indexed_dst.numel(),
               "Source size must match the indexed target size");
    auto indexed_dst_data = indexed_dst.data_ptr<scalar_t>();

    AT_CUDA_CHECK(cudaMemcpyAsync(
        indexed_dst_data, src_data,
        src.numel() * sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
    // index_put_ reads the staging tensor on the host, so the transfer must
    // have finished first.
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));

    // Accumulate the received values into the selected entries of dst.
    dst.index_put_({index}, indexed_dst, /*accumulate=*/true);
  });
}

// Adds the contents of the CUDA tensor `src` to the leading src.numel()
// elements of the CPU tensor `dst` (in flat, contiguous order). Elements of
// `dst` beyond that range are left untouched.
//
// Steps, all on the calling thread:
//   1. the previous values of the affected range of dst are cloned;
//   2. `src` is copied device-to-host over that range on the current stream
//      of src's device;
//   3. the stream is synchronized so the range is fully written;
//   4. the cloned previous values are added back in place.
//
// Preconditions: src is a contiguous CUDA tensor; dst is a contiguous CPU
// tensor of the same dtype holding at least src.numel() elements; the
// current stream of src's device is not the default stream. `dst` is not
// required to be pinned (that check is intentionally disabled below).
void conti_write_with_reduction_async_cuda(torch::Tensor src, torch::Tensor dst) {
  AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
  AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");

  // AT_ASSERTM(dst.is_pinned(), "Target tensor must be pinned");

  AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
  AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");

  AT_ASSERTM(src.scalar_type() == dst.scalar_type(),
             "Source and target tensor must have the same data type");
  AT_ASSERTM(src.numel() <= dst.numel(), "Dst size musth be larger than the Src size");

  auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
  AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
             "Asynchronous write requires a non-default CUDA stream");

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "conti_write_with_reduction_async", [&] {
    auto src_data = src.data_ptr<scalar_t>();

    // Flat view over exactly the elements the copy will overwrite. Working
    // on this range only keeps the tail of dst from being added to itself
    // when src is smaller than dst.
    auto dst_head = dst.data().view({-1}).narrow(0, 0, src.numel());
    auto dst_data = dst_head.data_ptr<scalar_t>();

    // Previous contents of the range; the raw copy below overwrites them.
    auto previous = dst_head.clone();

    AT_CUDA_CHECK(cudaMemcpyAsync(
        dst_data, src_data,
        src.numel() * sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
    // The host reads the range next; with a pinned dst the async copy may
    // still be in flight when cudaMemcpyAsync returns.
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));

    // Accumulation
    dst_head.add_(previous);
  });
}