#include <torch/extension.h>

#include "include/fused_softsignsgd_kernel.cuh"

// Argument validation helpers. `x` is expected to be a torch::Tensor (any
// object exposing is_cuda() / is_contiguous() works). The argument is
// parenthesized in the expansion so expressions can be passed as well as
// plain identifiers; `#x` still stringifies the argument as written.
#define CHECK_CUDA(x) AT_ASSERTM((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
// Runs both checks. Wrapped in do { ... } while (0) so the macro expands to a
// single statement: with a bare two-statement expansion, an unbraced
// `if (cond) CHECK_INPUT(t);` would guard only the CUDA check and always run
// the contiguity check.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

// C++ interface

// Validates the tensors belonging to one parameter and then hands them to
// fused_softsignsgd_cuda.
//
// Checks performed (failures are reported through AT_ASSERTM):
//   - p, g, exp_avg and exp_avg_abs are contiguous CUDA tensors;
//   - g, exp_avg and exp_avg_abs hold the same number of elements as p;
//   - p_copy is either empty (numel() == 0) or a contiguous CUDA tensor with
//     the same number of elements as p.
//
// beta, lr, decay, eps and power are forwarded to the kernel launcher
// unchanged; their exact meaning is defined by fused_softsignsgd_cuda.
// NOTE(review): by name, p/g look like parameter/gradient, exp_avg and
// exp_avg_abs like running-average optimizer state, and p_copy like an
// optional secondary copy of the parameters -- confirm against the kernel.
void softsignsgd_single_tensor(at::Tensor& p, 
          at::Tensor& p_copy, 
          at::Tensor& g, 
          at::Tensor& exp_avg, 
          at::Tensor& exp_avg_abs, 
          float beta, float lr, float decay, float eps, float power) {
  CHECK_INPUT(p);
  // Braces are required: CHECK_INPUT may expand to more than one statement,
  // and every check on p_copy must be skipped when p_copy is empty.
  if (p_copy.numel() > 0) {
    CHECK_INPUT(p_copy);
  }
  CHECK_INPUT(g);
  CHECK_INPUT(exp_avg);
  CHECK_INPUT(exp_avg_abs);

  // Every tensor that takes part in the update must match p element-for-element.
  const int64_t num_elem = p.numel();
  AT_ASSERTM(g.numel() == num_elem,
             "number of elements in g and p tensors should be equal");
  AT_ASSERTM(exp_avg.numel() == num_elem,
             "number of elements in exp_avg and p tensors should be equal");
  AT_ASSERTM(exp_avg_abs.numel() == num_elem,
             "number of elements in exp_avg_abs and p tensors should be equal");
  AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
             "number of elements in p_copy and p tensors should be equal, or "
             "p_copy should be empty");

  fused_softsignsgd_cuda(p, p_copy, g,
                         exp_avg, exp_avg_abs,
                         beta, lr, decay, eps, power);
}

// Multi-tensor entry point. Performs no validation of its own: every argument
// is passed, unchanged and in the same order, to multi_tensor_softsignsgd_cuda.
// NOTE(review): the layout expected for tensor_lists and the role of
// noop_flag / chunk_size are defined by multi_tensor_softsignsgd_cuda --
// confirm against its declaration.
void softsignsgd_multi_tensor(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float beta,
    const float lr,
    const float decay,
    const float epsilon,
    const float power) {
  multi_tensor_softsignsgd_cuda(chunk_size, noop_flag, tensor_lists,
                                beta, lr, decay, epsilon, power);
}

// Python bindings: exposes both entry points on the extension module under
// the same names as their C++ functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "softsignsgd_single_tensor",
      &softsignsgd_single_tensor,
      "softsignsgd optimized CUDA single tensor implementation.");
  m.def(
      "softsignsgd_multi_tensor",
      &softsignsgd_multi_tensor,
      "softsignsgd optimized CUDA multi tensor implementation.");
}
