from typing import *

from torch import Tensor, nn

from base_model_type import BaseModelType
from hat.model_hat_fc import ModelHATFc
from hat.model_hat_feature import ModelHATFeature
from utils import assert_type


class ModelHATEWC(nn.Module):
    """Two-stage network: a ``ModelHATFeature`` extractor followed by a ``ModelHATFc`` head.

    The only state this class owns is the two submodules ``self.feature`` and
    ``self.fc``; every public method forwards to one or both of them.

    NOTE(review): "HAT" / "EWC" in the name presumably mean Hard Attention to
    the Task and Elastic Weight Consolidation. No EWC-specific logic lives in
    this class itself -- confirm where the EWC penalty is applied.
    """

    def __init__(self, model_base: BaseModelType, list__ds_ncls: List[int],
                 ch: int, inputsize: int,
                 smax: float, hat_enabled: bool):
        """Create the feature extractor and the classifier head.

        Args:
            model_base: forwarded positionally to ``ModelHATFeature``; its
                meaning is defined by ``BaseModelType``.
            list__ds_ncls: handed to both submodules; presumably the number of
                classes for each task/dataset -- TODO confirm.
            ch: forwarded to ``ModelHATFeature`` as ``ch=`` (presumably the
                number of input channels -- TODO confirm).
            inputsize: forwarded to ``ModelHATFeature`` as ``inputsize=``.
            smax: forwarded to ``ModelHATFeature`` as ``smax=``.
            hat_enabled: forwarded to ``ModelHATFeature`` as ``hat_enabled=``.
        """
        super().__init__()

        # Submodules are registered feature-first; that order determines the
        # ordering of parameters() and state_dict() entries, so keep it.
        self.feature = ModelHATFeature(
            model_base,
            list__ds_ncls,
            ch=ch,
            inputsize=inputsize,
            smax=smax,
            hat_enabled=hat_enabled,
        )
        self.fc = ModelHATFc(list__ds_ncls)
    # enddef

    def expand_tasks(self, num_tasks_append: int):
        """Ask the feature extractor to grow by ``num_tasks_append`` tasks.

        Only ``self.feature`` is called here; ``self.fc`` is left untouched.
        NOTE(review): verify that the head does not also need expanding.
        """
        extractor = self.feature
        extractor.expand_tasks(num_tasks_append)
    # enddef

    def freeze_masks(self, idx: int):
        """Forward ``idx`` to ``self.feature.freeze_masks``."""
        extractor = self.feature
        extractor.freeze_masks(idx)
    # enddef

    def operate_embedding(self, idx: int, freeze: bool):
        """Forward ``idx`` and the ``freeze`` flag to ``self.feature.operate_embedding``."""
        extractor = self.feature
        extractor.operate_embedding(idx, freeze)
    # enddef

    def replace_emb(self, idx_src: int, idx_dst: int):
        """Forward to ``self.feature.replace_emb`` using keyword arguments."""
        extractor = self.feature
        extractor.replace_emb(idx_src=idx_src, idx_dst=idx_dst)
    # enddef

    def forward(self, x: Tensor, task_index: Tensor, **kwargs) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the feature extractor and then the head for the selected task.

        Args:
            x: input batch; must be a ``Tensor`` (checked by ``assert_type``).
            task_index: task selector; must be a ``Tensor`` (checked by
                ``assert_type``).
            **kwargs: passed unchanged to both ``self.feature`` and ``self.fc``.

        Returns:
            A pair ``(head_output, {'reg': reg})``. ``reg`` is the second value
            returned by ``self.feature``. It is presumably a regularisation
            term to add to the loss -- TODO confirm.
        """
        # Validate in the same order as before: x first, then task_index.
        for checked in (x, task_index):
            assert_type(checked, Tensor)

        features, reg_term = self.feature(x, task_index, **kwargs)
        head_out = self.fc(features, task_index, **kwargs)

        extras: Dict[str, Any] = {'reg': reg_term}
        return head_out, extras
    # enddef

    def on_after_backward(self, task_index: int, s: float, **kwargs) -> None:
        """Relay the post-backward hook to both submodules.

        ``self.feature`` receives ``task_index``, ``s`` and all ``**kwargs``.
        ``self.fc`` is then called with no arguments.
        NOTE(review): ``task_index`` is annotated ``int`` here but ``Tensor``
        in ``forward`` -- confirm which the callers actually pass.
        """
        extractor, head = self.feature, self.fc
        extractor.on_after_backward(task_index, s, **kwargs)
        head.on_after_backward()
    # enddef

# endclass
