#include <torch/extension.h>
#include <iostream>

#include <vector>

// Element-wise triangular window: max(0, 1 - |x|).
// Returns a freshly computed tensor; x is not modified. The backward
// functions in this file use the result as ds_du.
torch::Tensor zif_backward(torch::Tensor x) {
    const torch::Tensor magnitude = x.abs();
    const torch::Tensor window = torch::ones_like(x) - magnitude;
    // Negative values (|x| > 1) are cut off at zero.
    return window.clamp_min(0);
}

// Element-wise triangular window on the threshold-normalised input:
// max(0, 1 - |x / thre|).
// No sqrt(thre) rescaling is applied to the result.
torch::Tensor zif_backward(torch::Tensor x, torch::Tensor thre) {
    const torch::Tensor scaled_magnitude = (x / thre).abs();
    const torch::Tensor window = torch::ones_like(x) - scaled_magnitude;
    // Negative values (|x / thre| > 1) are cut off at zero.
    return window.clamp_min(0);
}

// Forward pass with subtractive (soft) reset.
//
// For each step t in [0, T), indexing x along dim 0:
//   mem          <- mem * tau + x[t]
//   mem_traj[t]  <- mem - thre            (recorded before the reset)
//   spike        <- mem >= thre
//   mem          <- mem - spike * thre
//
// Returns {spike_train, mem_traj}, both allocated with zeros_like(x); slices
// at t >= T are left at zero.
// NOTE(review): T is int8_t, so no more than 127 steps can be requested.
std::vector<torch::Tensor> neuron_thre_forward(torch::Tensor x, torch::Tensor thre, int8_t T, float tau) {
    auto spikes_out = torch::zeros_like(x);
    auto offsets_out = torch::zeros_like(x);
    torch::Tensor membrane = torch::zeros_like(x[0]);

    int step = 0;
    while (step < T) {
        // Leaky integration of the current input slice.
        membrane = membrane * tau + x[step];
        offsets_out[step] = membrane - thre;

        const torch::Tensor fired = (membrane >= thre);
        spikes_out[step] = fired;

        // Soft reset: subtract the threshold only where a spike occurred.
        membrane = membrane - fired * thre;
        ++step;
    }
    return {spikes_out, offsets_out};
}

// Forward pass with hard reset.
//
// Same integration as the soft-reset variant (mem <- mem * tau + x[t],
// spike <- mem >= thre, mem_traj[t] <- mem - thre), but after a spike the
// membrane is multiplied by the negated spike mask, i.e. set to zero where
// the neuron fired.
//
// Returns {spike_train, mem_traj}, both allocated with zeros_like(x); slices
// at t >= T are left at zero.
// NOTE(review): T is int8_t, so no more than 127 steps can be requested.
std::vector<torch::Tensor> neuron_thre_forward_hard_reset(torch::Tensor x, torch::Tensor thre, int8_t T, float tau) {
    auto spike_record = torch::zeros_like(x);
    auto offset_record = torch::zeros_like(x);
    torch::Tensor potential = torch::zeros_like(x[0]);

    for (int step = 0; step < T; ++step) {
        potential = potential * tau + x[step];
        offset_record[step] = potential - thre;

        const torch::Tensor fired = (potential >= thre);
        spike_record[step] = fired;

        // Hard reset: keep the potential only where no spike occurred.
        const torch::Tensor silent = ~fired;
        potential = potential * silent;
    }
    return {spike_record, offset_record};
}

// Backward pass (BPTT form) using the threshold-scaled window
// ds_du = zif_backward(mem_traj[t], thre).
//
// Walking t from T-1 down to 0:
//   grad_x[T-1] = grad_spike[T-1] * ds_du
//   grad_x[t]   = grad_spike[t] * ds_du + tau * (1 - ds_du) * grad_x[t+1]
// The threshold gradient accumulates
//   grad_spike[t] * ds_du + sp * spike_train[t] * ds_du
// over all steps, then is reduced with sum(0).mean() and negated.
//
// Returns {grad_x, grad_thre}. grad_x is allocated with zeros_like(grad_spike).
// NOTE(review): T < 1 is not guarded; T-1 is then a negative index.
std::vector<torch::Tensor> neuron_thre_backward(
    torch::Tensor grad_spike,
    torch::Tensor mem_traj,
    torch::Tensor spike_train,
    torch::Tensor thre,
    int8_t T,
    float tau,
    float sp)
{
    torch::Tensor grad_x = torch::zeros_like(grad_spike);

    // Per-step contribution to the (not yet negated) threshold gradient.
    const auto thre_term = [&](int step, const torch::Tensor& surrogate) {
        return grad_spike[step] * surrogate + sp * spike_train[step] * surrogate;
    };

    // Last step: there is no later gradient to carry back.
    const int last = T - 1;
    torch::Tensor surrogate = zif_backward(mem_traj[last], thre);
    torch::Tensor thre_accum = thre_term(last, surrogate);
    grad_x[last] = grad_spike[last] * surrogate;

    int step = last;
    while (step-- > 0) {
        surrogate = zif_backward(mem_traj[step], thre);
        thre_accum += thre_term(step, surrogate);

        const torch::Tensor carried =
            tau * (torch::ones_like(surrogate) - surrogate) * grad_x[step + 1];
        grad_x[step] = grad_spike[step] * surrogate + carried;
    }

    return {grad_x, -thre_accum.sum(0).mean()};
}

// Backward pass (BPTT form) for the hard-reset variant, using
// ds_du = zif_backward(mem_traj[t], thre).
//
// Walking t from T-1 down to 0:
//   grad_x[T-1] = grad_spike[T-1] * ds_du
//   grad_x[t]   = grad_spike[t] * ds_du
//               + tau * (1 - spike_train[t] - ds_du / thre * (mem_traj[t] + thre))
//                     * grad_x[t+1]
// The threshold gradient accumulates
//   grad_spike[t] * ds_du + sp * spike_train[t] * ds_du
// over all steps, then is reduced with sum(0).mean() and negated.
//
// Returns {grad_x, grad_thre}. grad_x is allocated with zeros_like(grad_spike).
// NOTE(review): T < 1 is not guarded; T-1 is then a negative index.
std::vector<torch::Tensor> neuron_thre_backward_hard_reset(
    torch::Tensor grad_spike,
    torch::Tensor mem_traj,
    torch::Tensor spike_train,
    torch::Tensor thre,
    int8_t T,
    float tau,
    float sp)
{
    torch::Tensor grad_x = torch::zeros_like(grad_spike);

    // Per-step contribution to the (not yet negated) threshold gradient.
    const auto thre_term = [&](int step, const torch::Tensor& surrogate) {
        return grad_spike[step] * surrogate + sp * spike_train[step] * surrogate;
    };

    // Last step: there is no later gradient to carry back.
    const int last = T - 1;
    torch::Tensor surrogate = zif_backward(mem_traj[last], thre);
    torch::Tensor thre_accum = thre_term(last, surrogate);
    grad_x[last] = grad_spike[last] * surrogate;

    for (int step = last - 1; step >= 0; --step) {
        surrogate = zif_backward(mem_traj[step], thre);
        thre_accum += thre_term(step, surrogate);

        // mem_traj[t] + thre undoes the "- thre" offset, assuming mem_traj
        // was produced by the forward functions in this file.
        const torch::Tensor mem_plus_thre = mem_traj[step] + thre;
        const torch::Tensor reset_term = surrogate / thre * mem_plus_thre;
        const torch::Tensor pass_through =
            torch::ones_like(surrogate) - spike_train[step] - reset_term;

        grad_x[step] = grad_spike[step] * surrogate + tau * pass_through * grad_x[step + 1];
    }

    return {grad_x, -thre_accum.sum(0).mean()};
}

// Backward pass (BPTT form) using the unscaled window
// ds_du = zif_backward(mem_traj[t]).
//
// Walking t from T-1 down to 0:
//   grad_x[T-1] = grad_spike[T-1] * ds_du
//   grad_x[t]   = grad_spike[t] * ds_du + tau * (1 - ds_du) * grad_x[t+1]
// The second output accumulates
//   grad_spike[t] * ds_du + sp * spike_train[t] * ds_du
// over all steps, then is reduced with sum(0).mean() and negated.
//
// `thre` is accepted but never read by this function.
// Returns {grad_x, grad_thre}. grad_x is allocated with zeros_like(grad_spike).
// NOTE(review): T < 1 is not guarded; T-1 is then a negative index.
std::vector<torch::Tensor> neuron_one_backward(
    torch::Tensor grad_spike,
    torch::Tensor mem_traj,
    torch::Tensor spike_train,
    torch::Tensor thre,
    int8_t T,
    float tau,
    float sp)
{
    torch::Tensor grad_x = torch::zeros_like(grad_spike);

    // Per-step contribution to the (not yet negated) second output.
    const auto thre_term = [&](int step, const torch::Tensor& surrogate) {
        return grad_spike[step] * surrogate + sp * spike_train[step] * surrogate;
    };

    // Last step: there is no later gradient to carry back.
    const int last = T - 1;
    torch::Tensor surrogate = zif_backward(mem_traj[last]);
    torch::Tensor thre_accum = thre_term(last, surrogate);
    grad_x[last] = grad_spike[last] * surrogate;

    int step = last;
    while (step-- > 0) {
        surrogate = zif_backward(mem_traj[step]);
        thre_accum += thre_term(step, surrogate);

        const torch::Tensor carried =
            tau * (torch::ones_like(surrogate) - surrogate) * grad_x[step + 1];
        grad_x[step] = grad_spike[step] * surrogate + carried;
    }

    return {grad_x, -thre_accum.sum(0).mean()};
}

// Backward pass (BPTT form) for the hard-reset variant, using the unscaled
// window ds_du = zif_backward(mem_traj[t]).
//
// Walking t from T-1 down to 0:
//   grad_x[T-1] = grad_spike[T-1] * ds_du
//   grad_x[t]   = grad_spike[t] * ds_du
//               + tau * (1 - spike_train[t] - ds_du * (mem_traj[t] + thre))
//                     * grad_x[t+1]
// The second output accumulates
//   grad_spike[t] * ds_du + sp * spike_train[t] * ds_du
// over all steps, then is reduced with sum(0).mean() and negated.
//
// Returns {grad_x, grad_thre}. grad_x is allocated with zeros_like(grad_spike).
// NOTE(review): T < 1 is not guarded; T-1 is then a negative index.
std::vector<torch::Tensor> neuron_one_backward_hard_reset(
    torch::Tensor grad_spike,
    torch::Tensor mem_traj,
    torch::Tensor spike_train,
    torch::Tensor thre,
    int8_t T,
    float tau,
    float sp)
{
    torch::Tensor grad_x = torch::zeros_like(grad_spike);

    // Per-step contribution to the (not yet negated) second output.
    const auto thre_term = [&](int step, const torch::Tensor& surrogate) {
        return grad_spike[step] * surrogate + sp * spike_train[step] * surrogate;
    };

    // Last step: there is no later gradient to carry back.
    const int last = T - 1;
    torch::Tensor surrogate = zif_backward(mem_traj[last]);
    torch::Tensor thre_accum = thre_term(last, surrogate);
    grad_x[last] = grad_spike[last] * surrogate;

    for (int step = last - 1; step >= 0; --step) {
        surrogate = zif_backward(mem_traj[step]);
        thre_accum += thre_term(step, surrogate);

        // mem_traj[t] + thre undoes the "- thre" offset, assuming mem_traj
        // was produced by the forward functions in this file.
        const torch::Tensor mem_plus_thre = mem_traj[step] + thre;
        const torch::Tensor reset_term = surrogate * mem_plus_thre;
        const torch::Tensor pass_through =
            torch::ones_like(surrogate) - spike_train[step] - reset_term;

        grad_x[step] = grad_spike[step] * surrogate + tau * pass_through * grad_x[step + 1];
    }

    return {grad_x, -thre_accum.sum(0).mean()};
}

// Backward pass, "sbp" form: the gradient carried from step t+1 is gated by
// (1 - spike_train[t]) and no tau factor is applied.
//
// With ds_du = zif_backward(mem_traj[t], thre), walking t from T-1 down to 0:
//   grad_x[T-1] = grad_spike[T-1] * ds_du
//   grad_x[t]   = grad_spike[t] * ds_du + (1 - spike_train[t]) * grad_x[t+1]
// The threshold gradient is
//   -sp * mean(spike_train) * (sum over t of mean(ds_du)) / T
// where mean(spike_train) is taken over the whole tensor.
//
// Returns {grad_x, grad_thre}. grad_x is allocated with zeros_like(grad_spike).
// NOTE(review): T < 1 is not guarded; T-1 is then a negative index and the
// final division uses T directly.
std::vector<torch::Tensor> neuron_thre_backward_sbp(
    torch::Tensor grad_spike,
    torch::Tensor mem_traj,
    torch::Tensor spike_train,
    torch::Tensor thre,
    int8_t T,
    float sp)
{
    torch::Tensor grad_x = torch::zeros_like(grad_spike);

    // Last step: there is no later gradient to carry back.
    const int last = T - 1;
    torch::Tensor surrogate = zif_backward(mem_traj[last], thre);
    torch::Tensor surrogate_mean_sum = surrogate.mean();
    grad_x[last] = grad_spike[last] * surrogate;

    int step = last;
    while (step-- > 0) {
        surrogate = zif_backward(mem_traj[step], thre);
        surrogate_mean_sum += surrogate.mean();

        // Gate: 1 where spike_train[t] is 0, 0 where it is 1.
        const torch::Tensor gate = torch::ones_like(spike_train[step]) - spike_train[step];
        grad_x[step] = grad_spike[step] * surrogate + gate * grad_x[step + 1];
    }

    torch::Tensor grad_thre = -sp * spike_train.mean() * surrogate_mean_sum / T;
    return {grad_x, grad_thre};
}

// Python bindings for the neuron kernels defined in this file.
// Each m.def maps a Python-visible name to one C++ function; the third
// argument is the docstring shown by help(). Binding names are unchanged;
// the docstrings are distinct so the variants can be told apart from Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_with_thre", &neuron_thre_forward,
        "Forward pass with soft (subtractive) reset; returns [spike_train, mem_traj]");
  m.def("backward_with_thre", &neuron_thre_backward,
        "BPTT backward pass, threshold-scaled surrogate; returns [grad_x, grad_thre]");
  m.def("forward_with_thre_hard", &neuron_thre_forward_hard_reset,
        "Forward pass with hard reset; returns [spike_train, mem_traj]");
  m.def("backward_with_thre_hard", &neuron_thre_backward_hard_reset,
        "BPTT backward pass for hard reset, threshold-scaled surrogate; returns [grad_x, grad_thre]");
  m.def("backward_with_one", &neuron_one_backward,
        "BPTT backward pass, unscaled surrogate; returns [grad_x, grad_thre]");
  m.def("backward_with_one_hard", &neuron_one_backward_hard_reset,
        "BPTT backward pass for hard reset, unscaled surrogate; returns [grad_x, grad_thre]");
  m.def("backward_with_thre_sbp", &neuron_thre_backward_sbp,
        "Spike-gated (sbp) backward pass, threshold-scaled surrogate; returns [grad_x, grad_thre]");
}