import torch.nn as nn
from hyper_params import Z_DIM, CHANNEL, CLASSES, Z_Dim

class block(nn.Module):
    """Residual block: ``num_convs`` conv-BN-ReLU stages plus a shortcut.

    If ``in_channels == out_channels`` every conv keeps the channel count and
    uses stride 1 (``final_stride`` is ignored), and the shortcut is the
    identity. Otherwise the last conv maps ``in_channels -> out_channels``
    with stride ``final_stride``, and the shortcut is a matching strided
    conv + BatchNorm projection (``self.upsampler``) so both branches have the
    same shape and can be summed.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        num_convs: number of conv-BN-ReLU stages on the main branch.
        kernel: conv kernel size (int or (kh, kw)). Odd sizes are expected so
            that the derived padding preserves spatial size at stride 1.
        final_stride: stride of the channel-changing conv and of the shortcut
            projection.
    """

    def __init__(
        self, in_channels, out_channels, num_convs, kernel=(3,3), final_stride=2
    ):
        super().__init__()
        # k // 2 padding keeps H and W unchanged for odd kernels at stride 1 and
        # gives both branches the same output size when strided. For the
        # default 3x3 kernel this is padding 1.
        if isinstance(kernel, int):
            padding = kernel // 2
        else:
            padding = tuple(k // 2 for k in kernel)

        if in_channels == out_channels:
            self.upsampler = None
        else:
            # Projection shortcut: matches the main branch's channels and stride.
            self.upsampler = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel,
                          padding=padding, stride=final_stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = []
        for i in range(num_convs):
            # Only the last conv of a channel-changing block widens and strides.
            changes_width = in_channels != out_channels and i == num_convs - 1
            width = out_channels if changes_width else in_channels
            stride = final_stride if changes_width else 1
            layers.append(nn.Conv2d(in_channels, width, kernel_size=kernel,
                                    stride=stride, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(width))
            # NOTE: a ReLU follows every BatchNorm, including the last one,
            # so the main branch is already rectified before the residual sum.
            layers.append(nn.ReLU())

        self.convolutions = nn.Sequential(*layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(convolutions(x) + shortcut(x))``.

        The shortcut is ``x`` itself, or ``relu(upsampler(x))`` when the block
        changes channel count.
        """
        # No clone needed: every op on the main branch is out-of-place, so
        # `x` is not modified before it is used as the shortcut.
        identity = x
        if self.upsampler is not None:
            identity = self.relu(self.upsampler(identity))
        out = self.convolutions(x)
        return self.relu(out + identity)

class ResNet18(nn.Module):
    """ResNet-18-style network: stem conv, max-pool, residual trunk, linear head.

    The trunk (see ``make_layers``) holds two residual blocks of two convs at
    each of levels 2-5. Levels 2-3 run at the stem width ``self.mc``. Level 4
    opens with the one block that switches to ``c_out`` channels (stride 2
    when the width actually changes). Level 5 stays at ``c_out``.
    ``make_layers`` is looked up on ``self``, so a subclass override of it is
    what this constructor uses.

    Args:
        c_in: input image channels.
        c_out: channels produced by the residual trunk.
        z: input size of ``fc``. It must equal the flattened trunk output size
            (``c_out * H' * W'``), because ``forward`` flattens without pooling.
        classes: number of outputs of ``fc``.

    NOTE(review): the stem width ``self.mc`` is derived from the module-level
    ``Z_DIM``, not from ``c_out``; confirm that is intended when passing a
    non-default ``c_out``.
    """

    def __init__(self, c_in=CHANNEL, c_out=Z_DIM, z=Z_Dim, classes=CLASSES):
        super().__init__()
        # Stem width: half of Z_DIM, but never fewer than one channel.
        self.mc = max(1, int(0.5 * Z_DIM))
        # Submodules are created in a fixed order (conv1 -> layers -> fc) so
        # parameter registration and random-init draws happen in that order.
        self.conv1 = nn.Conv2d(c_in, self.mc, kernel_size=(3, 3), padding=1, bias=False)
        self.layers = self.make_layers(self.mc, c_out)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # NOTE(review): registered but not used by forward(); features are
        # flattened without global pooling.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(z, classes)

    def make_layers(self, cin, cout):
        """Return the residual trunk as an ``nn.Sequential`` of ``block``s.

        Levels 2-5 hold 2, 2, 2, 2 two-conv blocks. Only the first block of
        level 4 maps ``cin -> cout``; every other block keeps its width.
        """
        trunk = []
        width = cin
        for level, depth in enumerate((2, 2, 2, 2), start=2):
            target = cout if level >= 4 else cin
            for _ in range(depth):
                trunk.append(block(width, target, num_convs=2))
                width = target
        return nn.Sequential(*trunk)

    def forward(self, x):
        """Return ``(fc(z), z)`` where ``z`` is the per-sample flattened trunk output."""
        h = self.maxpool(self.conv1(x))
        h = self.layers(h)
        z = h.reshape(h.shape[0], -1)
        return self.fc(z), z

    def freeze_all_params(self):
        """Turn off ``requires_grad`` on every parameter of the model, in place."""
        self.requires_grad_(False)

class ResNet34(ResNet18):
    """ResNet-34-style depth: levels 2-5 hold 3, 4, 6, 3 two-conv blocks.

    The constructor (same parameters and defaults as ``ResNet18``) is
    inherited. It calls ``self.make_layers``, which resolves to the override
    below, so the network is built exactly once with the caller's arguments.
    """

    def make_layers(self, cin, cout):
        """Return the residual trunk as an ``nn.Sequential`` of ``block``s.

        Only the first block of level 4 maps ``cin -> cout``; every other
        block keeps its width.
        """
        layers = []
        width = cin
        for level, depth in enumerate((3, 4, 6, 3), start=2):
            # Levels 2-3 stay at the stem width; levels 4-5 run at `cout`.
            target = cout if level >= 4 else cin
            for _ in range(depth):
                layers.append(block(width, target, num_convs=2))
                width = target
        return nn.Sequential(*layers)

class ResNet50(ResNet18):
    """ResNet-50-style depth: levels 2-5 hold 3, 4, 6, 3 three-conv blocks.

    NOTE: each block is three same-size convs (default 3x3 kernel), not a
    1x1/3x3/1x1 bottleneck.

    The constructor (same parameters and defaults as ``ResNet18``) is
    inherited. It calls ``self.make_layers``, which resolves to the override
    below, so the network is built exactly once with the caller's arguments.
    """

    def make_layers(self, cin, cout):
        """Return the residual trunk as an ``nn.Sequential`` of ``block``s.

        Only the first block of level 4 maps ``cin -> cout``; every other
        block keeps its width.
        """
        layers = []
        width = cin
        for level, depth in enumerate((3, 4, 6, 3), start=2):
            # Levels 2-3 stay at the stem width; levels 4-5 run at `cout`.
            target = cout if level >= 4 else cin
            for _ in range(depth):
                layers.append(block(width, target, num_convs=3))
                width = target
        return nn.Sequential(*layers)

class ResNet101(ResNet18):
    """ResNet-101-style depth: levels 2-5 hold 3, 4, 23, 3 three-conv blocks.

    The constructor (same parameters and defaults as ``ResNet18``) is
    inherited. It calls ``self.make_layers``, which resolves to the override
    below, so the network is built exactly once with the caller's arguments.
    """

    def make_layers(self, cin, cout):
        """Return the residual trunk as an ``nn.Sequential`` of ``block``s.

        Only the first block of level 4 maps ``cin -> cout``; every other
        block keeps its width.
        """
        layers = []
        width = cin
        for level, depth in enumerate((3, 4, 23, 3), start=2):
            # Levels 2-3 stay at the stem width; levels 4-5 run at `cout`.
            target = cout if level >= 4 else cin
            for _ in range(depth):
                layers.append(block(width, target, num_convs=3))
                width = target
        return nn.Sequential(*layers)
    
class ResNet152(ResNet18):
    """ResNet-152-style depth: levels 2-5 hold 3, 8, 36, 3 three-conv blocks.

    The constructor (same parameters and defaults as ``ResNet18``) is
    inherited. It calls ``self.make_layers``, which resolves to the override
    below, so the network is built exactly once with the caller's arguments.
    """

    def make_layers(self, cin, cout):
        """Return the residual trunk as an ``nn.Sequential`` of ``block``s.

        Only the first block of level 4 maps ``cin -> cout``; every other
        block keeps its width.
        """
        layers = []
        width = cin
        for level, depth in enumerate((3, 8, 36, 3), start=2):
            # Levels 2-3 stay at the stem width; levels 4-5 run at `cout`.
            target = cout if level >= 4 else cin
            for _ in range(depth):
                layers.append(block(width, target, num_convs=3))
                width = target
        return nn.Sequential(*layers)


def build_model(cid, device='cuda'):
    """Build a default-configured ResNet variant selected by ``int(cid) % 5``.

    Mapping: 0 -> ResNet18, 1 -> ResNet34, 2 -> ResNet50, 3 -> ResNet101,
    4 -> ResNet152. The model is moved to ``device`` before being returned.
    """
    variants = (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152)
    chosen = variants[int(cid) % len(variants)]
    return chosen().to(device)