import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from syne_tune import Reporter
import numpy as np
import logging
import math

from benchmarking.utils.experiment import Experiment
from .model import JointVAE
from ...train_utils.utils import EarlyStopper, save_checkpoint
from models.data_utils.util import get_device


# Small positive constant added to categorical probabilities before taking
# their logarithm, so that torch.log never receives an exact zero.
EPS = 1e-12


class JointVAETrainer:
    """
    Trainer for the JointVAE based on the implementation provided here:
    https://github.com/Schlumberger/joint-vae
    """

    def __init__(
        self,
        model: JointVAE,
        cont_capacity: tuple = None,
        disc_capacity: tuple = None,
        print_loss_every: int = 50,
        record_loss_every: int = 5,
    ):
        """
        Class to handle training of the model with convolutional layers
        with a 1x1 kernel.

        Parameters
        ----------
        model : jointvae.models.VAE instance

        cont_capacity : tuple (float, float, int, float) or None
            Tuple containing (min_capacity, max_capacity, num_iters, gamma_z).
            Parameters to control the capacity of the continuous latent
            channels. Cannot be None if model.is_continuous is True.

        disc_capacity : tuple (float, float, int, float) or None
            Tuple containing (min_capacity, max_capacity, num_iters, gamma_c).
            Parameters to control the capacity of the discrete latent channels.
            Cannot be None if model.is_discrete is True.

        print_loss_every : int
            Frequency at which loss is printed during training.

        record_loss_every : int
            Frequency at which loss is recorded during training.
        """

        self.use_cuda = torch.cuda.is_available()
        self.model = model
        self.print_loss_every = print_loss_every
        self.record_loss_every = record_loss_every
        self.cont_capacity = cont_capacity
        self.disc_capacity = disc_capacity
        self.writer = SummaryWriter()

        if self.model.is_continuous and self.cont_capacity is None:
            raise RuntimeError("Model is continuous but cont_capacity not provided.")

        if self.model.is_discrete and self.disc_capacity is None:
            raise RuntimeError("Model is discrete but disc_capacity not provided.")

        if self.use_cuda:
            self.model.cuda()

        # Initialize attributes
        self.num_steps = 0
        self.batch_size = None
        self.losses = {"loss": [], "recon_loss": [], "kl_loss": []}

        # Keep track of divergence values for each latent variable
        if self.model.is_continuous:
            self.losses["kl_loss_cont"] = []
            # For every dimension of continuous latent variables
            for i in range(self.model.latent_spec["cont"]):
                self.losses[f"kl_loss_cont_{str(i)}"] = []

        if self.model.is_discrete:
            self.losses["kl_loss_disc"] = []
            # For every discrete latent variable
            for i in range(len(self.model.latent_spec["disc"])):
                self.losses[f"kl_loss_disc_{str(i)}"] = []

    def train(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader = None,
        epochs: int = 10,
        lr: float = 1e-4,
        hyperparametertuning: bool = False,
        experiment: Experiment = None,
        early_stopping: bool = True,
        patience: int = 30,
        min_delta: float = 1.02,
        checkpointing: bool = True,
        device_id: int = None,
    ) -> None:
        """
        Trains the model.

        Parameters
        ----------
        train_loader : DataLoader

        epochs : int
            Number of epochs to train the model for.
        """

        self.batch_size = train_loader.batch_size
        self.model.train()
        device = get_device(device_id)
        self.model.to(device)
        self.model.device = device
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if hyperparametertuning:
            report = Reporter()

        if early_stopping:
            early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

        for epoch in range(epochs):

            # Train one epoch
            mean_epoch_loss, mean_epoch_loss_test = self._train_epoch(
                train_loader, test_loader, optimizer, experiment
            )

            self.writer.add_scalar(
                "[Train] Total epoch loss",
                self.batch_size * self.model.input_dim * mean_epoch_loss,
                epoch + 1,
            )
            self.writer.add_scalar(
                "[Test] Total epoch loss",
                self.batch_size * self.model.input_dim * mean_epoch_loss_test,
                epoch + 1,
            )

            # In hyperparameter tuning mode, report metrics.
            if hyperparametertuning:
                report(
                    step=epoch,
                    mean_loss=self.batch_size * self.model.input_dim * mean_epoch_loss,
                    epoch=epoch + 1,
                )

            logging.info(
                f"Epoch [{epoch + 1}/{epochs}]: Avg. Train Loss: {mean_epoch_loss:.8f}, Avg. Test Loss: {mean_epoch_loss_test:.8f}\n"
            )

            if checkpointing:
                save_checkpoint(
                    model=self.model,
                    experiment=experiment,
                    device_id=device_id,
                    epoch=epoch,
                    best=(
                        early_stopper.is_best(mean_epoch_loss_test)
                        if early_stopping
                        else None
                    ),
                )

            if early_stopping and early_stopper.early_stop(mean_epoch_loss_test):
                logging.info("Training was stopped early.")
                break

        # Finally close the tensorboard writer.
        self.writer.flush()
        self.writer.close()

    def _train_epoch(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader = None,
        optimizer: torch.optim.Adam = None,
        experiment: Experiment = None,
    ):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        train_loader : DataLoader
        test_loader : DataLoader
        experiment : Experiment
        """

        epoch_loss = 0.0
        train_mse = 0.0
        print_every_loss = 0.0  # Keeps track of loss to print every

        # Run one epoch of training.
        num_train_batches = math.ceil(
            len(train_loader.dataset) / train_loader.batch_size
        )
        for batch_idx, data in enumerate(train_loader):
            data = data.unsqueeze(1)
            iter_loss, iter_mse = self._train_iteration(
                data, optimizer=optimizer, experiment=experiment
            )
            epoch_loss += iter_loss
            print_every_loss += iter_loss
            if experiment:
                train_mse += iter_mse

            # Print loss info every self.print_loss_every iteration
            if batch_idx % self.print_loss_every == 0:
                if batch_idx == 0:
                    mean_loss = print_every_loss
                else:
                    mean_loss = print_every_loss / self.print_loss_every
                logging.info(
                    "{}/{}\tLoss: {:.6f}".format(
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        self.model.input_dim * mean_loss,
                    )
                )
                print_every_loss = 0.0

        # Calculate test loss
        if test_loader is not None:
            test_loss = 0.0
            test_mse = 0.0
            with torch.no_grad():
                num_test_batches = math.ceil(
                    len(test_loader.dataset) / test_loader.batch_size
                )
                for batch_idx, test_data in enumerate(test_loader):
                    test_data = test_data.unsqueeze(1)
                    test_iter_loss, iter_mse = self._train_iteration(
                        test_data,
                        training=False,
                        optimizer=optimizer,
                        experiment=experiment,
                    )
                    test_loss += test_iter_loss
                    if experiment:
                        test_mse += iter_mse

        if experiment:
            experiment.train_loss_mse.append(train_mse / num_train_batches)
            experiment.test_loss_mse.append(test_mse / num_test_batches)

        # Return mean epoch loss
        return epoch_loss / num_train_batches, test_loss / num_test_batches

    def _train_iteration(
        self,
        data: torch.Tensor,
        training: bool = True,
        optimizer: torch.optim.Adam = None,
        experiment: Experiment = None,
    ):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data : torch.Tensor
            A batch of data. Shape (N, C, Input Dim.)
        """

        if self.use_cuda:
            data = data.to(self.model.device)

        if training:
            self.num_steps += 1
            optimizer.zero_grad()
            recon_batch, latent_dist = self.model(data)
            loss = self._loss_function(data, recon_batch, latent_dist)
            loss.backward()
            optimizer.step()

            train_loss = loss.item()

            # In case of model comparison also report the mse
            if experiment:
                with torch.no_grad():
                    mse = nn.MSELoss()(
                        recon_batch.view(-1, self.model.input_dim),
                        data.view(-1, self.model.input_dim),
                    ).item()
            else:
                mse = None
            return train_loss, mse
        else:
            # Run inference and calculate loss for test dataset.
            with torch.no_grad():
                recon_batch, latent_dist = self.model(data)
                loss = self._loss_function(data, recon_batch, latent_dist)
                if experiment:
                    mse = nn.MSELoss()(
                        recon_batch.view(-1, self.model.input_dim),
                        data.view(-1, self.model.input_dim),
                    ).item()
                else:
                    mse = None

            test_loss = loss.item()
            return test_loss, mse

    def _loss_function(
        self, data: torch.Tensor, recon_data: torch.Tensor, latent_dist: dict
    ):
        """
        Calculates loss for a batch of data.

        Parameters
        ----------
        data : torch.Tensor
            Input data (e.g. batch of images). Should have shape (N, C, H, W)

        recon_data : torch.Tensor
            Reconstructed data. Should have shape (N, C, H, W)

        latent_dist : dict
            Dict with keys 'cont' or 'disc' or both containing the parameters
            of the latent distributions as values.
        """

        recon_loss = nn.MSELoss()(
            recon_data.view(-1, self.model.input_dim),
            data.view(-1, self.model.input_dim),
        )

        # Loss function takes mean over pixels, so unnormalise this
        recon_loss *= self.model.input_dim

        # Calculate KL divergences
        kl_cont_loss = 0  # Used to compute capacity loss (but not a loss in itself)
        kl_disc_loss = 0  # Used to compute capacity loss (but not a loss in itself)
        cont_capacity_loss = 0
        disc_capacity_loss = 0

        if self.model.is_continuous:
            # Calculate KL divergence
            mean, logvar = latent_dist["cont"]
            kl_cont_loss = self._kl_normal_loss(mean, logvar)
            # Linearly increase capacity of continuous channels
            cont_min, cont_max, cont_num_iters, cont_gamma = self.cont_capacity
            # Increase continuous capacity without exceeding cont_max
            cont_cap_current = (cont_max - cont_min) * self.num_steps / float(
                cont_num_iters
            ) + cont_min
            cont_cap_current = min(cont_cap_current, cont_max)
            # Calculate continuous capacity loss
            cont_capacity_loss = cont_gamma * torch.abs(cont_cap_current - kl_cont_loss)

        if self.model.is_discrete:
            # Calculate KL divergence
            kl_disc_loss = self._kl_multiple_discrete_loss(latent_dist["disc"])
            # Linearly increase capacity of discrete channels
            disc_min, disc_max, disc_num_iters, disc_gamma = self.disc_capacity
            # Increase discrete capacity without exceeding disc_max or theoretical
            # maximum (i.e. sum of log of dimension of each discrete variable)
            disc_cap_current = (disc_max - disc_min) * self.num_steps / float(
                disc_num_iters
            ) + disc_min
            disc_cap_current = min(disc_cap_current, disc_max)
            # Require float conversion here to not end up with numpy float
            disc_theoretical_max = sum(
                float(np.log(disc_dim)) for disc_dim in self.model.latent_spec["disc"]
            )
            disc_cap_current = min(disc_cap_current, disc_theoretical_max)
            # Calculate discrete capacity loss
            disc_capacity_loss = disc_gamma * torch.abs(disc_cap_current - kl_disc_loss)

        # Calculate total kl value to record it
        kl_loss = kl_cont_loss + kl_disc_loss

        # Calculate total loss
        total_loss = recon_loss + cont_capacity_loss + disc_capacity_loss

        # Record losses
        if self.model.training and self.num_steps % self.record_loss_every == 0:
            self.losses["recon_loss"].append(recon_loss.item())
            self.losses["kl_loss"].append(kl_loss.item())
            self.losses["loss"].append(total_loss.item())
            self.writer.add_scalar(
                "Reconstruction loss", recon_loss.item(), self.num_steps
            )
            self.writer.add_scalar(
                "KL loss: Continuous", kl_cont_loss.item(), self.num_steps
            )
            if self.model.is_discrete:
                self.writer.add_scalar(
                    "KL loss: Discrete", kl_disc_loss.item(), self.num_steps
                )
                self.writer.add_scalar(
                    "KL loss: Discrete - After cap and weighting",
                    kl_disc_loss.item(),
                    self.num_steps,
                )
            self.writer.add_scalar(
                "KL loss: Continuous - After cap and weighting",
                kl_cont_loss.item(),
                self.num_steps,
            )
            self.writer.add_scalar("Total loss", total_loss.item(), self.num_steps)
            self.writer.add_scalar(
                "Continuous capacity", cont_cap_current, self.num_steps
            )
            if self.model.is_discrete:
                self.writer.add_scalar(
                    "Discrete capacity", disc_cap_current, self.num_steps
                )

        # To avoid large losses normalise by number of pixels
        return total_loss / self.model.input_dim

    def _kl_normal_loss(self, mean: torch.Tensor, logvar: torch.Tensor):
        """
        Calculates the KL divergence between a normal distribution with
        diagonal covariance and a unit normal distribution.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (N, D) where D is dimension
            of distribution.

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (N, D)
        """

        # Calculate KL divergence
        kl_values = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
        # Mean KL divergence across batch for each latent variable
        kl_means = torch.mean(kl_values, dim=0)
        # KL loss is sum of mean KL of each latent variable
        kl_loss = torch.sum(kl_means)

        # Record losses
        if self.model.training and self.num_steps % self.record_loss_every == 1:
            self.losses["kl_loss_cont"].append(kl_loss.item())
            for i in range(self.model.latent_spec["cont"]):
                self.losses[f"kl_loss_cont_{str(i)}"].append(kl_means[i].item())

        return kl_loss

    def _kl_multiple_discrete_loss(self, alphas: list):
        """
        Calculates the KL divergence between a set of categorical distributions
        and a set of uniform categorical distributions.

        Parameters
        ----------
        alphas : list
            List of the alpha parameters of a categorical (or gumbel-softmax)
            distribution. For example, if the categorical atent distribution of
            the model has dimensions [2, 5, 10] then alphas will contain 3
            torch.Tensor instances with the parameters for each of
            the distributions. Each of these will have shape (N, D).
        """

        # Calculate kl losses for each discrete latent
        kl_losses = [self._kl_discrete_loss(alpha) for alpha in alphas]

        # Total loss is sum of kl loss for each discrete latent
        kl_loss = torch.sum(torch.cat(kl_losses))

        # Record losses
        if self.model.training and self.num_steps % self.record_loss_every == 1:
            self.losses["kl_loss_disc"].append(kl_loss.item())
            for i in range(len(alphas)):
                self.losses[f"kl_loss_disc_{str(i)}"].append(kl_losses[i].item())

        return kl_loss

    def _kl_discrete_loss(self, alpha: torch.Tensor):
        """
        Calculates the KL divergence between a categorical distribution and a
        uniform categorical distribution.

        Parameters
        ----------
        alpha : torch.Tensor
            Parameters of the categorical or gumbel-softmax distribution.
            Shape (N, D)
        """
        disc_dim = int(alpha.size()[-1])
        log_dim = torch.Tensor([np.log(disc_dim)])
        if self.use_cuda:
            log_dim = log_dim.cuda()
        # Calculate negative entropy of each row
        neg_entropy = torch.sum(alpha * torch.log(alpha + EPS), dim=1)
        # Take mean of negative entropy across batch
        mean_neg_entropy = torch.mean(neg_entropy, dim=0)
        # KL loss of alpha with uniform categorical variable
        kl_loss = log_dim + mean_neg_entropy
        return kl_loss
