#%%
import torch

from ..layer.tucker_conv_vanilla import Conv2d_tucker_vanilla
from ..layer.mat_conv_vanilla import Conv2d_mat_vanilla
from ..layer.CP_conv_vanilla import Conv2d_CP_vanilla
from ..__init__ import factorization , glob_start_rank_perc


# Module-level registry of every layer created by `conv` below. It is shared by
# all models built in this process and never cleared here, so constructing two
# networks appends both networks' layers. Presumably consumed elsewhere (e.g. by
# a training/rank-adaptation loop) — not visible in this file; confirm.
low_rank_layers: list = list()

def conv(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, groups: int = 1,
         padding: int = 1, bias: bool = False, factorization=factorization,
         dilation=None) -> torch.nn.Module:
    """Create a low-rank 2-D convolution layer and register it in ``low_rank_layers``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        groups: channel groups (forwarded to the 'tucker' and 'mat' layers only).
        padding: zero padding on each side.
        bias: whether the layer has a bias term.
        factorization: case-insensitive name of the factorization, one of
            ``'tucker'``, ``'mat'`` or ``'cp'``. The default is the package-level
            ``factorization`` value captured when this module was imported.
        dilation: convolution dilation. ``None`` (default) keeps the historical
            behaviour of using ``padding`` as the dilation.

    Returns:
        The constructed layer (also appended to ``low_rank_layers``).

    Raises:
        ValueError: if ``factorization`` is not one of the supported names.
    """
    if dilation is None:
        # Historical coupling: dilation mirrors padding. With kernel_size=3 and
        # stride=1 this keeps the spatial size (effective kernel 2*padding+1).
        dilation = padding

    kind = factorization.lower()
    shared = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias,
                  dilation=dilation, start_rank_percent=glob_start_rank_perc)

    if kind == 'tucker':
        layer = Conv2d_tucker_vanilla(in_channels, out_channels, groups=groups, **shared)
    elif kind == 'mat':
        layer = Conv2d_mat_vanilla(in_channels, out_channels, groups=groups, **shared)
    elif kind == 'cp':
        # NOTE(review): `groups` is not forwarded here (as in the original call);
        # confirm whether Conv2d_CP_vanilla supports grouped convolutions.
        layer = Conv2d_CP_vanilla(in_channels, out_channels, **shared)
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"Unknown factorization {factorization!r}; expected 'tucker', 'mat' or 'cp'"
        )

    low_rank_layers.append(layer)
    return layer


class Flatten(torch.nn.Module):
    """Collapse all non-batch dimensions into one.

    A tensor of shape ``(B, d1, d2, ...)`` becomes ``(B, d1 * d2 * ...)``;
    a 1-D tensor of shape ``(B,)`` becomes ``(B, 1)``.
    """

    def forward(self, input):
        '''Return ``input`` reshaped to ``(input.size(0), prod(input.shape[1:]))``.

        Dimension 0 is usually the batch size. The trailing size is computed
        explicitly instead of passing ``-1`` so that an empty batch
        (``input.size(0) == 0``) is flattened instead of raising.
        '''
        batch_size = input.size(0)
        # input.shape[1:] is a torch.Size; numel() of an empty Size is 1,
        # so 1-D inputs still map to (B, 1).
        features = input.shape[1:].numel()
        # reshape() returns a view when the memory layout allows it and copies
        # otherwise — the same result as the former .contiguous().view().
        return input.reshape(batch_size, features)
    

class AlexNet(torch.nn.Module):
    """AlexNet-style CNN whose convolutions are low-rank layers created by `conv`.

    Five stages of (conv 3x3, padding 1) -> BatchNorm2d -> MaxPool2d(2) -> ReLU
    are followed by Flatten -> Linear(256, 256) -> ReLU -> Dropout(0.2) ->
    Linear(256, output_dim), all held in ``self.layer`` (an nn.Sequential).

    The head expects a 256-dim flattened feature, i.e. a 1x1 map after five
    2x2 poolings; assumes 32x32 inputs and size-preserving convs — confirm.

    Args:
        output_dim: number of outputs of the final Linear layer.
        device: stored as ``self.device``; not otherwise used in this class.
    """

    def __init__(self, output_dim, device = 'cpu'):
        super().__init__()
        self.device = device

        # (in_channels, out_channels) of the five conv stages, in build order.
        stage_channels = [(3, 64), (64, 192), (192, 384), (384, 256), (256, 256)]

        modules = []
        for c_in, c_out in stage_channels:
            modules += [
                conv(c_in, c_out, 3, stride=1, padding=1, bias=False),
                # NOTE(review): PyTorch's BatchNorm `momentum` weights the *new*
                # batch statistic (default 0.1); 0.9 matches the Keras convention
                # and may be a porting slip — confirm before changing.
                torch.nn.BatchNorm2d(c_out, momentum=0.9),
                torch.nn.MaxPool2d(2),
                torch.nn.ReLU(),
            ]

        # Classifier head.
        modules += [
            Flatten(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, output_dim),
        ]

        self.layer = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Run the full feature extractor and classifier on ``x``."""
        logits = self.layer(x)
        return logits
    
def alexnet(output_dim: int = 10, device='cpu') -> "AlexNet":
    """Build an ``AlexNet`` instance.

    Args:
        output_dim: number of outputs of the final layer; defaults to 10, the
            value this factory previously hard-coded.
        device: forwarded to ``AlexNet`` (which stores it as ``self.device``).

    Returns:
        A freshly constructed ``AlexNet``.
    """
    return AlexNet(output_dim, device=device)