from typing import *

from torch import Tensor, nn
from torch.utils.data import DataLoader

from approaches.ewcgi.model_ewcgi_feature_alexnet import ModelEWCGIFeatureAlexNet
from approaches.spg.ablation import Ablation
from approaches.spg.model_spg_fc import ModelSPGFc
from utils import BColors, assert_type, myprint as print


class ModelEWCGI(nn.Module):
    """Feature extractor plus task-indexed classifier head for the EWCGI approach.

    The feature extractor is ``ModelEWCGIFeatureAlexNet`` and the head is
    ``ModelSPGFc``; the head is also handed to the feature extractor via
    ``set_fc``. Only ``backbone='alexnet'`` with ``ablation=None`` is supported.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 backbone: str, shift: float,
                 ablation: Optional[str], seqname: str,
                 **kwargs):
        """Build the feature extractor and the classifier head.

        Args:
            list__ncls: Class counts forwarded to both submodules (presumably
                one entry per task -- confirm against the caller).
            inputsize: Input shape forwarded to the feature extractor.
            batch_size: Forwarded to the feature extractor.
            backbone: Backbone name; only ``'alexnet'`` is implemented.
            shift: Forwarded to the feature extractor.
            ablation: Must be ``None`` (mapped to ``Ablation.Asis``).
            seqname: Forwarded to the feature extractor.
            **kwargs: For ``'alexnet'`` must contain ``nhid``, ``drop1`` and ``drop2``.

        Raises:
            NotImplementedError: For a non-``None`` ``ablation`` or an
                unsupported ``backbone``.
            KeyError: If a hyperparameter required by the backbone is missing.
        """
        super().__init__()

        if ablation is None:
            self.ablation = Ablation.Asis
        else:
            raise NotImplementedError(ablation)
        # endif
        print(f'Ablation: {self.ablation}', bcolor=BColors.OKGREEN)

        print(f'backbone: {backbone}', bcolor=BColors.OKBLUE)
        if backbone == 'alexnet':
            # Hyperparameters are read only for the backbone that needs them, so an
            # unsupported backbone reports NotImplementedError rather than KeyError.
            nhid = kwargs['nhid']
            drop1 = kwargs['drop1']
            drop2 = kwargs['drop2']
            self.feature = ModelEWCGIFeatureAlexNet(list__ncls, inputsize, batch_size,
                                                    nhid=nhid, drop1=drop1, drop2=drop2,
                                                    shift=shift, ablation=self.ablation,
                                                    seqname=seqname,
                                                    )
            self.fc = ModelSPGFc(list__ncls, dim=self.feature.last_dim)
            # The feature extractor keeps a reference to the head it feeds.
            self.feature.set_fc(self.fc)
        else:
            # 'mlp' and every other backbone are not implemented for this approach.
            raise NotImplementedError(backbone)
        # endif
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the parameter consumption reported for ``idx_task``.

        This approach always reports ``0.0``; ``idx_task`` is unused.
        """
        return 0.0
    # enddef

    def freeze_masks(self, idx: int, dl: DataLoader, **kwargs):
        """Delegate to ``self.feature.freeze_masks`` with the same arguments.

        Args:
            idx: Task index forwarded unchanged (presumably the task just
                trained -- confirm against the caller).
            dl: DataLoader forwarded unchanged.
            **kwargs: Forwarded unchanged.
        """
        self.feature.freeze_masks(idx, dl, **kwargs)
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float = 0,
                args_on_forward: Optional[Dict[str, Any]] = None) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the feature extractor, then the classifier head for ``idx_task``.

        Args:
            idx_task: Task index; must be an ``int``.
            x: Input batch; must be a ``Tensor``.
            s: Forwarded to the feature extractor as ``s``; ``int`` or ``float``.
            args_on_forward: Extra options forwarded to the feature extractor.
                When omitted (or ``None``), a fresh empty dict is used per call.

        Returns:
            ``(out, misc)``: the head's output and the auxiliary dict returned
            by the feature extractor.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)
        assert_type(s, [float, int])

        # The previous `= {}` default was one dict object shared by all calls, so any
        # mutation downstream would leak into later calls; build a new one each time.
        if args_on_forward is None:
            args_on_forward = {}
        # endif

        out, misc = self.feature(x, idx_task, s=s, args_on_forward=args_on_forward)
        out = self.fc(out, idx_task)

        return out, misc
    # enddef

    def on_after_backward_emb(self, s: float):
        """Post-backward hook for embeddings; intentionally a no-op here."""
        pass
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]) -> None:
        """Post-backward hook for parameters: only the classifier head's hook runs.

        Soft-masking through ``self.feature.on_after_backward_params`` is
        deliberately not applied in this approach, so ``idx_task``, ``s`` and
        ``args`` are currently unused and nothing is returned.
        """
        self.fc.on_after_backward()
    # enddef

# endclass
