# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu, avg_pool2d
from typing import List
from torch.nn import ModuleList
import torch.optim as optim
import helper
import math
from copy import deepcopy


class Conv3x3(nn.Module):
    """
    Task-conditioned 2D convolution holding one ``nn.Conv2d`` per task.

    Despite the name, ``kernel_size`` and ``padding`` are configurable (they
    default to 3 and 1). When ``args.forward_transfer`` is true, evaluating a
    task ``t > 0`` (with ``training=False``) reuses the task-0 convolution with
    a low-rank update ``lora_B[t-1].weight @ lora_A[t-1].weight`` (multiplied by
    ``self.scaling``) added to its weight. ``lora_A`` / ``lora_B`` are NOT
    created in this class; outside code must attach them before that path runs.
    """

    def __init__(self, in_planes: int, out_planes: int, stride: int = 1, n_tasks: int = 1,
                 kernel_size=3, padding=1, bias=False, args=None):
        """
        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param stride: convolution stride
        :param n_tasks: number of per-task ``nn.Conv2d`` copies to create
        :param kernel_size: convolution kernel size
        :param padding: convolution padding
        :param bias: whether each ``nn.Conv2d`` has a bias term
        :param args: configuration namespace; must not be None and must provide ``forward_transfer``
        """
        super(Conv3x3, self).__init__()
        self.conv = ModuleList([nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                                          padding=padding, bias=bias) for _ in range(n_tasks)])

        assert args is not None
        self.scaling = 1
        self.in_channels = in_planes
        self.out_channels = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        self.nb_tasks = n_tasks
        self.forward_transfer = args.forward_transfer

    def forward(self, x, task_id: int = 0, training: bool = False):
        """
        Apply the convolution selected by ``task_id`` / ``training``.

        - task 0 always uses ``conv[0]``;
        - training a later task uses the most recently added conv ``conv[-1]``;
        - evaluating task ``t > 0`` uses either ``conv[0]`` plus the LoRA update
          (forward transfer) or the dedicated ``conv[t]``.
        """
        if training or task_id == 0:
            conv = self.conv[0] if task_id == 0 else self.conv[-1]
            return conv(x)

        if self.forward_transfer:
            base = self.conv[0]
            delta = (self.lora_B[task_id - 1].weight @ self.lora_A[task_id - 1].weight).view(base.weight.shape)
            # The base weight is frozen (detached); gradients reach only the LoRA factors.
            weight = base.weight.detach() + delta * self.scaling
            # Keep the base conv's bias (frozen like the weight). Passing None here
            # would silently drop it whenever the conv was built with bias=True.
            bias = None if base.bias is None else base.bias.detach()
            return base._conv_forward(x, weight, bias)

        return self.conv[task_id](x)

    def reset_parameters(self):
        """
        Re-initialise the most recently attached LoRA factors, if there are any.

        ``lora_A[-1]`` gets Kaiming-uniform init and ``lora_B[-1]`` is zeroed,
        so the low-rank update ``B @ A`` starts at zero. The per-task convs in
        ``self.conv`` are not touched.
        """
        if hasattr(self, 'lora_A') and len(self.lora_A) > 0:
            nn.init.kaiming_uniform_(self.lora_A[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B[-1].weight)

class ModifiedSequential(nn.Sequential):
    """
    ``nn.Sequential`` variant that passes extra arguments to every child.

    Each child is called as ``child(out, *args, **kwargs)``, so every child
    must accept the same extra positional/keyword arguments. With no
    children, the input object is returned unchanged.
    """

    def forward(self, input, *args, **kwargs):
        out = input
        for layer in self:
            out = layer(out, *args, **kwargs)
        return out


class CustomSequential(nn.Module):
    """
    Optional ``conv`` followed by a per-task batch norm, used as a ResNet shortcut.

    When both ``conv`` and ``bn`` are None, the module behaves as an identity:
    the input goes through an empty ``ModifiedSequential``. Otherwise it
    computes ``bn[task_id](conv(x, task_id, training))``.
    """

    def __init__(self, conv=None, bn=None):
        super(CustomSequential, self).__init__()
        self.conv = conv
        self.bn = bn
        self.sequential = ModifiedSequential()

    def forward(self, x, task_id, training: bool = False):
        has_projection = not (self.conv is None and self.bn is None)
        if not has_projection:
            return self.sequential(x)
        projected = self.conv(x, task_id, training)
        return self.bn[task_id](projected)


class BasicBlockResNet(nn.Module):
    """
    The basic block of ResNet, with one set of conv / batch-norm weights per task.
    """
    expansion = 1

    def __init__(self, in_planes: int, planes: int, stride: int=1, n_tasks: int=1, args = None) -> None:
        """
        Build the two 3x3 task-conditioned convs, their per-task batch norms, and the shortcut.

        :param in_planes: the number of input channels
        :param planes: the number of channels (to be possibly expanded)
        :param stride: stride of the first conv (and of the projection shortcut)
        :param n_tasks: number of per-task copies of every layer
        :param args: configuration namespace forwarded to each ``Conv3x3``
        """
        super().__init__()
        self.conv1 = Conv3x3(in_planes=in_planes, out_planes=planes, stride=stride,
                             n_tasks=n_tasks, args=args)
        self.bn1 = ModuleList([nn.BatchNorm2d(planes) for _ in range(n_tasks)])
        self.conv2 = Conv3x3(in_planes=planes, out_planes=planes, n_tasks=n_tasks, args=args)
        self.bn2 = ModuleList([nn.BatchNorm2d(planes) for _ in range(n_tasks)])

        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # 1x1 conv + per-task BN to match both channel count and spatial size.
            shortcut_conv = Conv3x3(in_planes=in_planes, out_planes=out_planes, kernel_size=1, padding=0,
                                    n_tasks=n_tasks, stride=stride, bias=False, args=args)
            shortcut_bn = ModuleList([nn.BatchNorm2d(out_planes) for _ in range(n_tasks)])
            self.shortcut = CustomSequential(conv=shortcut_conv, bn=shortcut_bn)
        else:
            # Identity shortcut.
            self.shortcut = CustomSequential()

    def forward(self, x: torch.Tensor, task_id: int = 0, training: bool = False) -> torch.Tensor:
        """
        Compute a forward pass for the given task.

        :param x: input feature map with ``in_planes`` channels
        :param task_id: selects the per-task conv / batch-norm copies
        :param training: forwarded to the convs (selects the conv copy used while training)
        :return: feature map with ``planes * expansion`` channels, spatially downsampled by ``stride``
        """
        norm1 = self.bn1[task_id]
        hidden = relu(norm1(self.conv1(x, task_id, training)))
        norm2 = self.bn2[task_id]
        hidden = norm2(self.conv2(hidden, task_id, training))
        hidden += self.shortcut(x, task_id, training)
        return relu(hidden)


class ResNet(nn.Module):
    """
    ResNet network architecture. Designed for complex datasets.

    Every conv, batch-norm and linear layer is replicated once per task
    (``n_tasks``); the ``task_id`` passed to ``forward`` selects which copy is used.
    """

    def __init__(self, block: BasicBlockResNet, num_blocks: List[int],
                 num_classes: int, nf: int, n_tasks: int, args = None, device=None) -> None:
        """
        Instantiates the layers of the network.
        :param block: the basic ResNet block
        :param num_blocks: the number of blocks per layer (four entries are read)
        :param num_classes: the number of output classes
        :param nf: the number of filters
        :param n_tasks: number of per-task copies of every layer
        :param args: configuration namespace forwarded to every ``Conv3x3``
        :param device: stored as ``self.device``; not otherwise read in this class
        """
        super(ResNet, self).__init__()
        self.device = device
        self.in_planes = nf
        self.block = block
        self.num_classes = num_classes
        self.nf = nf
        self.conv1 = Conv3x3(in_planes=3, out_planes=nf * 1, n_tasks=n_tasks, args=args)
        self.bn1 = ModuleList([nn.BatchNorm2d(nf * 1) for _ in range(n_tasks)])
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1, n_tasks=n_tasks, args=args)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2, n_tasks=n_tasks, args=args)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2, n_tasks=n_tasks, args=args)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2, n_tasks=n_tasks, args=args)

        self.linear = ModuleList([nn.Linear(nf * 8 * block.expansion, num_classes) for _ in range(n_tasks)])
        # Aliases of the same per-task heads (same module objects, not copies).
        self.classifier = self.linear
        self.fcs = self.linear

    def _make_layer(self, block: BasicBlockResNet, planes: int,
                    num_blocks: int, stride: int, n_tasks: int, args=None) -> nn.Module:
        """
        Instantiates a ResNet layer.

        Side effect: updates ``self.in_planes`` to the output width of this layer.
        :param block: ResNet basic block
        :param planes: channels across the network
        :param num_blocks: number of blocks
        :param stride: stride of the first block (the rest use stride 1)
        :return: ResNet layer
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, n_tasks, args=args))
            self.in_planes = planes * block.expansion
        return ModifiedSequential(*layers)

    def forward(self, x: torch.Tensor, task_id: int = 0, training: bool = False, *args, **kwargs) -> torch.Tensor:
        """
        Compute a forward pass.

        Extra positional/keyword arguments are accepted and ignored.
        :param x: input tensor (batch_size, 3, H, W)
        :param task_id: selects the per-task layer copies and classifier head
        :param training: forwarded to the convs
        :return: logits (batch_size, num_classes)
        """
        out = relu(self.bn1[task_id](self.conv1(x, task_id, training)))  # (B, nf, H, W)
        out = self.layer1(out, task_id, training)  # (B, nf, H, W)
        out = self.layer2(out, task_id, training)  # (B, 2nf, H/2, W/2)
        out = self.layer3(out, task_id, training)  # (B, 4nf, H/4, W/4)
        out = self.layer4(out, task_id, training)  # (B, 8nf, H/8, W/8)
        # Pool with a kernel equal to the height; yields 1x1 only for square feature maps.
        out = avg_pool2d(out, out.shape[2])
        out = out.view(out.size(0), -1)
        out_cls = self.linear[task_id](out)
        return out_cls

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single flat tensor,
        in ``self.parameters()`` order.
        """
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.
        :param new_params: flat tensor laid out as returned by ``get_params``
        """
        assert new_params.size() == self.get_params().size()
        progress = 0
        for pp in self.parameters():
            numel = pp.numel()
            pp.data = new_params[progress: progress + numel].view(pp.size())
            progress += numel

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single flat tensor,
        in ``self.parameters()`` order.

        Parameters without a gradient (``grad is None``) contribute zeros.
        """
        grads = []
        for pp in self.parameters():
            # Only the selected task's modules take part in a forward pass, so
            # the other per-task copies routinely have no gradient.
            grad = pp.grad if pp.grad is not None else torch.zeros_like(pp)
            grads.append(grad.reshape(-1))
        return torch.cat(grads)

    def reset_last_parameters(self, args):
        """
        Re-initialise the current (last-added) sub-network.

        Pass 1 resets the last entry of every non-empty ``nn.ModuleList`` in the
        model. Pass 2 then calls ``Conv3x3.reset_parameters`` (LoRA init, which
        zeroes ``lora_B[-1]``). The LoRA pass must run last: ``named_modules``
        yields a ``Conv3x3`` before its child lists, so a single interleaved pass
        would let the generic reset of ``lora_B[-1]`` overwrite the zero init.
        :param args: namespace providing ``logger`` and ``verbose`` for ``helper.log_and_print``
        """
        helper.log_and_print(
            "Re-initializing the current sub-network with LoRA", args.logger, args.verbose
        )
        named = list(self.named_modules())

        # Pass 1: generic reset of the last element of each ModuleList.
        for name, module in named:
            if not isinstance(module, nn.ModuleList) or len(module) == 0:
                continue
            last_layer = module[-1]
            if hasattr(last_layer, 'reset_parameters'):
                last_layer.reset_parameters()
            elif 'bn' in name:
                # Nested container of BN layers: reset the first one of the last entry.
                module[-1][0].reset_parameters()
            # Any other module type is deliberately left untouched.

        # Pass 2: LoRA-specific init, after the generic resets above.
        for _, module in named:
            if isinstance(module, Conv3x3):
                module.reset_parameters()