"""
Implement custom layers
"""

import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F 
from SlidingPSN import mul_free_channel_wise_psn
from spikingjelly.activation_based import surrogate, neuron

# Firing threshold used by SpikeAct, SpikeAct_signed and the GRU soft reset.
Vth: float = 1.0
# Initial value of the leak parameters (alpha / beta) of GRUlayer and SFCLayer.
alpha_init_gru: float = 0.9
# ``K`` argument given to every mul_free_channel_wise_psn neuron built here.
K: int = 2
# Not referenced by the code visible in this module; may be used elsewhere.
alpha_init_conv: float = 0.9
# Sharpness of the arctan-shaped surrogate gradients of SpikeAct(_signed).
gamma: int = 10

class SpikeAct(torch.autograd.Function): 
    """Heaviside spike function with an arctan-shaped surrogate gradient.

    Forward: ``1.0`` where ``x_input >= v_th``, ``0.0`` elsewhere.
    Backward: ``grad_output / (1 + alpha * (x_input - v_th) ** 2)``.

    ``alpha`` and ``v_th`` are optional: ``SpikeAct.apply(x)`` uses the module
    constants ``gamma`` and ``Vth``; ``SpikeAct.apply(x, alpha, v_th)`` (as
    used by ``LIFNode``) overrides them.
    """
    @staticmethod
    def forward(ctx, x_input, alpha=None, v_th=None):
        # Resolve defaults at call time (not definition time) so the module
        # constants are read exactly when the original implementation read them.
        alpha = gamma if alpha is None else alpha
        v_th = Vth if v_th is None else v_th
        ctx.save_for_backward(x_input)
        ctx.alpha = alpha
        ctx.v_th = v_th
        output = torch.ge(x_input, v_th) 
        return output.float()

    @staticmethod
    def backward(ctx, grad_output):
        x_input, = ctx.saved_tensors 
        grad_input = grad_output.clone()
        ## derivative of arctan (scaled), centred on the threshold
        grad_input = grad_input * 1 / (1 + ctx.alpha * (x_input - ctx.v_th)**2)
        # One gradient per argument actually passed to apply(); alpha / v_th
        # are plain numbers and receive None.
        return (grad_input,) + (None,) * (len(ctx.needs_input_grad) - 1)

class SpikeAct_signed(torch.autograd.Function): ## ternact
    """Ternary spike function (-1 / 0 / +1) with a symmetric surrogate gradient.

    Forward (for ``Vth > 0``): ``+1`` where ``x >= Vth``, ``-1`` where
    ``x <= -Vth``, ``0`` in between.
    Backward: two arctan-derivative bumps centred at ``+Vth`` and ``-Vth``,
    divided by ``1 + 1 / (1 + 4 * Vth**2 * gamma)`` so that the surrogate
    equals 1 at ``x = +-Vth``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        summed_signs = torch.sign(x + Vth) + torch.sign(x - Vth)
        return torch.clamp(summed_signs, min=-1, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        (x_input,) = ctx.saved_tensors
        norm = 1 + 1/(1 + 4*Vth**2*gamma)
        bump_pos = 1/(1+ gamma * ((x_input - Vth)**2))
        bump_neg = 1/(1+ gamma * ((x_input + Vth)**2))
        return grad_output.clone() * 1/norm * (bump_pos + bump_neg)


class LIFNode(nn.Module):
    """Leaky integrate-and-fire neuron with a learnable per-channel decay.

    Args:
        C: number of channels; ``decay`` has shape ``[1, C, 1, 1]``.
        decay_init: initial value of every decay entry (the decay is clamped
            to [0, 1] at the start of each forward pass).
        v_th: firing threshold; each spike subtracts ``v_th`` from the
            membrane potential at the next step (soft reset).
    """
    def __init__(self, C:int, decay_init:float=0.9, v_th:float=1.):
        super(LIFNode, self).__init__()
        self.decay = nn.Parameter(torch.rand([1, C, 1, 1]))
        self.decay.data.fill_(decay_init)
        self.v_th = v_th

    def forward(self, x_seq):
        """Integrate ``x_seq`` over its first (time) dimension.

        Args:
            x_seq: input currents indexed by time on dim 0; each ``x_seq[t]``
                must broadcast against ``decay`` (presumably ``[N, C, H, W]``
                -- TODO confirm with callers).

        Returns:
            The per-step spike tensors (0. / 1.) stacked along dim 0.
        """
        torch.clamp_(self.decay.data, 0., 1.)
        # SpikeAct fires at the module threshold ``Vth``. Shifting the
        # potential by (Vth - v_th) makes it fire at ``self.v_th`` with its
        # surrogate gradient centred there (width ``gamma``). The shift is
        # exactly 0. when v_th == Vth, so the default case is unchanged.
        shift = Vth - self.v_th
        v = 0.
        s = 0.
        s_seq = []
        for t in range(x_seq.shape[0]):
            v = v * self.decay + x_seq[t] - s * self.v_th
            s = SpikeAct.apply(v + shift)
            s_seq.append(s)
        return torch.stack(s_seq)


# class ATan(torch.autograd.Function):
#     @torch.jit.script
#     def spikeact_forward(x_input: torch.Tensor):
#         return torch.ge(x_input, 0).to(x_input)

#     @torch.jit.script
#     def spikeact_backward(grad_output: torch.Tensor, alpha: float, x_input: torch.Tensor):
#         return grad_output * 1 / (1 + alpha * (x_input**2))

#     @staticmethod
#     def forward(ctx, x_input: torch.Tensor, alpha: float=15.):
#         ctx.save_for_backward(x_input)
#         ctx.alpha = alpha
#         output = ATan.spikeact_forward(x_input)
#         return output

#     @staticmethod
#     def backward(ctx, grad_output):
#         x_input, =  ctx.saved_tensors
#         return ATan.spikeact_backward(grad_output, ctx.alpha, x_input), None


@torch.jit.script
def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
    """Arctan surrogate gradient ``grad_output / (1 + (alpha * x) ** 2)``.

    Returns a ``(grad_x, None)`` pair; the ``None`` fills the gradient slot of
    the non-tensor ``alpha`` argument of ``atan``.
    """
    scaled = x * alpha
    weight = 1. / (1. + torch.pow(scaled, 2))
    return weight * grad_output, None

class atan(torch.autograd.Function):
    """Heaviside spike (``surrogate.heaviside``) whose backward pass uses the
    arctan surrogate ``atan_backward`` (gradient ``1 / (1 + (alpha * x) ** 2)``).
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep state for backward only when a gradient can flow back to x.
        needs_grad = x.requires_grad
        if needs_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        spikes = surrogate.heaviside(x)
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return atan_backward(grad_output, x, ctx.alpha)


class ATan(surrogate.SurrogateFunctionBase):
    """spikingjelly surrogate-function wrapper around this module's ``atan``.

    Args:
        alpha: sharpness of the surrogate gradient ``1 / (1 + (alpha * x) ** 2)``.
        spiking: forwarded unchanged to ``SurrogateFunctionBase``.
    """
    def __init__(self, alpha=2.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha):
        # Heaviside forward, arctan-shaped surrogate backward.
        return atan.apply(x, alpha)

    @staticmethod
    def backward(grad_output, x, alpha):
        grad_x, _ = atan_backward(grad_output, x, alpha)
        return grad_x
    
class SCNNlayer(nn.Module):
    """Spiking 2D (or 3D if ``conv3d=True``) convolution layer.

    Convolution -> optional batch norm -> ``mul_free_channel_wise_psn``
    spiking neuron. With ``useBN`` the convolution is created without bias.

    Args:
        height, width: output spatial size recorded on the layer (the forward
            pass reshapes with the actual convolution output size).
        in_channels, out_channels, kernel_size, dilation, stride, padding:
            forwarded to ``nn.Conv3d`` / ``nn.Conv2d``.
        useBN: insert ``nn.BatchNorm3d`` / ``nn.BatchNorm2d`` after the conv.
        conv3d: convolve jointly over (T, H, W) with ``nn.Conv3d``; otherwise
            apply ``nn.Conv2d`` to every time step independently.
        dilation_sn: ``dilation`` argument of the spiking neuron.
    """
    def __init__(self, height, width, in_channels, out_channels, kernel_size, dilation, stride, padding, useBN, conv3d=True, dilation_sn=1):
        super(SCNNlayer, self).__init__()
        self.conv3d = conv3d
        self.height = height
        self.width = width
        self.sn = mul_free_channel_wise_psn(C=out_channels, K=K, surrogate_function=surrogate.ATan(2.), dilation=dilation_sn)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.useBN = useBN

        if self.conv3d:
            if self.useBN:
                self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=1, bias=False, padding_mode='zeros')
                self.bn = nn.BatchNorm3d(out_channels)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=1, bias=True, padding_mode='zeros')
        else:
            if self.useBN:
                self.bn = nn.BatchNorm2d(out_channels)
                self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=1, bias=False, padding_mode='zeros')
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=1, bias=True, padding_mode='zeros')

        self.clamp()

        # Uniform init with bound sqrt(6 / fan_in). The fan-in is read from the
        # weight (shape [out, in/groups, *kernel]) so an integer kernel_size,
        # which nn.Conv* expands to k x k (x k), is counted correctly; for
        # tuple kernel sizes the bound is unchanged.
        fan_in = self.conv.weight[0].numel()
        k = np.sqrt(6 / fan_in)
        nn.init.uniform_(self.conv.weight, a=-k, b=k)

        if self.useBN:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        """Apply convolution (+BN) and the spiking neuron.

        Args:
            x: input of shape (T, N, Cin, H, W).

        Returns:
            Output of ``self.sn`` applied to the (T', N, Cout, H', W')
            convolution result (T' == T unless the 3D conv changes it).
        """
        T = x.size(0)
        N = x.size(1)
        if self.conv3d:
            x = x.permute(1,2,0,3,4) # [N,C,T,H,W]
            conv_all = self.conv(x)
            if self.useBN:
                conv_all = self.bn(conv_all)
            conv_all = conv_all.permute(2,0,1,3,4) # back to [T,N,C,H,W]
        else:
            x = x.contiguous()
            x = x.view(-1, x.size(2), x.size(3), x.size(4)) # fuse T and N dims
            conv_all = self.conv(x)
            if self.useBN:
                conv_all = self.bn(conv_all)
            # Un-fuse T and N with the real output size instead of the
            # constructor's height/width, which may not match it.
            conv_all = conv_all.view(T, N, conv_all.size(1), conv_all.size(2), conv_all.size(3))

        outputs = self.sn(conv_all)

        return outputs

    def clamp(self):
        """No constrained parameters; kept so every layer exposes ``clamp()``."""
        pass


class SBasicBlock(nn.Module):
    """Spiking ResNet basic block.

    conv1 (+bn1) -> sn1 -> conv2 (+bn2) -> add shortcut -> sn2. The shortcut
    is the block input itself or, when the stride is not (1, 1) or the
    channel count changes, a 1x1 convolution (+bn3).

    Args:
        height, width: spatial size recorded on the block (not used by forward).
        in_channels, out_channels, kernel_size, stride, padding: convolution
            parameters; ``conv2`` always uses stride (1, 1).
        dilation: accepted for interface compatibility; not used.
        useBN: batch norm after each convolution (convs are then bias-free).
        dilation_sn1, dilation_sn2: ``dilation`` of the two spiking neurons.
    """
    def __init__(self, height, width, in_channels, out_channels, kernel_size, dilation, stride, padding, useBN, dilation_sn1, dilation_sn2):
        super(SBasicBlock, self).__init__()
        self.height = height
        self.width = width

        self.sn1 = mul_free_channel_wise_psn(C=out_channels, K=K, surrogate_function=surrogate.ATan(2.), dilation=dilation_sn1)
        self.sn2 = mul_free_channel_wise_psn(C=out_channels, K=K, surrogate_function=surrogate.ATan(2.), dilation=dilation_sn2)

        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.useBN = useBN

        if self.useBN:
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, (1, 1), padding=padding, bias=False)
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, (1, 1), padding=padding, bias=True)

        # Uniform init with bound sqrt(6 / fan_in); the fan-in is read from the
        # weights so an integer kernel_size is expanded like nn.Conv2d does
        # (identical bounds for tuple kernel sizes).
        k1 = np.sqrt(6 / self.conv1.weight[0].numel())
        k2 = np.sqrt(6 / self.conv2.weight[0].numel())
        nn.init.uniform_(self.conv1.weight, a=-k1, b=k1)
        nn.init.uniform_(self.conv2.weight, a=-k2, b=k2)

        if self.useBN:
            nn.init.constant_(self.bn1.weight, 1)
            nn.init.constant_(self.bn1.bias, 0)
            nn.init.constant_(self.bn2.weight, 1)
            nn.init.constant_(self.bn2.bias, 0)

        # The shortcut needs a 1x1 projection whenever its shape differs from
        # the main path: a spatial stride, or a change in channel count
        # (otherwise ``conv2 + identity`` fails for in_channels != out_channels).
        if self.stride != (1,1) or in_channels != out_channels:
            if self.useBN:
                self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride, padding=(0,0), bias=False)
                self.bn3 = nn.BatchNorm2d(out_channels)
                nn.init.constant_(self.bn3.weight, 1)
                nn.init.constant_(self.bn3.bias, 0)
            else:
                self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride, padding=(0,0), bias=True)
            k3 = np.sqrt(6 / self.downsample.weight[0].numel()) # fan-in == in_channels for a 1x1 kernel
            nn.init.uniform_(self.downsample.weight, a=-k3, b=k3)

        self.clamp()


    def forward(self, x):
        """Run the residual block.

        Args:
            x: input of shape (T, N, Cin, H, W).

        Returns:
            ``(outputs, outputs1)``: the spike outputs of ``sn2`` (block
            output) and of ``sn1`` (intermediate layer).
        """
        T = x.size(0)
        N = x.size(1)

        x = x.contiguous()
        identity = x.view(-1, x.size(2), x.size(3), x.size(4)) # fuse T and N dims

        conv1 = self.conv1(identity)
        if self.useBN:
            conv1 = self.bn1(conv1)
        conv1 = conv1.view(T, N, conv1.size(1), conv1.size(2), conv1.size(3))

        outputs1 = self.sn1(conv1)
        input2 = outputs1.view(-1, outputs1.size(2), outputs1.size(3), outputs1.size(4)) # fuse T and N dims

        conv2 = self.conv2(input2)
        if self.useBN:
            conv2 = self.bn2(conv2)

        # 'downsample' is registered only when the shortcut must be projected.
        if 'downsample' in self._modules:
            identity = self.downsample(identity)
            if self.useBN:
                identity = self.bn3(identity)
        conv_all = conv2 + identity
        conv_all = conv_all.view(T, N, conv_all.size(1), conv_all.size(2), conv_all.size(3))

        outputs = self.sn2(conv_all)

        return outputs, outputs1

    def clamp(self):
        """No constrained parameters; kept so every layer exposes ``clamp()``."""
        pass


class SFCLayer(nn.Module):
    """Leaky-integrator readout layer (no spiking non-linearity).

    Each step applies ``dense`` to the input and integrates the result:

    * ``stateful=False``: ``u[t] = alpha * u[t-1] + dense(x[t])``
    * ``stateful=True`` (stateful synapse):
      ``i[t] = beta * i[t-1] + dense(x[t])`` and
      ``u[t] = alpha * u[t-1] + (1 - alpha) * i[t]``

    The potential ``u`` is returned at every time step.

    Args:
        in_size: input feature size.
        out_size: output feature size.
        stateful: use the stateful-synapse variant (adds the ``beta`` leak).
    """
    def __init__(self, in_size, out_size, stateful=True):
        super(SFCLayer, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.dense = nn.Linear(in_size, out_size, bias=True)
        self.stateful = stateful
        # uniform_(a, a) fills with exactly alpha_init_gru; kept as-is so that
        # seeded initialisation of the rest of the model is not perturbed.
        self.alpha = nn.Parameter(torch.zeros(out_size).uniform_(alpha_init_gru, alpha_init_gru))
        if stateful:
            self.beta = nn.Parameter(torch.zeros(out_size).uniform_(alpha_init_gru, alpha_init_gru))

    def forward(self, x):
        """Integrate ``x`` over time.

        Args:
            x: time-major input of shape (T, N, in_size).

        Returns:
            Tensor of shape (T, N, out_size) holding the potential at each step.
        """
        T = x.size(0)
        N = x.size(1)
        outputs = torch.zeros((T, N, self.out_size), device = x.device)
        potential = torch.zeros((N, self.out_size), device = x.device)
        current = torch.zeros((N, self.out_size), device = x.device)

        if self.stateful:
            for t in range(T):
                out = self.dense(x[t,:,:])
                current = self.beta * current + out
                potential = self.alpha * potential + (1 - self.alpha) * current
                outputs[t,:,:] = potential
        else: 
            for t in range(T):
                out = self.dense(x[t,:,:])
                potential = self.alpha * potential + out
                outputs[t,:,:] = potential

        return outputs

    def clamp(self):
        """Clamp the leak factors to [0, 1]: ``alpha`` and, if stateful, ``beta``."""
        self.alpha.data.clamp_(0.,1.)
        if self.stateful:
            self.beta.data.clamp_(0.,1.)


class SAdaptiveAvgPool2d(nn.Module):
    """Adaptive 2D average pooling applied independently at every time step.

    No spiking non-linearity is applied: the output is the pooled input.
    ``self.spikeact`` is kept as an (unused) attribute for compatibility.

    Args:
        kernel_size: output size passed to ``nn.AdaptiveAvgPool2d`` (int or
            tuple).
    """
    def __init__(self, kernel_size):
        super(SAdaptiveAvgPool2d, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(kernel_size)
        self.kernel_size = kernel_size
        self.spikeact = SpikeAct.apply


    def forward(self, x):
        """Pool ``x`` of shape (T, N, C, H, W) to (T, N, C, H', W')."""
        T = x.size(0)
        N = x.size(1)
        x = x.contiguous()
        x = x.view(-1, x.size(2), x.size(3), x.size(4)) # fuse T and N dims
        pool = self.avgpool(x)
        return pool.view(T, N, pool.size(1), pool.size(2), pool.size(3))

class SAvgPool2d(nn.Module):
    """2D average pooling applied independently at every time step.

    No spiking non-linearity is applied: the output is the pooled input.
    ``out_size`` is stored but not needed by ``forward``; ``channel_in`` is
    accepted for interface compatibility and ignored; ``self.spikeact`` is
    kept as an (unused) attribute.

    Args:
        kernel, stride, padding: forwarded to ``nn.AvgPool2d``.
        out_size: expected output spatial size (informational).
        channel_in: unused.
    """
    def __init__(self, kernel, stride, padding, out_size, channel_in,):
        super(SAvgPool2d, self).__init__()
        self.avgpool = nn.AvgPool2d(kernel, stride=stride, padding=padding)
        self.out_size = out_size
        self.spikeact = SpikeAct.apply


    def forward(self, x):
        """Pool ``x`` of shape (T, N, C, H, W) to (T, N, C, H', W')."""
        T = x.size(0)
        N = x.size(1)
        x = x.contiguous()
        x = x.view(-1, x.size(2), x.size(3), x.size(4)) # fuse T and N dims
        pool = self.avgpool(x)
        return pool.view(T, N, pool.size(1), pool.size(2), pool.size(3))


class GRUlayer(nn.Module):
    """Spiking GRU layer (SpikGRU2+ when ``twogates=True`` and ``ternact=True``).

    At every time step, with ``x = x[t]`` and ``s`` the previous output::

        z = sigmoid(wz(x) + uz(s))
        r = sigmoid(wr(x) + ur(s))              # twogates only
        c = alpha * c + wi(x) + ui(s) * r       # "* r" only when twogates
        u = z * u + (1 - z) * c - Vth * s       # soft reset by the last spikes
        s = spikeact(u)

    Args:
        input_size: size of the input features.
        hidden_size: number of neurons.
        ternact: ternary spikes (-1 / 0 / +1, ``SpikeAct_signed``) when True,
            binary spikes (0 / 1, ``SpikeAct``) when False.
        twogates: add the reset gate ``r``.
    """
    def __init__(self, input_size, hidden_size, ternact=True, twogates=True):
        super(GRUlayer, self).__init__()
        self.twogates = twogates
        self.ternact = ternact
        self.hidden_size = hidden_size
        self.wz = nn.Linear(input_size, hidden_size, bias=True)
        self.wi = nn.Linear(input_size, hidden_size, bias=True)
        self.uz = nn.Linear(hidden_size, hidden_size, bias=False)
        self.ui = nn.Linear(hidden_size, hidden_size, bias=True)
        if self.twogates:
            self.wr = nn.Linear(input_size, hidden_size, bias=True)
            self.ur = nn.Linear(hidden_size, hidden_size, bias=False)
        self.alpha = nn.Parameter(torch.zeros(hidden_size).uniform_(alpha_init_gru, alpha_init_gru))
        self.clamp()
        # Spike non-linearity selected by ``ternact``.
        self.spikeact = SpikeAct_signed.apply if ternact else SpikeAct.apply

        k_ff = np.sqrt(1./hidden_size)
        k_rec = np.sqrt(1./hidden_size)
        nn.init.uniform_(self.wi.weight, a=-k_ff, b=k_ff)
        nn.init.uniform_(self.wz.weight, a=-k_ff, b=k_ff)
        nn.init.uniform_(self.ui.weight, a=-k_rec, b=k_rec)
        nn.init.uniform_(self.uz.weight, a=-k_rec, b=k_rec)
        nn.init.uniform_(self.wi.bias, a=-k_ff, b=k_ff)
        nn.init.uniform_(self.wz.bias, a=-k_ff, b=k_ff)
        if self.twogates:
            nn.init.uniform_(self.wr.weight, a=-k_ff, b=k_ff)
            nn.init.uniform_(self.ur.weight, a=-k_rec, b=k_rec)
            nn.init.uniform_(self.wr.bias, a=-k_ff, b=k_ff)


    def forward(self, x):
        """Run the layer over a time-major input.

        Args:
            x: tensor of shape (T, N, input_size).

        Returns:
            Spike tensor of shape (T, N, hidden_size).
        """
        T = x.size(0)
        N = x.size(1)
        outputs = torch.zeros((T, N, self.hidden_size), device = x.device)
        output_prev = torch.zeros((N, self.hidden_size), device = x.device)
        temp = torch.zeros_like(output_prev)
        tempcurrent = torch.zeros_like(output_prev)

        for t in range(T): 

            tempZ = torch.sigmoid(self.wz(x[t,:,:]) + self.uz(output_prev)) 
            if self.twogates:
                tempR = torch.sigmoid(self.wr(x[t,:,:]) + self.ur(output_prev))
                tempcurrent = self.alpha * tempcurrent + self.wi(x[t,:,:]) + self.ui(output_prev) * tempR
            else:
                tempcurrent = self.alpha * tempcurrent + self.wi(x[t,:,:]) + self.ui(output_prev)

            # gated membrane update with soft reset by the previous spikes
            temp = tempZ * temp + (1 - tempZ) * tempcurrent - Vth * output_prev
            output_prev = self.spikeact(temp)

            outputs[t,:,:] = output_prev

        return outputs

    def clamp(self):
        """Clamp the leak factor ``alpha`` to [0, 1]."""
        self.alpha.data.clamp_(0.,1.)

class LiGRU(nn.Module):
    """Stack of three spiking ``GRUlayer``s, optionally bidirectional.

    Args:
        twogates: forwarded to every ``GRUlayer``.
        num_layers: must be 3; any other value only prints an error message
            and three layers are built anyway.
        bidirectional: add a second stack that reads the sequence in reverse;
            its outputs are concatenated with the forward ones on dim 2.
        dropout: dropout probability applied to the inputs of layers 2 and 3.
        input_size: feature size of the input.
        hidden_size: number of neurons per layer and direction.
    """
    def __init__(self, twogates, num_layers, bidirectional, dropout, input_size, hidden_size):
        super(LiGRU, self).__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        if num_layers != 3:
            print("Error in LiGRU: only defined with 3 layers")

        if self.bidirectional:
            self.grulayer1 = GRUlayer(input_size, hidden_size, twogates=twogates)
            self.grulayer2 = GRUlayer(hidden_size * 2, hidden_size, twogates=twogates)
            self.grulayer3 = GRUlayer(hidden_size * 2, hidden_size, twogates=twogates)
            self.grulayer1_b = GRUlayer(input_size, hidden_size, twogates=twogates)
            self.grulayer2_b = GRUlayer(hidden_size * 2, hidden_size, twogates=twogates)
            self.grulayer3_b = GRUlayer( hidden_size * 2, hidden_size, twogates=twogates)
        else:
            self.grulayer1 = GRUlayer(input_size, hidden_size, twogates=twogates)
            self.grulayer2 = GRUlayer(hidden_size, hidden_size, twogates=twogates)
            self.grulayer3 = GRUlayer(hidden_size, hidden_size, twogates=twogates)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run the stack over a time-major input.

        Args:
            x: tensor of shape (T, N, input_size).

        Returns:
            ``(out3, out2, out1)``: outputs of layers 3, 2 and 1, each
            (T, N, hidden_size), or (T, N, 2 * hidden_size) when
            bidirectional. In that case the backward direction is flipped back
            so index ``t`` of both halves refers to the same input time step.
        """
        if not self.bidirectional:
            out1 = self.grulayer1(x)
            out2 = self.grulayer2(self.dropout(out1))
            out3 = self.grulayer3(self.dropout(out2))
            return out3, out2, out1

        x_b = torch.flip(x, [0])
        out1 = self.grulayer1(x)
        out1_b = self.grulayer1_b(x_b)
        # The *_b layers run in reversed time: flip their outputs back to
        # forward time before pairing them with the forward direction.
        # Concatenating them un-flipped would pair time t with time T-1-t.
        out1_b_fwd = torch.flip(out1_b, [0])
        out2 = self.grulayer2(self.dropout(torch.cat((out1, out1_b_fwd), 2)))
        out2_b = self.grulayer2_b(self.dropout(torch.cat((torch.flip(out1, [0]), out1_b), 2)))
        out2_b_fwd = torch.flip(out2_b, [0])
        out3 = self.grulayer3(self.dropout(torch.cat((out2, out2_b_fwd), 2)))
        out3_b = self.grulayer3_b(self.dropout(torch.cat((torch.flip(out2, [0]), out2_b), 2)))
        outputs = torch.cat((out3, torch.flip(out3_b, [0])), 2)
        return outputs, torch.cat((out2, out2_b_fwd), 2), torch.cat((out1, out1_b_fwd), 2)

    def clamp(self):
        """Clamp the leak parameters of every GRU layer."""
        self.grulayer1.clamp()
        self.grulayer2.clamp()
        self.grulayer3.clamp()
        if self.bidirectional:
            self.grulayer1_b.clamp()
            self.grulayer2_b.clamp()
            self.grulayer3_b.clamp()