import torch
import torch.nn as nn
import models.sn as sn

class alexnet(nn.Module):
    """Small AlexNet-style CNN: two conv/pool stages and a three-layer MLP head.

    Both convolutions keep the spatial size (5x5 kernel, padding 2) and each
    max-pool roughly halves it (3x3 kernel, stride 2, padding 1). The first
    fully connected layer is hard-wired to ``8 * 8 * 256`` input features,
    which is what a 3-channel 32x32 input produces after the two stages.

    Args:
        num_class: Number of outputs of the final linear layer.
    """

    def __init__(self, num_class: int):
        super().__init__()
        # Stage 1: 3 -> 96 channels, then spatial downsampling.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, bias=False, kernel_size=(5, 5), padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Stage 2: 96 -> 256 channels, then spatial downsampling.
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, bias=False, kernel_size=(5, 5), padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Classifier head operating on the flattened feature map.
        self.layer3 = nn.Sequential(
            nn.Linear(8 * 8 * 256, 384),
            nn.ReLU(inplace=True),
            nn.Linear(384, 192),
            nn.ReLU(inplace=True),
            nn.Linear(192, num_class),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch; must reduce to a 256x8x8 feature map after the
                two conv/pool stages (e.g. shape ``(N, 3, 32, 32)``).

        Returns:
            Tensor of shape ``(N, num_class)`` with the raw outputs of the
            last linear layer (no softmax is applied).
        """
        x = self.layer1(x)
        x = self.layer2(x)

        # flatten() returns a view when possible and copies otherwise, so it
        # also works when the feature map is not contiguous in NCHW order
        # (e.g. channels_last inputs), where view() would raise.
        x = torch.flatten(x, 1)
        return self.layer3(x)


def add_sn(m, beta):
    """Recursively apply ``sn.spectral_norm`` to Conv2d/Linear layers of ``m``.

    Children are processed first: each direct child is passed through
    ``add_sn`` and re-registered on ``m`` under its original name. After
    that, ``m`` itself is handed to ``sn.spectral_norm`` if it is an
    ``nn.Conv2d`` or ``nn.Linear``; any other module is returned untouched.

    Args:
        m: Module to process; its child registry is updated in place.
        beta: Forwarded unchanged as the ``beta`` keyword of
            ``sn.spectral_norm`` (semantics defined by the project-local
            ``models.sn`` module, not visible here).

    Returns:
        The result of ``sn.spectral_norm(m, beta=beta)`` for Conv2d/Linear
        modules, otherwise ``m`` itself.
    """
    for child_name, child in m.named_children():
        converted = add_sn(child, beta=beta)
        # Re-register under the same key so the child keeps its name.
        m.add_module(child_name, converted)

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return m
    return sn.spectral_norm(m, beta=beta)


def alexnet_sn(num_class, beta):
    """Build an ``alexnet`` and pass it through ``add_sn``.

    Args:
        num_class: Forwarded to the ``alexnet`` constructor.
        beta: Forwarded to ``add_sn`` as its ``beta`` keyword.

    Returns:
        Whatever ``add_sn`` returns for the freshly constructed network.
    """
    base_model = alexnet(num_class)
    return add_sn(base_model, beta=beta)

