from typing import *

from approaches.abst_appr import AbstractAppr
from approaches.hat.model_hat import ModelHAT
from utils import print_num_params


class Appr(AbstractAppr):
    """Approach that trains a ``ModelHAT`` network with ``hat_enabled=False``.

    The base class is configured with ``smax=-1`` and ``lamb=0``, and nothing
    is done when a task finishes (``complete_learning`` is a no-op).
    NOTE(review): these values look like sentinels that switch off HAT's
    scaling/regularisation, making this a non-HAT baseline — confirm against
    ``AbstractAppr`` and ``ModelHAT``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 nhid: int, drop1: float, drop2: float, backbone: str):
        """Configure the base approach, build the model and report its size.

        Args:
            device: device the model is moved to via ``.to(self.device)``.
            list__ncls: class counts forwarded to both the base class and
                ``ModelHAT`` (presumably one entry per task — confirm).
            inputsize: input shape forwarded to both the base class and ``ModelHAT``.
            lr, lr_factor, lr_min: learning-rate settings forwarded to the base class.
            epochs_max, patience_max: training-length limits forwarded to the base class.
            nhid, drop1, drop2, backbone: architecture options forwarded to ``ModelHAT``.
        """
        # Everything the base class needs; smax and lamb are fixed by this
        # approach rather than exposed as constructor arguments.
        base_kwargs = dict(
            device=device,
            list__ncls=list__ncls,
            inputsize=inputsize,
            lr=lr,
            lr_factor=lr_factor,
            lr_min=lr_min,
            epochs_max=epochs_max,
            patience_max=patience_max,
            smax=-1,
            lamb=0,
        )
        super().__init__(**base_kwargs)

        # self.smax and self.device are not assigned in this class; they are
        # expected to be set by AbstractAppr.__init__ (presumably from the
        # values passed above — confirm).
        network = ModelHAT(
            list__ncls=list__ncls,
            inputsize=inputsize,
            nhid=nhid,
            drop1=drop1,
            drop2=drop2,
            backbone=backbone,
            smax=self.smax,
            hat_enabled=False,
        )
        self.model = network.to(self.device)

        print_num_params(self.model)
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """No-op: this approach does no extra work in ``complete_learning``.

        Args:
            idx_task: task index (unused here).
            **kwargs: additional arguments (ignored).
        """
        return None
    # enddef
# endclass
