import time
from argparse import Namespace
from typing import *

import torch
from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.pathnet.appr_pathnet_orig import Appr as OrigAppr
from approaches.pathnet.model_pathnet import Net


class Appr(AbstractAppr):
    """Adapter exposing the original PathNet implementation through the ``AbstractAppr`` interface.

    Construction builds a PathNet ``Net`` and wraps it in the original ``Appr``
    (imported here as ``OrigAppr``). Because the wrapped implementation is called
    with whole tensors rather than DataLoaders, ``train`` and ``test`` first drain
    their loaders into single tensors placed on ``self.device``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 batch_size: int, drop1: float, drop2: float,
                 nhid: int = 2048, clipgrad: float = 10000):
        """Build the PathNet model and the wrapped original approach.

        Args:
            device: device the model and data tensors are moved to.
            list__ncls: number of classes for each task, indexed by task id.
            inputsize: input shape, forwarded to ``Net``.
            lr, lr_factor, lr_min: learning-rate settings forwarded to ``OrigAppr``.
            epochs_max: forwarded to ``OrigAppr`` as ``nepochs``.
            patience_max: forwarded to ``OrigAppr`` as ``lr_patience``.
            batch_size: forwarded to ``OrigAppr`` as ``sbatch``.
            drop1, drop2: stored in ``self.args`` as ``pdrop1`` / ``pdrop2``.
            nhid: hidden size passed to ``Net`` (default 2048, the previous hard-coded value).
            clipgrad: forwarded to ``OrigAppr`` (default 10000, the previous hard-coded value).
        """
        # smax / lamb are explicitly passed as None: this approach does not supply them.
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=None, lamb=None)

        # (task id, number of classes) pairs, passed to Net as ``taskcla``.
        taskcla = list(enumerate(list__ncls))

        # argparse-style namespace handed to both Net and OrigAppr; presumably they
        # read pdrop1 / pdrop2 / parameter from it -- verify against those modules.
        self.args = Namespace(pdrop1=drop1,
                              pdrop2=drop2,
                              parameter='',
                              )

        model = Net(inputsize=inputsize, taskcla=taskcla,
                    nhid=nhid, args=self.args).to(self.device)
        self.appr = OrigAppr(device=self.device,
                             model=model,
                             nepochs=epochs_max,
                             sbatch=batch_size,
                             lr=lr,
                             lr_min=lr_min,
                             lr_factor=lr_factor,
                             lr_patience=patience_max,
                             clipgrad=clipgrad,
                             args=self.args,
                             )
    # enddef

    @staticmethod
    def _concat_loader(dl: DataLoader,
                       device: Union[str, torch.device]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Drain ``dl`` and return all inputs and all targets as two tensors on ``device``.

        Assumes each batch unpacks as an ``(x, y)`` pair of tensors that can be
        concatenated along dim 0.

        Raises:
            RuntimeError: if ``dl`` yields no batches (there is nothing to concatenate).
        """
        list__x, list__y = [], []
        for x, y in dl:
            list__x.append(x)
            list__y.append(y)
        # endfor
        if not list__x:
            # torch.cat([]) would also raise RuntimeError, but with an unhelpful message.
            raise RuntimeError('DataLoader yielded no batches; cannot build input tensors')
        # endif
        return torch.cat(list__x, dim=0).to(device), torch.cat(list__y, dim=0).to(device)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Train task ``idx_task`` with the wrapped PathNet approach.

        The full training and validation sets are loaded onto ``self.device`` first;
        only the call into ``OrigAppr.train`` is timed.

        Args:
            idx_task: id of the task being trained.
            dl_train: training data loader.
            dl_val: validation data loader.
            args_on_forward: not used by this approach.
            args_on_after_backward: not used by this approach.

        Returns:
            Elapsed seconds spent inside ``OrigAppr.train``.
        """
        xtrain, ytrain = self._concat_loader(dl_train, self.device)
        xvalid, yvalid = self._concat_loader(dl_val, self.device)

        # perf_counter is monotonic, so the measurement is unaffected by system clock changes.
        time_start = time.perf_counter()
        self.appr.train(t=idx_task,
                        xtrain=xtrain, ytrain=ytrain,
                        xvalid=xvalid, yvalid=yvalid,
                        args=self.args,
                        )
        time_consumed = time.perf_counter() - time_start

        return time_consumed
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` on the whole of ``dl_test``.

        Args:
            idx_task: id of the task to evaluate.
            dl_test: test data loader.
            args_on_forward: not used by this approach.

        Returns:
            ``{'loss_test': ..., 'acc_test': ...}`` as returned by ``OrigAppr.eval``.
        """
        xtest, ytest = self._concat_loader(dl_test, self.device)

        loss, acc = self.appr.eval(t=idx_task,
                                   x=xtest, y=ytest)

        return {
            'loss_test': loss,
            'acc_test': acc,
            }
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """No-op: this approach does nothing extra after a task finishes."""
        pass
    # enddef
