import torch.nn as nn
import torch
import random

NUM_CLASSES = 100
CHANNELS = 3

class Resnet20(nn.Module):
    """ResNet-20 style classifier: a 3x3 stem followed by three residual stages.

    Each stage consists of three residual units of two 3x3 convolutions with
    batch norm. Stage widths are 16, 32 and 64 channels. The first unit of the
    second and third stage uses a stride-2 conv (``conv3_1`` / ``conv4_1``) and
    projects its shortcut with a stride-2 3x3 conv + BN (``conv3_0``/``bn3_0``
    and ``conv4_0``/``bn4_0``). Features are globally average-pooled and fed to
    a bias-free linear classifier ``fc``.

    The submodule names ``conv{stage}_{idx}`` / ``bn{stage}_{idx}`` are part of
    the de-facto interface: they are the ``state_dict`` keys, the subclasses in
    this file replace layers by these names, and the freezing helpers look them
    up.
    """

    def __init__(self, in_channels=CHANNELS, outputs=NUM_CLASSES):
        super().__init__()

        def conv3x3(c_in, c_out, stride=1):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride, bias=False)

        # Registration order below is exactly the historical explicit layout:
        # it fixes the state_dict key order (and the order in which parameter
        # initialisation consumes the RNG).
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.act = nn.ReLU()

        # Stage 1: conv2_1 .. conv2_6, 16 -> 16, stride 1, identity shortcuts.
        for idx in range(1, 7):
            setattr(self, f"conv2_{idx}", conv3x3(16, 16))
            setattr(self, f"bn2_{idx}", nn.BatchNorm2d(16))

        # Stages 2 and 3: index 0 is the shortcut projection and index 1 the
        # strided first conv (both c_in -> c_out, stride 2); indices 2..6 are
        # c_out -> c_out, stride 1.
        for stage, c_in, c_out in ((3, 16, 32), (4, 32, 64)):
            for idx in range(7):
                entry = idx <= 1
                setattr(self, f"conv{stage}_{idx}",
                        conv3x3(c_in if entry else c_out, c_out, stride=2 if entry else 1))
                setattr(self, f"bn{stage}_{idx}", nn.BatchNorm2d(c_out))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, outputs, bias=False)

    def _residual_unit(self, x, conv_a, bn_a, conv_b, bn_b, proj_conv=None, proj_bn=None):
        """Apply one residual unit.

        Computes ``act(bn_b(conv_b(act(bn_a(conv_a(x))))) + skip)``, where
        ``skip`` is ``x`` itself, or ``proj_bn(proj_conv(x))`` when a
        projection is given. The shortcut is evaluated before the main path.
        """
        if proj_conv is None:
            skip = x
        else:
            skip = proj_bn(proj_conv(x))
        out = self.act(bn_a(conv_a(x)))
        out = bn_b(conv_b(out))
        out += skip
        return self.act(out)

    def forward(self, x):
        """Return class logits of shape ``(batch, outputs)``.

        NOTE(review): ``x`` is presumably an image batch of shape
        ``(N, in_channels, H, W)`` — confirm against the data pipeline.
        """
        x = self.act(self.bn1(self.conv1(x)))

        # Stage 1 (16 channels, identity shortcuts).
        x = self._residual_unit(x, self.conv2_1, self.bn2_1, self.conv2_2, self.bn2_2)
        x = self._residual_unit(x, self.conv2_3, self.bn2_3, self.conv2_4, self.bn2_4)
        x = self._residual_unit(x, self.conv2_5, self.bn2_5, self.conv2_6, self.bn2_6)

        # Stage 2 (32 channels); the first unit downsamples and projects.
        x = self._residual_unit(x, self.conv3_1, self.bn3_1, self.conv3_2, self.bn3_2,
                                self.conv3_0, self.bn3_0)
        x = self._residual_unit(x, self.conv3_3, self.bn3_3, self.conv3_4, self.bn3_4)
        x = self._residual_unit(x, self.conv3_5, self.bn3_5, self.conv3_6, self.bn3_6)

        # Stage 3 (64 channels); the first unit downsamples and projects.
        x = self._residual_unit(x, self.conv4_1, self.bn4_1, self.conv4_2, self.bn4_2,
                                self.conv4_0, self.bn4_0)
        x = self._residual_unit(x, self.conv4_3, self.bn4_3, self.conv4_4, self.bn4_4)
        x = self._residual_unit(x, self.conv4_5, self.bn4_5, self.conv4_6, self.bn4_6)

        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

class Resnet20_approximated(Resnet20):
    """Resnet20 whose two narrowest stages can be slimmed down by ``rate``.

    ``lf`` selects which nominal widths are scaled:

    * ``lf >= 2``: the stem and stage 1 (nominal width 16, ``w1``);
    * ``lf >= 3``: additionally stage 2 (nominal width 32, ``w2``).

    A scaled width is ``max(1, int(width * rate))``. Stage 3 always keeps 64
    channels, so the classifier head has the same shape as in ``Resnet20``.
    Layer names, strides and topology match ``Resnet20``, so its ``forward``
    is reused unchanged.
    """

    def __init__(self, in_channels=CHANNELS, outputs=NUM_CLASSES, lf=0, rate=1.0):
        # NOTE(review): the parent first builds a full-width default network;
        # every submodule is then replaced below by assignment under the same
        # name, which keeps the parent's registration (state_dict) order.
        super().__init__()

        def scaled(width, active):
            return max(1, int(width * rate)) if active else width

        w1 = scaled(16, lf >= 2)
        w2 = scaled(32, lf >= 3)
        w3 = 64

        def conv3x3(c_in, c_out, stride=1):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride, bias=False)

        # Stem.
        self.conv1 = nn.Conv2d(in_channels, w1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w1)
        self.act = nn.ReLU()

        # Stage 1: conv2_1 .. conv2_6, w1 -> w1, stride 1.
        for idx in range(1, 7):
            setattr(self, f"conv2_{idx}", conv3x3(w1, w1))
            setattr(self, f"bn2_{idx}", nn.BatchNorm2d(w1))

        # Stages 2 and 3: index 0 is the shortcut projection and index 1 the
        # strided first conv (both c_in -> c_out, stride 2); 2..6 keep c_out.
        for stage, c_in, c_out in ((3, w1, w2), (4, w2, w3)):
            for idx in range(7):
                entry = idx <= 1
                setattr(self, f"conv{stage}_{idx}",
                        conv3x3(c_in if entry else c_out, c_out, stride=2 if entry else 1))
                setattr(self, f"bn{stage}_{idx}", nn.BatchNorm2d(c_out))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w3, outputs, bias=False)

class Resnet20_dropout(Resnet20):
    """Resnet20 with every stage's width uniformly scaled by ``rate``.

    The nominal widths 16 / 32 / 64 become ``max(1, int(width * rate))``.
    Layer names, strides and topology match ``Resnet20``, so its ``forward``
    is reused unchanged. The classifier input width follows the (scaled)
    last stage.

    Args:
        in_channels: number of input channels of the stem conv.
        outputs: number of classifier outputs.
        rate: width multiplier applied to every stage.
    """

    def __init__(self, in_channels=CHANNELS, outputs=NUM_CLASSES, rate=1.0):
        super(Resnet20_dropout, self).__init__()
        # Clamp every width to at least one channel. Without the clamp, any
        # rate < 1/16 gives int(16 * rate) == 0 and silently builds
        # zero-channel layers. Those end in a bias-free Linear(0, outputs) head,
        # i.e. constant zero logits. Resnet20_approximated uses the same clamp.
        w1 = max(1, int(16 * rate))
        w2 = max(1, int(32 * rate))
        w3 = max(1, int(64 * rate))

        def conv3x3(c_in, c_out, stride=1):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride, bias=False)

        # Submodules are (re)assigned in the parent's registration order, so
        # state_dict keys and their order match Resnet20.
        self.conv1 = nn.Conv2d(in_channels, w1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w1)
        self.act = nn.ReLU()

        # Stage 1: conv2_1 .. conv2_6, w1 -> w1, stride 1.
        for idx in range(1, 7):
            setattr(self, f"conv2_{idx}", conv3x3(w1, w1))
            setattr(self, f"bn2_{idx}", nn.BatchNorm2d(w1))

        # Stages 2 and 3: index 0 is the shortcut projection and index 1 the
        # strided first conv (both c_in -> c_out, stride 2); 2..6 keep c_out.
        for stage, c_in, c_out in ((3, w1, w2), (4, w2, w3)):
            for idx in range(7):
                entry = idx <= 1
                setattr(self, f"conv{stage}_{idx}",
                        conv3x3(c_in if entry else c_out, c_out, stride=2 if entry else 1))
                setattr(self, f"bn{stage}_{idx}", nn.BatchNorm2d(c_out))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w3, outputs, bias=False)

def freeze_layer(model: Resnet20, n):
    """Disable gradients for the first ``n`` blocks of ``model`` (cumulative).

    Blocks are counted from the input side:

    1. stem: ``conv1`` / ``bn1``
    2. stage 1: ``conv2_1`` .. ``conv2_6`` with matching ``bn2_*``
    3. stage 2: ``conv3_0`` .. ``conv3_6`` with matching ``bn3_*``
    4. stage 3: ``conv4_0`` .. ``conv4_6`` with matching ``bn4_*``

    Block ``k`` is frozen whenever ``n >= k``, so ``n <= 0`` freezes nothing.
    ``fc`` is never touched. For each layer, the conv ``weight`` and the
    BatchNorm ``weight`` and ``bias`` get ``requires_grad_(False)``. The convs
    in this file are built with ``bias=False``, so they have no conv bias. The
    model is modified in place and nothing is returned.

    NOTE(review): BatchNorm ``running_mean`` / ``running_var`` buffers are not
    touched here. Presumably they keep updating while the model is in train
    mode — confirm that is intended for a "frozen" block.

    Raises:
        AssertionError: if ``n > 4`` (only when assertions are enabled).
    """
    assert n <= 4
    blocks = (
        ["1"],
        [f"2_{i}" for i in range(1, 7)],
        [f"3_{i}" for i in range(7)],
        [f"4_{i}" for i in range(7)],
    )
    for depth, suffixes in enumerate(blocks, start=1):
        if n >= depth:
            for suffix in suffixes:
                getattr(model, "conv" + suffix).weight.requires_grad_(False)
                norm = getattr(model, "bn" + suffix)
                norm.weight.requires_grad_(False)
                norm.bias.requires_grad_(False)

def random_freeze_layer(model: "Resnet20", n, seed=12345):
    """Freeze ``n`` randomly chosen blocks of ``model`` and return their names.

    Candidate blocks and the layers they cover:

    * ``'conv1'``: ``conv1`` / ``bn1``
    * ``'conv2'``: ``conv2_1`` .. ``conv2_6`` with matching ``bn2_*``
    * ``'conv3'``: ``conv3_0`` .. ``conv3_6`` with matching ``bn3_*``
    * ``'conv4'``: ``conv4_0`` .. ``conv4_6`` with matching ``bn4_*``

    ``n`` distinct blocks are drawn by a generator seeded with ``seed``, so the
    choice is deterministic for a given ``(n, seed)``. For every layer in a
    chosen block, the conv ``weight`` and the BatchNorm ``weight`` and ``bias``
    get ``requires_grad_(False)``; the model is modified in place.

    Args:
        model: network exposing the ``conv*`` / ``bn*`` attributes above.
        n: number of blocks to freeze (``0 <= n <= 4``).
        seed: seed for the block selection. It no longer reseeds the global
            ``random`` module.

    Returns:
        list[str]: the chosen block names, in draw order.

    Raises:
        AssertionError: if ``n > 4`` (only when assertions are enabled).
        ValueError: if ``n`` is negative (raised by ``Random.sample``).
    """
    assert n <= 4
    layers = ['conv1', 'conv2', 'conv3', 'conv4']
    # A private generator seeded with `seed` reproduces exactly the draw of the
    # former `random.seed(seed); random.sample(...)`. It does so without
    # reseeding the process-wide `random` module, which used to reset every
    # other user of `random` on each call.
    frozen_layers = random.Random(seed).sample(layers, k=n)

    suffixes_by_block = {
        'conv1': ['1'],
        'conv2': [f'2_{i}' for i in range(1, 7)],
        'conv3': [f'3_{i}' for i in range(7)],
        'conv4': [f'4_{i}' for i in range(7)],
    }
    # Visit blocks in their fixed input-to-output order, independent of draw order.
    for block in layers:
        if block not in frozen_layers:
            continue
        for suffix in suffixes_by_block[block]:
            getattr(model, f'conv{suffix}').weight.requires_grad_(False)
            bn = getattr(model, f'bn{suffix}')
            bn.weight.requires_grad_(False)
            bn.bias.requires_grad_(False)
    return frozen_layers

if __name__ == "__main__":
    # Smoke check: build the default model and print every state_dict entry
    # with its running index and tensor shape.
    my_model = Resnet20(in_channels=CHANNELS, outputs=NUM_CLASSES)
    print("STATE DICT")
    for i1, (k, v) in enumerate(my_model.state_dict().items()):
        print(f"{i1}:layer name: {k}, shape: {v.shape}")