import torch
import torch.nn as nn
import models.sn as sn
class incept(nn.Module):
    """Two-branch inception module: parallel 1x1 and 3x3 convolutions.

    Both branches receive the same input and preserve its height and width
    (the 3x3 branch uses padding=1). Their ReLU-activated outputs are
    concatenated along dim 1, so the result has
    ``ch1_filters + ch3_filters`` channels.

    NOTE(review): the convolutions are bias-free and nothing normalizes
    their outputs in this variant; confirm that is intentional.
    """

    def __init__(self, input_filters, ch1_filters, ch3_filters):
        super().__init__()
        # 1x1 branch. Attribute names are part of the state_dict keys, so
        # they must stay stable for checkpoint compatibility.
        self.ch1_output = nn.Sequential(
            nn.Conv2d(input_filters, ch1_filters, kernel_size=(1, 1), bias=False),
            nn.ReLU(inplace=True),
        )
        # 3x3 branch; padding=1 keeps the spatial size unchanged.
        self.ch3_output = nn.Sequential(
            nn.Conv2d(input_filters, ch3_filters, kernel_size=(3, 3), padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # 1x1 branch first, then 3x3, concatenated channel-wise.
        branches = (self.ch1_output(x), self.ch3_output(x))
        return torch.cat(branches, dim=1)


class downsample(nn.Module):
    """Grid-reduction module: strided 3x3 conv next to 3x3 max-pooling.

    Both paths use kernel 3, stride 2 and padding 1, so an ``H x W`` input
    becomes ``ceil(H/2) x ceil(W/2)``. The output concatenates the
    ``ch3_filters`` conv channels (first) with the ``input_filters`` pooled
    channels, giving ``ch3_filters + input_filters`` channels in total.
    """

    def __init__(self, input_filters, ch3_filters):
        super().__init__()
        # Despite its name this is the 3x3 stride-2 conv branch; the
        # attribute name is kept because it appears in state_dict keys.
        self.ch1_output = nn.Sequential(
            nn.Conv2d(input_filters, ch3_filters, kernel_size=(3, 3), stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )
        # Parameter-free path: the input channels, spatially downsampled.
        self.pool_output = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        conv_path = self.ch1_output(x)
        pooled_path = self.pool_output(x)
        return torch.cat((conv_path, pooled_path), dim=1)

class inception(nn.Module):
    """Small Inception-style image classifier without normalization layers.

    Layout: a 3x3 stem conv, eight ``incept`` modules split into three
    stages by two ``downsample`` modules (each maps H, W to ceil(H/2),
    ceil(W/2)), then global average pooling and a linear classifier.
    Because of ``AdaptiveAvgPool2d((1, 1))`` any input height/width works.

    Args:
        num_class: number of output logits.
        in_channels: channels of the input images. Defaults to 3, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_class, in_channels=3):
        super().__init__()
        # Construction order is kept fixed: it determines parameter order
        # and how the RNG is consumed during weight initialization.

        # Stem: 3x3 conv, stride 1, padding 1 -> 96 channels, same H x W.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 96, bias=False, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        # Stage 1 (input resolution). incept outputs ch1 + ch3 channels.
        self.layer2 = incept(96, 32, 32)               # -> 64
        self.layer3 = incept(32 + 32, 32, 48)          # -> 80
        # downsample outputs its conv channels + the pooled input channels.
        self.layer4 = downsample(32 + 48, 80)          # -> 80 + 80 = 160

        # Stage 2 (after the first downsample).
        self.layer5 = incept(32 + 48 + 80, 112, 48)    # -> 160
        self.layer6 = incept(112 + 48, 96, 64)         # -> 160
        self.layer7 = incept(96 + 64, 80, 80)          # -> 160
        self.layer8 = incept(80 + 80, 48, 96)          # -> 144
        self.layer9 = downsample(48 + 96, 96)          # -> 96 + 144 = 240

        # Stage 3 (after the second downsample).
        self.layer10 = incept(48 + 96 + 96, 176, 160)  # -> 336
        self.layer11 = incept(176 + 160, 176, 160)     # -> 336

        # Global average pooling to 1x1, then the classifier head.
        self.layer12 = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(176 + 160, num_class)

    def forward(self, x):
        """Return logits of shape (N, num_class).

        Expects ``x`` of shape (N, in_channels, H, W).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)         # (N, 336, 1, 1)
        x = x.view(x.size(0), -1)   # (N, 336)
        return self.fc(x)


class incept_bn(nn.Module):
    """Batch-normalized two-branch inception module.

    A 1x1 conv-BN-ReLU branch and a 3x3 conv-BN-ReLU branch (padding=1, so
    the spatial size is preserved) run on the same input. Their outputs are
    concatenated along dim 1, giving ``ch1_filters + ch3_filters`` channels.
    The convs are bias-free because each is followed by BatchNorm.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, **conv_kwargs):
        # Sequential indices 0/1/2 (conv/bn/relu) are part of state_dict keys.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, bias=False, **conv_kwargs),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def __init__(self, input_filters, ch1_filters, ch3_filters):
        super().__init__()
        self.ch1_output = self._conv_bn_relu(input_filters, ch1_filters, kernel_size=(1, 1))
        self.ch3_output = self._conv_bn_relu(input_filters, ch3_filters, kernel_size=(3, 3), padding=1)

    def forward(self, x):
        narrow = self.ch1_output(x)
        wide = self.ch3_output(x)
        return torch.cat((narrow, wide), dim=1)


class downsample_bn(nn.Module):
    """Batch-normalized grid-reduction module.

    A 3x3 stride-2 conv-BN-ReLU path runs alongside a 3x3 stride-2 max-pool.
    Both use padding 1, so an ``H x W`` input becomes
    ``ceil(H/2) x ceil(W/2)``. The output has ``ch3_filters + input_filters``
    channels, with the conv channels first.
    """

    def __init__(self, input_filters, ch3_filters):
        super().__init__()
        # Named ``ch1_output`` for checkpoint compatibility, although it is
        # the 3x3 strided conv branch.
        conv = nn.Conv2d(input_filters, ch3_filters, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.ch1_output = nn.Sequential(conv, nn.BatchNorm2d(ch3_filters), nn.ReLU(inplace=True))
        # Parameter-free path passing the input channels through, downsampled.
        self.pool_output = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        conv_path = self.ch1_output(x)
        pooled_path = self.pool_output(x)
        return torch.cat((conv_path, pooled_path), dim=1)

class inception_bn(nn.Module):
    """Small Inception-style image classifier with BatchNorm after every conv.

    Same topology and channel counts as ``inception``, but built from
    ``incept_bn`` / ``downsample_bn``, and the stem has a BatchNorm2d. The
    layout is a 3x3 stem, eight inception modules split into three stages by
    two downsample modules, global average pooling and a linear classifier.
    ``AdaptiveAvgPool2d((1, 1))`` makes the network input-size agnostic.

    Args:
        num_class: number of output logits.
        in_channels: channels of the input images. Defaults to 3, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_class, in_channels=3):
        super().__init__()
        # Construction order is kept fixed: it determines parameter order
        # and how the RNG is consumed during weight initialization.

        # Stem: 3x3 conv-BN-ReLU, stride 1, padding 1 -> 96 channels.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 96, bias=False, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
        )
        # Stage 1 (input resolution). incept_bn outputs ch1 + ch3 channels.
        self.layer2 = incept_bn(96, 32, 32)               # -> 64
        self.layer3 = incept_bn(32 + 32, 32, 48)          # -> 80
        # downsample_bn outputs its conv channels + the pooled input channels.
        self.layer4 = downsample_bn(32 + 48, 80)          # -> 80 + 80 = 160

        # Stage 2 (after the first downsample).
        self.layer5 = incept_bn(32 + 48 + 80, 112, 48)    # -> 160
        self.layer6 = incept_bn(112 + 48, 96, 64)         # -> 160
        self.layer7 = incept_bn(96 + 64, 80, 80)          # -> 160
        self.layer8 = incept_bn(80 + 80, 48, 96)          # -> 144
        self.layer9 = downsample_bn(48 + 96, 96)          # -> 96 + 144 = 240

        # Stage 3 (after the second downsample).
        self.layer10 = incept_bn(48 + 96 + 96, 176, 160)  # -> 336
        self.layer11 = incept_bn(176 + 160, 176, 160)     # -> 336

        # Global average pooling to 1x1, then the classifier head.
        self.layer12 = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(176 + 160, num_class)

    def forward(self, x):
        """Return logits of shape (N, num_class).

        Expects ``x`` of shape (N, in_channels, H, W).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)         # (N, 336, 1, 1)
        x = x.view(x.size(0), -1)   # (N, 336)
        return self.fc(x)


def add_sn(m, beta):
    """Recursively apply ``sn.spectral_norm`` to every Conv2d / Linear in *m*.

    Children are processed depth-first before *m* itself. Each child is
    re-registered on its parent, under the same name, as whatever
    ``add_sn`` returned for it. Modules that are neither ``nn.Conv2d`` nor
    ``nn.Linear`` (e.g. BatchNorm, ReLU, pooling) are left as they are.

    Args:
        m: module tree to transform; its children are replaced in place.
        beta: forwarded unchanged to ``sn.spectral_norm``; its meaning is
            defined by the project-local ``models.sn`` module.

    Returns:
        ``sn.spectral_norm(m, beta=beta)`` if *m* is a Conv2d/Linear,
        otherwise *m* itself.
    """
    # Snapshot the children so the registry isn't iterated while updated.
    for child_name, child in list(m.named_children()):
        m.add_module(child_name, add_sn(child, beta=beta))

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return m
    return sn.spectral_norm(m, beta=beta)


def inception_sn(num_class, beta):
    """Build an ``inception`` model with spectral norm on every Conv2d/Linear.

    Args:
        num_class: number of output logits.
        beta: forwarded unchanged to ``add_sn``.
    """
    model = inception(num_class)
    return add_sn(model, beta=beta)

def inception_bn_sn(num_class, beta):
    """Build an ``inception_bn`` model with spectral norm on every Conv2d/Linear.

    BatchNorm layers are not spectrally normalized; only Conv2d and Linear
    layers are wrapped by ``add_sn``.

    Args:
        num_class: number of output logits.
        beta: forwarded unchanged to ``add_sn``.
    """
    model = inception_bn(num_class)
    return add_sn(model, beta=beta)

# model=inception_bn_sn(10,1).cuda()
# from torchsummary import summary
# summary(model, input_size=(3, 32, 32),batch_size=128)