import time
from argparse import Namespace
from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.hypernet.appr_hypernet_orig import Appr as ApprOrig
from approaches.hypernet.mlp import train_utils as tutils
from approaches.hypernet.utils import sim_utils as sutils


class Appr(AbstractAppr):
    """Adapter exposing the hypernetwork approach through ``AbstractAppr``.

    The constructor assembles the configuration namespace, the
    logging/tensorboard environment, the main network and (unless
    ``config.mnet_only`` is set) the hypernetwork. ``train`` and ``test``
    delegate the actual work to an ``ApprOrig`` instance.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 lamb: float,
                 batch_size: int,
                 out_dir: str, expname: str,
                 ):
        """Build the configuration, networks and logging environment.

        Args:
            device: Device identifier, forwarded to the base class, to
                ``ApprOrig`` and to the network factories.
            list__ncls: Number of classes per task; its length is used as
                ``config.num_tasks``.
            inputsize: Input shape forwarded to the base class and to the
                main-network factory.
            lr: Forwarded to the base class only (``config.lr`` is fixed).
            lr_factor: Forwarded to the base class only.
            lr_min: Forwarded to the base class only.
            epochs_max: Forwarded to the base class; also stored as
                ``config.epochs`` and ``config.n_iter``.
            patience_max: Forwarded to the base class; also stored as
                ``config.patience_max``.
            lamb: Stored as both ``config.adam_beta1`` and ``config.beta``.
                The base class is given ``lamb=0.0`` instead.
            batch_size: Stored as ``config.batch_size`` and
                ``config.val_batch_size``.
            out_dir: Stored as ``config.out_dir``.
            expname: Experiment name; used as the logger-name suffix and
                later passed to ``tutils.save_summary_dict``.
        """
        # NOTE(review): the base class deliberately(?) receives smax=0 and
        # lamb=0.0 rather than the constructor's `lamb` -- confirm.
        super().__init__(device=device, list__ncls=list__ncls,
                         inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=0, lamb=0.0,
                         )
        self.expname = expname
        self.device = device
        self.list__ncls = list__ncls
        self.inputsize = inputsize

        self.appr = ApprOrig(device=device, list__ncls=list__ncls)

        self.config = self._make_config(num_tasks=len(list__ncls),
                                        epochs_max=epochs_max,
                                        patience_max=patience_max,
                                        lamb=lamb,
                                        batch_size=batch_size,
                                        out_dir=out_dir)

        # The first element returned by `setup_environment` is not needed here.
        _, self.writer, self.logger = sutils.setup_environment(
            self.config, logger_name='det_cl_cifar_%s' % expname)

        # Container for variables shared across function.
        self.shared = Namespace()
        self.shared.experiment = 'mlp'

        # The dataset-loading step (`tutils.load_datasets`) is disabled here;
        # data is supplied through the DataLoaders passed to `train()` and
        # `test()`.

        ### Create main network.
        # TODO Allow main net only training.
        self.mnet = self._create_main_net()

        ### Create the hypernetwork.
        if self.config.mnet_only:
            self.hnet = None
        else:
            self.hnet = self._create_hypernet()
        # endif

        ### Initialize the performance measures, that should be tracked during
        ### training.
        tutils.setup_summary_dict(self.config, self.shared, self.mnet, hnet=self.hnet)

        # Add hparams to tensorboard, such that the identification of runs is
        # easier.
        summary = self.shared.summary
        self.writer.add_hparams(hparam_dict={
            **vars(self.config),
            'num_weights_main': summary['num_weights_main'],
            'num_weights_hyper': summary['num_weights_hyper'],
            'num_weights_ratio': summary['num_weights_ratio'],
            }, metric_dict={})

        # FIXME: Method "calc_fix_target_reg" expects a None value.
        # But `writer.add_hparams` can't deal with `None` values.
        # The config built above already uses None, so this normalisation only
        # takes effect if that default is switched to -1.
        if self.config.cl_reg_batch_size == -1:
            self.config.cl_reg_batch_size = None
        # endif
    # enddef

    @staticmethod
    def _make_config(num_tasks: int, epochs_max: int, patience_max: int,
                     lamb: float, batch_size: int, out_dir: str) -> Namespace:
        """Return the configuration namespace consumed by the hypernet utilities."""
        emb_size = 32
        return Namespace(
            out_dir=out_dir,
            loglevel_info=True,
            # NOTE(review): hard-coded independently of the `device` argument
            # given to the constructor -- confirm this is intended.
            use_cuda=True,
            num_tasks=num_tasks,
            cl_reg_batch_size=None,
            random_seed=123,
            deterministic_run=False,
            mnet_only=False,
            cl_scenario=1,
            mlp_arch='2048,2048',
            custom_network_init=False,
            hyper_chunks=42000,
            hnet_arch='',
            sa_hnet_filters='128,512,256,128',
            sa_hnet_kernels=5,
            sa_hnet_attention_layers='1,3',
            hnet_act='relu',
            temb_size=emb_size,
            emb_size=emb_size,
            hnet_dropout_rate=-1,
            hnet_noise_dim=-1,
            temb_std=-1,
            std_normal_emb=1.0,
            std_normal_temb=1.0,
            hnet_init_shift=False,
            continue_emb_training=False,
            # The learning rate is fixed; the constructor's lr / lr_factor /
            # lr_min only reach the base class. A disabled earlier variant
            # derived them instead: lr=lr * .04, lr_factor=1.0 / lr_factor,
            # lr_min=lr_min * .04.
            lr=0.001,
            momentum=0.0,
            weight_decay=0.0,
            use_adam=True,
            # NOTE(review): `lamb` feeds both `adam_beta1` and `beta` (below);
            # confirm that coupling the two settings is intended.
            adam_beta1=lamb,
            use_rmsprop=False,
            plateau_lr_scheduler=False,
            epochs=epochs_max,
            patience_max=patience_max,
            n_iter=epochs_max,
            lambda_lr_scheduler=False,
            batch_size=batch_size,
            val_batch_size=batch_size,
            val_iter=500,
            soft_targets=False,
            train_from_scratch=False,
            beta=lamb,
            online_target_computation=False,
            backprop_dt=False,
            no_lookahead=False,
            )
    # enddef

    def _create_main_net(self):
        """Instantiate a main network from the current configuration."""
        return tutils.get_main_model(self.config,
                                     self.list__ncls, self.inputsize,
                                     self.shared, self.logger, self.device,
                                     no_weights=not self.config.mnet_only)
    # enddef

    def _create_hypernet(self):
        """Instantiate a hypernetwork for the current main network."""
        return tutils.get_hnet_model(self.config, self.mnet, self.logger, self.device)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Train on one task and return the elapsed time in seconds.

        Args:
            idx_task: Index of the task, passed on as ``task_id``.
            dl_train: Training data loader.
            dl_val: Validation data loader.
            args_on_forward: Not used by this method.
            args_on_after_backward: Not used by this method.

        Returns:
            Seconds spent in this call, measured with a monotonic clock.
        """
        # A monotonic clock keeps the measured duration immune to system
        # clock adjustments that may occur during a long training run.
        time_start = time.perf_counter()

        if idx_task > 0 and self.config.train_from_scratch:
            # FIXME Since we simply override the current network, future testing
            # on this new network for old tasks doesn't make sense. So we
            # shouldn't report `final` accuracies.
            if self.config.mnet_only:
                self.logger.info('From scratch training: Creating new main network.')
                self.mnet = self._create_main_net()
            else:
                self.logger.info('From scratch training: Creating new hypernetwork.')
                self.hnet = self._create_hypernet()
            # endif
        # endif

        self.appr.train(task_id=idx_task,
                        dl_train=dl_train, dl_val=dl_val,
                        mnet=self.mnet, hnet=self.hnet,
                        device=self.device,
                        config=self.config,
                        shared=self.shared,
                        writer=self.writer,
                        logger=self.logger, )

        return time.perf_counter() - time_start
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate one task and return its metrics.

        Args:
            idx_task: Index of the task, passed on as ``task_id``.
            dl_test: Test data loader.
            args_on_forward: Not used by this method.

        Returns:
            A dict with ``'loss_test'`` (always ``0.0``; no loss is obtained
            from the delegate) and ``'acc_test'`` (the delegate's accuracy
            divided by 100).
        """
        # The second value returned by the delegate is not used here.
        test_acc, _ = self.appr.test(task_id=idx_task,
                                     dl=dl_test,
                                     mnet=self.mnet,
                                     hnet=self.hnet,
                                     device=self.device,
                                     shared=self.shared,
                                     config=self.config,
                                     writer=self.writer,
                                     logger=self.logger,
                                     train_iter=None,
                                     task_emb=None,
                                     cl_scenario=None,
                                     test_size=None,
                                     )
        return {
            'loss_test': 0.0,
            # presumably `test_acc` is a percentage, rescaled here to [0, 1]
            # -- verify against ApprOrig.test.
            'acc_test': test_acc / 100,
            }
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """Persist the summary dict once learning is done; ``idx_task`` is unused."""
        tutils.save_summary_dict(self.config, self.shared, self.expname)
    # enddef
