#%%
import torch

from ..layer.tucker_conv_vanilla import Conv2d_tucker_vanilla
from ..layer.mat_conv_vanilla import Conv2d_mat_vanilla
from ..layer.CP_conv_vanilla import Conv2d_CP_vanilla
from ..__init__ import factorization , glob_start_rank_perc


# Every layer created through `conv` is appended here in creation order.
# The list lives at module level, so it accumulates across every model built
# in this process.
low_rank_layers = []

def conv(in_channels: int, out_channels: int,kernel_size:int, stride: int = 1, groups: int = 1, padding: int = 1,bias : bool = False,factorization = factorization) -> torch.nn.Conv2d:
    """Build a low-rank 2-D convolution and register it in ``low_rank_layers``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride (callers in this file pass a tuple).
        groups: number of groups; forwarded to the Tucker and matrix layers
            only (the CP branch does not receive it).
        padding: zero padding. It is ALSO forwarded as ``dilation``. For a
            3x3 kernel with stride 1, this keeps the spatial size unchanged,
            assuming the low-rank layers follow ``torch.nn.Conv2d``'s
            padding/dilation semantics (NOTE(review): confirm).
        bias: whether the layer has an additive bias.
        factorization: ``'tucker'``, ``'mat'`` or ``'cp'`` (case-insensitive).
            The default is the package-level ``factorization`` value, captured
            once when this module is imported.

    Returns:
        The newly created low-rank convolution layer.

    Raises:
        ValueError: if ``factorization`` is not a supported scheme. Previously
            this surfaced as an ``UnboundLocalError``.
    """
    scheme = factorization.lower()
    # Keyword arguments shared by every factorization scheme.
    common = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                  bias=bias, dilation=padding,
                  start_rank_percent=glob_start_rank_perc)
    if scheme == 'tucker':
        layer = Conv2d_tucker_vanilla(in_channels, out_channels, groups=groups, **common)
    elif scheme == 'mat':
        layer = Conv2d_mat_vanilla(in_channels, out_channels, groups=groups, **common)
    elif scheme == 'cp':
        # `groups` is intentionally not forwarded (matches prior behaviour).
        # NOTE(review): confirm whether Conv2d_CP_vanilla supports it.
        layer = Conv2d_CP_vanilla(in_channels, out_channels, **common)
    else:
        raise ValueError(
            f"Unknown factorization {factorization!r}; expected 'tucker', 'mat' or 'cp'"
        )
    low_rank_layers.append(layer)
    return layer


# Layer recipes for the standard VGG variants.
# An int is a 3x3 conv with that many output channels; "M" is a 2x2 max-pool.
VGG_types = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M",
              512, 512, "M", 512, 512, "M"],
    "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


class Flatten(torch.nn.Module):
    """Collapse every dimension after the first (batch) dimension.

    An input of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1 * d2 * ...)``.
    The tensor is made contiguous first, so ``view`` also succeeds on
    non-contiguous inputs (e.g. after a transpose).
    """

    def forward(self, input):
        flat_shape = (input.size(0), -1)
        return input.contiguous().view(*flat_shape)


class VGG(torch.nn.Module):
    """VGG-style classifier whose convolutions are built by ``conv``.

    The feature extractor follows ``architecture``:
    - each int adds a conv (3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU
      triple with that many output channels;
    - any other entry (``"M"`` in ``VGG_types``) adds a 2x2 / stride-2 max-pool.

    The head is Flatten -> Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear
    -> ReLU -> Dropout -> Linear(num_classes). Every module is stored in one
    ``torch.nn.Sequential`` (``self.layer``) under explicit names.

    Args:
        architecture: list of ints (conv output channels) and ``"M"`` markers.
        in_channels: channels of the input image.
        in_height: input height; must be a multiple of
            ``2 ** architecture.count("M")``.
        in_width: input width; same divisibility requirement.
        num_hidden: width of the two hidden fully-connected layers.
        num_classes: number of output logits.

    Raises:
        ValueError: if the input size is not divisible by the total pooling
            factor, or ``architecture`` has no conv (int) entry.
    """

    def __init__(
            self,
            architecture,
            in_channels=3,
            in_height=224,
            in_width=224,
            num_hidden=4096,
            num_classes=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.in_width = in_width
        self.in_height = in_height
        self.num_hidden = num_hidden
        self.num_classes = num_classes

        # Validate the geometry before creating any layers. `conv` appends
        # every layer it builds to the module-level `low_rank_layers`, so
        # failing after the feature loop would leave orphaned entries there.
        factor = 2 ** architecture.count("M")
        if in_height % factor or in_width % factor:
            raise ValueError(f"`in_height` and `in_width` must be multiples of {factor}")
        last_out_channels = next(
            (x for x in reversed(architecture) if isinstance(x, int)), None
        )
        if last_out_channels is None:
            raise ValueError("`architecture` must contain at least one conv (int) entry")
        out_height = in_height // factor
        out_width = in_width // factor

        self.layer = torch.nn.Sequential()
        for j, x in enumerate(architecture):
            if isinstance(x, int):
                self.layer.add_module('conv_' + str(j), conv(
                    in_channels=in_channels,
                    out_channels=x,
                    kernel_size=3,
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=False
                ))
                # NOTE(review): PyTorch's BatchNorm `momentum` weights the NEW
                # batch statistic (default 0.1). 0.9 makes the running stats
                # follow mostly the latest batch; a Keras-style 0.9 decay
                # would be 0.1 here. Confirm this is intended.
                self.layer.add_module('bn_' + str(j), torch.nn.BatchNorm2d(x, momentum=0.9))
                self.layer.add_module('relu_' + str(j), torch.nn.ReLU(inplace=True))
                in_channels = x
            else:
                self.layer.add_module('maxpool_' + str(j),
                                      torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))

        # Head ReLU names continue past the feature indices (0..depth-1) so
        # that module names never collide.
        depth = len(architecture)
        self.layer.add_module('flat', Flatten())
        self.layer.add_module('linear_1', torch.nn.Linear(
            last_out_channels * out_height * out_width, num_hidden))
        self.layer.add_module('bn_1d', torch.nn.BatchNorm1d(num_hidden))
        self.layer.add_module('relu_' + str(depth + 1), torch.nn.ReLU(inplace=True))
        self.layer.add_module('drop_1', torch.nn.Dropout(p=0.2))
        self.layer.add_module('linear_2', torch.nn.Linear(num_hidden, num_hidden))
        self.layer.add_module('relu_' + str(depth + 2), torch.nn.ReLU(inplace=True))
        self.layer.add_module('drop_2', torch.nn.Dropout(p=0.2))
        self.layer.add_module('classifier', torch.nn.Linear(num_hidden, num_classes))
        self.init_weights()

    def init_weights(self, name='kn'):
        """Re-initialise the Linear / Conv2d modules that are direct children of ``self.layer``.

        Args:
            name: one of
                - ``'kn'``: Kaiming-normal weights (ReLU gain) and U(0, 1)
                  biases;
                - ``'orthogonal'``: orthogonal weights with gain 1.41 and zero
                  biases.

        Layers without a bias are left bias-free. Only modules that are
        instances of ``torch.nn.Linear`` or ``torch.nn.Conv2d`` are touched.
        NOTE(review): whether the low-rank layers from ``conv`` subclass
        ``Conv2d`` is not visible here.

        Raises:
            ValueError: if ``name`` is not a supported scheme.
        """
        if name not in ('kn', 'orthogonal'):
            raise ValueError(f"Unknown init scheme {name!r}; expected 'kn' or 'orthogonal'")
        for module in self.layer:
            if not isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                continue
            # `bias` is a Parameter or None. Testing its truthiness directly
            # raises for multi-element tensors, so compare with None.
            has_bias = module.bias is not None
            if name == 'kn':
                torch.nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if has_bias:
                    torch.nn.init.uniform_(module.bias)
            else:
                torch.nn.init.orthogonal_(module.weight, gain=1.41)
                if has_bias:
                    torch.nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Run ``x`` through the feature extractor and classifier head."""
        return self.layer(x)
    

def vgg16():
    """Build a VGG-16 for 32x32 three-channel inputs, 256 hidden units and 10 classes."""
    return VGG(
        VGG_types["VGG16"],
        in_channels=3,
        in_height=32,
        in_width=32,
        num_hidden=256,
        num_classes=10,
    )