import torch
import torch.nn as nn
import torch.nn.functional as F

class Boxcar(torch.autograd.Function):
    """Heaviside spike function with a rectangular ("boxcar") surrogate gradient.

    Forward emits ``1.0`` where ``input > thresh`` and ``0.0`` elsewhere.
    Backward lets the incoming gradient through only where
    ``|input - thresh| < subthresh``; ``thresh`` and ``subthresh`` receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, input, thresh, subthresh):
        # The membrane potential (before any reset) defines the surrogate window.
        ctx.save_for_backward(input)
        ctx.thresh, ctx.subthresh = thresh, subthresh
        fired = input > thresh
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        membrane = ctx.saved_tensors[0]
        window = (membrane - ctx.thresh).abs() < ctx.subthresh
        surrogate = grad_output.clone() * window.float()
        return surrogate, None, None

class HeavisideBoxcarCall(nn.Module):
    """Spike activation module: ``Boxcar`` (spiking) or a linear fallback.

    Args:
        thresh: firing threshold compared against the membrane potential.
        subthresh: half-width of the boxcar surrogate-gradient window.
        alpha: slope of the non-spiking path (``x * alpha``).
        spiking: if True, use ``Boxcar.apply`` (Heaviside forward, boxcar
            surrogate backward); otherwise use ``primitive_function``.
    """

    def __init__(self, thresh=1.0, subthresh=0.5, alpha=1.0, spiking=True):
        super().__init__()
        self.alpha = alpha
        self.spiking = spiking
        # Kept as 0-dim CPU tensors; PyTorch accepts 0-dim CPU tensors as
        # scalars in CUDA elementwise ops. The former ``tensor.to(device)``
        # calls discarded their result (no-ops) and have been removed.
        self.thresh = torch.tensor(thresh)
        self.subthresh = torch.tensor(subthresh)
        if spiking:
            self.f = Boxcar.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        if self.spiking:
            return self.f(x, self.thresh, self.subthresh)
        # primitive_function takes (x, alpha); passing (x, thresh, subthresh)
        # as before raised a TypeError on the non-spiking path.
        return self.f(x, self.alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha):
        """Linear (non-spiking) activation: scale ``x`` by ``alpha``."""
        return (x * alpha)


class SpkTrace(torch.autograd.Function):
    """Pass spikes through unchanged; gate their gradient with the spike trace.

    Forward returns ``input_spk`` cast to float. Backward lets the gradient
    reach ``input_spk`` only where ``input_spk_trace > spk_trace_threshold``;
    the trace and the threshold receive no gradient.
    """

    @staticmethod
    def forward(ctx, input_spk, input_spk_trace, spk_trace_threshold):
        ctx.spk_trace_threshold = spk_trace_threshold
        ctx.save_for_backward(input_spk_trace)
        return input_spk.float()

    @staticmethod
    def backward(ctx, grad_output):
        trace = ctx.saved_tensors[0]
        gate = (trace - ctx.spk_trace_threshold) > 0
        gated = grad_output.clone() * gate.float()
        return gated, None, None


class SpkTraceCall:
    """Callable wrapper that binds a fixed threshold to ``SpkTrace``.

    Args:
        spk_trace_threshold: trace value above which spike gradients are kept.
    """

    def __init__(self, spk_trace_threshold=0.5):
        # Stored as a 0-dim CPU tensor (usable as a scalar with CUDA inputs).
        # The former ``self.spk_trace_threshold.to(device)`` call discarded its
        # result and was a no-op, so it has been removed.
        self.spk_trace_threshold = torch.tensor(spk_trace_threshold)

    def __call__(self, input_spk, input_spk_trace):
        """Return ``input_spk`` as float, gating its gradient by the trace."""
        return SpkTrace.apply(input_spk, input_spk_trace, self.spk_trace_threshold)


class SpkTraceModified(torch.autograd.Function):
    """Pass spikes through unchanged; mask their gradient with a binary trace.

    Forward returns ``input_spk`` cast to float. Backward multiplies the
    incoming gradient by ``input_spk_trace_binary``; the mask itself receives
    no gradient.
    """

    @staticmethod
    def forward(ctx, input_spk, input_spk_trace_binary):
        # The mask is only needed to gate the gradient in backward.
        ctx.save_for_backward(input_spk_trace_binary)
        return input_spk.float()

    @staticmethod
    def backward(ctx, grad_output):
        (input_spk_trace_binary,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Exactly one gradient per forward input (input_spk, input_spk_trace_binary).
        # The old code returned a surplus third ``None`` that autograd only
        # tolerated because it drops trailing ``None`` values.
        return grad_input * input_spk_trace_binary, None


class SpkTraceModifiedCall:
    """Stateless callable that forwards to ``SpkTraceModified.apply``."""

    def __call__(self, input_spk, input_spk_trace_binary):
        masked_identity = SpkTraceModified.apply
        return masked_identity(input_spk, input_spk_trace_binary)


class SpkTraceBinary(torch.autograd.Function):
    """Binarize a spike trace at 0.5 with a straight-through gradient.

    Forward returns ``1.0`` where the trace exceeds 0.5 and ``0.0`` elsewhere.
    Backward copies the incoming gradient unchanged to the trace.
    """

    @staticmethod
    def forward(ctx, input_spk_trace):
        # Do not zero-fill gradients that autograd would otherwise pass as None.
        ctx.set_materialize_grads(False)
        return (input_spk_trace > 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the step function is treated as identity.
        return grad_output.clone()


class SpkTraceBinaryCall:
    """Stateless callable that forwards to ``SpkTraceBinary.apply``."""

    def __call__(self, input_spk_trace):
        binarize = SpkTraceBinary.apply
        return binarize(input_spk_trace)


class SpkTraceBinaryModified(torch.autograd.Function):
    """Binarize a spike trace at 0.5; pass gradients only where trace > 0.25.

    Forward returns ``1.0`` where the trace exceeds 0.5 and ``0.0`` elsewhere.
    Backward lets the incoming gradient through where the saved trace is
    greater than 0.25, and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input_spk_trace):
        ctx.save_for_backward(input_spk_trace)
        above_half = (input_spk_trace - 0.5) > 0
        return above_half.float()

    @staticmethod
    def backward(ctx, grad_output):
        trace = ctx.saved_tensors[0]
        window = (trace > 0.25).float()
        return grad_output.clone() * window


class SpkTraceBinaryModifiedCall:
    """Stateless callable that forwards to ``SpkTraceBinaryModified.apply``."""

    def __call__(self, input_spk_trace):
        binarize = SpkTraceBinaryModified.apply
        return binarize(input_spk_trace)


class SpkTraceSwitch(torch.autograd.Function):
    """Output the first input; route the same gradient to both inputs.

    Forward returns ``input_spk_mul_w`` as-is (``input_spk_b_mul_w`` does not
    affect the value). Backward hands an identical copy of the incoming
    gradient to both inputs.
    """

    @staticmethod
    def forward(ctx, input_spk_mul_w, input_spk_b_mul_w):
        ctx.set_materialize_grads(False)
        return input_spk_mul_w

    @staticmethod
    def backward(ctx, grad_output):
        shared = grad_output.clone()
        return shared, shared


class SpkTraceSwitchCall:
    """Stateless callable that forwards to ``SpkTraceSwitch.apply``."""

    def __call__(self, input_spk_mul_w, input_spk_b_mul_w):
        switch = SpkTraceSwitch.apply
        return switch(input_spk_mul_w, input_spk_b_mul_w)


class AccAlwaysGrad(torch.autograd.Function):
    """Cast spikes to float in forward; pass the gradient through unchanged.

    Author's note (translated from Korean): this makes the output produce a
    gradient at every step, and it works even when the spike entering the
    accumulator neuron is 0. "Always update" probably does not matter: once a
    spike fires, Vmem is 0 anyway, so no update happens.
    """

    @staticmethod
    def forward(ctx, input_spk):
        ctx.set_materialize_grads(False)
        spikes = input_spk.float()
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient.
        return grad_output.clone()


class AccAlwaysGradCall:
    """Stateless callable that forwards to ``AccAlwaysGrad.apply``."""

    def __call__(self, input_spk):
        always_grad = AccAlwaysGrad.apply
        return always_grad(input_spk)


class OnePool2d(torch.autograd.Function):
    """2x2 / stride-2 max pooling whose backward spreads gradient over each window.

    Forward is ``max_pool2d(input, kernel_size=2, stride=2)``. Backward
    nearest-upsamples the incoming gradient by 2 on the last two dims, so every
    position of a 2x2 window receives that window's gradient. It then
    multiplies by the saved input, so positions whose input is 0 get no
    gradient. Trailing rows/columns dropped by pooling (odd H/W) get zero
    gradient.
    """

    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        # Functional form: same op as nn.MaxPool2d(2, 2) without building a
        # module on every call.
        return F.max_pool2d(input_tensor, kernel_size=2, stride=2)

    @staticmethod
    def backward(ctx, grad_output):
        (input_tensor,) = ctx.saved_tensors
        # Nearest-neighbour x2 upsampling on (H, W). Unlike F.interpolate this
        # is also correct for unbatched (C, H, W) inputs.
        grad_up = grad_output.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)
        # With odd H/W, pooling floors the output size. Zero-pad the uncovered
        # trailing row/column so the gradient matches the input's shape.
        pad_h = input_tensor.shape[-2] - grad_up.shape[-2]
        pad_w = input_tensor.shape[-1] - grad_up.shape[-1]
        if pad_h or pad_w:
            grad_up = F.pad(grad_up, (0, pad_w, 0, pad_h))
        # One gradient per forward input.
        return grad_up * input_tensor


class pPLIF_Node(nn.Module):
    """Leaky integrate-and-fire step: ``mem = mem * decay * (1 - spike_before) + I_in``.

    The ``(1 - spike_before)`` factor zeroes the carried-over potential where
    the previous step spiked (assuming ``spike_before`` is 0/1).

    Args:
        surrogate_function: spike function applied to the updated membrane
            potential. If omitted, each node gets its own
            ``HeavisideBoxcarCall()``. The old default was a single instance
            created at import time and shared by every node.
    """

    def __init__(self, surrogate_function=None):
        super().__init__()
        if surrogate_function is None:
            surrogate_function = HeavisideBoxcarCall()
        self.surrogate_function = surrogate_function

    def forward(
        self,
        mem: torch.Tensor,
        spike_before: torch.Tensor,
        decay: torch.Tensor,
        I_in: torch.Tensor,
    ):
        """Advance one step; return ``(mem, spike)``."""
        mem = mem * decay * (1 - spike_before) + I_in
        spike = self.surrogate_function(mem)
        return mem, spike


class PLIF_Node(nn.Module):
    """LIF step: ``mem = mem * (1 - decay) * (1 - spike_before) + I_in * decay``.

    Here ``decay`` weights the input current and ``1 - decay`` weights the
    carried-over potential. ``(1 - spike_before)`` zeroes the carried-over
    potential where the previous step spiked (assuming ``spike_before`` is 0/1).

    Args:
        surrogate_function: spike function applied to the updated membrane
            potential. If omitted, each node gets its own
            ``HeavisideBoxcarCall()``. The old default was a single instance
            created at import time and shared by every node.
    """

    def __init__(self, surrogate_function=None):
        super().__init__()
        if surrogate_function is None:
            surrogate_function = HeavisideBoxcarCall()
        self.surrogate_function = surrogate_function

    def forward(
        self,
        mem: torch.Tensor,
        spike_before: torch.Tensor,
        decay: torch.Tensor,
        I_in: torch.Tensor,
    ):
        """Advance one step; return ``(mem, spike)``."""
        mem = mem * (1 - decay) * (1 - spike_before) + I_in * (decay)
        spike = self.surrogate_function(mem)
        return mem, spike


class pPTRACE_Node(nn.Module):
    """Spike-trace node: integrates input spikes into a sigmoid-squashed trace.

    Each step:
      1. Update the trace as ``mem * decay + 0.5 * spk_in``. ``spk_in`` is
         detached; ``decay`` is also detached unless ``parametric_tau``.
      2. Squash it as ``sigmoid(8 * (mem - 0.5))``.
      3. Binarize it with ``surrogate_function1``.
      4. For ``version == "v3"``, combine it with ``spk_in`` via
         ``surrogate_function3``.

    Args:
        parametric_tau: if False, no gradient flows into ``decay``.
        version: forward variant; only ``"v3"`` is implemented.
        surrogate_function1, surrogate_function2, surrogate_function3: callable
            classes, instantiated on every forward call.
            ``surrogate_function2`` is stored but not used by ``forward``.
    """

    def __init__(
        self,
        parametric_tau=True,
        version="v3",
        surrogate_function1=SpkTraceBinaryModifiedCall,
        surrogate_function2=SpkTraceModifiedCall,
        surrogate_function3=SpkTraceSwitchCall,
    ):
        super().__init__()
        self.parametric_tau = parametric_tau
        self.surrogate_function1 = surrogate_function1
        self.surrogate_function2 = surrogate_function2
        self.surrogate_function3 = surrogate_function3
        self.version = version

    def forward(self, mem: torch.Tensor, decay: torch.Tensor, spk_in: torch.Tensor):
        """Advance one step; return ``(mem, spike)``.

        Raises:
            ValueError: if ``self.version`` is not a supported variant.
        """
        if self.parametric_tau:
            mem = mem * decay + spk_in.detach() * 0.5
        else:
            mem = mem * decay.detach() + spk_in.detach() * 0.5
        # Same as 1 / (1 + exp(-(mem - 0.5) * 8)), but torch.sigmoid avoids
        # exp overflow (and NaN gradients) for large negative arguments.
        mem = torch.sigmoid((mem - 0.5) * 8)
        spike_b = self.surrogate_function1()(mem)

        if self.version == "v3":
            spike = self.surrogate_function3()(spike_b, spk_in)
        else:
            # Previously fell through and raised UnboundLocalError on `spike`.
            raise ValueError(f"unsupported pPTRACE_Node version: {self.version!r}")
        return mem, spike


class pPLI_Node(nn.Module):
    """Accumulator ("acc") neuron: integrates input without spiking or reset.

    Args:
        decay_acc: if True, the previous membrane value is scaled by ``decay``
            before the new input is added; otherwise it is a plain running sum.
        surrogate_function: callable applied to ``spk_in`` before it is
            accumulated.
    """

    def __init__(self, decay_acc=False, surrogate_function=AccAlwaysGradCall()):
        super().__init__()
        self.decay_acc = decay_acc
        self.surrogate_function = surrogate_function

    def forward(self, mem: torch.Tensor, decay: torch.Tensor, spk_in: torch.Tensor):
        """Return the updated membrane potential."""
        carried = mem * decay if self.decay_acc else mem
        return carried + self.surrogate_function(spk_in)


class pPLI_Node_BPTT(nn.Module):
    """Accumulator ("acc") neuron for BPTT: adds ``spk_in`` directly.

    Args:
        decay_acc: if True, the previous membrane value is scaled by ``decay``
            before the input is added; otherwise it is a plain running sum.
        surrogate_function: stored on the module but not used by ``forward``.
    """

    def __init__(self, decay_acc=False, surrogate_function=AccAlwaysGradCall()):
        super().__init__()
        self.decay_acc = decay_acc
        self.surrogate_function = surrogate_function

    def forward(self, mem: torch.Tensor, decay: torch.Tensor, spk_in: torch.Tensor):
        """Return the updated membrane potential."""
        if not self.decay_acc:
            return mem + spk_in
        return mem * decay + spk_in
