from typing import *

import torch
from torch import Tensor, nn

from utils import assert_type


class ModelSPGFc(nn.Module):
    """Multi-head linear classifier with one ``nn.Linear`` head per task.

    Head ``i`` is ``Linear(dim, list__ncls[i], bias=True)``. Each head is
    wrapped in an ``nn.Sequential`` so parameter names stay
    ``list__classifier.<i>.0.{weight,bias}``, which keeps existing
    checkpoints loadable. ``forward`` picks the head by task index.

    Args:
        list__ncls: number of output units for each task, in task order.
        dim: number of input features every head expects after flattening.
        max_grad_norm: total-norm threshold used by ``on_after_backward``
            when clipping gradients. Defaults to 10000, the value that was
            previously hard-coded.
    """

    def __init__(self, list__ncls: List[int], dim: int, max_grad_norm: float = 10000) -> None:
        super().__init__()

        self.max_grad_norm = max_grad_norm
        # The single Linear stays inside nn.Sequential on purpose: dropping
        # the wrapper would rename parameters and break saved state_dicts.
        self.list__classifier = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, ncls, bias=True)) for ncls in list__ncls
        )
    # enddef

    def forward(self, x: Tensor, idx_task: int, **kwargs) -> Tensor:
        """Apply the head belonging to task ``idx_task`` to ``x``.

        Args:
            x: input batch. Dimension 0 is treated as the batch axis, and all
                remaining dimensions are flattened into the head's ``dim``
                input features.
            idx_task: index into ``list__classifier`` selecting the head.
            **kwargs: accepted and ignored. Presumably kept so all models
                share one call signature (TODO confirm against callers).

        Returns:
            Raw (unnormalized) Linear outputs of shape
            ``(batch, list__ncls[idx_task])``.
        """
        assert_type(x, Tensor)
        assert_type(idx_task, int)

        clf = self.list__classifier[idx_task]
        # reshape instead of view: same result for contiguous inputs, but
        # non-contiguous inputs (e.g. permuted/sliced feature maps) are
        # flattened via a copy instead of raising a RuntimeError.
        x = x.reshape(x.shape[0], -1)
        return clf(x)
    # enddef

    def on_after_backward(self, **kwargs) -> None:
        """Clip the total gradient norm of all parameters to ``max_grad_norm``.

        Presumably invoked by the training loop right after ``backward()``
        (TODO confirm against the caller). ``**kwargs`` is ignored.
        """
        # getattr fallback: instances unpickled from a version that predates
        # the max_grad_norm attribute keep the original threshold.
        max_norm = getattr(self, "max_grad_norm", 10000)
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm)
    # enddef
# endclass
