import nninfo
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from abc import ABC, abstractmethod
from nninfo.config import CLUSTER_MODE
from nninfo.quantization import quantizer_list_factory

# Module-level logger obtained from nninfo's logging helper.
log = nninfo.log.get_logger(__name__)

# Public names exported by this module.
__all__ = ["Trainer", "Tester"]

# Optimizers that can be selected through the ``optim_str`` training parameter.
# Maps the name to the corresponding torch.optim class.
OPTIMIZERS_PYTORCH = {"SGD": optim.SGD, "Adam": optim.Adam}

# Losses that can be selected through the ``loss_str`` training parameter.
# Maps the name to the corresponding torch.nn loss class
# (note that "CELoss" stands for nn.CrossEntropyLoss).
LOSSES_PYTORCH = {
    "BCELoss": nn.BCELoss,
    "CELoss": nn.CrossEntropyLoss,
    "MSELoss": nn.MSELoss,
}


class ExperimentComponent(ABC):
    """
    Abstract class that defines the parent property for each component of the experiment
    (the experiment is then the parent of each component, if they are connected).

    A freshly created component has no parent (``parent`` is None).
    """

    def __init__(self, *args, **kwargs):
        self._parent = None
        super(ExperimentComponent, self).__init__(*args, **kwargs)

    @property
    def parent(self):
        """The object this component is attached to, or None if it is detached."""
        return self._parent

    @parent.setter
    def parent(self, parent):
        """
        Attach the component to ``parent`` (or detach it by passing None).

        Replacing an existing parent is logged as a warning, removing it is
        logged as info. Attaching a parent for the first time is silent.
        """
        if self._parent is not None:
            if parent is not None:
                # Both placeholders must be filled by format(); passing parent.id
                # to log.warning() instead made format() raise an IndexError.
                log.warning(
                    "Parent of {} is changed to experiment {}.".format(
                        type(self), parent.id
                    )
                )
            else:
                log.info("Parent of {} is removed.".format(type(self)))
        self._parent = parent


class Trainer(ExperimentComponent):
    """
    Trains the network using chapter structure.
    Define your training settings here.
    """

    def __init__(self):
        """
        Initialize a new instance of Trainer.
        Later, set parameters via set_training_parameters().

        All settings start out unset (None) except ``shuffle`` (True) and
        ``momentum`` (0.). The network, task, tester, optimizer, loss and
        quantizer are only filled in by initialize_components().
        """

        super(Trainer, self).__init__()

        # Experiment components, resolved from the parent in initialize_components().
        self._net = None
        self._task = None
        # Declared here so that every attribute read in train_chapter() exists
        # even before initialize_components() has run.
        self._tester = None

        self._optim_str = None
        self._optimizer = None

        self._batch_size = None
        self._shuffle = True
        self._lr = None

        self._loss = None
        self._loss_str = None

        self._dataset_name = None
        self._n_epochs_trained = 0
        self._n_epochs_chapter = None

        self._n_chapters_trained = 0

        # Parameters handed to quantizer_list_factory(), and the quantizer
        # list built from them in initialize_components().
        self._quantizer_params = None
        self._quantizer = None

        # Loss value of every training batch, appended in train_chapter().
        self.train_log = []

        # TODO: Implement logging of training.

        # Momentum passed to the optimizer when optim_str is "SGD".
        self.momentum = 0.

    @property
    def loss(self):
        """The loss function instance; None until initialize_components() has created it."""
        return self._loss

    @property
    def lr(self):
        """The learning rate; None until it is set via set_training_parameters()."""
        return self._lr

    @property
    def n_epochs_trained(self):
        """Number of epochs trained so far (incremented after each epoch in train_chapter())."""
        return self._n_epochs_trained

    @property
    def n_chapters_trained(self):
        """Number of chapters trained so far (incremented at the end of each chapter)."""
        return self._n_chapters_trained

    def set_n_epochs_trained(self, n_epochs_trained):
        """
        Sets the number of epochs trained to a new value.
        Should not be called by user, only by experiment.

        Args:
            n_epochs_trained (int): New value of the epoch counter.
        """
        log.info("n_epochs_trained is changed from outside.")
        self._n_epochs_trained = n_epochs_trained

    def set_n_chapters_trained(self, n_chapters_trained):
        """
        Sets the number of chapters trained to a new value.
        Should not be called by user, only by experiment.

        Args:
            n_chapters_trained (int): New value of the chapter counter.
        """
        log.info("n_chapters_trained is changed from outside.")
        self._n_chapters_trained = n_chapters_trained

    def optimizer_state_dict(self):
        """
        Return the state dict of the current optimizer.

        The optimizer only exists after initialize_components() has run;
        before that self._optimizer is None and this call fails.
        """
        return self._optimizer.state_dict()

    def load_optimizer_state_dict(self, opt_state_dict):
        """
        Load a previously saved state dict into the current optimizer.

        The optimizer only exists after initialize_components() has run;
        before that self._optimizer is None and this call fails.

        Args:
            opt_state_dict: State dict as returned by optimizer_state_dict().
        """
        self._optimizer.load_state_dict(opt_state_dict)

    def get_training_parameters(self):
        """
        Collect the current training settings and progress counters.

        Returns:
            dict: The values of dataset_name, optim_str, batch_size, shuffle, lr,
                loss_str, n_epochs_trained, n_chapters_trained, n_epochs_chapter
                and quantizer (the quantizer parameters), in that order.
        """
        return dict(
            dataset_name=self._dataset_name,
            optim_str=self._optim_str,
            batch_size=self._batch_size,
            shuffle=self._shuffle,
            lr=self._lr,
            loss_str=self._loss_str,
            n_epochs_trained=self._n_epochs_trained,
            n_chapters_trained=self._n_chapters_trained,
            n_epochs_chapter=self._n_epochs_chapter,
            quantizer=self._quantizer_params,
        )

    def set_training_parameters(
        self,
        dataset_name=None,
        optim_str=None,
        loss_str=None,
        lr=None,
        shuffle=None,
        batch_size=None,
        n_epochs_chapter=None,
        quantizer=None
    ):
        """
        Sets training parameters. Is also called when loading parameters from file.

        Args:
            dataset_name (str): Name of the dataset in the TaskManagers dataset
                dict that should be trained on.
            optim_str (str): One of the optimizers available in constant OPTIMIZERS_PYTORCH.
                It is easy to add new ones, if necessary, since most commonly used ones are
                already implemented in pytorch.
            loss_str (str): One of the losses available in LOSSES_PYTORCH.
                It is easy to add new ones, if necessary, since most commonly used ones are
                already implemented in pytorch.
            lr (float): The learning rate that should be used for the training.
            shuffle (bool): Whether to shuffle
            batch_size (int): Number of samples from the dataset that should be used together
                as a batch for one training step, for example in (Batch) Stochastic Gradient
                Descent.
            n_epochs_chapter (int): If the number of epochs per chapter is a constant it can
                be also set here. Otherwise it must be passed each time train_chapter is
                called.

        """

        if self.parent is not None:
            if self.parent.components_locked:
                log.error(
                    "For consistency reasons, the trainer settings cannot be changed "
                    + "once the experiment is locked."
                )
                raise PermissionError

        if dataset_name is not None:
            self._dataset_name = dataset_name
        if lr is not None:
            self._lr = lr
        if batch_size is not None:
            self._batch_size = batch_size
        if optim_str is not None:
            self._optim_str = optim_str
        if loss_str is not None:
            self._loss_str = loss_str
        if shuffle is not None:
            self._shuffle = shuffle
        if n_epochs_chapter is not None:
            self._n_epochs_chapter = n_epochs_chapter
        if quantizer is not None:
            self._quantizer_params = quantizer

    @property
    def key_parameters_set(self):
        """
        Check whether all settings required for training are present.

        A warning is logged for every setting that is still None.

        Returns:
            bool: True if dataset name, optimizer name, learning rate, loss name
                and batch size are all set, False otherwise.
        """
        # (current value, warning logged when the value is missing)
        required = (
            (
                self._dataset_name,
                "Missing dataset name. If you want to train the full standard,"
                + " set it via set_training_parameters(dataset_name='full_set')",
            ),
            (self._optim_str, "Missing optimizer name as parameter optim_str."),
            (self._lr, "Missing learning rate lr."),
            (self._loss_str, "Missing loss name as parameter loss_str."),
            (
                self._batch_size,
                "Missing training batch size as parameter batch_size.",
            ),
        )

        complete = True
        for value, complaint in required:
            if value is None:
                log.warning(complaint)
                complete = False
        return complete

    def train_chapter(
        self, use_cuda, use_ipex, n_epochs_chapter=None
    ):
        """
        Perform the training steps for a given number of epochs. If no n_epochs_chapter is given
        it is expected to have already been set in set_training_parameters(..).

        After every epoch the mean training loss per sample is printed and logged
        (together with the test loss unless CLUSTER_MODE is set). The loss of
        every batch is appended to self.train_log. A checkpoint is saved at the
        end of the chapter via _end_chapter().

        Args:
            use_cuda (bool): If True, the network is moved to CUDA (unless it already
                is there) and every batch is moved to CUDA before the forward pass.
            use_ipex (bool): Forwarded to _start_chapter(), which passes it on to
                initialize_components() when the components are initialized.
            n_epochs_chapter (int):    Number of epochs to train for this chapter of the training.
        """

        # make experiment components ready for training
        self._start_chapter(n_epochs_chapter, use_ipex)
        # set model to train mode
        self._net.train()

        if use_cuda and not next(self._net.parameters()).is_cuda:
            print('Moving model to CUDA')
            self._net.cuda()

        # create a DataLoader that then feeds the chosen dataset into the network during training
        feeder = DataLoader(
            self._task[self._dataset_name],
            batch_size=self._batch_size,
            shuffle=self._shuffle,
        )

        # central training loop
        for _ in range(self._n_epochs_chapter):
            # Tester.test() switches the network to eval mode at the end of an
            # epoch, so train mode has to be re-entered before the next one.
            self._net.train()

            full_loss = 0
            for local_x_batch, local_y_batch in feeder:
                if use_cuda:
                    local_x_batch = local_x_batch.cuda()
                    local_y_batch = local_y_batch.cuda()

                # zeroes the gradient buffers of all parameters
                self._optimizer.zero_grad()

                pred_y = self._net(local_x_batch, quantizers=self._quantizer)
                loss = self._loss(pred_y, local_y_batch)
                loss.backward()
                self._optimizer.step()

                # read the scalar once; every .item() call forces a device sync
                batch_loss = loss.item()
                self.train_log.append(batch_loss)
                # weight by batch size so that a smaller last batch is averaged correctly
                full_loss += batch_loss * len(local_y_batch)
            self._n_epochs_trained += 1
            print_str = (
                "trained epoch: "
                + str(self._n_epochs_trained)
                + "; train loss: "
                + str(np.sum(full_loss) / len(feeder.dataset))
                + ("" if CLUSTER_MODE else "; test loss: " + str(self._tester.test(quantizer=self._quantizer)))
            )
            print(print_str)
            log.info(print_str)
        self._end_chapter()

    def _start_chapter(self, n_epochs_chapter, use_ipex=False):
        """
        Prepare the trainer and its parent for the next training chapter.

        If no epoch has been trained yet, a checkpoint is saved through the
        parent; if additionally the parent's run_id is 0, the components are
        initialized and the parent locks and saves them first.

        Args:
            n_epochs_chapter (int): Number of epochs for this chapter. If None, the
                value set earlier via set_training_parameters() is used.
            use_ipex (bool): Passed on to initialize_components().

        Raises:
            KeyError: If the number of epochs per chapter is neither passed nor
                has been set before.
        """
        # Known issue (original author's note): this does not check whether
        # there is already a zeroth epoch!
        first_epoch_in_run = self._n_epochs_trained == 0
        first_overall_epoch = first_epoch_in_run and self.parent.run_id == 0
        if first_overall_epoch:
            self.initialize_components(use_ipex)
            self.parent.lock_and_save_components()
        if first_epoch_in_run:
            self.parent.save_checkpoint()

        if n_epochs_chapter is not None:
            self._n_epochs_chapter = n_epochs_chapter
        elif self._n_epochs_chapter is None:
            message = (
                "Please set n_epochs_chapter in set_training_parameters() "
                "or train_chapter()"
            )
            log.error(message)
            raise KeyError(message)
        log.info("Started training chapter {}.".format(self._n_chapters_trained + 1))

    def initialize_components(self, use_ipex=False):
        """
        Build optimizer, loss and quantizer and fetch network, task and tester
        from the parent.

        Args:
            use_ipex (bool): If True, network and optimizer are passed through
                intel_extension_for_pytorch's optimize().

        Raises:
            KeyError: If not all key parameters have been set, or if optim_str /
                loss_str are not keys of OPTIMIZERS_PYTORCH / LOSSES_PYTORCH.
        """
        if not self.key_parameters_set:
            log.error("Not all key parameters have been set.")
            raise KeyError

        self._net = self.parent.network

        # only SGD is given the momentum setting
        optimizer_options = {"lr": self._lr}
        if self._optim_str == "SGD":
            optimizer_options["momentum"] = self.momentum
        optimizer_class = OPTIMIZERS_PYTORCH[self._optim_str]
        self._optimizer = optimizer_class(self._net.parameters(), **optimizer_options)

        loss_class = LOSSES_PYTORCH[self._loss_str]
        self._loss = loss_class()

        self._task = self.parent.task
        self._tester = self.parent.tester

        limits = self.parent.network.get_limits_list()
        self._quantizer = quantizer_list_factory(self._quantizer_params, limits)

        if not use_ipex:
            return

        # imported lazily so that the extension is only required when requested
        import intel_extension_for_pytorch as ipex
        # NOTE(review): depending on ipex.optimize's `inplace` setting, self._net may
        # afterwards be a different object than self.parent.network - confirm.
        self._net, self._optimizer = ipex.optimize(self._net, optimizer=self._optimizer)

    def _end_chapter(self):
        """
        Close the current chapter: increment the chapter counter, report it via
        log and stdout, and let the parent save a checkpoint.
        """
        self._n_chapters_trained = self._n_chapters_trained + 1
        message = "Finished training chapter {}.".format(self._n_chapters_trained)
        log.info(message)
        print(message)
        self.parent.save_checkpoint()


class Tester(ExperimentComponent):
    """
    Is called after each training chapter to perform predefined tests and save their results.

    Args:
        dataset_name (str): Name of the dataset in the TaskManagers dataset
            dict that should be tested on.
    """

    def __init__(self, dataset_name):
        """
        Create a tester with a default dataset.

        Args:
            dataset_name (str): Name of the dataset that test() evaluates when it
                is called without an explicit dataset_name.
        """
        super(Tester, self).__init__()
        self._dataset_name = dataset_name
        # Network and task are fetched from the parent on every test() call.
        self._net = None
        self._task = None

    """
    # Not in use, just standard tests performed.
    def set_testing_parameters(
            self,
            dataset_name=None
            ):

        if dataset_name is not None:
            self._dataset_name = dataset_name

    @property
    def key_parameters_set(self):
        param_flag = True
        if self._dataset_name is None:
            print("Missing dataset name." +
                  " Set it via set_testing_parameters()")
            param_flag = False
        return param_flag
    """

    def test(self, dataset_name=None, return_accuracy=False, quantizer_params=None, quantizer=None):
        """
        Performs test with the dataset given as dataset_name at initialization of Tester.

        Returns:
            total loss over all test samples
        """
        self._net = self.parent.network
        self._task = self.parent.task
        loss_fn = self.parent.trainer.loss
        
        if quantizer is None:
            quantizer = quantizer_list_factory(quantizer_params, self._net.get_limits_list())

        self._net.eval()
        if dataset_name is None:
            dataset_name = self._dataset_name
        feeder = DataLoader(
            self._task[dataset_name], batch_size=1000#len(self._task[dataset_name])
        )
        test_loss = 0
        with torch.no_grad():
            if return_accuracy:
                total = 0
                correct = 0
                for x_test, y_test in feeder:
                    pred_y_test, _ = self._net.probe(x_test, quantizer)
                    if pred_y_test.shape[1] == 1:
                        pred = pred_y_test >= 0.5
                        truth = y_test >= 0.5
                        test_acc = float(1.0 * pred.eq(truth).sum() / pred_y_test.shape[0])
                    elif y_test.data.ndim == 1:
                        # One-hot-representation
                        correct += torch.sum(y_test.data == pred_y_test.argmax(dim=1)).item()
                        total += y_test.size(0)
                        test_acc = correct / total
                    else:
                        # Other, i.e. binary output representations
                        correct += torch.sum(torch.all(y_test.data ==
                                                       torch.round(pred_y_test), dim=1)).item()
                        total += y_test.size(0)
                        test_acc = correct / total
                return test_acc
            else:
                total_size=0
                for x_test, y_test in feeder:
                    pred_y_test = self._net(x_test, quantizer)
                    loss = loss_fn(pred_y_test, y_test)
                    test_loss += loss.item() * x_test.shape[0]
                    total_size += x_test.shape[0]
                return test_loss / total_size

    def get_testing_parameters(self):
        """
        Return the tester's settings.

        Returns:
            dict: Contains the single entry "dataset_name".
        """
        return {"dataset_name": self._dataset_name}
