from typing import *

from approaches.abst_appr import AbstractAppr
from approaches.ewcgi.model_ewcgi import ModelEWCGI
from approaches.param_consumable import ParamConsumable
from approaches.spg.ablation import Ablation
from utils import print_num_params


class Appr(AbstractAppr, ParamConsumable):
    """Approach built around a ``ModelEWCGI`` network.

    Depending on the model's ablation setting, the model's masks are frozen
    either on every epoch (early-gradients ablations) or once when learning
    of a task completes (every other setting).
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 lr: float, lr_factor: float, lr_min: float, lamb: float, epochs_max: int, patience_max: int,
                 backbone: str, shift: float, ablation: Optional[str], **kwargs):
        """Initialise the base approach and build the model.

        The optimisation settings (``lr``, ``lr_factor``, ``lr_min``,
        ``epochs_max``, ``patience_max``, ``lamb``) are forwarded to the
        parent constructor together with a fixed ``smax=1``. The remaining
        arguments, plus any extra ``kwargs``, configure ``ModelEWCGI``.
        """
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=1, lamb=lamb)

        # NOTE(review): `self.device` is presumably set by the parent
        # constructor above -- confirm against AbstractAppr.
        network = ModelEWCGI(list__ncls=list__ncls, inputsize=inputsize, batch_size=batch_size,
                             backbone=backbone, shift=shift, ablation=ablation,
                             **kwargs)
        self.model = network.to(self.device)

        print_num_params(self.model)
    # enddef

    def _is_early_gradients_ablation(self) -> bool:
        """Return True when the model's ablation is one of the EarlyGradients variants."""
        return self.model.ablation in (
            Ablation.EarlyGradients0,
            Ablation.EarlyGradients10,
            Ablation.EarlyGradients20,
        )
    # enddef

    def freeze_mask_on_each_epoch(self, idx_task: int, epoch: int, is_final: bool, **kwargs) -> None:
        """Freeze masks for this epoch, but only under an early-gradients ablation.

        The current ``epoch`` is appended to the keyword arguments that are
        forwarded to ``operate_spg``; the caller's ``kwargs`` are not mutated.
        ``is_final`` is accepted but not used here.
        """
        if not self._is_early_gradients_ablation():
            return
        # endif

        forwarded = {**kwargs, 'epoch': epoch}
        self.operate_spg(idx_task, **forwarded)
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Freeze masks once a task is learned, unless an early-gradients ablation is active.

        Under the early-gradients ablations this is a no-op, since the
        freezing is driven from ``freeze_mask_on_each_epoch`` instead.
        """
        if not self._is_early_gradients_ablation():
            self.operate_spg(idx_task, **kwargs)
        # endif
    # enddef

    def operate_spg(self, idx_task: int, **kwargs):
        """Call the model's ``freeze_masks`` for ``idx_task``.

        ``kwargs`` must contain ``'dl_train'`` (a ``KeyError`` is raised
        otherwise). That value is passed positionally, and the full
        ``kwargs`` -- still including ``'dl_train'`` -- are forwarded as well.
        """
        train_loader = kwargs['dl_train']

        self.model.freeze_masks(idx_task, train_loader, **kwargs)
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return whatever the model reports as parameters consumed for ``idx_task``."""
        consumed = self.model.compute_param_consumed(idx_task)
        return consumed
    # enddef
# endclass
