from typing import *

from argparse import Namespace
import torch
import torch.optim as optim
from torch.nn import functional as F
import numpy as np
from time import time
from approaches.hypernet.mnets.classifier_interface import Classifier
from torch.utils.data import DataLoader

from approaches.hypernet.utils import sim_utils as sutils
import approaches.hypernet.utils.optim_step as opstep
import approaches.hypernet.utils.hnet_regularizer as hreg
from approaches.hypernet.utils.torch_utils import get_optimizer
from approaches.hypernet.mlp import train_utils as tutils
import utils


class Appr(object):
    """Hypernetwork-based continual-learning approach (only CL1 is implemented).

    Every task owns a contiguous slice ("head") of the main network's output
    neurons; ``list__ncls[t]`` is the number of classes of task ``t`` and the
    heads are laid out in task order. Training and evaluation read mini-batches
    from PyTorch ``DataLoader`` objects.
    """

    def __init__(self, device: str, list__ncls: List[int]):
        """
        Args:
            device: Torch device string (stored only; ``test``/``train`` take
                their own ``device`` argument).
            list__ncls: Number of classes of every task, in task order.
        """
        self.device = device
        self.list__ncls = list__ncls
    # enddef

    ###############
    ### Helpers ###
    ###############
    def _head_range(self, task_id):
        """Return ``(start, end)`` of the output neurons that belong to ``task_id``."""
        start = sum(self.list__ncls[:task_id])
        return start, start + self.list__ncls[task_id]
    # enddef

    def _regged_outputs(self, task_id):
        """Per previous task ``i < task_id``, the output-neuron indices of head ``i``.

        Used by the hnet regularizer: for every earlier task only its own head
        has to stay fixed; other heads may change.
        """
        return [list(range(*self._head_range(i))) for i in range(task_id)]
    # enddef

    def _to_one_hot(self, labels, task_id):
        """One-hot encode integer ``labels`` over the classes of ``task_id``'s head."""
        return torch.eye(self.list__ncls[task_id], device=labels.device)[labels]
    # enddef

    @staticmethod
    def _soft_targets(T, soft_label=0.95):
        """Turn one-hot targets ``T`` into label-smoothed targets.

        Entries equal to 1 become ``soft_label``; the remaining probability mass
        is spread uniformly over the other classes. With a single class there is
        nothing to spread, so the off-value is 0.
        """
        num_classes = T.shape[1]
        off_value = (1 - soft_label) / (num_classes - 1) if num_classes > 1 else 0.0
        # `full_like` keeps dtype and device of `T`, so this also works on GPU.
        return torch.where(T == 1,
                           torch.full_like(T, soft_label),
                           torch.full_like(T, off_value))
    # enddef

    @staticmethod
    def _lambda_lr(epoch):
        """Multiplicative Factor for Learning Rate Schedule.

        Computes a multiplicative factor for the initial learning rate based
        on the current epoch. This method can be used as argument
        ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`.
        The schedule is inspired by the Resnet CIFAR-10 schedule suggested
        here https://keras.io/examples/cifar10_resnet/.

        Args:
            epoch (int): The number of epochs
        Returns:
            lr_scale (float): learning rate scale
        """
        if epoch > 180:
            return 0.5e-3
        elif epoch > 160:
            return 1e-3
        elif epoch > 120:
            return 1e-2
        elif epoch > 80:
            return 1e-1
        # endif
        return 1.
    # enddef

    @staticmethod
    def _get_mnet_kwargs(mnet, config, task_id):
        """Build extra kwargs for ``mnet.forward``.

        If the main network has batchnorm layers and per-task running stats are
        checkpointed, ``condition=task_id`` selects the stats of the given task.

        Raises:
            NotImplementedError: If ``config.bn_distill_stats`` is set.
        """
        mnet_kwargs = {}
        if mnet.batchnorm_layers is not None:
            if config.bn_distill_stats:
                raise NotImplementedError()
            elif not config.bn_no_running_stats and \
                    not config.bn_no_stats_checkpointing:
                mnet_kwargs['condition'] = task_id
            # endif
        # endif
        return mnet_kwargs
    # enddef

    @staticmethod
    def _set_train_mode(mnet, hnet):
        """Put ``mnet`` (and ``hnet`` if given) back into training mode."""
        mnet.train()
        if hnet is not None:
            hnet.train()
        # endif
    # enddef

    @staticmethod
    def _cycle_batches(dl):
        """Yield batches from ``dl`` endlessly, starting a new pass when exhausted.

        Each new pass calls ``iter(dl)`` again (for a ``DataLoader`` with
        ``shuffle=True`` this draws a new permutation per pass).

        Raises:
            ValueError: If a full pass over ``dl`` yields no batch.
        """
        while True:
            n_yielded = 0
            for batch in dl:
                n_yielded += 1
                yield batch
            # endfor
            if n_yielded == 0:
                raise ValueError('The training DataLoader yielded no batches.')
            # endif
        # endwhile
    # enddef

    ##################
    ### Evaluation ###
    ##################
    def test(self, task_id,
             # data,
             dl: DataLoader,
             mnet, hnet, device, shared, config, writer, logger,
             train_iter=None, task_emb=None, cl_scenario=None, test_size=None):
        """Evaluate the current performance using the test set.

        Note:
            The hypernetwork ``hnet`` may be ``None``, in which case it is assumed
            that the main network ``mnet`` has internal weights.

        Args:
            task_id (int): The task whose head (and, by default, embedding) is
                evaluated.
            dl (DataLoader): Loader yielding ``(x, y)`` batches with integer
                labels ``y``; inputs are flattened to ``(batch, -1)``.
            (....): See docstring of function :func:`train`.
            train_iter (int, optional): The current training iteration. If given, it
                is used for tensorboard logging.
            task_emb (torch.Tensor, optional): Task embedding. If given, no task ID
                will be provided to the hypernetwork. This might be useful if the
                performance of other than the trained task embeddings should be
                tested.

                .. note::
                    This option may only be used for ``cl_scenario=1``. It doesn't
                    make sense if the task ID has to be inferred.
            cl_scenario (int, optional): In case the system should be tested on
                another CL scenario than the one user-defined in ``config``.

                .. note::
                    It is up to the user to ensure that the CL scenarios are
                    compatible in this implementation.
            test_size (int, optional): In case the testing shouldn't be performed
                on the entire test set, this option can be used to specify the
                number of test samples to be used (the first ``test_size``
                samples yielded by ``dl``).

        Returns:
            (tuple): Tuple containing:

            - **test_acc**: Test accuracy on classification task.
            - **task_acc**: Task prediction accuracy (always 100% for **CL1**).

        Raises:
            NotImplementedError: For ``cl_scenario > 1`` (task inference and
                per-sample head selection are not implemented).
            ValueError: If no test sample was evaluated.
        """
        if cl_scenario is None:
            cl_scenario = config.cl_scenario
        else:
            assert cl_scenario in [1, 2, 3]
        # endif

        # `task_emb` ignored for other cl scenarios!
        assert task_emb is None or cl_scenario == 1, \
            '"task_emb" may only be specified for CL1, as we infer the ' + \
            'embedding for other scenarios.'

        mnet.eval()
        if hnet is not None:
            hnet.eval()
        # endif

        if train_iter is None:
            logger.info('### Test run ...')
        else:
            logger.info('# Testing network before running training step %d ...' % \
                        train_iter)
        # endif

        # Tell the main network which batch statistics to use, in case batchnorm
        # is used and the batchnorm stats are checkpointed per task.
        mnet_kwargs = self._get_mnet_kwargs(mnet, config, task_id)
        if 'condition' in mnet_kwargs and task_emb is not None:
            # NOTE `task_emb` might have nothing to do with `task_id`.
            logger.warning('Using batch statistics accumulated for task ' +
                           '%d for batchnorm, but testing is ' % task_id +
                           'performed using a given task embedding.')
        # endif

        if cl_scenario > 1:
            # TODO Infer the task (embedding) and choose the output head per sample.
            raise NotImplementedError()
        # endif

        # CL1: the output head of the current task is selected.
        head_start, head_end = self._head_range(task_id)

        with torch.no_grad():
            # The generated weights do not depend on the inputs, so they are
            # computed once rather than per batch.
            if hnet is None:
                weights = None
            elif task_emb is not None:
                weights = hnet.forward(task_emb=task_emb)
            else:
                weights = hnet.forward(task_id=task_id)
            # endif

            n_target = len(dl.dataset)
            if test_size is not None and test_size < n_target:
                n_target = test_size
                logger.info('Note, only part of test set is used for this test run!')
            # endif

            # Single pass over the loader; the last batch is truncated if
            # `test_size` ends in the middle of it.
            n_seen = 0
            n_correct = 0
            test_loss = 0.0
            for x, y in dl:
                if n_seen >= n_target:
                    break
                # endif
                n_take = min(x.shape[0], n_target - n_seen)
                X = x[:n_take].to(device).reshape(n_take, -1)
                T = self._to_one_hot(y[:n_take].to(device), task_id)

                Y_hat_logits = mnet.forward(X, weights=weights, **mnet_kwargs)
                # Restrict to the current head before comparing/taking losses.
                Y_hat_logits = Y_hat_logits[:, head_start:head_end]

                # Softmax is monotonic, so the argmax of the logits is the
                # predicted class. Targets are 1-hot encoded.
                n_correct += (Y_hat_logits.argmax(dim=1) ==
                              T.argmax(dim=1)).sum().item()
                test_loss += float(Classifier.logit_cross_entropy_loss(
                    Y_hat_logits, T, reduction='sum'))
                n_seen += n_take
            # endfor
        # endwith

        if n_seen == 0:
            raise ValueError('Cannot evaluate task %d: no test samples.' % task_id)
        # endif

        test_acc = 100.0 * n_correct / n_seen
        # The task identity is given in CL1, so task prediction is always right.
        task_acc = 100.0
        test_loss /= n_seen
        logger.debug('Test loss of task %d: %f' % (task_id + 1, test_loss))

        msg = '### Test accuracy of task %d' % (task_id + 1) \
              + (' (before training iteration %d)' % train_iter if \
                     train_iter is not None else '') \
              + ': %.3f' % (test_acc) \
              + (' (using a given task embedding)' if task_emb is not None \
                     else '') \
              + (' - task prediction accuracy: %.3f' % task_acc if \
                     cl_scenario > 1 else '')
        logger.info(msg)

        if train_iter is not None:
            writer.add_scalar('test/task_%d/class_accuracy' % task_id,
                              test_acc, train_iter)

            if cl_scenario > 1:
                writer.add_scalar('test/task_%d/task_pred_accuracy' % \
                                  task_id, task_acc, train_iter)
            # endif
        # endif

        return test_acc, task_acc
    # enddef

    ################
    ### Training ###
    ################
    def train(self, task_id,
              # data,
              dl_train: DataLoader, dl_val: DataLoader,
              mnet, hnet, device, config, shared, writer, logger):
        """Train the hyper network using the task-specific loss plus a regularizer
        that should overcome catastrophic forgetting.

        :code:`loss = task_loss + beta * regularizer`.

        Args:
            task_id: The index of the task on which we train.
            dl_train: Loader of the task's training set ``(x, y)`` with integer
                labels; it is cycled for as many iterations as configured.
            dl_val: Loader used for the periodic evaluation via :meth:`test`
                (and as validation signal of the plateau LR scheduler).
            mnet: The model of the main network.
            hnet: The model of the hyper network. May be ``None``, in which case
                the main network's own ``weights`` are optimized.
            device: Torch device (cpu or gpu).
            config: The command line arguments.
            shared (argparse.Namespace): Set of variables shared between functions.
            writer: The tensorboard summary writer.
            logger: The logger that should be used rather than the print method.

        Raises:
            NotImplementedError: If ``config.cl_scenario != 1`` or batchnorm
                statistic distillation is requested.
        """
        start_time = time()

        logger.info('Training network ...')

        if config.cl_scenario != 1:
            # Only the multi-head (CL1) setup is implemented: the output head
            # is chosen from the known task identity.
            raise NotImplementedError()
        # endif
        head_start, head_end = self._head_range(task_id)

        self._set_train_mode(mnet, hnet)

        #################
        ### Optimizer ###
        #################
        # Define the optimizers used to train main network and hypernet.
        if hnet is not None:
            theta_params = list(hnet.theta)
            if config.continue_emb_training:
                for prev_task in range(task_id):  # for all previous task embeddings
                    theta_params.append(hnet.get_task_emb(prev_task))
                # endfor
            # endif

            # Only for the current task embedding.
            # Important that this embedding is in a different optimizer in case
            # we use the lookahead.
            emb_optimizer = get_optimizer([hnet.get_task_emb(task_id)],
                                          config.lr, momentum=config.momentum,
                                          weight_decay=config.weight_decay, use_adam=config.use_adam,
                                          adam_beta1=config.adam_beta1, use_rmsprop=config.use_rmsprop)
        else:
            theta_params = mnet.weights
            emb_optimizer = None
        # endif

        theta_optimizer = get_optimizer(theta_params, config.lr,
                                        momentum=config.momentum, weight_decay=config.weight_decay,
                                        use_adam=config.use_adam, adam_beta1=config.adam_beta1,
                                        use_rmsprop=config.use_rmsprop)

        ################################
        ### Learning rate schedulers ###
        ################################
        if config.plateau_lr_scheduler:
            assert (config.epochs != -1)
            # The scheduler config has been taken from here:
            # https://keras.io/examples/cifar10_resnet/
            # Note, we use 'max' instead of 'min' as we look at accuracy rather
            # than validation loss!
            plateau_scheduler_theta = optim.lr_scheduler.ReduceLROnPlateau(
                theta_optimizer, 'max', factor=config.lr_factor, patience=config.patience_max,
                min_lr=config.lr_min, cooldown=0,
                )
            plateau_scheduler_emb = None
            if emb_optimizer is not None:
                plateau_scheduler_emb = optim.lr_scheduler.ReduceLROnPlateau(
                    emb_optimizer, 'max', factor=config.lr_factor, patience=config.patience_max,
                    min_lr=config.lr_min, cooldown=0,
                    )
            # endif
        # endif

        if config.lambda_lr_scheduler:
            assert (config.epochs != -1)

            lambda_scheduler_theta = optim.lr_scheduler.LambdaLR(theta_optimizer,
                                                                 self._lambda_lr)
            lambda_scheduler_emb = None
            if emb_optimizer is not None:
                lambda_scheduler_emb = optim.lr_scheduler.LambdaLR(emb_optimizer,
                                                                   self._lambda_lr)
            # endif
        # endif

        ##############################
        ### Prepare CL Regularizer ###
        ##############################
        # Whether we will calculate the regularizer (needs a hypernetwork).
        calc_reg = task_id > 0 and hnet is not None and not config.mnet_only \
            and config.beta > 0 and not config.train_from_scratch

        # Compute targets when the reg is activated and we are not training
        # the first task
        if calc_reg:
            if config.online_target_computation:
                # Compute targets for the regularizer whenever they are needed.
                # -> Computationally expensive.
                targets_hypernet = None
                prev_theta = [p.detach().clone() for p in hnet.theta]
                prev_task_embs = [p.detach().clone() for p in hnet.get_task_embs()]
            else:
                # Compute targets for the regularizer once and keep them all in
                # memory -> Memory expensive.
                targets_hypernet = hreg.get_current_targets(task_id, hnet)
                prev_theta = None
                prev_task_embs = None
            # endif

            # In a multi-head setup only the head of each previous task has to
            # keep its outputs; other heads may change.
            regged_outputs = None
            if config.cl_scenario != 2:
                regged_outputs = self._regged_outputs(task_id)
            # endif
        # endif

        # We need to tell the main network, which batch statistics to use, in case
        # batchnorm is used and we checkpoint the batchnorm stats.
        mnet_kwargs = self._get_mnet_kwargs(mnet, config, task_id)

        ######################
        ### Start training ###
        ######################
        iter_per_epoch = -1
        if config.epochs == -1:
            training_iterations = config.n_iter
        else:
            assert (config.epochs > 0)
            # One epoch == one full pass over `dl_train` (batches as produced by
            # the loader itself).
            iter_per_epoch = len(dl_train)
            training_iterations = config.epochs * iter_per_epoch
        # endif

        summed_iter_runtime = 0.0
        train_batches = self._cycle_batches(dl_train)

        for i in range(training_iterations):
            ### Evaluate network.
            # We test the network before we run the training iteration.
            # That way, we can see the initial performance of the untrained network.
            if i % config.val_iter == 0:
                self.test(task_id, dl_val,
                          mnet, hnet, device, shared, config, writer,
                          logger, train_iter=i)
                self._set_train_mode(mnet, hnet)
            # endif

            if i % 200 == 0:
                logger.info('Training step: %d ...' % i)
            # endif

            iter_start_time = time()

            theta_optimizer.zero_grad()
            if emb_optimizer is not None:
                emb_optimizer.zero_grad()
            # endif

            #######################################
            ### Data for current task and batch ###
            #######################################
            x, y = next(train_batches)
            X = x.to(device).view(x.shape[0], -1)
            T = self._to_one_hot(y.to(device), task_id)

            ########################
            ### Loss computation ###
            ########################
            if config.mnet_only or hnet is None:
                weights = None
            else:
                weights = hnet.forward(task_id=task_id)
            # endif

            Y_hat_logits = mnet.forward(X, weights, **mnet_kwargs)

            # Restrict output neurons to the current head.
            Y_hat_logits = Y_hat_logits[:, head_start:head_end]
            assert (T.shape[1] == Y_hat_logits.shape[1])
            # compute loss on task and compute gradients
            if config.soft_targets:
                loss_task = Classifier.softmax_and_cross_entropy(
                    Y_hat_logits, self._soft_targets(T, soft_label=0.95))
            else:
                loss_task = Classifier.logit_cross_entropy_loss(Y_hat_logits, T)
            # endif

            # Compute gradients based on task loss (those might be used in the CL
            # regularizer).
            loss_task.backward(retain_graph=calc_reg,
                               create_graph=calc_reg and config.backprop_dt)

            # The current task embedding only depends on the task loss, so we can
            # update it already.
            if emb_optimizer is not None:
                emb_optimizer.step()
            # endif

            # TODO (translated from an author note): perform the weight update
            # and the regularization simultaneously.
            #############################
            ### CL (HNET) Regularizer ###
            #############################
            loss_reg = 0

            if calc_reg:
                if config.no_lookahead:
                    dTembs = None
                    dTheta = None
                else:
                    dTheta = opstep.calc_delta_theta(theta_optimizer, False,
                                                     lr=config.lr, detach_dt=not config.backprop_dt)

                    if config.continue_emb_training:
                        # The previous task embeddings were appended last to
                        # `theta_params`, so they are the last `task_id` deltas.
                        dTembs = dTheta[-task_id:]
                        dTheta = dTheta[:-task_id]
                    else:
                        dTembs = None
                    # endif
                # endif

                loss_reg = hreg.calc_fix_target_reg(hnet, task_id,
                                                    targets=targets_hypernet, dTheta=dTheta, dTembs=dTembs,
                                                    mnet=mnet, inds_of_out_heads=regged_outputs,
                                                    prev_theta=prev_theta, prev_task_embs=prev_task_embs,
                                                    batch_size=config.cl_reg_batch_size)

                loss_reg *= config.beta

                loss_reg.backward()
            # endif

            # Now, that we computed the regularizer, we can use the accumulated
            # gradients and update the hnet (or mnet) parameters.
            theta_optimizer.step()

            # Accuracy is only monitored; no graph is needed.
            Y_hat = F.softmax(Y_hat_logits.detach(), dim=1)
            classifier_accuracy = Classifier.accuracy(Y_hat, T) * 100.0

            #########################
            # Learning rate scheduler
            #########################
            if config.plateau_lr_scheduler:
                assert (iter_per_epoch != -1)
                if i % iter_per_epoch == 0 and i > 0:
                    curr_epoch = i // iter_per_epoch
                    logger.info('Computing test accuracy for plateau LR ' +
                                'scheduler (epoch %d).' % curr_epoch)
                    # We need a validation quantity for the plateau LR scheduler;
                    # the accuracy on `dl_val` is used.
                    # `train_iter` is increased as the messages of `test` suggest
                    # that testing happens before that iteration.
                    test_acc, _ = self.test(task_id, dl_val,
                                            mnet, hnet, device, shared,
                                            config, writer, logger, train_iter=i + 1)
                    self._set_train_mode(mnet, hnet)

                    plateau_scheduler_theta.step(test_acc)
                    if plateau_scheduler_emb is not None:
                        plateau_scheduler_emb.step(test_acc)
                    # endif
                # endif
            # endif

            if config.lambda_lr_scheduler:
                assert (iter_per_epoch != -1)
                if i % iter_per_epoch == 0 and i > 0:
                    curr_epoch = i // iter_per_epoch
                    logger.info('Applying Lambda LR scheduler (epoch %d).'
                                % curr_epoch)

                    lambda_scheduler_theta.step()
                    if lambda_scheduler_emb is not None:
                        lambda_scheduler_emb.step()
                    # endif
                # endif
            # endif

            ###########################
            ### Tensorboard summary ###
            ###########################
            # We don't wanna slow down training by having too much output.
            if i % 50 == 0:
                writer.add_scalar('train/task_%d/class_accuracy' % task_id,
                                  classifier_accuracy, i)
                writer.add_scalar('train/task_%d/loss_task' % task_id, loss_task, i)
                writer.add_scalar('train/task_%d/loss_reg' % task_id, loss_reg, i)
            # endif

            ### Show the current training progress to the user.
            if i % config.val_iter == 0:
                msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                      '(on current training batch).'
                logger.debug(msg.format(i, classifier_accuracy))
            # endif

            iter_end_time = time()
            summed_iter_runtime += (iter_end_time - iter_start_time)

            if i % 200 == 0:
                logger.info('Training step: %d ... Done -- (runtime: %f sec)' % \
                            (i, iter_end_time - iter_start_time))
            # endif
            # NOTE: early stopping is not implemented.
        # endfor

        if mnet.batchnorm_layers is not None:
            if not config.bn_distill_stats and \
                    not config.bn_no_running_stats and \
                    not config.bn_no_stats_checkpointing:
                # Checkpoint the current running statistics (that have been
                # estimated while training the current task).
                for bn_layer in mnet.batchnorm_layers:
                    assert (bn_layer.num_stats == task_id + 1)
                    bn_layer.checkpoint_stats()
                # endfor
            # endif
        # endif

        # Average over the iterations actually run (not `config.n_iter`, which
        # is unrelated when training is configured in epochs).
        avg_iter_time = summed_iter_runtime / max(training_iterations, 1)
        logger.info('Average runtime per training iteration: %f sec.' % \
                    avg_iter_time)

        logger.info('Elapsed time for training task %d: %f sec.' % \
                    (task_id + 1, time() - start_time))
    # enddef
