from typing import *

from torch import Tensor, nn

from approaches.hat.model_hat_fc import ModelHATFc
from approaches.hat.model_hat_feature_alexnet import ModelHATFeatureAlexNet
from utils import BColors, assert_type, myprint as print


class ModelHAT(nn.Module):
    """HAT model: a backbone feature extractor followed by a ``ModelHATFc`` head.

    Only the ``'alexnet'`` backbone is implemented (built by
    ``ModelHATFeatureAlexNet``). The head is sized from the backbone's
    ``last_dim`` and receives the task index on every forward pass.
    """

    def __init__(self, list__ncls: List[int],
                 inputsize: Tuple[int, ...], nhid: int,
                 smax: float, hat_enabled: bool,
                 backbone: str,
                 eq_ncls: bool = False,
                 **kwargs):
        """Build the backbone and the head.

        Args:
            list__ncls: one int per task (presumably the class count of each
                task -- confirm against ``ModelHATFc``); passed to both the
                backbone and the head.
            inputsize: input shape tuple forwarded to the backbone.
            nhid: hidden width forwarded to the backbone.
            smax: stored as ``self.smax`` and forwarded to the backbone.
            hat_enabled: forwarded to the backbone.
            backbone: backbone name; only ``'alexnet'`` is implemented.
            eq_ncls: forwarded to ``ModelHATFc``.
            **kwargs: for ``'alexnet'`` must contain ``drop1`` and ``drop2``,
                which are forwarded to the backbone.

        Raises:
            NotImplementedError: if ``backbone`` is not ``'alexnet'``.
            KeyError: if ``drop1`` or ``drop2`` is missing for ``'alexnet'``.
        """
        super().__init__()
        self.smax = smax

        print(f'backbone: {backbone}', bcolor=BColors.OKBLUE)
        if backbone == 'mlp':
            # Recognised name, but no implementation exists yet.
            raise NotImplementedError("backbone 'mlp' is not implemented yet")
        elif backbone == 'alexnet':
            missing = [key for key in ('drop1', 'drop2') if key not in kwargs]
            if missing:
                # Same exception type as a bare kwargs[...] lookup, but report every missing key at once.
                raise KeyError(f"backbone 'alexnet' requires keyword argument(s) {missing}")
            # endif
            self.feature = ModelHATFeatureAlexNet(list__ncls, inputsize,
                                                  smax=smax, hat_enabled=hat_enabled,
                                                  nhid=nhid, drop1=kwargs['drop1'], drop2=kwargs['drop2'])
        else:
            raise NotImplementedError(f'unknown backbone {backbone!r}; supported: alexnet')
        # endif
        last_dim = self.feature.last_dim

        self.fc = ModelHATFc(list__ncls, dim=last_dim, eq_ncls=eq_ncls)
    # enddef

    def freeze_masks(self, idx: int) -> None:
        """Delegate ``freeze_masks(idx)`` to the backbone."""
        self.feature.freeze_masks(idx)
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: Optional[float] = None,
                args_on_forward: Optional[Dict[str, Any]] = None) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the backbone for task ``idx_task``, then the head.

        Args:
            idx_task: task index, given to both the backbone and the head.
            x: input tensor.
            s: optional float/int forwarded to the backbone (``None`` allowed).
            args_on_forward: optional dict forwarded unchanged to the backbone.

        Returns:
            ``(out, misc)``: the head's output and the ``misc`` dict returned
            by the backbone.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)
        assert_type(s, [float, int], allow_none=True)

        out, misc = self.feature(idx_task, x, s=s, args_on_forward=args_on_forward)
        out = self.fc(out, idx_task)

        return out, misc
    # enddef

    def on_after_backward_emb(self, s: float) -> None:
        """Delegate the post-backward embedding hook to the backbone."""
        self.feature.on_after_backward_emb(s=s)
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]) -> Any:
        """Run the post-backward parameter hooks: backbone first, then head.

        Returns:
            Whatever the backbone's ``on_after_backward_params`` returns.
        """
        blocking = self.feature.on_after_backward_params(idx_task, s=s, args=args)
        self.fc.on_after_backward()

        return blocking
    # enddef

# endclass
