import time
from argparse import Namespace
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.pathnet.appr_pathnet_orig import Appr as OrigAppr
from approaches.pathnet.model_pathnet_alexnet import PathNetAlexNet
from approaches.pathnet.model_pathnet_mlp import PathNetMLP
from utils import print_num_params


class Appr(AbstractAppr):
    """Adapter exposing the original PathNet approach through ``AbstractAppr``.

    Builds a PathNet backbone (``PathNetMLP`` or ``PathNetAlexNet``) and
    delegates training and evaluation to
    ``approaches.pathnet.appr_pathnet_orig.Appr``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 backbone: str, nhid: int,
                 batch_size: int, drop1: float, drop2: float, expand_factor: float, N: int, M: int):
        """Build the PathNet model and the wrapped original approach.

        Args:
            device: Device string; the model is moved to ``self.device``
                (set by the base-class constructor).
            list__ncls: Number of classes for each task, in task order.
            inputsize: Input size forwarded to the backbone constructor.
            lr: Base learning rate. It is divided by 10 before being passed
                to both the base class and the wrapped approach.
            lr_factor: Learning-rate decay factor forwarded to both.
            lr_min: Minimum learning rate forwarded to both.
            epochs_max: Maximum epochs (``nepochs`` of the wrapped approach).
            patience_max: Patience (``lr_patience`` of the wrapped approach).
            backbone: Either ``'mlp'`` or ``'alexnet'``.
            nhid: Hidden size forwarded to the backbone constructor.
            batch_size: Batch size (``sbatch`` of the wrapped approach).
            drop1: Stored as ``pdrop1`` in ``self.args``.
            drop2: Stored as ``pdrop2`` in ``self.args``.
            expand_factor: Stored as ``expand_factor`` in ``self.args``.
            N: Stored as ``N`` in ``self.args``.
            M: Stored as ``M`` in ``self.args``.

        Raises:
            NotImplementedError: If ``backbone`` is not ``'mlp'`` or ``'alexnet'``.
        """
        # The supplied learning rate is scaled down by 10; the scaled value is
        # what both the base class and the wrapped approach receive.
        # NOTE(review): rationale for the /10 is not visible here — confirm.
        lr = lr / 10
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=1, lamb=0)

        # (task_index, num_classes) pairs, passed as ``taskcla`` to the backbone.
        taskcla = list(enumerate(list__ncls))

        # Hyper-parameters in the attribute-style container the original
        # PathNet code reads from.
        # NOTE(review): meaning of ``parameter=''`` is defined by the original
        # PathNet code, not here — confirm.
        self.args = Namespace(pdrop1=drop1,
                              pdrop2=drop2,
                              expand_factor=expand_factor,
                              M=M, N=N,
                              parameter='',
                              )

        if backbone == 'mlp':
            model = PathNetMLP(inputsize=inputsize, taskcla=taskcla,
                               nhid=nhid, args=self.args).to(self.device)
        elif backbone == 'alexnet':
            model = PathNetAlexNet(inputsize=inputsize, taskcla=taskcla,
                                   nhid=nhid, args=self.args).to(self.device)
        else:
            raise NotImplementedError(
                f"Unsupported backbone: {backbone!r} (expected 'mlp' or 'alexnet')")
        # endif

        print_num_params(model)

        # clipgrad=10000 is a very large threshold; presumably this makes
        # gradient clipping effectively inactive — confirm in OrigAppr.
        self.appr = OrigAppr(device=self.device,
                             model=model,
                             nepochs=epochs_max,
                             sbatch=batch_size,
                             lr=lr,
                             lr_min=lr_min,
                             lr_factor=lr_factor,
                             lr_patience=patience_max,
                             clipgrad=10000,
                             args=self.args,
                             )
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              **kwargs
              ) -> Dict[str, float]:
        """Train the wrapped approach on task ``idx_task``.

        ``args_on_forward``, ``args_on_after_backward`` and ``kwargs`` are
        accepted for interface compatibility but are not used by this wrapper.

        Returns:
            ``{'time_consumed': seconds}``, the elapsed time of the wrapped
            ``train`` call.
        """
        # perf_counter is monotonic and high-resolution, so it is the correct
        # clock for measuring a duration (time.time() can jump on clock changes).
        time_start = time.perf_counter()
        self.appr.train(t=idx_task, dl_train=dl_train, dl_val=dl_val)
        time_consumed = time.perf_counter() - time_start

        ret = {
            'time_consumed': time_consumed,
            }

        return ret
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` on ``dl_test`` via the wrapped approach.

        ``args_on_forward`` is accepted for interface compatibility but unused.

        Returns:
            ``{'loss_test': loss, 'acc_test': acc}`` as reported by the
            wrapped approach's ``eval``.
        """
        loss, acc = self.appr.eval(t=idx_task, dl=dl_test)

        return {
            'loss_test': loss,
            'acc_test': acc,
            }
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """No-op: this wrapper performs no extra work after a task finishes."""
        pass
    # enddef
