import torch
import torch.nn as nn
# from spikingjelly.clock_driven import layer
from spikingjelly.activation_based import functional, layer, neuron, surrogate
import numpy as np
from layers import *

# Names exported by ``from <this module> import *``.
# NOTE(review): none of the names listed here is defined in the visible part
# of this file, and the ``SCNN`` class defined in this file is not listed.
# A star-import raises AttributeError for any ``__all__`` entry the module
# does not actually provide, so confirm that these names really exist
# (e.g. via ``from layers import *``) or update the list.
__all__ = ['SpikingResNet', 'spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101',
           'spiking_resnet152']

# Module-level hyper-parameters. None of them is referenced by the code
# visible in this file. They are bound after ``from layers import *``, so in
# this module's namespace they take precedence over any same-named values
# that import may have provided.
# NOTE(review): presumably a firing threshold, initial decay factors for the
# GRU / convolutional layers, and a scaling factor -- confirm against the
# ``layers`` module before relying on these meanings.
Vth = 1.0
alpha_init_gru = 0.9
alpha_init_conv = 0.9
gamma = 0.1

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, which keeps the spatial size
    unchanged when ``stride`` is 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; also used as the padding.

    Returns:
        A freshly constructed ``nn.Conv2d`` module.
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": dilation,
        "groups": groups,
        "bias": False,
        "dilation": dilation,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        A freshly constructed ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class SCNN(nn.Module):
    """Spiking convolutional front-end with an FC or LiGRU + FC back-end.

    Originally described as "RESNET-18 frontend + 3 BiGRU backend". With
    ``front=True`` only the front-end is used and its pooled features go
    straight into a fully connected layer; otherwise they pass through a
    ``LiGRU`` before the fully connected layer.

    Args:
        front: Selects the FC-only head (true) or the LiGRU + FC head (false).
        num_classes: Output size given to the final ``SFCLayer``.
        useBN: Forwarded as ``useBN`` to ``SCNNlayer`` and every
            ``SBasicBlock``.
        twogates: Forwarded as the first positional argument of ``LiGRU``.
        bidirectional: Forwarded as the third positional argument of
            ``LiGRU``.
    """

    def __init__(self, front, num_classes=100, useBN=True, twogates=True, bidirectional=True):
        super().__init__()

        # Channel widths: input, then doubling from 64 up to 512.
        ch_in = 2
        ch_a = 64
        ch_b, ch_c, ch_d = 2 * ch_a, 4 * ch_a, 8 * ch_a
        feature_dim = ch_d  # pooled to 1x1, so features == channels
        gru_hidden_size = 1024

        # Shared 2D settings for the residual blocks.
        kernel_2d = (3, 3)
        pad_2d = (1, 1)
        dil_2d = (1, 1)
        keep_res = (1, 1)
        halve_res = (2, 2)

        def basic_block(size, c_in, c_out, stride, sn1, sn2):
            # ``size`` is passed twice as the first two positional arguments.
            # NOTE(review): presumably the block's spatial height/width --
            # confirm against ``SBasicBlock``.
            return SBasicBlock(size, size, c_in, c_out, kernel_2d, dil_2d, stride, pad_2d,
                               useBN=useBN, dilation_sn1=sn1, dilation_sn2=sn2)

        self.layer1 = SCNNlayer(44, 44, ch_in, ch_a, (5, 7, 7), (1, 1, 1), (1, 2, 2), (2, 3, 3),
                                useBN=useBN, dilation_sn=1)
        self.avgpool = SAvgPool2d((3, 3), (2, 2), (1, 1), 22, ch_a)
        # layer2_2 is registered before layer2_1, exactly as in the original
        # code; the order is kept so sub-module / state_dict ordering and
        # parameter-initialisation order stay unchanged.
        self.layer2_2 = basic_block(22, ch_a, ch_a, keep_res, 2, 3)
        self.layer2_1 = basic_block(22, ch_a, ch_a, keep_res, 1, 2)
        self.layer3_1 = basic_block(11, ch_a, ch_b, halve_res, 3, 1)
        self.layer3_2 = basic_block(11, ch_b, ch_b, keep_res, 2, 3)
        self.layer4_1 = basic_block(6, ch_b, ch_c, halve_res, 1, 2)
        self.layer4_2 = basic_block(6, ch_c, ch_c, keep_res, 3, 1)
        self.layer5_1 = basic_block(3, ch_c, ch_d, halve_res, 2, 3)
        self.layer5_2 = basic_block(3, ch_d, ch_d, keep_res, 1, 2)
        self.adaptavgpool = SAdaptiveAvgPool2d((1, 1))
        # Flatten from dim 2 onwards so the first two dims are left intact
        # (we do not want to flatten through the T (timesteps) dim).
        self.flatten = nn.Flatten(start_dim=2, end_dim=-1)
        self.dropout = nn.Dropout(p=0.5)

        self.front = front

        if front:
            self.dense = SFCLayer(feature_dim, num_classes, stateful=True)
        else:
            self.gru = LiGRU(twogates, 3, bidirectional, 0.2, feature_dim, gru_hidden_size)
            # NOTE(review): the input width is 2 * gru_hidden_size regardless
            # of ``bidirectional`` -- confirm LiGRU's output width when
            # ``bidirectional=False``.
            self.dense = SFCLayer(2 * gru_hidden_size, num_classes, stateful=False)
        functional.set_step_mode(self, step_mode='m')

    def clamp(self):
        """Call ``clamp()`` on every layer that exposes it, front to back.

        The ``gru`` attribute only exists when ``front`` is false, so it is
        only included in that case.
        """
        targets = [
            self.layer1,
            self.layer2_1, self.layer2_2,
            self.layer3_1, self.layer3_2,
            self.layer4_1, self.layer4_2,
            self.layer5_1, self.layer5_2,
        ]
        if not self.front:
            targets.append(self.gru)
        targets.append(self.dense)
        for module in targets:
            module.clamp()

    def forward(self, x):
        """Run the network and average the final layer's output over dim 0.

        Args:
            x: Input laid out as (N, T, Cin, X, Y); the first two dims are
                swapped before the first layer.

        Returns:
            The output of ``self.dense`` averaged over its first dimension.
        """
        # In: (N, T, Cin, X, Y) -> (T, N, Cin, X, Y)
        feats = self.avgpool(self.layer1(x.permute(1, 0, 2, 3, 4)))

        residual_stack = (
            self.layer2_1, self.layer2_2,
            self.layer3_1, self.layer3_2,
            self.layer4_1, self.layer4_2,
            self.layer5_1, self.layer5_2,
        )
        for block in residual_stack:
            # Each block returns a pair; only the first element is propagated.
            feats, _ = block(feats)

        feats = self.dropout(self.flatten(self.adaptavgpool(feats)))

        if self.front:
            # NOTE(review): this is a second consecutive dropout on the same
            # features (kept to preserve behaviour) -- confirm it is intended.
            head_in = self.dropout(feats)
        else:
            # The GRU returns three values; only the first feeds the head.
            seq_out, _, _ = self.gru(feats)
            head_in = self.dropout(seq_out)

        return self.dense(head_in).mean(0)
      