import random

from models.layers import *

class conv(nn.Module):
    """2-D convolution followed by batch normalization, wrapped in ``SeqToANNContainer``.

    The wrapped pair is stored as ``self.fwd``; ``forward`` simply delegates to it.

    Args:
        in_plane: Number of input channels for ``nn.Conv2d``.
        out_plane: Number of output channels (also the ``BatchNorm2d`` feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding, bias=True):
        super().__init__()
        # Build the conv before the BN so parameter initialization order is unchanged.
        conv_layer = nn.Conv2d(
            in_plane,
            out_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        bn_layer = nn.BatchNorm2d(out_plane)
        # NOTE(review): SeqToANNContainer lives in models.layers; presumably it folds the
        # time axis into the batch axis so these 2-D layers run per time step - confirm.
        self.fwd = SeqToANNContainer(conv_layer, bn_layer)

    def forward(self, x):
        """Apply conv + BN to ``x`` through the sequence container."""
        return self.fwd(x)
class ResidualBlock(nn.Module):
    """Spiking residual block: conv -> LIF -> conv, add skip path, then LIF.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        stride: Stride of the first 3x3 convolution.
        shortcut: Optional module applied to the block input to form the skip path.
            When ``None`` the input itself is added (identity skip), which requires
            matching shapes.

    ``forward`` returns ``(spikes, v)``, where ``v`` is the sum of the second
    values returned by the two ``LIFSpike`` neurons.
    """

    def __init__(self, in_ch, out_ch, stride=1, shortcut=None):
        super().__init__()
        self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
        # Submodules are created in the same order as before (init RNG / state_dict order).
        self.conv1 = conv(in_ch, out_ch, 3, stride, 1, bias=False)
        self.neuron1 = LIFSpike()
        self.conv2 = conv(out_ch, out_ch, 3, 1, 1, bias=False)
        self.neuron2 = LIFSpike()
        self.right = shortcut

    def forward(self, input):
        """Run the block on ``input`` and return ``(spikes, v1 + v2)``."""
        hidden, v_first = self.neuron1(self.conv1(input))
        branch = self.conv2(hidden)
        # Skip path: identity unless a projection shortcut was supplied.
        if self.right is not None:
            skip = self.right(input)
        else:
            skip = input
        branch += skip
        spikes, v_second = self.neuron2(branch)
        return spikes, v_first + v_second



class ResNet19(nn.Module):
    """ResNet-19 style spiking network built from ``ResidualBlock`` stages.

    Pipeline as wired in ``forward``:
        normalize -> ``add_dimention`` over ``T`` steps -> 3x3 stem conv + LIF
        -> stage1 (128 ch, stride 1) -> stage2 (256 ch, stride 2)
        -> stage3 (512 ch, stride 2) -> 2x2 average pool -> flatten
        -> Linear(512*W*W, 256) + BatchNorm1d -> Linear(256, num_classes).

    Args:
        num_classes: Output size of the final linear layer.
        norm: Optional ``(mean, std)`` tuple forwarded to ``TensorNormalization``.
            Any non-tuple value (including ``None``) uses the built-in constants.
        args: Arbitrary configuration object; only stored on ``self.args``.
        T: Number of time steps passed to ``add_dimention``. Defaults to 2, the
            previously hard-coded value.
        img_size: Side length of the square input images, used to size ``fc1``.
            Defaults to 64 (TinyImageNet), which reproduces the previously
            hard-coded ``W = 8``. Pass 32 for CIFAR-sized inputs.

    Raises:
        ValueError: If ``img_size`` is too small to leave a non-empty feature map.

    ``forward`` returns ``(logits, v)``, where ``v`` is the running sum of the
    second values returned by every ``LIFSpike`` / ``ResidualBlock`` call.
    """

    def __init__(self, num_classes=10, norm=None, args=None, T=2, img_size=64):
        super(ResNet19, self).__init__()
        self.T = T
        self.args = args
        if isinstance(norm, tuple):
            self.norm = TensorNormalization(*norm)
        else:
            # NOTE(review): these defaults look like CIFAR-10 channel statistics,
            # while the default img_size targets TinyImageNet - confirm intent.
            self.norm = TensorNormalization((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        self.pre_conv = conv(3, 128, 3, stride=1, padding=1, bias=False)
        self.neuron1 = LIFSpike()
        self.block1 = self.make_layer(128, 128, 3, stride=1)
        self.block2 = self.make_layer(128, 256, 3, stride=2)
        self.block3 = self.make_layer(256, 512, 2, stride=2)
        self.pool = SeqToANNContainer(nn.AvgPool2d(2, 2))
        # Flatten everything after the first two dims (presumably time and batch).
        self.flatten = nn.Flatten(2)
        # Spatial side of the pooled map that reaches fc1 (8 for 64x64 inputs).
        W = self._feature_size(img_size)
        self.fc1 = SeqToANNContainer(
            nn.Linear(512 * W * W, 256),
            nn.BatchNorm1d(256),
        )
        self.fc2 = SeqToANNContainer(nn.Linear(256, num_classes))

    @staticmethod
    def _feature_size(img_size):
        """Return the side length of the feature map fed to ``fc1`` for ``img_size`` inputs.

        The stem and stage1 keep the size. Stages 2 and 3 each use a 3x3/pad-1
        conv and a 1x1/pad-0 shortcut with stride 2, both giving ``ceil(s / 2)``.
        The final ``AvgPool2d(2, 2)`` floors ``s / 2``.
        """
        size = img_size
        for _ in range(2):  # stride-2 downsampling in block2 and block3
            size = (size + 1) // 2
        size //= 2  # AvgPool2d(2, 2)
        if size < 1:
            raise ValueError(
                f"img_size={img_size} is too small for ResNet19 (feature map collapses to 0)"
            )
        return size

    def make_layer(self, in_ch, out_ch, block_num, stride=1):
        """Build a stage of ``block_num`` residual blocks.

        The first block gets a 1x1 conv projection shortcut (always created, even
        when shapes already match) and the given ``stride``. The remaining blocks
        use stride 1 and an identity skip.
        """
        shortcut = conv(in_ch, out_ch, 1, stride, 0, bias=False)
        layers = [ResidualBlock(in_ch, out_ch, stride, shortcut)]
        layers.extend(ResidualBlock(out_ch, out_ch) for _ in range(1, block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return ``(logits, v)``."""
        x = self.norm(x)
        # Presumably repeats x along a new time axis of length T - see models.layers.
        x = add_dimention(x, self.T)
        v = 0
        x, v1 = self.neuron1(self.pre_conv(x))
        v = v + v1

        # Residual blocks return (spikes, v), so they are called one by one
        # rather than through nn.Sequential's forward.
        for stage in (self.block1, self.block2, self.block3):
            for block in stage:
                x, v1 = block(x)
                v = v + v1

        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc2(self.fc1(x))
        return x, v