import os.path
import time
from argparse import Namespace
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.ucl.appr_ucl_orig import Appr as ApprOrig
from approaches.ucl.model_ucl_alexnet import ModelUCLAlexNet
from approaches.ucl.model_ucl_mlp import ModelUCLMLP


class Appr(AbstractAppr):
    """Adapter exposing the original UCL implementation through ``AbstractAppr``.

    All training and evaluation work is delegated to ``ApprOrig``
    (``approaches.ucl.appr_ucl_orig.Appr``). This class only builds the
    backbone model, packs the UCL hyper-parameters, and turns the results
    into the dictionaries returned by ``train`` / ``test``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 backbone: str, batch_size: int, expname: str,
                 epochs_max: int, patience_max: int, lamb: float,
                 nhid: int, drop1: float, drop2: float, ratio: float,
                 alpha: float, beta: float,
                 list__dl_val: List[DataLoader], log_dir: str):
        """Build the UCL backbone model and the wrapped ``ApprOrig`` instance.

        Args:
            device: Device identifier forwarded to the base class and ``ApprOrig``.
            list__ncls: Number of classes of each task, in task order.
            inputsize: Input size forwarded to the model constructor and ``ApprOrig``.
            lr: Learning rate; divided by 10 before being forwarded (see note below).
            lr_factor, lr_min, epochs_max, patience_max: Schedule / early-stopping
                settings forwarded to both the base class and ``ApprOrig``.
            backbone: ``'mlp'`` (``ModelUCLMLP``) or ``'alexnet'`` (``ModelUCLAlexNet``).
            batch_size: Forwarded to ``ApprOrig`` as ``sbatch``.
            expname: Forwarded to ``ApprOrig`` as ``log_name``.
            lamb: Forwarded to ``ApprOrig`` only; the base class gets ``lamb=None``.
            nhid, drop1, drop2: Model constructor settings.
            ratio: Passed to the model constructor and packed into ``args``.
            alpha, beta: Packed into the ``args`` namespace given to ``ApprOrig``.
            list__dl_val: Validation loaders forwarded to ``ApprOrig``.
            log_dir: Log directory forwarded to ``ApprOrig``; created if missing.

        Raises:
            NotImplementedError: If ``backbone`` is neither ``'mlp'`` nor ``'alexnet'``.
        """
        # NOTE(review): the configured lr is scaled down 10x for both the base
        # class and ApprOrig. The reason is not visible in this file (presumably
        # UCL-specific tuning) -- confirm before changing.
        lr = lr / 10
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=None, lamb=None)

        # exist_ok=True avoids the race between an exists() check and makedirs().
        os.makedirs(log_dir, exist_ok=True)

        # (task index, number of classes) pairs passed to the UCL model constructors.
        taskcla = list(enumerate(list__ncls))
        if backbone == 'mlp':
            model = ModelUCLMLP(taskcla=taskcla, inputsize=inputsize, ratio=ratio,
                                nhid=nhid, drop1=drop1, drop2=drop2)
            conv_net = False
        elif backbone == 'alexnet':
            model = ModelUCLAlexNet(taskcla=taskcla, inputsize=inputsize, ratio=ratio,
                                    nhid=nhid, drop1=drop1, drop2=drop2)
            conv_net = True
        else:
            raise NotImplementedError(backbone)
        # endif

        # ApprOrig reads these hyper-parameters from an argparse-style namespace.
        args = Namespace(alpha=alpha,
                         beta=beta,
                         ratio=ratio,
                         conv_net=conv_net)

        self.appr = ApprOrig(model=model, inputsize=inputsize,
                             epochs_max=epochs_max, sbatch=batch_size,
                             lr=lr, lr_min=lr_min, lr_factor=lr_factor, patience_max=patience_max,
                             lamb=lamb, log_name=expname, device=device,
                             list__dl_val=list__dl_val, args=args, log_dir=log_dir)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train task ``idx_task`` via ``ApprOrig.train`` and report elapsed time.

        ``args_on_forward``, ``args_on_after_backward`` and ``list__dl_test``
        are not used by this approach (presumably they are part of the
        ``AbstractAppr`` signature).

        Returns:
            ``{'time_consumed': seconds}``, the duration of the training call.
        """
        # perf_counter is intended for measuring durations; unlike time.time()
        # it is not affected by system clock adjustments.
        t_start = time.perf_counter()
        self.appr.train(t=idx_task,
                        dl_train=dl_train, dl_val=dl_val)
        t_end = time.perf_counter()

        ret = {
            'time_consumed': (t_end - t_start)
            }

        return ret
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` on ``dl_test`` via ``ApprOrig.eval``.

        ``args_on_forward`` is not used by this approach.

        Returns:
            ``{'loss_test': loss, 'acc_test': acc}`` as reported by ``ApprOrig.eval``.
        """
        loss, acc = self.appr.eval(t=idx_task, dl=dl_test)

        ret = {
            'loss_test': loss,
            'acc_test': acc,
            }

        return ret
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Intentionally a no-op for this approach."""
        pass
    # enddef
