from typing import *

from torch import Tensor, nn

from approaches.hat.model_hat_fc import ModelHATFc
from approaches.hat.model_hat_feature import ModelHATFeature
from utils import assert_type


class ModelHAT(nn.Module):
    """Task-conditioned network: a ``ModelHATFeature`` trunk feeding a ``ModelHATFc`` head.

    Both sub-modules receive the task index on every forward pass. The trunk
    additionally receives the scale ``s`` and returns a second value that is
    surfaced to the caller under the ``'reg'`` key.

    NOTE(review): "HAT" presumably means Hard Attention to the Task; any
    masking logic lives inside the sub-modules, not in this wrapper.
    """

    def __init__(self, list__ncls: List[int],
                 inputsize: Tuple[int, ...],
                 smax: float, hat_enabled: bool,
                 drop1: float, drop2: float):
        """Build the trunk and the head.

        Args:
            list__ncls: Forwarded to both sub-modules. Presumably the number
                of classes per task; confirm against ``ModelHATFc``.
            inputsize: Input shape, forwarded to the trunk only.
            smax: Kept on ``self.smax`` and forwarded to the trunk.
            hat_enabled: Forwarded to the trunk only.
            drop1: Forwarded to the trunk. The name suggests a dropout rate;
                TODO confirm.
            drop2: Forwarded to the trunk. The name suggests a dropout rate;
                TODO confirm.
        """
        super().__init__()
        self.smax = smax

        # nn.Module records sub-modules in assignment order, which fixes the
        # order of parameters()/state_dict() entries, so the trunk must be
        # assigned before the head.
        self.feature = ModelHATFeature(
            list__ncls,
            inputsize=inputsize,
            smax=smax,
            hat_enabled=hat_enabled,
            drop1=drop1,
            drop2=drop2,
        )
        self.fc = ModelHATFc(list__ncls)
    # enddef

    def freeze_masks(self, idx: int):
        """Forward ``idx`` to the trunk's ``freeze_masks``; the head is left untouched."""
        trunk = self.feature
        trunk.freeze_masks(idx)
    # enddef

    def forward(
            self,
            idx_task: int,
            x: Tensor,
            s: float,
            args_on_forward: Dict[str, Any],
    ) -> Tuple[Tensor, Dict[str, Any]]:
        """Run ``x`` through the trunk and then the head for task ``idx_task``.

        Args:
            idx_task: Task index, passed to both trunk and head.
            x: Input tensor for the trunk.
            s: Scale forwarded to the trunk (int or float is accepted).
            args_on_forward: Extra options forwarded verbatim to the trunk.

        Returns:
            ``(head_out, extras)``. ``head_out`` is what the head returns, and
            ``extras`` is ``{'reg': reg}``, where ``reg`` is the trunk's second
            return value (presumably a regularisation term; verify).
        """
        # Validate arguments in the same order as before: task, input, scale.
        expectations = (
            (idx_task, int),
            (x, Tensor),
            (s, [float, int]),
        )
        for value, expected in expectations:
            assert_type(value, expected)

        hidden, reg = self.feature(idx_task, x, s=s, args_on_forward=args_on_forward)
        head_out = self.fc(idx_task, hidden)

        extras = {'reg': reg}
        return head_out, extras
    # enddef

    def on_after_backward(self, idx_task: int, s: float, args: Dict[str, Any]) -> None:
        """Let the trunk, then the head, run their post-backward hooks.

        Presumably invoked by the training loop after ``backward()``; confirm
        with the caller. Only the trunk receives ``idx_task``, ``s`` and
        ``args``; the head's hook is called with no arguments.
        """
        self.feature.on_after_backward(idx_task, s=s, args=args)
        head = self.fc
        head.on_after_backward()
    # enddef

# endclass
