import torch.nn as nn
import torch

# Default number of input channels, used as the default `in_channels` of the
# models defined in this file.
CHANNELS = 1
# Default number of output logits, used as the default `outputs` of the models
# defined in this file.
CLASSES = 62

class CNN(nn.Module):
    """Two-block convolutional classifier with a width multiplier.

    Architecture: ``(Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2) x 2 -> Linear``.
    Both convolutions emit ``max(1, int(20 * rate))`` channels.

    Args:
        in_channels: Number of channels of the input images.
        outputs: Number of logits produced by the final linear layer.
        rate: Width multiplier applied to the base width of 20 channels; the
            resulting width is never smaller than 1.
        input_size: Spatial size of the input images, either one int (square
            images) or a ``(height, width)`` pair. The default of 28 yields
            the ``7 * 7`` feature map that was previously hard-coded.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 2x2
            pooling stages (any side smaller than 4).
    """

    def __init__(self, in_channels=CHANNELS, outputs=CLASSES, rate=1.0, input_size=28) -> None:
        super().__init__()
        channels = max(1, int(20 * rate))

        # Accept either a single side length or an explicit (height, width).
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size

        # The padded 3x3 convolutions keep the spatial size; each stride-2
        # max-pool halves it, rounding down.
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 pooling stages"
            )

        # Modules are created in the same order as before so parameter
        # ordering and random initialisation are unchanged.
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.act = nn.ReLU()
        self.fc = nn.Linear(channels * pooled_height * pooled_width, outputs)

    def forward(self, x):
        """Return the logits for a batch of images ``x``.

        ``x`` is fed straight into ``conv1``, so it must be a batch of images
        with ``in_channels`` channels and the spatial size given at
        construction time.
        """
        x = self.pool(self.act(self.bn1(self.conv1(x))))
        x = self.pool(self.act(self.bn2(self.conv2(x))))
        # Keep the batch dimension and flatten everything else for the
        # classifier; unlike reshape(batch, -1) this also works for batch 0.
        return self.fc(x.flatten(start_dim=1))
    
class successive_CNN(nn.Module):
    """Two-block convolutional classifier where only the second block scales.

    Architecture: ``(Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2) x 2 -> Linear``.
    The first convolution always emits 20 channels; the second emits
    ``max(1, int(20 * rate))`` channels.

    Args:
        in_channels: Number of channels of the input images.
        outputs: Number of logits produced by the final linear layer.
        rate: Width multiplier applied to the second block's base width of 20
            channels; the resulting width is never smaller than 1.
        input_size: Spatial size of the input images, either one int (square
            images) or a ``(height, width)`` pair. The default of 28 yields
            the ``7 * 7`` feature map that was previously hard-coded.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 2x2
            pooling stages (any side smaller than 4).
    """

    def __init__(self, in_channels=CHANNELS, outputs=CLASSES, rate=1.0, input_size=28) -> None:
        super().__init__()
        base_channels = 20
        scaled_channels = max(1, int(20 * rate))

        # Accept either a single side length or an explicit (height, width).
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size

        # The padded 3x3 convolutions keep the spatial size; each stride-2
        # max-pool halves it, rounding down.
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 pooling stages"
            )

        # Modules are created in the same order as before so parameter
        # ordering and random initialisation are unchanged.
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.conv2 = nn.Conv2d(base_channels, scaled_channels, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(scaled_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.act = nn.ReLU()
        self.fc = nn.Linear(scaled_channels * pooled_height * pooled_width, outputs)

    def forward(self, x):
        """Return the logits for a batch of images ``x``.

        ``x`` is fed straight into ``conv1``, so it must be a batch of images
        with ``in_channels`` channels and the spatial size given at
        construction time.
        """
        x = self.pool(self.act(self.bn1(self.conv1(x))))
        x = self.pool(self.act(self.bn2(self.conv2(x))))
        # Keep the batch dimension and flatten everything else for the
        # classifier; unlike reshape(batch, -1) this also works for batch 0.
        return self.fc(x.flatten(start_dim=1))

def freeze_layer(model: nn.Module, n):
    """Freeze the first ``n`` conv/batch-norm blocks of ``model`` in place.

    Block 1 is ``model.conv1`` + ``model.bn1``; block 2 is ``model.conv2`` +
    ``model.bn2``. Freezing sets ``requires_grad`` to False on every parameter
    of those layers. Nothing is ever un-frozen, and no other layer is touched.

    Note: only learnable parameters are affected. BatchNorm running
    statistics are buffers, so they keep updating while the module is in
    training mode.

    Args:
        model: Any module exposing ``conv1``, ``bn1``, ``conv2`` and ``bn2``
            (only the blocks that actually get frozen are accessed).
        n: Number of leading blocks to freeze; must satisfy ``0 <= n <= 2``.

    Raises:
        AssertionError: If ``n`` is outside ``[0, 2]``.
    """
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`, while keeping the exception type callers already see.
    if not 0 <= n <= 2:
        raise AssertionError(f"n must be between 0 and 2, got {n!r}")

    for index in (1, 2):
        if n >= index:
            for layer in (getattr(model, f"conv{index}"), getattr(model, f"bn{index}")):
                # Iterating parameters() also copes with layers built with
                # bias=False / affine=False, whose `.bias` is None.
                for param in layer.parameters():
                    param.requires_grad_(False)

def freeze_layer_random(model: nn.Module, n):
    """Freeze ``n`` of the two conv/batch-norm blocks, chosen at random.

    Block 1 is ``model.conv1`` + ``model.bn1``; block 2 is ``model.conv2`` +
    ``model.bn2``. With ``n == 0`` nothing is frozen, with ``n == 1`` one of
    the two blocks is picked via ``random.choice``, and otherwise both blocks
    are frozen. Freezing sets ``requires_grad`` to False on every parameter of
    the affected layers, in place.

    Note: only learnable parameters are affected. BatchNorm running
    statistics are buffers, so they keep updating while the module is in
    training mode.

    Args:
        model: Any module exposing ``conv1``, ``bn1``, ``conv2`` and ``bn2``
            (only the blocks that actually get frozen are accessed).
        n: Number of blocks to freeze; must satisfy ``0 <= n <= 2``.

    Returns:
        A list holding ``0`` followed by the indices (1 and/or 2) of the
        blocks this call did NOT freeze: ``[0, 1, 2]`` for ``n == 0``,
        ``[0, 1]`` or ``[0, 2]`` for ``n == 1``, and ``[0]`` otherwise.
        Index 0 is never frozen here; presumably it denotes the final ``fc``
        layer -- confirm against callers.

    Raises:
        AssertionError: If ``n`` is outside ``[0, 2]``.
    """
    import random

    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`, while keeping the exception type callers already see.
    if not 0 <= n <= 2:
        raise AssertionError(f"n must be between 0 and 2, got {n!r}")

    if n == 0:
        return [0, 1, 2]

    # The RNG is consulted only when exactly one block has to be picked.
    frozen = [random.choice([1, 2])] if n == 1 else [1, 2]

    for index in frozen:
        for layer in (getattr(model, f"conv{index}"), getattr(model, f"bn{index}")):
            # Iterating parameters() also copes with layers built with
            # bias=False / affine=False, whose `.bias` is None.
            for param in layer.parameters():
                param.requires_grad_(False)

    return [0] + [index for index in (1, 2) if index not in frozen]

def add_group_lasso(model: nn.Module):
    """Return the group-lasso penalty of ``model.conv1`` and ``model.conv2``.

    For each convolution weight (indexed ``[out, in, kh, kw]``) two sums of
    L2 norms are added up:

    * one norm per ``(out, in)`` pair, taken over the kernel dims ``(kh, kw)``;
    * one norm per ``out`` index, taken over ``(in, kh, kw)``.

    The norms are computed with ``torch.linalg.vector_norm`` rather than
    ``pow(2).sum().pow(0.5)``: the latter back-propagates ``inf * 0 = NaN``
    as soon as a whole group is exactly zero, which is precisely the state
    this regulariser pushes weights towards. ``vector_norm`` yields a zero
    gradient for such groups.

    Args:
        model: Any module exposing ``conv1`` and ``conv2`` with 4-D weights.

    Returns:
        A scalar tensor holding the summed penalty (differentiable w.r.t. the
        convolution weights).
    """
    penalty = 0
    for conv in (model.conv1, model.conv2):
        weight = conv.weight
        per_kernel = torch.linalg.vector_norm(weight, ord=2, dim=(2, 3)).sum()
        per_filter = torch.linalg.vector_norm(weight, ord=2, dim=(1, 2, 3)).sum()
        penalty = penalty + per_kernel + per_filter
    return penalty

if __name__ == "__main__":
    # Manual smoke check: build a default-width CNN and list each parameter's
    # name, shape, element count and trainability.
    model = CNN(CHANNELS)
    named_params = model.named_parameters()
    for k, v in named_params:
        print(f"layer name: {k}, shape: {v.shape}, size = {torch.numel(v)} requires grad = {v.requires_grad}")