from typing import *

import torch
from torch import Tensor, nn

from utils import assert_type


class ModelHATFc(nn.Module):
    """Multi-head linear classifier with one ``nn.Linear`` head per task.

    Head ``t`` maps a flattened ``dim``-feature input to ``list__ncls[t]``
    outputs. When ``eq_ncls`` is true, every head instead maps to
    ``max(list__ncls)`` outputs.

    Args:
        list__ncls: Number of output classes for each task; one head is built
            per entry, in order.
        dim: Number of input features each head expects after flattening.
        eq_ncls: If true, give every head the same width ``max(list__ncls)``
            instead of its own class count.

    Raises:
        ValueError: If ``list__ncls`` is empty (propagated from ``max``).
    """

    # Maximum total 2-norm of all parameter gradients enforced by
    # ``on_after_backward``.
    GRAD_CLIP_MAX_NORM = 10000

    def __init__(self, list__ncls: List[int], dim: int, eq_ncls: bool):
        super().__init__()

        ncls_max = max(list__ncls)

        self.list__classifier = nn.ModuleList()
        for ncls in list__ncls:
            dim_output = ncls_max if eq_ncls else ncls

            # The single Linear is kept inside nn.Sequential so parameter names
            # stay ``list__classifier.<t>.0.weight`` / ``.0.bias``; changing
            # the wrapper would rename state_dict keys of any saved weights.
            clf = nn.Sequential(
                nn.Linear(dim, dim_output, bias=True),
            )
            self.list__classifier.append(clf)
        # endfor
    # enddef

    def forward(self, x: Tensor, idx_task: int, **kwargs) -> Tensor:
        """Apply the head of task ``idx_task`` to ``x``.

        Args:
            x: Input tensor. Every dimension after the first is flattened, so
                the trailing dimensions must multiply to ``dim``. The first
                dimension is presumably the batch dimension.
            idx_task: Index into ``self.list__classifier`` selecting the head.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor of shape ``(x.shape[0], n_out)``, where ``n_out`` is the
            output width of the selected head.
        """
        # Runtime type checks via the project helper ``assert_type``.
        assert_type(x, Tensor)
        assert_type(idx_task, int)

        clf = self.list__classifier[idx_task]
        # reshape (not view): identical for contiguous inputs, but it also
        # flattens non-contiguous ones (e.g. permuted feature maps) instead of
        # raising a RuntimeError.
        x = x.reshape(x.shape[0], -1)

        return clf(x)
    # enddef

    def on_after_backward(self, **kwargs) -> None:
        """Clip the combined gradient norm of all parameters in place.

        Presumably invoked by the training loop after the backward pass.
        Gradients are rescaled so that their total 2-norm is at most
        ``GRAD_CLIP_MAX_NORM``. Gradients already below that norm are left
        unchanged.

        Args:
            **kwargs: Accepted for interface compatibility; ignored.
        """
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.GRAD_CLIP_MAX_NORM)
    # enddef

    # No ``freeze_masks`` hook: this head owns only Linear layers and has no
    # mask parameters to freeze.

# endclass
