from copy import deepcopy
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.hat.model_hat import ModelHAT
from utils import print_num_params


class Appr(AbstractAppr):
    """Approach that trains every task starting from the same initial model.

    The network is a ``ModelHAT`` built with ``hat_enabled=False``, and the base
    class is configured with ``smax=1`` and ``lamb=0``. An untouched copy of the
    freshly built model is kept in ``self.model_backup``; each call to
    :meth:`train` starts from a deep copy of it, so weights learned on one task
    are not carried into the next (presumably an independent-per-task
    baseline -- confirm against the experiment configs).

    Because the live model is replaced on every :meth:`train` call, a task can
    only be evaluated meaningfully while its own model is loaded. :meth:`test`
    therefore caches each task's result and returns the cached value on later
    calls.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int, backbone: str,
                 nhid: int, drop1: float, drop2: float):
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=1, lamb=0)

        self.model = ModelHAT(list__ncls=list__ncls, inputsize=inputsize,
                              smax=self.smax, hat_enabled=False,
                              nhid=nhid, drop1=drop1, drop2=drop2, backbone=backbone).to(self.device)
        # Untouched copy of the initial weights; train() restarts every task from it.
        self.model_backup = deepcopy(self.model)

        print_num_params(self.model)

        # idx_task -> result dict returned by AbstractAppr.test, filled lazily by test().
        self.results_test: Dict[int, Dict[str, float]] = {}
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train task ``idx_task`` from a fresh copy of the initial model.

        Any cached test result for ``idx_task`` is discarded both before and
        after training. The first :meth:`test` call for this task after this
        method returns therefore evaluates the model trained here, not a result
        recorded earlier.

        Returns:
            Whatever ``AbstractAppr.train`` returns for this task.
        """
        self.model = deepcopy(self.model_backup)

        # A result pinned before this task was (re)trained -- e.g. by an
        # evaluation pass over all tasks -- would otherwise be returned forever.
        self.results_test.pop(idx_task, None)

        result = super().train(idx_task, dl_train, dl_val, args_on_forward, args_on_after_backward, list__dl_test)

        # If test() was called while training (e.g. by the base class), the cache
        # holds an intermediate result; drop it so the final model gets evaluated.
        self.results_test.pop(idx_task, None)

        return result
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Return test results for ``idx_task``, evaluating only on first request.

        The first call (since the task was last trained) runs
        ``AbstractAppr.test`` and stores the resulting dict. Subsequent calls
        return that same stored dict without re-running evaluation.

        NOTE(review): the cache key is ``idx_task`` only -- a later call with a
        different ``dl_test`` or ``args_on_forward`` still gets the stored result.
        """
        if idx_task not in self.results_test:
            self.results_test[idx_task] = super().test(idx_task, dl_test, args_on_forward)
        # endif

        return self.results_test[idx_task]
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """No-op hook: this approach performs no end-of-task processing."""
        pass
    # enddef
