import torch
import torch.nn as nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Lets a fixed resampling step be used wherever an ``nn.Module`` is
    expected (e.g. inside ``nn.Sequential``).

    Args:
        scale_factor: multiplier applied to the spatial size(s); forwarded
            unchanged to ``F.interpolate``.
        mode: interpolation algorithm name forwarded to ``F.interpolate``
            (e.g. ``'nearest'``, ``'bilinear'``).
        align_corners: forwarded to ``F.interpolate``. ``None`` (the default)
            keeps ``F.interpolate``'s own default; a bool may only be given
            for the linear-family modes, per ``F.interpolate``'s rules.
    """

    def __init__(self, scale_factor=2, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, input):
        """Resample ``input`` with the configured scale factor and mode."""
        return F.interpolate(
            input=input,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        # Shown by print(model) so the resampling config is visible.
        return (
            f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
            f"align_corners={self.align_corners}"
        )


class ZIF(torch.autograd.Function):
    """Heaviside step in the forward pass with a triangular surrogate gradient.

    Forward: ``out = (input > 0)`` cast with ``.float()``, i.e. 1.0 where
    ``input`` is strictly positive and 0.0 elsewhere.

    Backward: the (zero almost everywhere) derivative of the step is replaced
    by ``(1 / gama**2) * max(gama - |input|, 0)``. This is a triangle centred
    at 0 with half-width ``gama`` and peak ``1 / gama``.

    Usage: ``ZIF.apply(input, gama)``, where ``gama`` is a positive Python
    number (anything ``float()`` accepts). No gradient is returned for ``gama``.
    """

    @staticmethod
    def forward(ctx, input, gama):
        gama = float(gama)
        if gama <= 0:
            # gama == 0 would divide by zero in backward; gama < 0 would make
            # the surrogate identically zero. Fail early and clearly instead.
            raise ValueError(f"gama must be positive, got {gama}")
        ctx.save_for_backward(input)
        # Non-tensor state belongs on ctx; no need to wrap it in a tensor.
        ctx.gama = gama
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        gama = ctx.gama
        surrogate = (1 / gama) * (1 / gama) * ((gama - input.abs()).clamp(min=0))
        # The multiply already allocates a new tensor, so no clone is needed.
        return grad_output * surrogate, None


class SpikeNode(nn.Module):
    """Stateful leaky integrate-and-fire neuron with a hard reset.

    Each ``forward`` call performs one time step:

    * ``v <- tau * v + input``: leak the stored potential, then integrate the input.
    * ``spike = ZIF(v - vth, gamma)``: emit 1.0 where ``v`` exceeds ``vth``,
      with ZIF's surrogate gradient.
    * Where a spike fired, ``v`` is set to ``v_reset``; elsewhere it is kept.

    The membrane potential ``v`` persists across calls and is not detached,
    so gradients flow through it across time steps. Call ``reset()`` to
    restore it to ``v_reset``.

    NOTE(review): ``v`` starts as the scalar ``v_reset`` and broadcasts
    against the first input. Later inputs presumably must broadcast with the
    stored ``v``, so call ``reset()`` between sequences/batches. Confirm
    against callers.
    """

    def __init__(self, vth=1.0, tau=0.5, gamma=1.0, v_reset=0.):
        super().__init__()
        self.act = ZIF.apply
        self.vth, self.tau = vth, tau
        self.gamma, self.v_reset = gamma, v_reset
        self.v = v_reset

    def forward(self, input):
        """Leak and integrate ``input`` into ``v``, then fire; returns the spikes."""
        leaked = self.tau * self.v
        self.v = leaked + input
        return self.spiking()

    def spiking(self):
        """Threshold ``v``, hard-reset the neurons that fired, and return spikes."""
        fired = self.act(self.v - self.vth, self.gamma)
        kept = self.v * (1 - fired)
        self.v = kept + self.v_reset * fired
        return fired

    def reset(self):
        """Restore the membrane potential to ``v_reset``."""
        self.v = self.v_reset
