import torch
import torch.nn as nn

# from spikingjelly.activation_based.neuron import LIFNode
# from spikingjelly.activation_based.surrogate import PiecewiseLeakyReLU, Sigmoid
# from spikingjelly.clock_driven import surrogate


# Width of the rectangular surrogate-gradient window used by SpikeFunction.backward:
# the surrogate derivative is 1/a where |input - v_threshold| < a/2 and 0 elsewhere.
# It is read as a module global at backward time, so rebinding it affects later backward passes.
a = 1.0


class LIF(nn.Module):
    """Multi-step leaky integrate-and-fire (LIF) neuron with a hard reset to zero.

    ``x`` is stepped over its first dimension (time) and its second dimension
    is treated as the batch. For every step ``t``::

        u[t] = (1 / tau) * u[t-1] * (1 - s[t-1]) + x[t]
        s[t] = spikefunc(u[t], v_threshold)

    i.e. the membrane potential decays by ``1/tau`` per step and is zeroed
    after a spike, with ``u[-1] = 0`` and ``s[-1] = 0``.

    Args:
        tau: decay constant; the potential is multiplied by ``1/tau`` each step.
        v_threshold: firing threshold passed to ``spikefunc``.
        v_reset: accepted for interface compatibility but unused; the reset is
            always multiplicative, to 0.
        detach_reset: if True, the reset term ``s[t-1]`` is detached so no
            gradient flows through the reset.
        backend: accepted for interface compatibility but unused.

    Returns (forward):
        A float tensor of spikes with the same shape as ``x``, on ``x.device``
        (default dtype).
    """
    def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, v_reset: float = 0.0,  detach_reset: bool = False,
                 backend='torch'):
        super(LIF, self).__init__()
        self.tau = tau
        self.v_threshold = v_threshold
        self.detach_reset = detach_reset

    def forward(self, x):
        T = x.shape[0]
        B = x.shape[1]
        decay = 1 / self.tau
        u = torch.zeros((B,) + x.shape[2:], device=x.device)
        # Spike emitted at the previous step. It equals spikefunc(u_prev), so it
        # is reused as the reset mask instead of re-evaluating spikefunc on the
        # same potential. With u starting at 0, a zero initial mask gives the
        # same u[0] = x[0] as the original formulation.
        spike = torch.zeros_like(u)
        o = torch.zeros(x.shape, device=x.device)
        for t in range(T):
            reset_mask = spike.detach() if self.detach_reset else spike
            u = decay * u * (1 - reset_mask) + x[t, ...]
            spike = spikefunc(u, self.v_threshold)
            o[t, ...] = spike
        return o


class MixedLIF(nn.Module):
    """LIF activation whose output differs between the two halves of the batch.

    Intended for two-branch (contrastive) training. ``x`` is stepped over its
    first dimension (time), and its second dimension (batch, size ``B``) is
    split at ``bs = B // 2``:

    * branch 1 (``x[:, :bs]``): ordinary LIF spikes ``spikefunc(u, v_threshold)``;
    * branch 2 (``x[:, bs:]``): a continuous, ReLU-like ramp of the membrane
      potential, ``clamp(u2 - v_threshold + 0.5, 0, 1)``.

    Both branches use the dynamics ``u <- (1/tau) * u * (1 - s) + x[t]``, where
    ``s`` is the spike mask of the previous potential. The reset is always
    detached here. ``v_reset``, ``detach_reset`` and ``backend`` are accepted
    for interface compatibility but unused.

    Odd batch sizes are supported: branch 2 gets the extra sample.
    """
    def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, v_reset: float = 0.0,  detach_reset: bool = False,
                 backend='torch'):
        super(MixedLIF, self).__init__()
        self.tau = tau
        self.v_threshold = v_threshold

    def forward(self, x):
        T = x.shape[0]
        B = x.shape[1]
        bs = B // 2
        decay = 1 / self.tau
        u = torch.zeros((bs,) + x.shape[2:], device=x.device)
        # Branch 2 holds x[:, bs:], which has B - bs rows (not bs when B is odd).
        u2 = torch.zeros((B - bs,) + x.shape[2:], device=x.device)
        # Previous branch-1 spike; reused (detached) as the next reset mask,
        # since it equals spikefunc(u_prev).
        spike = torch.zeros_like(u)
        o = torch.zeros(x.shape, device=x.device)
        for t in range(T):
            u = decay * u * (1 - spike.detach()) + x[t, :bs, ...]
            u2 = decay * u2 * (1 - spikefunc(u2, self.v_threshold).detach()) + x[t, bs:, ...]
            spike = spikefunc(u, self.v_threshold)
            o[t, :bs, ...] = spike
            o[t, bs:, ...] = torch.clamp(u2 - self.v_threshold + 0.5, min=0, max=1.0)
        return o


class LIFt(nn.Module):
    """LIF activation applied to both branches of two-branch contrastive training.

    ``x`` is stepped over its first dimension (time). Its second dimension
    holds the two branches (``x[:, :B//2]`` and ``x[:, B//2:]``), and both go
    through the same LIF dynamics::

        u[t] = (1 / tau) * u[t-1] * (1 - s[t-1]) + x[t]
        s[t] = spikefunc(u[t], v_threshold)

    The reset term ``s[t-1]`` is always detached from the graph.

    Every operation is element-wise and both branches share ``tau`` and
    ``v_threshold``. The halves are therefore simulated together as one batch,
    which matches stepping them separately and also works for odd batch sizes.

    ``v_reset``, ``detach_reset`` and ``backend`` are accepted for interface
    compatibility but unused (the reset is always to 0 and always detached).
    """
    def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, v_reset: float = 0.0,  detach_reset: bool = False,
                 backend='torch'):
        super(LIFt, self).__init__()
        self.tau = tau
        self.v_threshold = v_threshold

    def forward(self, x):
        T = x.shape[0]
        decay = 1 / self.tau
        u = torch.zeros((x.shape[1],) + x.shape[2:], device=x.device)
        # Previous spike; equals spikefunc(u_prev), so it is reused (detached)
        # as the reset mask rather than recomputed.
        spike = torch.zeros_like(u)
        o = torch.zeros(x.shape, device=x.device)
        for t in range(T):
            u = decay * u * (1 - spike.detach()) + x[t, ...]
            spike = spikefunc(u, self.v_threshold)
            o[t, ...] = spike
        return o


class SpikeFunction(torch.autograd.Function):
    """Heaviside spike with a rectangular surrogate gradient.

    Forward: returns 1.0 where ``input > v_threshold`` and 0.0 elsewhere, as a
    float tensor.

    Backward: scales ``grad_output`` by ``1/a`` inside the window
    ``|input - v_threshold| < a/2`` and zeroes it outside; ``a`` is the
    module-level window width. No gradient is produced for ``v_threshold``.
    """
    @staticmethod
    def forward(ctx, input, v_threshold):
        ctx.save_for_backward(input)
        ctx.v_threshold = v_threshold
        fired = input > v_threshold
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        in_window = (membrane - ctx.v_threshold).abs() < (a / 2)
        surrogate = in_window / a
        return grad_output * surrogate, None

# Functional alias used by the neuron classes: spikefunc(u, v_threshold) returns
# (u > v_threshold) as a float tensor, with the rectangular surrogate gradient
# defined in SpikeFunction.backward.
spikefunc = SpikeFunction.apply


# class LIF(nn.Module):
#     def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, v_reset: float = 0.0,  detach_reset: bool = False,
#                  backend='torch'):
#         super(LIF, self).__init__()
#         self.timestep = timestep
#         self.lif = LIFNode(step_mode='m', tau=1 / (1 - 1 / tau), v_threshold=v_threshold,
#                            surrogate_function=PiecewiseLeakyReLU(w=1.0, c=0.0, spiking=True),
#                            decay_input=False,
#                            detach_reset=detach_reset,
#                            backend=backend)
#
#     def forward(self, x):
#         x = self.lif(x)
#         return x
#
#
# class LIFt(nn.Module):
#     """
#     Activation function for the two branches of contrastive learning.
#     branch-1: LIF
#     branch-2: LIF.
#     """
#
#     def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, timestep=4, v_reset: float = 0.0,
#                  detach_reset: bool = False, backend='torch'):
#         super(LIFt, self).__init__()
#         self.timestep = timestep
#         self.lif_1 = LIFNode(step_mode='m', tau=1 / (1 - 1 / tau), v_threshold=v_threshold,
#                              # surrogate_function=PiecewiseLeakyReLU(w=1.0, c=0.0, spiking=True),
#                              decay_input=False,
#                              detach_reset=detach_reset,
#                              backend=backend)
#         self.lif_2 = LIFNode(step_mode='m', tau=1/(1-1/tau), v_threshold=v_threshold,
#                              # surrogate_function=PiecewiseLeakyReLU(w=1.0, c=0.0, spiking=True),
#                              decay_input=False,
#                              detach_reset=detach_reset,
#                              backend=backend)
#
#     def forward(self, x):
#         B = x.shape[1]
#         bs = B // 2
#
#         o = torch.zeros(x.shape, device=x.device)
#         o[:, :bs, ...] = self.lif_1(x[:, :bs, ...])
#         o[:, bs:, ...] = self.lif_2(x[:, bs:, ...])
#         return o
#
#
# class MixedLIF(nn.Module):
#     """
#     Activation function for the two paths of contrastive learning.
#     Path-1: original LIf
#     Path-2: Relu-like continuous func.
#     Vth and tau are set same as class spikingjelly.activation_based.neuron.LIFNode
#     """
#     def __init__(self, tau: float = 2.0, v_threshold: float = 1.0, v_reset: float = 0.0,  detach_reset: bool = False,
#                  backend='torch'):
#         super(MixedLIF, self).__init__()
#         self.lif = LIFNode(step_mode='m', tau=1/(1-1/tau), v_threshold=v_threshold,
#                            # surrogate_function=PiecewiseLeakyReLU(w=1.0, c=0.0, spiking=True),
#                            decay_input=False,
#                            detach_reset=detach_reset,
#                            backend=backend)
#         self.ClampLu = LIFNode(step_mode='m', tau=1/(1-1/tau), v_threshold=v_threshold,
#                                surrogate_function=Sigmoid(alpha=4.0, spiking=False),
#                                decay_input=False,
#                                detach_reset=detach_reset,
#                                backend=backend)
#
#     def forward(self, x):
#         B = x.shape[1]
#         bs = B//2
#
#         o = torch.zeros(x.shape, device=x.device)
#         o[:, :bs, ...] = self.lif(x[:, :bs, ...])
#         o[:, bs:, ...] = self.ClampLu(x[:, bs:, ...])
#         return o


# class CustomLIFNode(neuron.BaseNode):
#     def __init__(self, v_threshold: float = 1.0, v_reset: float = 0.0, tau: float = 2.0, detach_reset: bool = False,
#                  backend='torch'):
#         super().__init__(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=None, detach_reset=detach_reset)
#         self.tau = tau
#         self.v_threshold = v_threshold
#
#     def neuronal_charge(self, x: torch.Tensor):
#         self.v = (1 - 1 / self.tau) * self.v + 1 / self.tau * x
#
#     def neuronal_fire(self):
#         return torch.clamp(self.v, min=self.v_threshold - 0.5, max=self.v_threshold + 0.5)-(self.v_threshold - 0.5)
#
#     def neuronal_reset(self, spike: torch.Tensor):
#         """
#         * :ref:`API in English <BaseNode.neuronal_reset-en>`
#
#         .. _BaseNode.neuronal_reset-cn:
#
#         根据当前神经元释放的脉冲，对膜电位进行重置。
#
#         * :ref:`中文API <BaseNode.neuronal_reset-cn>`
#
#         .. _BaseNode.neuronal_reset-en:
#
#
#         Reset the membrane potential according to neurons' output spikes.
#         """
#         @staticmethod
#         @torch.jit.script
#         def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
#             v = (1. - spike) * v + spike * v_reset
#             return v
#
#         @staticmethod
#         @torch.jit.script
#         def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
#             v = v - spike * v_threshold
#             return v
#
#         if self.detach_reset:
#             spike_d = spike.detach()
#         else:
#             spike_d = spike
#         if self.v_reset is None:
#             # soft reset
#             self.v = jit_soft_reset(self.v, spike_d, self.v_threshold)
#         else:
#             # hard reset
#             self.v = jit_hard_reset(self.v, spike_d, self.v_reset)
#
