import torch
import torch.nn as nn
import numpy as np
import random
from PIL import Image

# Number of target classes. Equals the default `outputs` of the models below;
# nothing in the visible part of this file reads it -- presumably used by
# importers, TODO confirm.
NUM_CLASSES = 10
# Number of input image channels; passed as `in_channels` by the __main__
# smoke script at the bottom of this file.
CHANNELS = 3

class AlexNet(nn.Module):
    """Compact five-stage CNN classifier with a global width multiplier.

    The network is five ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)`` stages followed by three linear layers.  Base widths are
    16, 16, 32, 32 and 64 channels and 128 and 96 hidden units; each one is
    multiplied by ``rate``, truncated to an int and floored at one unit.

    ``fc1`` is sized for a 7x7 feature map after the five halving pools, so
    inputs need a height and width in the range 224..255 (e.g. 224x224).

    Args:
        in_channels: Number of channels of the input images.
        outputs: Number of values produced by the final linear layer.
        rate: Width multiplier applied to every hidden layer.
    """

    def __init__(self, in_channels=3, outputs=10, rate=1.0) -> None:
        super().__init__()

        def scaled(base_width):
            # Truncate like int(), but never let a layer vanish entirely.
            return max(1, int(base_width * rate))

        c16, c32, c64 = scaled(16), scaled(32), scaled(64)
        h128, h96 = scaled(128), scaled(96)

        # Registration order is kept stable so state_dict key order does not
        # change for existing checkpoints/consumers.
        self.conv1 = nn.Conv2d(in_channels, c16, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(c16)

        self.conv2 = nn.Conv2d(c16, c16, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(c16)

        self.conv3 = nn.Conv2d(c16, c32, kernel_size=(3, 3), padding=1)
        self.bn3 = nn.BatchNorm2d(c32)

        self.conv4 = nn.Conv2d(c32, c32, kernel_size=(3, 3), padding=1)
        self.bn4 = nn.BatchNorm2d(c32)

        self.conv5 = nn.Conv2d(c32, c64, kernel_size=(3, 3), padding=1)
        self.bn5 = nn.BatchNorm2d(c64)

        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # The floor of one channel must be applied BEFORE multiplying by the
        # 7x7 spatial size; otherwise a tiny ``rate`` yields in_features=1
        # while conv5 still emits one 7x7 map (49 features) and forward fails.
        self.fc1 = nn.Linear(c64 * 7 * 7, h128, bias=False)
        self.fc2 = nn.Linear(h128, h96, bias=False)
        self.fc = nn.Linear(h96, outputs)

    def forward(self, x):
        """Run the five conv stages and the classifier head.

        Args:
            x: Input batch; the first dimension is treated as the batch
                dimension and everything else is flattened before ``fc1``.

        Returns:
            The output of the final linear layer (no activation applied),
            with ``outputs`` values per sample.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in stages:
            x = self.pool(self.act(bn(conv(x))))

        # Flatten everything except the batch dimension.
        x = x.reshape(x.shape[0], -1)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc(x)

class AlexNet_approximated(AlexNet):
    """AlexNet variant in which only the leading conv layers are narrowed.

    The output widths of ``conv1``..``conv4`` are scaled by ``rate`` only
    when ``lf`` is large enough; ``conv5`` and the linear layers always keep
    their full size (64 channels, 128 and 96 hidden units):

    * ``lf >= 2``: ``conv1`` emits ``max(1, int(16 * rate))`` channels
    * ``lf >= 3``: ``conv2`` emits ``max(1, int(16 * rate))`` channels
    * ``lf >= 4``: ``conv3`` emits ``max(1, int(32 * rate))`` channels
    * ``lf >= 5``: ``conv4`` emits ``max(1, int(32 * rate))`` channels

    ``forward`` is inherited from ``AlexNet``; it only relies on the layer
    attribute names defined here.

    Args:
        in_channels: Number of channels of the input images.
        outputs: Number of values produced by the final linear layer.
        lf: Threshold selecting how many leading conv layers are narrowed
            (see the list above).  NOTE(review): presumably the number of
            frozen leading layers -- confirm with the caller.
        rate: Width multiplier for the narrowed layers.
    """

    def __init__(self, in_channels=3, outputs=10, lf=0, rate=1.0) -> None:
        # Deliberately initialise only the nn.Module machinery: running
        # AlexNet.__init__ here would allocate and randomly initialise a
        # complete default-sized network whose layers are all replaced below.
        nn.Module.__init__(self)

        def width(base, min_lf):
            # Narrow the layer only once ``lf`` reaches its threshold.
            return max(1, int(base * rate)) if lf >= min_lf else base

        w1 = width(16, 2)
        w2 = width(16, 3)
        w3 = width(32, 4)
        w4 = width(32, 5)

        # Same attribute names and registration order as AlexNet, so the
        # state_dict layout matches the parent's.
        self.conv1 = nn.Conv2d(in_channels, w1, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(w1)

        self.conv2 = nn.Conv2d(w1, w2, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(w2)

        self.conv3 = nn.Conv2d(w2, w3, kernel_size=(3, 3), padding=1)
        self.bn3 = nn.BatchNorm2d(w3)

        self.conv4 = nn.Conv2d(w3, w4, kernel_size=(3, 3), padding=1)
        self.bn4 = nn.BatchNorm2d(w4)

        self.conv5 = nn.Conv2d(w4, 64, kernel_size=(3, 3), padding=1)
        self.bn5 = nn.BatchNorm2d(64)

        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(64 * 7 * 7, 128, bias=False)
        self.fc2 = nn.Linear(128, 96, bias=False)
        self.fc = nn.Linear(96, outputs)

def random_freeze_layer(model: nn.Module, n, seed=12345):
    """Freeze ``n`` randomly chosen layers of ``model``.

    Candidates are ``conv1``..``conv5``, ``fc1`` and ``fc2``.  Freezing means
    calling ``requires_grad_(False)`` on every parameter of the layer; a conv
    layer is always frozen together with its batch norm (``convK`` with
    ``bnK``).  The final ``fc`` layer is never a candidate.

    The choice is drawn from a private ``random.Random(seed)``, which yields
    the same selection as seeding the global generator would, but leaves the
    process-wide ``random`` state untouched.

    NOTE: ``requires_grad_(False)`` only stops gradient updates; BatchNorm
    running statistics are buffers and still update while the module is in
    training mode.

    Args:
        model: Module exposing ``conv1``..``conv5``, ``bn1``..``bn5``,
            ``fc1`` and ``fc2`` attributes (e.g. ``AlexNet``).
        n: How many layers to freeze, between 0 and 7 inclusive.
        seed: Seed that makes the selection reproducible.

    Returns:
        The names of the frozen layers, in the order they were sampled.

    Raises:
        ValueError: If ``n`` is outside ``0..7``.
    """
    layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc1', 'fc2']
    if not 0 <= n <= len(layers):
        raise ValueError(f"n must be between 0 and {len(layers)}, got {n}")

    frozen_layers = random.Random(seed).sample(layers, k=n)
    for layer_name in frozen_layers:
        module_names = [layer_name]
        if layer_name.startswith('conv'):
            # convK is paired with bnK.
            module_names.append('bn' + layer_name[len('conv'):])
        for module_name in module_names:
            for param in getattr(model, module_name).parameters():
                param.requires_grad_(False)
    return frozen_layers

def freeze_layer(model: nn.Module, n):
    """Freeze the first ``n`` layers of ``model`` in network order.

    The order is ``conv1``..``conv5``, ``fc1``, ``fc2``.  Freezing means
    calling ``requires_grad_(False)`` on every parameter of the layer; a conv
    layer is always frozen together with its batch norm (``convK`` with
    ``bnK``).  The final ``fc`` layer is never frozen.

    NOTE: ``requires_grad_(False)`` only stops gradient updates; BatchNorm
    running statistics are buffers and still update while the module is in
    training mode.

    Args:
        model: Module exposing ``conv1``..``conv5``, ``bn1``..``bn5``,
            ``fc1`` and ``fc2`` attributes (e.g. ``AlexNet``).
        n: How many leading layers to freeze, between 0 and 7 inclusive.

    Raises:
        ValueError: If ``n`` is outside ``0..7``.
    """
    # Each entry lists the modules frozen together at that depth.
    freeze_order = (
        ('conv1', 'bn1'),
        ('conv2', 'bn2'),
        ('conv3', 'bn3'),
        ('conv4', 'bn4'),
        ('conv5', 'bn5'),
        ('fc1',),
        ('fc2',),
    )
    if not 0 <= n <= len(freeze_order):
        raise ValueError(
            f"n must be between 0 and {len(freeze_order)}, got {n}"
        )

    for depth, module_names in enumerate(freeze_order, start=1):
        if n < depth:
            break
        for module_name in module_names:
            for param in getattr(model, module_name).parameters():
                param.requires_grad_(False)

if __name__ == "__main__":
    # Smoke check: build a partially narrowed model and list, with a running
    # index, every state_dict entry and then every trainable-parameter entry
    # together with its shape.
    my_model = AlexNet_approximated(CHANNELS, lf=2, rate=0.5)

    print("STATE DICT\n")
    for index, (name, tensor) in enumerate(my_model.state_dict().items()):
        print(f"{index}:layer name: {name}, shape: {tensor.shape}")

    print("NAMED PARAM\n")
    for index, (name, param) in enumerate(my_model.named_parameters()):
        print(f"{index}:layer name: {name}, shape: {param.shape}")