from typing import *

from torch import Tensor, nn
from torch.utils.data import DataLoader

from approaches.spg.ablation import Ablation
from approaches.spg.model_spg_fc import ModelSPGFc
from approaches.spg.model_spg_feature_alexnet import ModelSPGFeatureAlexNet
from approaches.spg.spg import SPG
from utils import BColors, assert_type, myprint as print


class ModelSPG(nn.Module):
    """Top-level SPG model: a feature module followed by an ``fc`` module.

    ``forward`` runs ``self.feature`` and then ``self.fc``, both indexed by
    the task id. The remaining methods delegate to those two sub-modules or
    aggregate statistics over every ``SPG`` sub-module.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 backbone: str, shift: float,
                 ablation: Optional[str], seqname: str,
                 **kwargs):
        """Build the feature module and the ``fc`` module for ``backbone``.

        Args:
            list__ncls: Forwarded to the feature module and to ``ModelSPGFc``
                (presumably the number of classes per task -- confirm
                against those classes).
            inputsize: Forwarded to the feature module.
            batch_size: Forwarded to the feature module.
            backbone: Only ``'alexnet'`` is implemented.
            shift: Forwarded to the feature module.
            ablation: ``None``, ``'nochi'``, ``'eg0'``, ``'eg10'`` or
                ``'eg20'``; mapped onto the matching ``Ablation`` member.
            seqname: Forwarded to the feature module.
            **kwargs: Must contain ``'nhid'``; the ``'alexnet'`` backbone
                additionally requires ``'drop1'`` and ``'drop2'``.

        Raises:
            KeyError: If a required entry is missing from ``kwargs``.
            NotImplementedError: If ``ablation`` or ``backbone`` is not one
                of the supported values.
        """
        super(ModelSPG, self).__init__()

        nhid = kwargs['nhid']

        if ablation is None:
            self.ablation = Ablation.Asis
        elif ablation == 'nochi':
            self.ablation = Ablation.NoCrossHeadImportance
        elif ablation == 'eg0':
            self.ablation = Ablation.EarlyGradients0
        elif ablation == 'eg10':
            self.ablation = Ablation.EarlyGradients10
        elif ablation == 'eg20':
            self.ablation = Ablation.EarlyGradients20
        else:
            raise NotImplementedError(ablation)
        # endif
        print(f'Ablation: {self.ablation}', bcolor=BColors.OKGREEN)

        print(f'backbone: {backbone}', bcolor=BColors.OKBLUE)
        if backbone == 'mlp':
            # Name the offending value, as the ablation branch already does.
            raise NotImplementedError(backbone)
        elif backbone == 'alexnet':
            drop1 = kwargs['drop1']
            drop2 = kwargs['drop2']
            self.feature = ModelSPGFeatureAlexNet(list__ncls, inputsize, batch_size,
                                                  nhid=nhid, drop1=drop1, drop2=drop2,
                                                  shift=shift, ablation=self.ablation,
                                                  seqname=seqname,
                                                  )
            self.fc = ModelSPGFc(list__ncls, dim=self.feature.last_dim)
            # The feature module is handed a reference to the fc module.
            self.feature.set_fc(self.fc)
        else:
            raise NotImplementedError(backbone)
        # endif
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the blocked fraction summed over all ``SPG`` sub-modules.

        Each ``SPG`` sub-module reports ``(num_all, num_blocked)`` through
        ``count_consumtion(idx_task, strict=True)``; the totals are summed
        and ``num_blocked / num_all`` is returned.

        Args:
            idx_task: Task index forwarded to every ``count_consumtion``.

        Returns:
            The ratio ``num_blocked / num_all``, or ``0.0`` when the summed
            ``num_all`` is zero.
        """
        num_all = 0
        num_blocked = 0
        for module in self.modules():
            if isinstance(module, SPG):
                _num_all, _num_blocked = module.count_consumtion(idx_task, strict=True)
                num_all += _num_all
                num_blocked += _num_blocked
            # endif
        # endfor

        if num_all == 0:
            # Avoid dividing by zero; a float keeps the return type uniform.
            return 0.0
        # endif

        return num_blocked / num_all
    # enddef

    def freeze_masks(self, idx: int, dl: DataLoader, **kwargs) -> None:
        """Delegate to ``self.feature.freeze_masks`` with the same arguments."""
        self.feature.freeze_masks(idx, dl, **kwargs)
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float = 0,
                args_on_forward: Optional[Dict[str, Any]] = None) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the feature module and then the ``fc`` module for one task.

        Args:
            idx_task: Task index passed to both sub-modules.
            x: Input tensor.
            s: Scalar forwarded to the feature module.
            args_on_forward: Extra options forwarded to the feature module.
                ``None`` (the default) is replaced by a new empty dict.

        Returns:
            ``(out, misc)`` where ``out`` is the output of ``self.fc`` and
            ``misc`` is the second value returned by ``self.feature``.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)
        assert_type(s, [float, int])

        if args_on_forward is None:
            # A fresh dict per call: a ``{}`` default would be one object
            # shared by every call and every instance.
            args_on_forward = {}
        # endif

        out, misc = self.feature(x, idx_task, s=s, args_on_forward=args_on_forward)
        out = self.fc(out, idx_task)

        return out, misc
    # enddef

    def on_after_backward_emb(self, s: float) -> None:
        """Hook that intentionally does nothing."""
        pass
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]):
        """Run the post-backward hooks of the feature and ``fc`` modules.

        Args:
            idx_task: Task index forwarded to the feature module.
            s: Scalar forwarded to the feature module.
            args: Extra options forwarded to the feature module.

        Returns:
            Whatever ``self.feature.on_after_backward_params`` returns.
        """
        blocking = self.feature.on_after_backward_params(idx_task, s=s, args=args)
        self.fc.on_after_backward()

        return blocking
    # enddef

# endclass
