import torch.nn as nn
import math
import torch.nn.functional as F
import torch
from typing import List, Tuple
from models.base import DownsampleB, SeparableConv2d
import helper


class SeparableResNet(nn.Module):
    """ResNet-style multi-task network built from task-aware blocks.

    A stem (``conv1``) and a stack of residual stages are followed by a
    batch-norm + ReLU, a global max-pool and a linear classifier. The stem
    batch-norm (``pre_bn``), the final batch-norm (``bns``) and the classifier
    (``fcs``) each exist once per task; ``task_id`` selects which copy is used.
    ``conv1`` and the residual blocks receive ``task_id`` as a call argument.
    """

    def __init__(self, block, num_classes: List[int], factor, depth, logger, device,
                 forward_transfer):
        """Build the network.

        Args:
            block: Callable creating one residual block. It is called as
                ``block(in_planes, planes, nb_tasks, stride, forward_transfer)``
                for the first block of a stage and as
                ``block(in_planes, planes, nb_tasks=..., forward_transfer=...)``
                for the remaining ones.
            num_classes: One entry per task; entry ``i`` is the output size of
                the classifier for task ``i``.
            factor: Width multiplier applied to every base channel count
                (the result is truncated with ``int``).
            depth: Number of stages to keep out of the 5 available ones
                (applied as a slice ``[:depth]``).
            logger: Optional logger; when not ``None`` the stage configuration
                is reported through ``logger.info``.
            device: Stored on ``self.device``; not used inside this class.
            forward_transfer: Forwarded unchanged to ``SeparableConv2d`` and to
                every ``block`` — its meaning is defined by those modules.

        Raises:
            ValueError: If ``depth`` selects no stage at all.
        """
        super().__init__()
        self.device = device

        self.nb_tasks = len(num_classes)

        # Per-stage configuration, truncated to the requested depth.
        strides = [1, 2, 2, 2, 2][:depth]
        filt_sizes = [64, 128, 256, 512, 1024][:depth]
        layers = [2, 2, 2, 2, 2][:depth]
        if not filt_sizes:
            # Fail early with a clear message instead of an IndexError on
            # ``filt_sizes[-1]`` after half of the model has been built.
            raise ValueError(
                f"depth={depth!r} selects no stage; at least one of the "
                "5 available stages must be kept"
            )
        # Actual channel count of every stage (what the layers really use).
        widths = [int(size * factor) for size in filt_sizes]

        stem_width = int(64 * factor)
        self.in_planes = stem_width
        # NOTE(review): positional arguments presumably are kernel_size=3,
        # stride=1, padding=1, dilation=1, bias=False — confirm against the
        # SeparableConv2d signature.
        self.conv1 = SeparableConv2d(
            3,
            stem_width,
            3,
            1,
            1,
            1,
            False,
            self.nb_tasks,
            forward_transfer
        )
        self.pre_bn = nn.ModuleList(
            [nn.BatchNorm2d(stem_width) for _ in range(self.nb_tasks)]
        )

        if logger is not None:
            logger.info("strides : {}".format(strides))
            logger.info("filter sizes : {}".format(widths))
            logger.info("layers : {}".format(layers))

        stage_blocks, stage_downsamples = [], []
        for width, num_blocks, stride in zip(widths, layers, strides):
            blocks, ds = self._make_layer(
                block,
                width,
                num_blocks,
                stride=stride,
                forward_transfer=forward_transfer
            )
            stage_blocks.append(nn.ModuleList(blocks))
            stage_downsamples.append(ds)

        # blocks[s][b] is block b of stage s; ds[s] adapts the residual that
        # enters stage s.
        self.blocks = nn.ModuleList(stage_blocks)
        self.ds = nn.ModuleList(stage_downsamples)

        out_width = widths[-1]
        self.bns = nn.ModuleList(
            [
                nn.Sequential(nn.BatchNorm2d(out_width), nn.ReLU(True))
                for _ in range(self.nb_tasks)
            ]
        )
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.fcs = nn.ModuleList(
            [nn.Linear(out_width, n_cls) for n_cls in num_classes]
        )

        # Number of blocks per stage; drives the loops in ``forward``.
        self.layer_config = layers

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming-style normal init based on fan-out.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def seed(self, x: torch.Tensor, task_id: int, training=False) -> torch.Tensor:
        """Run the stem: ``conv1`` followed by the batch-norm of ``task_id``.

        ``training`` is forwarded unchanged to ``conv1``.
        """
        x = self.conv1(x, task_id, training)
        x = self.pre_bn[task_id](x)
        return x

    def _make_layer(self, block, planes, blocks, stride,
                    forward_transfer) -> Tuple[list, nn.Module]:
        """Create one stage and the module that adapts its incoming residual.

        Args:
            block: Block factory (see ``__init__``).
            planes: Output channel count of every block in the stage.
            blocks: Number of blocks in the stage.
            stride: Stride handed to the first block of the stage.
            forward_transfer: Forwarded unchanged to every block.

        Returns:
            ``(layers, downsample)`` where ``layers`` is a plain list of the
            created blocks and ``downsample`` is an empty ``nn.Sequential``
            (identity) when neither the stride nor the channel count changes,
            otherwise ``DownsampleB(2)``.

        Side effects:
            Updates ``self.in_planes`` to ``planes``.
        """
        downsample = nn.Sequential()
        if stride != 1 or self.in_planes != planes:
            downsample = DownsampleB(2)

        # Only the first block sees the stage's stride and the previous width.
        layers = [
            block(self.in_planes, planes, self.nb_tasks, stride, forward_transfer)
        ]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_planes,
                    planes,
                    nb_tasks=self.nb_tasks,
                    forward_transfer=forward_transfer
                )
            )

        return layers, downsample

    def forward(self, x: torch.Tensor, task_id: int, training=False, *args, **kwargs) -> torch.Tensor:
        """Compute the classifier output of task ``task_id`` for input ``x``.

        Args:
            x: Input batch fed to the stem.
            task_id: Index selecting the per-task batch-norms and classifier;
                also passed to ``conv1`` and every block.
            training: Forwarded unchanged to ``conv1`` and every block.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(x.size(0), num_classes[task_id])``.
        """
        x = self.seed(x, task_id, training)

        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                # Only the first block of a stage after the first one needs
                # its residual adapted; everywhere else the shortcut is x.
                if b == 0 and segment != 0:
                    residual = self.ds[segment](x)
                else:
                    residual = x
                x = self.blocks[segment][b](x, task_id, training)
                x = F.relu(residual + x)

        x = self.bns[task_id](x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fcs[task_id](x)
        return x

    def reset_last_parameters(self, args):
        """Re-initialise the last entry of every ``ModuleList`` and every
        ``SeparableConv2d`` in the model.

        Args:
            args: Object exposing ``logger`` and ``verbose``; both are passed
                to ``helper.log_and_print``.

        Modules without a ``reset_parameters`` method and empty
        ``ModuleList`` containers are skipped.
        """
        helper.log_and_print(
            "Re-initializing the current sub-network with LoRA", args.logger, args.verbose
        )
        # named_modules() walks the whole tree, so nested ModuleLists and
        # SeparableConv2d instances inside the blocks are visited here as well
        # as by the explicit 'blocks' branch below. That duplication is kept
        # on purpose so the sequence of reset calls matches the previous
        # behaviour.
        for name, module in self.named_modules():
            if isinstance(module, nn.ModuleList):
                if len(module) == 0:
                    continue
                last_layer = module[-1]
                if hasattr(last_layer, 'reset_parameters'):
                    last_layer.reset_parameters()
                elif name == 'bns':
                    # Each entry is Sequential(BatchNorm2d, ReLU): reset the BN.
                    last_layer[0].reset_parameters()
                elif name == 'blocks':
                    for stage in module.children():
                        for basicblock in stage.children():
                            for child in basicblock.children():
                                if isinstance(child, SeparableConv2d):
                                    child.reset_parameters()
                                elif isinstance(child, nn.ModuleList):
                                    if len(child) == 0:
                                        continue
                                    tail = child[-1]
                                    if hasattr(tail, 'reset_parameters'):
                                        tail.reset_parameters()
            elif isinstance(module, SeparableConv2d):
                module.reset_parameters()