from typing import *

import torch
from torch import Tensor, nn

from utils import assert_type


class ModelHATFc(nn.Module):
    """Bank of per-task linear classification heads.

    One independent ``nn.Linear(dim_input, n_classes)`` head is built for
    each entry of ``list__ncls``; ``forward`` routes the input through the
    head selected by ``idx_task``. (The "HAT" in the name presumably refers
    to a hard-attention-to-the-task continual-learning setup — not
    verifiable from this class alone.)
    """

    def __init__(self, list__ncls: List[int], *,
                 dim_input: int = 2048,
                 max_grad_norm: float = 10000.0):
        """Build one classification head per task.

        Args:
            list__ncls: number of output classes for each task; head ``i``
                produces ``list__ncls[i]`` logits.
            dim_input: size of the feature vector fed to every head
                (defaults to the previously hard-coded 2048).
            max_grad_norm: total-norm threshold used by
                ``on_after_backward`` when clipping gradients (defaults to
                the previously hard-coded 10000).
        """
        super(ModelHATFc, self).__init__()

        self.dim_input = dim_input
        self.max_grad_norm = max_grad_norm

        self.list__classifier = nn.ModuleList()
        for dim_output in list__ncls:
            # Keep the single-layer Sequential wrapper: it fixes the
            # state_dict key layout ("list__classifier.<i>.0.weight"), so
            # removing it would break loading of existing checkpoints.
            clf = nn.Sequential(
                nn.Linear(dim_input, dim_output),
                )
            self.list__classifier.append(clf)
        # endfor
    # enddef

    def forward(self, idx_task: int, x: Tensor) -> Tensor:
        """Apply the head belonging to task ``idx_task`` to ``x``.

        Args:
            idx_task: index into the per-task heads. Indexing follows
                ``nn.ModuleList`` semantics, so an out-of-range value raises
                IndexError.
            x: input features; ``nn.Linear`` consumes the last dimension,
                which must equal ``dim_input``.

        Returns:
            Logits from the selected head, with the last dimension equal to
            ``list__ncls[idx_task]``.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        clf = self.list__classifier[idx_task]
        out = clf(x)

        return out
    # enddef

    def on_after_backward(self) -> None:
        """Clip the total gradient norm of all parameters in place.

        This is not a ``torch.nn.Module`` hook, so PyTorch never calls it
        automatically. The caller must invoke it between ``loss.backward()``
        and ``optimizer.step()`` (presumably done by the surrounding
        training loop — confirm). Parameters without a ``.grad`` are
        ignored by ``clip_grad_norm_``.
        """
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
    # enddef

# endclass
