import time
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.ewc.mixin_ewc import MixinEWC
from approaches.spgfi.model_spgfi import ModelSPGFI
from utils import print_num_params


class Appr(AbstractAppr, MixinEWC):
    """Approach that trains a ``ModelSPGFI`` and runs an EWC step after each task.

    Training itself is delegated to ``AbstractAppr.train``; this class adds a
    call to ``ewc_in_train`` afterwards and, once a task is finished, hands
    ``self.fisher`` to the model via ``freeze_ewc_masks``.

    NOTE(review): ``ewc_in_train``, ``self.fisher`` and ``self.smax`` are not
    defined in this class -- presumably they come from ``MixinEWC`` /
    ``AbstractAppr``; confirm against those classes.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 backbone: str, **kwargs):
        """Initialise both base classes and build the model on ``self.device``.

        Args:
            device: Device identifier forwarded to ``AbstractAppr``.
            list__ncls: Forwarded to ``AbstractAppr`` and ``ModelSPGFI``
                (presumably the number of classes per task -- confirm).
            inputsize: Input size forwarded to ``AbstractAppr`` and ``ModelSPGFI``.
            batch_size: Forwarded to ``ModelSPGFI`` only.
            lr, lr_factor, lr_min: Learning-rate settings forwarded to ``AbstractAppr``.
            epochs_max, patience_max: Training-loop limits forwarded to ``AbstractAppr``.
            backbone: Backbone name forwarded to ``ModelSPGFI``.
            **kwargs: Extra keyword arguments forwarded unchanged to ``ModelSPGFI``.
        """
        # Both bases are initialised explicitly (not through super()), mixin first.
        MixinEWC.__init__(self)
        # smax and lamb are fixed for this approach rather than exposed to callers.
        AbstractAppr.__init__(self, device=device, list__ncls=list__ncls, inputsize=inputsize,
                              lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                              epochs_max=epochs_max, patience_max=patience_max,
                              smax=1, lamb=0)
        self.model = ModelSPGFI(list__ncls=list__ncls, inputsize=inputsize, batch_size=batch_size,
                                backbone=backbone, **kwargs).to(self.device)

        print_num_params(self.model)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train on one task, then run the EWC step and account for its duration.

        Args:
            idx_task: Index of the task being trained.
            dl_train: Training data loader; also passed to ``ewc_in_train``.
            dl_val: Validation data loader.
            args_on_forward: Forwarded to the base ``train`` and to ``ewc_in_train``.
            args_on_after_backward: Forwarded to the base ``train``.
            list__dl_test: Accepted for interface compatibility; it is not used
                here and is not forwarded to the base ``train``.

        Returns:
            The dict returned by the base ``train``, with the duration of the
            EWC step added to its ``'time_consumed'`` entry.
        """
        ret = super().train(idx_task=idx_task, dl_train=dl_train, dl_val=dl_val,
                            args_on_forward=args_on_forward,
                            args_on_after_backward=args_on_after_backward)

        # perf_counter() is monotonic, so the measured interval cannot be skewed
        # (or turn negative) when the system clock is adjusted mid-run.
        time_start = time.perf_counter()
        self.ewc_in_train(idx_task=idx_task, dl_train=dl_train,
                          smax=self.smax, args_on_forward=args_on_forward)
        ret['time_consumed'] += time.perf_counter() - time_start

        return ret
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Finalise task ``idx_task`` by delegating to ``operate_ewcsm``."""
        self.operate_ewcsm(idx_task, **kwargs)
    # enddef

    def operate_ewcsm(self, idx_task: int, **kwargs):
        """Pass ``self.fisher`` to the model's ``freeze_ewc_masks`` for ``idx_task``.

        ``**kwargs`` is forwarded unchanged, so it must not contain the keys
        ``idx_task`` or ``fisher`` (they would collide with the explicit arguments).
        """
        self.model.freeze_ewc_masks(idx_task=idx_task, fisher=self.fisher, **kwargs)
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the model's ``compute_param_consumed`` result for ``idx_task``."""
        return self.model.compute_param_consumed(idx_task)
    # enddef
# endclass
