from typing import *

from approaches.abst_appr import AbstractAppr
from approaches.hat.model_hat import ModelHAT
from approaches.param_consumable import ParamConsumable
from utils import print_num_params


class Appr(AbstractAppr, ParamConsumable):
    """Approach that trains a ``ModelHAT`` network.

    Optimisation hyperparameters (learning-rate schedule, epoch/patience limits,
    ``smax``, ``lamb``) are handed to ``AbstractAppr``. The architecture options
    are used to build the ``ModelHAT`` instance kept in ``self.model``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 smax: float, lamb: float,
                 backbone: str,
                 nhid: int, drop1: float, drop2: float, hat_enabled: bool):
        """Set up the base approach, then build the network and move it to ``self.device``.

        ``smax`` is given both to the base class and to ``ModelHAT``.
        ``backbone``, ``nhid``, ``drop1``, ``drop2`` and ``hat_enabled`` only
        configure ``ModelHAT``; they are not passed to the base class.
        """
        # Base-class initialisation runs first: ``self.device`` is read below,
        # so it is presumably established by AbstractAppr.__init__ (verify there).
        base_kwargs = dict(device=device, list__ncls=list__ncls, inputsize=inputsize,
                           lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                           epochs_max=epochs_max, patience_max=patience_max,
                           smax=smax, lamb=lamb)
        super().__init__(**base_kwargs)

        network = ModelHAT(backbone=backbone, list__ncls=list__ncls, inputsize=inputsize,
                           nhid=nhid, drop1=drop1, drop2=drop2,
                           smax=smax, hat_enabled=hat_enabled)
        self.model = network.to(self.device)

        # Report the freshly built network via the project utility.
        print_num_params(self.model)
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Finish task ``idx_task`` by asking the model to freeze that task's masks.

        Any extra keyword arguments are accepted and ignored here.
        """
        model = self.model
        model.freeze_masks(idx_task)
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the consumed-parameter figure for ``idx_task``.

        The computation is delegated entirely to
        ``self.model.feature.model.compute_param_consumed``.
        """
        inner = self.model.feature.model
        return inner.compute_param_consumed(idx_task)
    # enddef
# endclass
