from typing import *

import numpy as np
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset

from approaches.spg.ablation import Ablation
from approaches.spg.spg import SPG
from approaches.spg.other_task_loss import OtherTaskLoss
from mvseq import MultiVariableSequence
from utils import assert_type

# Whether the SPG-wrapped fully connected layers (fc1, fc2) of Alexnet carry a bias term.
use_bias: bool = True


class ModelSPGFeatureAlexNet(nn.Module):
    """Wrapper exposing an SPG-gated :class:`Alexnet` trunk as a feature model.

    The classification head is not built here; it is attached afterwards with
    :meth:`set_fc`. ``last_dim`` records the width of the features returned by
    :meth:`forward`.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 nhid: int, drop1: float, drop2: float, shift: float, ablation: Ablation,
                 seqname: str):
        """Build the inner :class:`Alexnet`.

        Args:
            list__ncls: accepted for interface compatibility; not used here
                (the head is supplied later through :meth:`set_fc`).
            inputsize, batch_size, nhid, drop1, drop2, shift, ablation, seqname:
                forwarded unchanged to :class:`Alexnet`.
        """
        super().__init__()
        self.model = Alexnet(inputsize=inputsize, batch_size=batch_size,
                             nhid=nhid, drop1=drop1, drop2=drop2,
                             shift=shift, ablation=ablation, seqname=seqname)

        # Output width of the trunk (its last layer, fc2, maps nhid -> nhid).
        self.last_dim = nhid
    # enddef

    def set_fc(self, fc: nn.Module):
        """Install ``fc`` as the head; the inner model reads it as ``self.fc`` when freezing masks."""
        self.model.fc = fc
    # enddef

    def freeze_masks(self, idx_task: int, dl: DataLoader, **kwargs):
        """Delegate to :meth:`Alexnet.freeze_masks`; ``kwargs`` (e.g. ``epoch``) pass through unchanged."""
        self.model.freeze_masks(idx_task, dl, **kwargs)
    # enddef

    def forward(self, x: Tensor, idx_task: int, s: float, args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict]:
        """Run the trunk on ``x`` and return the inner model's ``(features, misc)`` pair."""
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        features, misc = self.model(x, idx_task, s=s, args_on_forward=args_on_forward)

        return features, misc
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]) -> None:
        """Forward the post-backward hook to the inner model.

        The inner model's hook does not read ``smax``, so ``NotImplemented`` is
        passed purely as a placeholder.
        """
        return self.model.on_after_backward_params(idx_task=idx_task, s=s, smax=NotImplemented, args=args)
    # enddef


# endclass

def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along one spatial axis.

    Evaluates ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and returns it as an ``int``.
    """
    padded_len = Lin + 2 * padding
    # Span covered by one (possibly dilated) kernel application.
    effective_kernel = dilation * (kernel_size - 1) + 1
    steps = (padded_len - effective_kernel) / float(stride) + 1
    return int(np.floor(steps))


class Flattener(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    Module form of ``h.view(h.shape[0], -1)``, so the flatten step can sit
    inside a sequence of layers. It holds no parameters.
    """

    def __init__(self):
        super().__init__()
    # enddef

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` reshaped to ``(batch, -1)``.

        ``flatten(start_dim=1)`` is used rather than ``view(batch, -1)``.
        ``view`` fails on non-contiguous inputs, and it fails on an empty
        batch, where ``-1`` is ambiguous. For contiguous inputs the result is
        still a view. A 1-D input keeps its previous result of shape ``(N, 1)``.
        """
        if x.dim() == 1:
            return x.unsqueeze(1)
        # endif
        return x.flatten(start_dim=1)
    # enddef


class Alexnet(nn.Module):
    """AlexNet-style feature trunk whose conv/linear layers are wrapped in SPG.

    Three conv stages (conv -> ReLU -> dropout -> 2x2 max-pool) are followed by
    a flatten and two fully connected stages (linear -> ReLU -> dropout).
    The task head is not created here. ``freeze_masks`` reads ``self.fc``,
    which must be assigned from outside (``ModelSPGFeatureAlexNet.set_fc``
    does this).
    """

    def __init__(self, inputsize: Tuple[int, ...], batch_size: int,
                 nhid: int, drop1: float, drop2: float, shift: float,
                 ablation: Ablation, seqname: str):
        """
        Args:
            inputsize: input shape. ``inputsize[0]`` is the channel count and
                ``inputsize[1]`` the spatial size. fc1's input width is
                ``256 * smid ** 2``, so square inputs are assumed.
            batch_size: base batch size; ``freeze_masks`` iterates with 10x this.
            nhid: width of fc1 and fc2 (the trunk's output width).
            drop1: dropout probability after conv stages 1 and 2.
            drop2: dropout probability after conv stage 3 and both linear stages.
            shift: forwarded to every SPG wrapper.
            ablation: selects the variant used by ``freeze_masks``.
            seqname: forwarded to every SPG wrapper.
        """
        super().__init__()

        self.ablation = ablation

        nch, size = inputsize[0], inputsize[1]
        self.batch_size = batch_size

        # Every conv stage ends in a 2x2 max-pool, hence the ``// 2`` after
        # each output-size computation.
        self.c1 = SPG(nn.Conv2d(nch, 64, kernel_size=size // 8), shift, target_name='c1', seqname=seqname)
        s = compute_conv_output_size(size, size // 8)
        s = s // 2

        self.c2 = SPG(nn.Conv2d(64, 128, kernel_size=size // 10), shift, target_name='c2', seqname=seqname)
        s = compute_conv_output_size(s, size // 10)
        s = s // 2

        self.c3 = SPG(nn.Conv2d(128, 256, kernel_size=2), shift, target_name='c3', seqname=seqname)
        s = compute_conv_output_size(s, 2)
        s = s // 2

        # Spatial side length of the last pooled feature map.
        self.smid = s
        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        self.fc1 = SPG(nn.Linear(256 * self.smid ** 2, nhid, bias=use_bias), shift, target_name='fc1', seqname=seqname)
        self.fc2 = SPG(nn.Linear(nhid, nhid, bias=use_bias), shift, target_name='fc2', seqname=seqname)
    # enddef

    def _feature_layers(self) -> List[nn.Module]:
        """Return the trunk's layers in application order.

        Both ``forward`` and ``freeze_masks`` run exactly this sequence, so the
        layer order is defined in one place. A fresh parameter-free
        ``Flattener`` is created on every call.
        """
        return [
            # 1st conv stage
            self.c1, self.relu, self.drop1, self.maxpool,
            # 2nd conv stage
            self.c2, self.relu, self.drop1, self.maxpool,
            # 3rd conv stage
            self.c3, self.relu, self.drop2, self.maxpool,
            # flatten before fc1
            Flattener(),
            # fully connected stages
            self.fc1, self.relu, self.drop2,
            self.fc2, self.relu, self.drop2,
        ]
    # enddef

    def _spg_modules(self) -> Iterator[Tuple[str, nn.Module]]:
        """Yield ``(qualified_name, module)`` for every SPG submodule, in ``named_modules()`` order."""
        return ((name, module) for name, module in self.named_modules() if isinstance(module, SPG))
    # enddef

    def _resolve_device(self):
        """Return the device that ``freeze_masks`` moves mini-batches to.

        The device of the last ``forward`` input (``self.device``) is preferred.
        If ``forward`` has not run yet, the device of the first parameter is
        used instead of failing.
        """
        device = getattr(self, 'device', None)
        if device is not None:
            return device
        # endif
        for param in self.parameters():
            return param.device
        # endfor
        raise AttributeError("Alexnet has no 'device' yet and no parameters to infer it from; call forward() first")
    # enddef

    def forward(self, x: Tensor, idx_task: int, s: float, args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict]:
        """Run the trunk on ``x``.

        Records ``x.device`` as ``self.device`` for later use by ``freeze_masks``.
        ``s`` and ``args_on_forward`` are accepted but not used here.

        Returns:
            ``(features, misc)``, where ``misc`` is ``{'reg': 0}``.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        self.device = x.device

        h = x
        for layer in self._feature_layers():
            h = layer(h)
        # endfor

        misc = {
            'reg': 0
            }

        return h, misc
    # enddef

    def _collect_history(self, idx_task: int, t: int, dl: DataLoader, device) -> Dict[str, List[Any]]:
        """Run every SPG module's ``compute_a_by_p`` on each mini-batch for head ``t``.

        The sequence passed to SPG is the full trunk, then ``self.fc`` called
        with ``t``, then a loss. The loss is cross-entropy when ``t`` is the
        current task and ``OtherTaskLoss`` otherwise.

        Returns a dict mapping each SPG module name to its per-batch results.
        """
        dict_list_history: Dict[str, List[Any]] = {}
        for x, y in dl:
            assert_type(x, Tensor)
            assert_type(y, Tensor)

            x = x.to(device)
            y = y.to(device)

            lossfunc = nn.CrossEntropyLoss() if t == idx_task else OtherTaskLoss()

            modules = MultiVariableSequence(
                [(layer, None) for layer in self._feature_layers()]
                + [(self.fc, [t]), (lossfunc, [y])]
            )

            for name_module, module in self._spg_modules():
                dict_list_history.setdefault(name_module, []).append(module.compute_a_by_p(modules, x))
            # endfor
        # endfor
        return dict_list_history
    # enddef

    @staticmethod
    def _aggregate_history(dict_list_history: Dict[str, List[Any]]) -> Dict[str, Dict[str, Any]]:
        """Sum each module's per-parameter results across mini-batches.

        Returns ``{module_name: {param_name: sum_of_values}}``.
        """
        dict_history: Dict[str, Dict[str, Any]] = {}
        for name_module, list_history in dict_list_history.items():
            totals = dict_history.setdefault(name_module, {})
            for history in list_history:
                for name_param, g in history.items():
                    if name_param not in totals:
                        totals[name_param] = 0
                    # endif
                    totals[name_param] += g
                # endfor
            # endfor
        # endfor
        return dict_history
    # enddef

    def freeze_masks(self, idx_task: int, dl: DataLoader, **kwargs):
        """Accumulate SPG statistics over ``dl`` and freeze every SPG mask for ``idx_task``.

        ``dl.dataset`` is re-batched at ``10 * batch_size``. The heads visited
        are 0..idx_task, except ``NoCrossHeadImportance``, which visits only
        ``idx_task``.

        For each head ``t``, every SPG module's ``compute_a_by_p`` results are
        summed over mini-batches and handed to its ``register_grad``. Then
        every SPG module's ``freeze_masks`` is called. Artifacts are saved only
        for ``Ablation.Asis``.

        The EarlyGradients0/10/20 ablations do nothing unless
        ``kwargs['epoch']`` equals 0/10/20 respectively.

        Raises:
            NotImplementedError: for an unsupported ``self.ablation``.
            KeyError: if ``dl`` yields no batches while SPG modules exist
                (no history is collected for them).
        """
        dl = DataLoader(dl.dataset, batch_size=self.batch_size * 10)
        epoch = kwargs.get('epoch')

        # Each early-gradient ablation acts only at one specific epoch.
        trigger_epochs = {
            Ablation.EarlyGradients0: 0,
            Ablation.EarlyGradients10: 10,
            Ablation.EarlyGradients20: 20,
        }

        if self.ablation == Ablation.Asis:
            range_tasks = range(idx_task + 1)
        elif self.ablation == Ablation.NoCrossHeadImportance:
            range_tasks = [idx_task]
        elif self.ablation in trigger_epochs:
            if epoch != trigger_epochs[self.ablation]:
                return
            # endif
            range_tasks = range(idx_task + 1)
        else:
            raise NotImplementedError(self.ablation)
        # endif

        device = self._resolve_device()
        for t in range_tasks:
            dict_history = self._aggregate_history(self._collect_history(idx_task, t, dl, device))

            for name_module, module in self._spg_modules():
                module.register_grad(idx_task, t, None, dict_history[name_module])
            # endfor
        # endfor|t

        # freeze masks
        save_artifact = self.ablation == Ablation.Asis
        for _, module in self._spg_modules():
            module.freeze_masks(idx_task, save_artifact=save_artifact)
        # endfor
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, smax: float, args: Dict[str, Any]):
        """Call ``block`` on every SPG module after backward; a no-op for task 0.

        ``args`` must provide ``'epoch'`` and ``'idx_batch'``. ``s`` and
        ``smax`` are not used. ``show=True`` is passed only at epoch 0,
        batch 0.
        """
        if idx_task == 0:
            return
        # endif

        epoch = args['epoch']
        idx_batch = args['idx_batch']
        show = epoch == 0 and idx_batch == 0

        for name_module, module in self._spg_modules():
            module.block(idx_task, name=name_module, show=show)
        # endfor
    # enddef

# endclass
