import time
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.agem.appr_agem_orig import Appr as ApprOrig
from approaches.hat.model_hat import ModelHAT
from utils import print_num_params


class Appr(AbstractAppr):
    """Adapter exposing ``ApprOrig`` through the ``AbstractAppr`` interface.

    The constructor builds a ``ModelHAT`` with ``hat_enabled=False`` and hands
    it to ``ApprOrig`` (imported from ``approaches.agem.appr_agem_orig``);
    ``train`` and ``test`` then simply delegate to that wrapped object.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 backbone: str, nhid: int, drop1: float, drop2: float,
                 buffer_size: int, buffer_percent: float,
                 ):
        """Create the backbone model and the wrapped approach.

        Args:
            device: Forwarded to both the base class and ``ApprOrig``.
            list__ncls: Forwarded to the base class and ``ModelHAT``.
            inputsize: Forwarded to the base class and ``ModelHAT``.
            lr: Forwarded to the base class and ``ApprOrig``.
            lr_factor: Forwarded to the base class and ``ApprOrig``.
            lr_min: Forwarded to the base class and ``ApprOrig``.
            epochs_max: Forwarded to the base class and ``ApprOrig``.
            patience_max: Forwarded to the base class and ``ApprOrig``.
            backbone: Forwarded to ``ModelHAT``.
            nhid: Forwarded to ``ModelHAT``.
            drop1: Forwarded to ``ModelHAT``.
            drop2: Forwarded to ``ModelHAT``.
            buffer_size: Forwarded to ``ApprOrig``.
            buffer_percent: Forwarded to ``ApprOrig``.
        """
        # NOTE(review): smax/lamb are passed as fixed values; presumably they
        # are HAT-specific settings irrelevant here -- confirm in AbstractAppr.
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=1, lamb=0)

        # The HAT model is instantiated with its HAT behaviour switched off
        # (hat_enabled=False, smax=0).
        model = ModelHAT(list__ncls=list__ncls, inputsize=inputsize,
                         nhid=nhid, smax=0, hat_enabled=False, backbone=backbone,
                         drop1=drop1, drop2=drop2, eq_ncls=True)

        print_num_params(model)

        self.appr = ApprOrig(model=model, device=device,
                             epochs_max=epochs_max, patience_max=patience_max,
                             lr=lr, lr_min=lr_min, lr_factor=lr_factor,
                             buffer_size=buffer_size, buffer_percent=buffer_percent)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train the wrapped approach on one task and report the elapsed time.

        Args:
            idx_task: Task index, passed to ``ApprOrig.train`` as ``t``.
            dl_train: Training data loader.
            dl_val: Validation data loader.
            args_on_forward: Accepted for interface compatibility; not used.
            args_on_after_backward: Accepted for interface compatibility;
                not used.
            list__dl_test: Accepted for interface compatibility; not used.

        Returns:
            A dict with the single key ``'time_consumed'``: the duration of
            the training call in seconds.
        """
        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted (or turn negative) by system clock adjustments.
        started = time.perf_counter()
        self.appr.train(t=idx_task, dl_train=dl_train, dl_val=dl_val)
        elapsed = time.perf_counter() - started

        return {
            'time_consumed': elapsed,
            }
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate the wrapped approach on one task.

        Args:
            idx_task: Task index, passed to ``ApprOrig.eval`` as ``t``.
            dl_test: Data loader to evaluate on.
            args_on_forward: Accepted for interface compatibility; not used.

        Returns:
            A dict with keys ``'loss_test'`` and ``'acc_test'`` holding the
            two values returned by ``ApprOrig.eval``.
        """
        loss, acc = self.appr.eval(t=idx_task, dl=dl_test)

        return {
            'loss_test': loss,
            'acc_test': acc,
            }
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Do nothing; no post-task work is performed by this approach."""
        pass
    # enddef
