import pdb
import time
import torch
import torch.optim as optim
import numpy as np

from tqdm import *
from gpvae.estimators.gpvae_estimators import elbo_estimator
from gpvae.utils.metric_utils import rmse, mll


class Trainer:
    """Train a model with Adam and periodically report its performance.

    The training data used for ELBO estimation and for conditioning
    predictions is obtained from ``train_loader.dataset.get_dataset(device)``,
    so the loader's dataset must provide ``get_dataset(device)`` as well as
    ``move_data(batch, device)``.

    :param model: A nn.Module, the model to be trained.
    :param loss_fn: A function, the loss function. It is called as
    ``loss_fn(model, batch, num_samples, contains_nan, decoder_scale)``.
    :param train_loader: A DataLoader, iterates over the training
    dataset.
    :param epochs: An int, the number of epochs to train for.
    :param test_data: A tuple (x, y) (or None), the test data to evaluate
    performance on every print_epoch_freq epochs.
    :param learning_rate: A float, the learning rate.
    :param num_samples: An int, the number of samples to estimate
    the ELBO with.
    :param print_epoch_freq: An int, how frequently (in epochs) to print
    performance, save a checkpoint and evaluate on the test data.
    :param print_batch_freq: An int (or None), how frequently (in
    batches) to print the batch loss. None or 0 disables batch printing.
    :param checkpoints: A string (or None), path of the checkpoint file
    to save to / load from.
    :param log_dir: A string (or None), directory to output
    tensorboard. Only stored by this class.
    :param device: A torch.device, device to perform computations on.
    """
    def __init__(self, model, loss_fn, train_loader, epochs, test_data=None,
                 learning_rate=0.001, num_samples=1, print_epoch_freq=50,
                 print_batch_freq=None, checkpoints=None, log_dir=None,
                 device=torch.device('cpu')):

        self.model = model
        self.model.to(device)
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.N = len(self.train_loader.dataset)
        self.epochs = epochs

        # Full training set, used for ELBO estimation and prediction.
        self.train_data = self.train_loader.dataset.get_dataset(device)
        if test_data is not None:
            self.test_data = (test_data[0].to(device), test_data[1].to(device))
        else:
            self.test_data = None

        self.learning_rate = learning_rate
        self.optimiser = optim.Adam(self.model.parameters(),
                                    self.learning_rate)
        self.num_samples = num_samples
        self.print_epoch_freq = print_epoch_freq
        self.print_batch_freq = print_batch_freq
        self.checkpoints = checkpoints
        self.log_dir = log_dir
        self.device = device
        self.start_epoch = 0

        # performance metrics
        self.epoch_losses = []
        self.compute_loss_times = []
        self.gradient_step_times = []
        self.test_rmses = []
        self.test_lls = []
        self.elbos = []

    def save_checkpoint(self, epoch):
        """Save model, optimiser and loss history to ``self.checkpoints``.

        :param epoch: An int, the epoch that has just finished. The stored
        value is ``epoch + 1``, i.e. the epoch to resume from.
        """
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.model.state_dict(),
            'optimiser': self.optimiser.state_dict(),
            'losses': self.epoch_losses},
            self.checkpoints)

    def load_checkpoint(self):
        """Restore training state from ``self.checkpoints`` if it exists.

        On success the model and optimiser states, the loss history and
        ``start_epoch`` are restored. If the file does not exist, training
        starts afresh from epoch 0.
        """
        print('Loading checkpoint from {}'.format(self.checkpoints))
        try:
            # map_location lets a checkpoint written on another device
            # (e.g. a GPU) be restored onto this trainer's device.
            checkpoint = torch.load(self.checkpoints,
                                    map_location=self.device)
        except FileNotFoundError:
            print('No checkpoint exists at {}. Starting training '
                  'afresh.'.format(self.checkpoints))
            self.start_epoch = 0
            return

        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        # Key must match the one written by save_checkpoint.
        self.optimiser.load_state_dict(checkpoint['optimiser'])
        self.epoch_losses = checkpoint['losses']
        print('Resuming training from epoch {}'.format(self.start_epoch))

    def _train_epoch(self, contains_nan, decoder_scale):
        """Run one pass over the train loader, taking a step per batch.

        :return: A tuple (mean loss, mean loss-computation time, mean
        backward-and-step time) over the batches of this epoch.
        """
        losses = []
        loss_times = []
        backwards_times = []
        for i, batch in enumerate(self.train_loader):
            batch = self.train_loader.dataset.move_data(batch, self.device)
            self.optimiser.zero_grad()
            t1 = time.time()
            loss = self.loss_fn(self.model, batch, self.num_samples,
                                contains_nan, decoder_scale)
            t2 = time.time()
            loss.backward()
            self.optimiser.step()
            t3 = time.time()
            losses.append(loss.item())
            loss_times.append(t2 - t1)
            backwards_times.append(t3 - t2)

            # Print every print_batch_freq-th batch; None or 0 disables it.
            if self.print_batch_freq and i % self.print_batch_freq == 0:
                tqdm.write('Batch {}. Loss: {:2.2f}'.format(i,
                                                            loss.item()))

        # A plain float keeps NumPy scalars out of the checkpointed loss
        # history, so the checkpoint only contains basic Python/torch types.
        mean_loss = float(np.mean(losses))
        return mean_loss, np.mean(loss_times), np.mean(backwards_times)

    def _evaluate_test(self, test_samples, contains_nan):
        """Predict on the test inputs and record the test RMSE and LL."""
        # NOTE(review): prediction runs with autograd enabled; wrapping it
        # in torch.no_grad() would save memory if predict_y does not rely
        # on gradients - confirm.
        self.model.eval()
        mean, sigma, mean_samples, sigma_samples = \
            self.model.predict_y(self.train_data,
                                 self.test_data[0],
                                 num_samples=test_samples,
                                 contains_nan=contains_nan)
        self.model.train(True)

        test_rmse = rmse(self.test_data[1], mean)
        test_ll = mll(self.test_data[1], mean, sigma)
        test_ll_samples = mll(self.test_data[1],
                              mean_samples, sigma_samples)
        tqdm.write('Test RMSE: {:.3f}, Test LL: {:.3f} or '
                   '{:.3f}'.format(test_rmse, test_ll, test_ll_samples))
        self.test_rmses.append(test_rmse)
        self.test_lls.append(test_ll)

    def train_model(self, test_samples=10, contains_nan=False,
                    decoder_scale=None, compute_elbo=False):
        """Train the model from ``start_epoch`` up to ``epochs``.

        Every ``print_epoch_freq`` epochs, and on the final epoch, the
        average loss is printed, a checkpoint is saved (if ``checkpoints``
        is set), the train ELBO is estimated (if ``compute_elbo``) and the
        test RMSE / log-likelihood are computed (if test data was given).

        :param test_samples: An int, the num_samples passed to
        ``model.predict_y`` when evaluating on the test data.
        :param contains_nan: A bool, forwarded to the loss function, the
        ELBO estimator and ``model.predict_y``.
        :param decoder_scale: Forwarded to the loss function and the ELBO
        estimator.
        :param compute_elbo: A bool, whether to estimate the train ELBO at
        each reporting epoch.
        """
        self.model.train(True)
        for epoch in tqdm(range(self.start_epoch, self.epochs)):
            mean_loss, mean_loss_time, mean_backwards_time = \
                self._train_epoch(contains_nan, decoder_scale)

            self.epoch_losses.append(mean_loss)
            self.compute_loss_times.append(mean_loss_time)
            self.gradient_step_times.append(mean_backwards_time)

            is_report_epoch = ((epoch % self.print_epoch_freq) == 0
                               or epoch == (self.epochs - 1))
            if not is_report_epoch:
                continue

            tqdm.write('Epoch {}. Average loss: {:.3f}'.format(epoch,
                                                               mean_loss))

            if self.checkpoints is not None:
                self.save_checkpoint(epoch)

            if compute_elbo:
                elbo = elbo_estimator(self.model, self.train_data,
                                      num_samples=100,
                                      contains_nan=contains_nan,
                                      decoder_scale=decoder_scale)
                tqdm.write('Train ELBO: {:.3f}.'.format(elbo))
                self.elbos.append(elbo)

            if self.test_data is not None:
                self._evaluate_test(test_samples, contains_nan)

        tqdm.write('Training complete.')
