import time
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.pgn.appr_pnn_orig import Appr as OrigAppr
from approaches.pgn.model_pnn_alexnet import ProgressiveAlexNet
from utils import print_num_params


class Appr(AbstractAppr):
    """Adapter exposing the original progressive-network approach (``OrigAppr``)
    through the ``AbstractAppr`` interface.

    Training and evaluation are delegated entirely to ``self.appr``. This class
    only builds the backbone model, maps hyper-parameter names onto
    ``OrigAppr``'s constructor, and times the training call.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 backbone: str, nhid: int,
                 batch_size: int, drop1: float, drop2: float, expand_factor: float):
        """Build the backbone model and the wrapped ``OrigAppr`` trainer.

        Args:
            device: Device the model is moved to (passed to ``model.to``) and
                forwarded to the base class and ``OrigAppr``.
            list__ncls: Number of classes for each task, indexed by task id.
            inputsize: Input shape forwarded to the base class and backbone.
            lr, lr_factor, lr_min: Learning-rate settings forwarded unchanged to
                both ``AbstractAppr`` and ``OrigAppr``.
            epochs_max: Forwarded as ``OrigAppr(nepochs=...)``.
            patience_max: Forwarded as ``OrigAppr(lr_patience=...)``.
            backbone: Backbone name; only ``'alexnet'`` is supported.
            nhid, drop1, drop2, expand_factor: Forwarded to ``ProgressiveAlexNet``.
            batch_size: Forwarded as ``OrigAppr(sbatch=...)``.

        Raises:
            NotImplementedError: If ``backbone`` is anything other than
                ``'alexnet'`` (including ``'mlp'``).
        """
        # smax/lamb are required by the base-class signature; the fixed values
        # are presumably placeholders because training is delegated to
        # OrigAppr -- TODO confirm against AbstractAppr.
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=1, lamb=0)

        # (task_id, num_classes) pairs, passed to the backbone as ``taskcla``.
        taskcla = list(enumerate(list__ncls))

        if backbone == 'alexnet':
            model = ProgressiveAlexNet(inputsize=inputsize,
                                       taskcla=taskcla,
                                       nhid=nhid, drop1=drop1, drop2=drop2,
                                       expand_factor=expand_factor)
        else:
            # 'mlp' has no progressive implementation yet; report the offending
            # backbone name in every unsupported case.
            raise NotImplementedError(backbone)
        # endif

        print_num_params(model)

        self.appr = OrigAppr(device=device, model=model.to(device),
                             nepochs=epochs_max, sbatch=batch_size,
                             lr=lr, lr_min=lr_min, lr_factor=lr_factor,
                             lr_patience=patience_max)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train on task ``idx_task`` by delegating to ``OrigAppr.train``.

        ``args_on_forward``, ``args_on_after_backward`` and ``list__dl_test``
        are accepted for interface compatibility but are not used here.

        Returns:
            ``{'time_consumed': seconds}``: the elapsed duration of the
            delegated training call.
        """
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # (or go negative) if the system clock is adjusted during training.
        time_start = time.perf_counter()
        self.appr.train(t=idx_task, dl_train=dl_train, dl_val=dl_val)
        time_consumed = time.perf_counter() - time_start

        return {'time_consumed': time_consumed}
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` on ``dl_test`` via ``OrigAppr.eval``.

        ``args_on_forward`` is accepted for interface compatibility but unused.

        Returns:
            ``{'loss_test': loss, 'acc_test': acc}`` where ``(loss, acc)`` is
            the pair returned by ``OrigAppr.eval``.
        """
        loss, acc = self.appr.eval(t=idx_task, dl=dl_test)

        return {
            'loss_test': loss,
            'acc_test': acc,
            }
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """No-op hook: this approach needs no post-task step; ``kwargs`` are ignored."""
        pass
    # enddef
