from typing import *

from torch import Tensor, nn

from approaches.spg.model_spg_fc import ModelSPGFc
from approaches.spgfi.model_spgfi_feature_alexnet import ModelSPGFIFeatureAlexNet
from approaches.spgfi.spgfi import SPGFI
from utils import BColors, assert_type, myprint as print


class ModelSPGFI(nn.Module):
    """SPGFI model: a backbone feature extractor followed by a task-indexed FC head.

    Only the ``'alexnet'`` backbone is currently implemented; any other value
    raises ``NotImplementedError``.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 backbone: str, **kwargs):
        """Build the feature extractor and classifier head.

        Args:
            list__ncls: Number of classes per task (forwarded to the feature
                extractor and to ``ModelSPGFc``).
            inputsize: Input shape forwarded to the feature extractor.
            batch_size: Batch size forwarded to the feature extractor.
            backbone: Backbone name; only ``'alexnet'`` is supported.
            **kwargs: For ``'alexnet'``: ``nhid``, ``drop1`` and ``drop2`` are required.

        Raises:
            NotImplementedError: If ``backbone`` is not ``'alexnet'``.
            KeyError: If a required keyword argument for the backbone is missing.
        """
        super(ModelSPGFI, self).__init__()

        print(f'backbone: {backbone}', bcolor=BColors.OKBLUE)
        if backbone == 'mlp':
            raise NotImplementedError("ModelSPGFI: the 'mlp' backbone is not implemented")
        elif backbone == 'alexnet':
            # Read the hyper-parameters only for the backbone that consumes them.
            nhid = kwargs['nhid']
            drop1 = kwargs['drop1']
            drop2 = kwargs['drop2']
            self.feature = ModelSPGFIFeatureAlexNet(list__ncls, inputsize, batch_size,
                                                    nhid=nhid, drop1=drop1, drop2=drop2)
            self.fc = ModelSPGFc(list__ncls, dim=self.feature.last_dim)
            # The feature extractor keeps a reference to the head; what it uses it
            # for is defined in ModelSPGFIFeatureAlexNet (not visible here).
            self.feature.set_fc(self.fc)
        else:
            raise NotImplementedError(f"ModelSPGFI: unsupported backbone {backbone!r}")
        # endif
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the ratio of blocked to total counts over all ``SPGFI`` submodules.

        Each ``SPGFI`` module reports ``(num_all, num_blocked)`` through
        ``count_consumtion(idx_task, strict=True)``, and the counts are summed.
        The exact meaning of "blocked" is defined by ``SPGFI``; presumably these
        are parameters reserved by earlier tasks. Verify against ``SPGFI``.

        Returns:
            ``num_blocked / num_all``, or ``0.0`` when no ``SPGFI`` module
            reports any parameters.
        """
        num_all = 0
        num_blocked = 0
        for module in self.modules():
            if isinstance(module, SPGFI):
                _num_all, _num_blocked = module.count_consumtion(idx_task, strict=True)
                num_all += _num_all
                num_blocked += _num_blocked
            # endif
        # endfor

        if num_all == 0:
            # Avoid division by zero; return a float to match the annotation.
            return 0.0
        else:
            return num_blocked / num_all
        # endif
    # enddef

    def freeze_ewc_masks(self, idx_task: int, fisher: Dict[str, Tensor], **kwargs):
        """Delegate mask freezing for ``idx_task`` to the feature extractor.

        Args:
            idx_task: Task index.
            fisher: Mapping of names to tensors, passed through unchanged.
            **kwargs: Extra options passed through unchanged.
        """
        self.feature.freeze_ewc_masks(idx_task=idx_task, fisher=fisher, **kwargs)
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float = 0,
                args_on_forward: Optional[Dict[str, Any]] = None) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the feature extractor and then the task-specific head.

        Args:
            idx_task: Task index selecting the head output.
            x: Input tensor.
            s: Scalar forwarded to the feature extractor (``float`` or ``int``).
            args_on_forward: Optional extra arguments for the feature extractor.
                ``None`` (the default) is replaced by a fresh empty dict on each call.

        Returns:
            A tuple ``(out, misc)``. ``out`` is the output of ``self.fc`` and
            ``misc`` is the auxiliary dict returned by the feature extractor.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)
        assert_type(s, [float, int])

        # Fresh dict per call: a shared mutable default would leak state between
        # calls if the feature extractor mutates it.
        if args_on_forward is None:
            args_on_forward = {}
        # endif

        out, misc = self.feature(x, idx_task, s=s, args_on_forward=args_on_forward)
        out = self.fc(out, idx_task)

        return out, misc
    # enddef

    def on_after_backward_emb(self, s: float):
        """Hook called after backward for embeddings; intentionally a no-op here."""
        pass
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]) -> Any:
        """Run post-backward hooks on the feature extractor and the head.

        Returns:
            Whatever ``self.feature.on_after_backward_params`` returns.
        """
        blocking = self.feature.on_after_backward_params(idx_task, s=s, args=args)
        self.fc.on_after_backward()

        return blocking
    # enddef

# endclass
