import copy

import numpy as np

from sklearn.utils import check_random_state

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data

from torch.utils.tensorboard import SummaryWriter
from time import time

class EarlyStopping:
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.

    The object is called once per epoch with the current validation loss and
    exposes the decision through the ``early_stop`` attribute.
    """
    def __init__(self, patience=5, min_delta=1e-2):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        # number of consecutive calls without a sufficient improvement
        self.counter = 0
        # best (lowest) loss seen so far; None until the first call
        self.best_loss = None
        # set to True once `counter` reaches `patience`
        self.early_stop = False

    def __call__(self, val_loss):
        """
        Record a new validation loss and update the stopping decision.

        :param val_loss: validation loss of the current epoch
        """
        if self.best_loss is None:
            # first observation: nothing to compare against yet
            self.best_loss = val_loss
            return

        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
            return

        # Anything that is not a strict improvement larger than min_delta
        # counts against the patience budget. This includes an improvement
        # exactly equal to min_delta and NaN losses, both of which would be
        # silently ignored by a separate `<` comparison.
        self.counter += 1
        if self.counter >= self.patience:
            print('\n INFO: Early stopping')
            self.early_stop = True



def _set_device(disable_cuda=False):
    """Set device to CPU or GPU.

    Parameters
    ----------
    disable_cuda : bool (default=False)
        Whether to use CPU instead of GPU.

    Returns
    -------
    device : torch.device object
        Device to use (CPU or GPU).
    """
    # XXX we might also want to use CUDA_VISIBLE_DEVICES if it is set
    if not disable_cuda and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    return device


def _validation_loss(model, dataset_valid, X_valid, y_valid, val_loss_fn,
                     device, batch_size_predict, disable_cuda,
                     is_vae, is_nvp, is_packed_autoreg):
    """Compute the validation loss of ``model`` for the active model type.

    Returns
    -------
    val_loss : float
        Scalar validation loss.
    per_dim_val_losses : object or None
        Second value returned by ``val_loss_fn`` in the packed autoregressive
        case, None for every other model type.
    """
    if is_packed_autoreg:
        y_valid_pred = predict(model, dataset_valid,
                               batch_size=batch_size_predict,
                               disable_cuda=disable_cuda, verbose=0,
                               is_vae=is_vae,
                               is_packed_autoreg=is_packed_autoreg)
        val_loss, per_dim_val_losses = val_loss_fn(y_valid, y_valid_pred)
        return val_loss.item(), per_dim_val_losses

    if is_nvp:
        # The loss function is evaluated directly on each (condition, target)
        # sample of the validation set; the model itself is not called here.
        model.eval()
        val_loss = 0
        for sample in dataset_valid:
            cond_data = sample[0].float().to(device)
            target = sample[1].to(device)
            with torch.no_grad():
                val_loss += -val_loss_fn(target, cond_data).mean().item()
        return val_loss / len(dataset_valid), None

    y_valid_pred = predict(model, X_valid,
                           batch_size=batch_size_predict,
                           disable_cuda=disable_cuda, verbose=0,
                           is_vae=is_vae)
    if is_vae:
        # NOTE(review): X_valid / y_valid are not moved to `device` before
        # being passed to model.encode -- confirm encode handles this on GPU.
        y_valid_pred = [y_valid_pred, *model.encode(y_valid, X_valid)]
    return val_loss_fn(y_valid, y_valid_pred).item(), None


def train(model, dataset_train, dataset_valid=None,
          validation_fraction=None, n_epochs=10, batch_size=128,
          loss_fn=nn.MSELoss(), optimizer=None, scheduler=None,
          return_best_model=False, disable_cuda=False,
          batch_size_predict=None, drop_last=False, numpy_random_state=None,
          is_vae=False, is_nvp=False, is_packed_autoreg=False,
          val_loss_fn=None, verbose=True, shuffle=True,
          tensorboard_path=None, log_loss=False, sampler=None, early_stopping=None):
    """Training model using the provided dataset and given loss function.

    Parameters
    ----------
    model : pytorch nn.Module
        Model to be trained.

    dataset_train : Tensor dataset.
        Training data set.

    dataset_valid : Tensor dataset.
        If not None, data set used to compute a validation loss. This data set
        is not used to train the model.

    validation_fraction : float in (0, 1).
        If not None, fraction of samples from dataset to put aside to be
        used as a validation set. If dataset_valid is not None then
        dataset_valid overrides validation_fraction.

    n_epochs : int
        Number of epochs.

    batch_size : int
        Batch size.

    loss_fn : function
        Pytorch loss function, called as ``loss_fn(y, out)``. When ``is_nvp``
        is True it is called as ``loss_fn(y, x)`` and its negated sum is
        minimized. When ``is_packed_autoreg`` is True it must return a
        ``(loss, per_dim_loss)`` pair.

    optimizer : object or list of objects
        Pytorch optimizer(s). If None, Adam with a learning rate of 1e-3 is
        used on all the model parameters.

    scheduler : object or list of objects
        Pytorch scheduler(s). ``ReduceLROnPlateau`` schedulers are stepped
        with the validation loss after each validation pass; any other
        scheduler is stepped once per epoch, after the optimizer updates of
        that epoch.

    return_best_model : bool
        Whether to return the best model on the validation loss. More exactly,
        if set to True, the model trained at the epoch that lead to the best
        performance on the validation dataset is returned. In this case the
        best validation loss is also returned.

    disable_cuda : bool
        Whether to use CPU instead of GPU.

    batch_size_predict : int
        Batch size to use for the computation of the validation loss
        in case of a very large valid dataset. If None, no batch size is used.

    drop_last : bool
        Whether to drop the last batch in the dataloader if incomplete.

    numpy_random_state : int or numpy RNG
        Used when shuffling the training dataset before splitting it into
        a training and a validation datasets.

    is_vae : bool
        Whether the model we are training is a VAE. The model is then called
        as ``model(y, x)``.

    is_nvp : bool
        Whether the model we are training is a RealNVP.

    is_packed_autoreg : bool
        Whether the model is called as ``model(y, x)`` with a loss returning
        per-dimension losses. The model must expose ``n_preds`` and, when
        ``return_best_model`` is True, an indexable ``nets`` attribute.

    val_loss_fn : function
        The function to be used for valid loss.
        If None, loss_fn will be used.

    verbose : bool
        Whether to print training information.

    shuffle : bool
        Whether to shuffle the data. Ignored by the training dataloader when
        a ``sampler`` is given, since the two options are mutually exclusive.

    tensorboard_path : string
        Path to the tensorboard directory. If set to None, ignored.

    log_loss : bool
        Whether to store the per-epoch losses on the model, in
        ``model.tracked_losses`` and ``model.tracked_val_losses``.

    sampler : object
        Sampler forwarded to the training dataloader.

    early_stopping : object
        EarlyStopping object, called with the validation loss of each epoch.
        Only used when a validation set is available.

    Returns
    -------
    model : pytorch nn.Module
        Trained model. If return_best_model is set to True and a validation
        set is available, the tuple ``(best_model, best_val_loss)`` is
        returned instead.

    """
    # This makes the training extremely slow but useful for debugging
    # torch.autograd.set_detect_anomaly(True)

    # use GPU by default if cuda is available, otherwise use CPU
    device = _set_device(disable_cuda=disable_cuda)
    model = model.to(device)
    numpy_rng = check_random_state(numpy_random_state)

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

    if val_loss_fn is None:
        val_loss_fn = loss_fn

    writer_train = None
    writer_valid = None
    if tensorboard_path is not None:
        s_time = int(time())
        writer_train = SummaryWriter(tensorboard_path + f"/train_{s_time}", flush_secs=1)

    # dataset_valid has priority over validation_fraction. if no dataset_valid
    # but validation fraction then build dataset_train and dataset_valid
    if dataset_valid is None and validation_fraction is not None:
        # split dataset into a training and validation set
        if validation_fraction <= 0 or validation_fraction >= 1:
            raise ValueError('validation_fraction should be in (0, 1).')

        n_samples = len(dataset_train)
        indices = np.arange(n_samples)
        if shuffle:
            numpy_rng.shuffle(indices)
        ind_split = int(np.floor(validation_fraction * n_samples))
        train_indices, val_indices = indices[ind_split:], indices[:ind_split]
        dataset_valid = data.TensorDataset(*dataset_train[val_indices])
        dataset_train = data.TensorDataset(*dataset_train[train_indices])

    X_valid = None
    y_valid = None
    if dataset_valid is not None:
        if is_packed_autoreg:
            # load the whole validation set as a single batch
            dataloader = data.DataLoader(dataset=dataset_valid,
                                         batch_size=len(dataset_valid))
            X_valid, y_valid = next(iter(dataloader))
        else:
            X_valid = dataset_valid.tensors[0]
            y_valid = dataset_valid.tensors[1]

        if tensorboard_path is not None:
            writer_valid = SummaryWriter(tensorboard_path + f"/valid_{s_time}", flush_secs=1)
        if log_loss:
            model.tracked_val_losses = []

    # Best-model bookkeeping. best_model stays None until an epoch yields a
    # candidate so that a fallback can be applied after the loop.
    best_model = None
    best_val_loss = np.inf
    best_per_dim_val_losses = None

    # DataLoader refuses shuffle=True together with an explicit sampler, so
    # the sampler takes precedence when one is provided.
    dataset_train = data.DataLoader(dataset_train, batch_size=batch_size,
                                    shuffle=(shuffle and sampler is None),
                                    drop_last=drop_last, sampler=sampler)

    plateau_cls = optim.lr_scheduler.ReduceLROnPlateau
    val_scheduler = isinstance(scheduler, plateau_cls) or (
        isinstance(scheduler, list)
        and len(scheduler) > 0
        and isinstance(scheduler[0], plateau_cls)
    )

    if log_loss:
        model.tracked_losses = []

    for epoch in range(n_epochs):

        model.train()

        train_loss = 0
        total_norm = 0.0
        n_seen = 0      # samples actually iterated (drop_last / sampler aware)
        n_stepped = 0   # batches for which an optimizer step was taken
        if is_packed_autoreg:
            per_dim_losses = torch.zeros((model.n_preds))

        # training
        for (i, (x, y)) in enumerate(dataset_train):
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            n_batch = len(x)
            n_seen += n_batch

            if is_vae:
                out = model(y, x)
                loss = loss_fn(y, out)
                train_loss += n_batch * loss.item()
            elif is_packed_autoreg:
                out = model(y, x)
                loss, per_dim_loss = loss_fn(y, out)
                # Detach so the autograd graph of every batch is not kept
                # alive by the accumulator, and move to CPU to match it.
                # assumes per_dim_loss is a tensor -- TODO confirm
                per_dim_losses += n_batch * per_dim_loss.detach().cpu()
                train_loss += n_batch * loss.item()
            elif is_nvp:
                loss = -loss_fn(y, x).sum()
                train_loss += loss.item()
            else:
                out = model(x)
                loss = loss_fn(y, out)
                train_loss += n_batch * loss.item()

            if torch.isnan(x).any():
                print(f"at batch {i} x has NaN")

            if torch.isnan(y).any():
                print(f"at batch {i} y has NaN")

            # backward and optimization
            loss.backward()

            # compute gradient norm; parameters that received no gradient
            # (frozen or unused in the forward pass) have grad None
            sq_norm = 0.0
            for p in model.parameters():
                if p.grad is None:
                    continue
                sq_norm += p.grad.detach().norm(2).item() ** 2
            norm = sq_norm ** (1. / 2)

            # sanity check
            if np.isnan(norm):
                print(f"at batch {i} the gradients are: {norm}")
                print("skip this batch")
                continue

            # clip gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), 200.0)

            if isinstance(optimizer, list):
                for o in optimizer:
                    o.step()
            else:
                optimizer.step()

            total_norm += norm
            n_stepped += 1

        # Epoch-based schedulers are stepped after the optimizer updates of
        # the epoch; stepping them first skips the first value of the
        # learning rate schedule.
        if scheduler is not None and not val_scheduler:
            if isinstance(scheduler, list):
                for s in scheduler:
                    s.step()
            else:
                scheduler.step()

        if verbose:
            print("===== gradient norm check (average norm across batches) =====")
            print(f"total norm: {total_norm / max(n_stepped, 1)}")
            print("\n")

        # average over the samples that were actually seen during the epoch
        n_avg = max(n_seen, 1)
        train_loss /= n_avg
        if is_packed_autoreg:
            per_dim_losses /= n_avg

        if log_loss:
            model.tracked_losses.append(train_loss)

        if verbose:
            if dataset_valid is None:
                print('[{}/{}] Training loss: {:.4f}'
                      .format(epoch, n_epochs - 1, train_loss))
            else:
                print('[{}/{}] Training loss: {:.4f}'
                      .format(epoch, n_epochs - 1, train_loss), end='\t')

        if writer_train is not None:
            writer_train.add_scalar('loss', train_loss, epoch)
            if is_packed_autoreg:
                for dim, value in enumerate(per_dim_losses):
                    writer_train.add_scalar(f'{dim}_dim_loss', value, epoch)

        # loss on validation set
        if dataset_valid is None:
            continue

        val_loss, per_dim_val_losses = _validation_loss(
            model, dataset_valid, X_valid, y_valid, val_loss_fn, device,
            batch_size_predict, disable_cuda, is_vae, is_nvp,
            is_packed_autoreg)

        if log_loss:
            model.tracked_val_losses.append(val_loss)

        if verbose:
            print('Validation loss: {:.4f}'.format(val_loss))

        if writer_valid is not None:
            writer_valid.add_scalar('loss', val_loss, epoch)
            if is_packed_autoreg:
                for dim, value in enumerate(per_dim_val_losses):
                    writer_valid.add_scalar(f'{dim}_dim_loss', value, epoch)

        if val_scheduler:
            if isinstance(scheduler, list):
                for dim, s in enumerate(scheduler):
                    # one scheduler per dimension when per-dimension losses
                    # exist, otherwise fall back to the global loss
                    if per_dim_val_losses is None:
                        s.step(val_loss)
                    else:
                        s.step(per_dim_val_losses[dim])
            else:
                scheduler.step(val_loss)

        if return_best_model:
            if is_packed_autoreg:
                # Each sub-network is selected independently on its own
                # validation loss (not on the training loss).
                if best_per_dim_val_losses is None:
                    best_per_dim_val_losses = np.full(
                        len(per_dim_val_losses), np.inf)
                    best_model = copy.deepcopy(model)
                for dim in range(len(per_dim_val_losses)):
                    dim_loss = float(per_dim_val_losses[dim])
                    if dim_loss < best_per_dim_val_losses[dim]:
                        best_model.nets[dim] = copy.deepcopy(model.nets[dim])
                        best_per_dim_val_losses[dim] = dim_loss
                best_val_loss = best_per_dim_val_losses.mean()
            elif val_loss < best_val_loss:
                if isinstance(model, torch.jit.RecursiveScriptModule):
                    # scripted modules are snapshotted through a save/load
                    # round trip to the file "my_model"
                    model.save("my_model")
                    best_model = torch.jit.load("my_model")
                else:
                    best_model = copy.deepcopy(model)  # XXX I don't like this
                best_val_loss = val_loss
                best_model.selected_epoch = epoch

        if early_stopping is not None:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                break

    # release the tensorboard event files
    for writer in (writer_train, writer_valid):
        if writer is not None:
            writer.close()

    # return best model and best val loss if we want it
    if (dataset_valid is not None) and return_best_model:
        if best_model is None:
            # Either no epoch was run or no epoch produced a comparable
            # validation loss (e.g. NaN): fall back to the current model.
            best_model = model
            if n_epochs <= 0:  # we return the passed model
                best_val_loss, _ = _validation_loss(
                    best_model, dataset_valid, X_valid, y_valid, val_loss_fn,
                    device, batch_size_predict, disable_cuda, is_vae, is_nvp,
                    is_packed_autoreg)

        return best_model, best_val_loss

    return model


def predict(model, dataset, batch_size=None, disable_cuda=False, verbose=0,
            is_vae=False, is_packed_autoreg=False, ):
    """Run the model over ``dataset`` and return the concatenated outputs.

    The model is switched to evaluation mode and gradients are disabled.
    Batches are read in order (no shuffling); when ``batch_size`` is None the
    whole dataset is processed as a single batch. Depending on the flags the
    output of a batch comes from ``model.sample(x)`` (``is_vae``),
    ``model.forward(y, x)`` (``is_packed_autoreg``) or ``model.forward(x)``.
    Batch outputs are moved to the CPU and concatenated along dimension 0.
    """
    n_per_batch = len(dataset) if batch_size is None else batch_size

    model.eval()
    device = _set_device(disable_cuda=disable_cuda)

    loader = data.DataLoader(dataset, batch_size=n_per_batch, shuffle=False)
    outputs = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            if is_packed_autoreg:
                # batches are (x, y) pairs in the packed autoregressive case
                inputs, targets = batch
                targets = targets.to(device)
            else:
                inputs = batch
            inputs = inputs.to(device)

            if is_vae:
                batch_out = model.sample(inputs)
            elif is_packed_autoreg:
                batch_out = model.forward(targets, inputs)
            else:
                batch_out = model.forward(inputs)
            outputs.append(batch_out.cpu())

            # progress report every 100 batches
            if verbose and batch_idx % 100 == 0:
                print('[{}/{}]'.format(batch_idx, len(loader)))

    return torch.cat(outputs, dim=0)
