import os
import logging
import math
from functools import reduce
from collections import defaultdict
import json
from timeit import default_timer
import matplotlib.pyplot as plt
import wandb
from tqdm import trange, tqdm
import numpy as np
import torch

from disvae.models.losses import get_loss_f
from disvae.utils.math import log_density_gaussian
from disvae.utils.modelIO import save_metadata
import matplotlib.gridspec as gridspec
import brewer2mpl

bmap = brewer2mpl.get_map('Set1', 'qualitative', 3)
colors = bmap.mpl_colors

TEST_LOSSES_FILE = "test_losses.log"
METRICS_FILENAME = "metrics.log"
METRIC_HELPERS_FILE = "metric_helpers.pth"
VAR_THRESHOLD = 5e-2

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
eps = 1e-8
from tqdm import trange

class Normal(nn.Module):
    """Univariate Normal distribution parameterised by ``(mu, logsigma)``.

    Sampling uses the reparameterization trick. Wherever a ``params`` tensor
    is accepted, its last dimension must hold ``[mu, logsigma]`` (in that
    order); all computations are element-wise.

    Parameters
    ----------
    mu : float, optional
        Default mean, used when no ``params`` are supplied.

    sigma : float, optional
        Default standard deviation (stored as ``log(sigma)``).
    """

    def __init__(self, mu=0, sigma=1):
        super().__init__()
        # log(2 * pi), the constant term of the Gaussian log-density.
        self.normalization = torch.Tensor([np.log(2 * np.pi)])

        self.mu = torch.Tensor([mu])
        self.logsigma = torch.Tensor([math.log(sigma)])

    def _check_inputs(self, size, mu_logsigma):
        """Resolve ``(mu, logsigma)`` from explicit params and/or a size.

        Parameters
        ----------
        size : torch.Size or tuple or None
            Shape to expand the parameters to.

        mu_logsigma : torch.Tensor or None
            Tensor whose last dimension is ``[mu, logsigma]``. When ``None``
            the defaults given at construction are used.

        Raises
        ------
        ValueError
            If both ``size`` and ``mu_logsigma`` are ``None``.
        """
        if size is None and mu_logsigma is None:
            raise ValueError(
                'Either one of size or params should be provided.')

        if mu_logsigma is None:
            # Only a size: broadcast the stored default parameters.
            return self.mu.expand(size), self.logsigma.expand(size)

        mu = mu_logsigma.select(-1, 0)
        logsigma = mu_logsigma.select(-1, 1)
        if size is not None:
            mu = mu.expand(size)
            logsigma = logsigma.expand(size)
        return mu, logsigma

    def sample(self, size=None, params=None):
        """Draw a reparameterized sample ``mu + exp(logsigma) * eps``."""
        mu, logsigma = self._check_inputs(size, params)
        std_z = torch.randn(mu.size()).type_as(mu)
        return std_z * torch.exp(logsigma) + mu

    def log_density(self, sample, params=None):
        """Return the element-wise log-density of ``sample``.

        Uses ``params`` when given, otherwise the default parameters
        broadcast to the shape of ``sample``.
        """
        if params is not None:
            mu, logsigma = self._check_inputs(None, params)
        else:
            mu, logsigma = self._check_inputs(sample.size(), None)
            mu = mu.type_as(sample)
            logsigma = logsigma.type_as(sample)

        c = self.normalization.type_as(sample)
        inv_sigma = torch.exp(-logsigma)
        tmp = (sample - mu) * inv_sigma
        return -0.5 * (tmp * tmp + 2 * logsigma + c)

    def NLL(self, params, sample_params=None):
        """Analytically computes
            E_N(mu_2,sigma_2^2) [ - log N(mu_1, sigma_1^2) ]
        If mu_2, and sigma_2^2 are not provided, defaults to entropy.
        """
        mu, logsigma = self._check_inputs(None, params)
        if sample_params is not None:
            sample_mu, sample_logsigma = self._check_inputs(None, sample_params)
        else:
            sample_mu, sample_logsigma = mu, logsigma

        c = self.normalization.type_as(sample_mu)
        nll = logsigma.mul(-2).exp() * (sample_mu - mu).pow(2) \
            + torch.exp(sample_logsigma.mul(2) - logsigma.mul(2)) + 2 * logsigma + c
        return nll.mul(0.5)

    def kld(self, params):
        """Computes KL(q||p) where q is the given distribution and p
        is the standard Normal distribution.
        """
        mu, logsigma = self._check_inputs(None, params)
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2)
        neg_kld = logsigma.mul(2).add(1) - mu.pow(2) - logsigma.exp().pow(2)
        return neg_kld.mul(-0.5)

    def get_params(self):
        """Return the default parameters as a tensor ``[mu, logsigma]``."""
        return torch.cat([self.mu, self.logsigma])

    @property
    def nparams(self):
        """Number of parameters per dimension (mu and logsigma)."""
        return 2

    @property
    def ndim(self):
        return 1

    @property
    def is_reparameterizable(self):
        return True

    def __repr__(self):
        # Shows (mu, sigma), i.e. sigma rather than logsigma.
        tmpstr = self.__class__.__name__ + ' ({:.3f}, {:.3f})'.format(
            self.mu.item(), self.logsigma.exp().item())
        return tmpstr

def estimate_entropies(qz_samples, qz_params, q_dist, n_samples=10000, weights=None, dw=True,
                       batch_size=10):
    """Estimate, per latent dimension, the entropy of the aggregate posterior.

    Computes the term:
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]
    for every latent dimension j, where q(z_j) = 1/N sum_n=1^N q(z_j|x_n)
    (or the ``weights``-weighted mixture when ``weights`` is given).
    Assumes samples are from q(z|x) for *all* x in the dataset.
    Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x).

    Computes numerically stable NLL:
        - log q(z) = log N - logsumexp_n=1^N log q(z|x_n)

    All work is done on the device of ``qz_samples``; ``qz_params`` (and
    ``weights`` if given) are expected to live on that same device.

    Inputs:
    -------
        qz_samples (K, N) tensor
        qz_params  (N, K, nparams) tensor
        q_dist     object exposing ``nparams`` and ``log_density(sample, params)``
        n_samples  number of samples used for the Monte-Carlo estimate; at
                   most N are used when ``weights`` is None (sampling
                   without replacement)
        weights    (N) tensor of un-normalised mixture weights, optional
        dw         unused; kept for backward compatibility
        batch_size number of samples evaluated per iteration (trades memory
                   for speed)

    Returns:
    --------
        entropies (K) tensor, one estimate per latent dimension.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))

    device = qz_samples.device

    # Sub-sample the columns (data points) used for the outer expectation.
    # The permutation is drawn on CPU and then moved, so seeded runs keep the
    # same random stream regardless of the target device.
    if weights is None:
        sample_inds = torch.randperm(qz_samples.size(1))[:n_samples]
    else:
        sample_inds = torch.multinomial(weights, n_samples, replacement=True)
    qz_samples = qz_samples.index_select(1, sample_inds.to(device))

    K, S = qz_samples.size()
    N, _, nparams = qz_params.size()
    assert nparams == q_dist.nparams
    assert K == qz_params.size(1)

    # log of the mixture weight of every q(z|x_n): uniform 1/N by default.
    if weights is None:
        log_weights = -math.log(N)
    else:
        log_weights = torch.log(weights.view(N, 1, 1) / weights.sum())

    entropies = torch.zeros(K, device=device)

    # Broadcast views (no copy); hoisted because they do not change per chunk.
    samples = qz_samples.view(1, K, S).expand(N, K, S)
    params = qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)

    # The context manager closes the progress bar even if a chunk fails.
    with trange(S) as pbar:
        for start in range(0, S, batch_size):
            stop = min(start + batch_size, S)
            logqz_i = q_dist.log_density(samples[:, :, start:stop],
                                         params[:, :, start:stop])

            # computes - log q(z_i) summed over minibatch
            entropies += -torch.logsumexp(
                logqz_i + log_weights, dim=0, keepdim=False).detach().sum(1)

            pbar.update(stop - start)

    entropies /= S

    return entropies
class Evaluator:
    """
    Class to handle evaluation of a model (test losses and disentanglement metrics).

    Parameters
    ----------
    model: disvae.vae.VAE

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    is_progress_bar: bool, optional
        Whether to use a progress bar for evaluation.
    """

    def __init__(self, model, loss_f,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 is_progress_bar=True):
        """Store the evaluation configuration and move the model to ``device``.

        The chosen device is reported once through ``logger``.
        """
        # Plain configuration first ...
        self.device = device
        self.logger = logger
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.loss_f = loss_f

        # ... then the model, placed on the evaluation device.
        self.model = model.to(self.device)

        device_message = "Testing Device: {}".format(self.device)
        self.logger.info(device_message)

    def __call__(self, data_loader, is_metrics=False, is_losses=True):
        """Compute the requested test metrics and/or losses.

        The model is switched to eval mode for the duration of the call and
        put back in train mode afterwards if it was training, even when the
        evaluation raises.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        is_metrics: bool, optional
            Whether to compute the disentangling metrics.

        is_losses: bool, optional
            Whether to compute and store the test losses.

        Returns
        -------
        metrics: dict or None
            Metric helpers returned by `compute_metrics`, or `None` when
            `is_metrics` is False.

        losses: dict or None
            Averaged test losses, or `None` when `is_losses` is False.
        """
        start = default_timer()
        is_still_training = self.model.training
        self.model.eval()

        metrics, losses = None, None
        try:
            if is_metrics:
                self.logger.info('Computing metrics...')
                _, _, metrics = self.compute_metrics(data_loader)
                self.logger.info('Metrics: {}'.format(metrics))
                # Metrics are not written with save_metadata here.

            if is_losses:
                self.logger.info('Computing losses...')
                losses = self.compute_losses(data_loader)
                self.logger.info('Losses: {}'.format(losses))
                save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)
        finally:
            # Never leave a training model stuck in eval mode.
            if is_still_training:
                self.model.train()

        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))

        return metrics, losses

    def compute_losses(self, dataloader):
        """Average every loss term recorded by the loss function over the data.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
            Yields ``(data, _)`` pairs; the second element is ignored.

        Returns
        -------
        losses: dict
            Mean of each quantity the loss function appended to the storer.
            Also logged to wandb when a run is active.
        """
        storer = defaultdict(list)
        batches = tqdm(dataloader, leave=False, disable=not self.is_progress_bar)

        for batch, _ in batches:
            batch = batch.to(self.device)

            try:
                recon_batch, latent_dist, latent_sample = self.model(batch)
                self.loss_f(batch, recon_batch, latent_dist, self.model.training,
                            storer, latent_sample=latent_sample)
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor)
                self.loss_f.call_optimize(batch, self.model, None, storer)

        losses = {}
        for name, values in storer.items():
            losses[name] = sum(values) / len(values)

        if wandb.run:
            wandb.log(losses)
        return losses

    def compute_metrics(self, dataloader, N_x_samples=1000, M_z_samples=100):
        """Compute all the disentanglement metrics (MIG and axis-aligned metric).

        Besides returning them, the metric helpers are saved with
        ``torch.save`` under ``save_dir`` and, when a wandb run is active,
        logged to wandb together with a latent-vs-ground-truth figure.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
            Its dataset must expose ``lat_sizes`` (numpy array) and
            ``lat_names``.

        N_x_samples, M_z_samples: int, optional
            Unused; kept for backward compatibility.

        Returns
        -------
        params_zCx: tuple of torch.Tensor
            Sufficient statistics of q(z|x) as returned by `compute`.

        labels: torch.Tensor
            Labels as returned by `compute`.

        metric_helpers: dict
            Holds ``H_zi``, ``H_ziCvj``, ``H_vj``, ``mig_k``, ``aam_k`` and
            ``max_info``.

        Raises
        ------
        ValueError
            If the dataset does not expose its ground-truth factors.
        """
        try:
            lat_sizes = dataloader.dataset.lat_sizes
            lat_names = dataloader.dataset.lat_names
        except AttributeError as err:
            dataset_name = type(getattr(dataloader, "dataset", None)).__name__
            raise ValueError("Dataset needs to have known true factors of variations to compute the metric. This does not seem to be the case for {}".format(dataset_name)) from err

        lat_sizes_list = lat_sizes.tolist()

        self.logger.info("Computing the empirical distribution q(z|x).")
        params_zCx, labels = self.compute(dataloader)
        samples_z = self.model.reparameterize(*params_zCx)
        N, K = samples_z.shape
        # NOTE(review): Normal reads the second statistic as log(sigma). If the
        # encoder emits a log-variance it should be halved before being used
        # here - confirm against the model.
        qz_params = torch.cat([p.view(N, K, 1) for p in params_zCx], 2)

        self.logger.info("Estimating the marginal entropy.")
        # marginal entropy H(z_j)
        H_zi = estimate_entropies(
            samples_z.view(N, K).transpose(0, 1),
            qz_params.view(N, K, 2),
            Normal()).cpu()

        # conditional entropy H(z|v)
        H_ziCvj = self._estimate_H_zCv(samples_z, qz_params, lat_sizes_list, lat_names).cpu()

        # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
        mut_info = H_zi - H_ziCvj
        sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)

        # H_vi
        H_vj = torch.from_numpy(lat_sizes).float().log()
        metric_helpers = {'H_zi': H_zi,
                          'H_ziCvj': H_ziCvj,
                          'H_vj': H_vj}
        # Both helpers also record their per-factor values in metric_helpers.
        mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
        aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)
        metric_helpers['max_info'] = mut_info.max(1)[0].cpu()

        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))

        if wandb.run:
            tables = {}
            summary = {'mig': mig.item(),
                       'aam': aam.item()}
            for name, value in metric_helpers.items():
                if not isinstance(value, torch.Tensor):
                    summary[name] = value
                elif value.dim() == 0:
                    summary[name] = value.item()
                elif value.dim() == 1:
                    # wandb.Table data is a list of rows: a vector is one row.
                    columns = [str(i) for i in range(len(value))]
                    tables[name] = wandb.Table(columns, data=[value.tolist()])
                else:
                    columns = [str(i) for i in range(value.shape[1])]
                    tables[name] = wandb.Table(columns, data=value.tolist())

            self._plot_latent_vs_ground(params_zCx, tables, latnt_sizes=lat_sizes_list)
            wandb.log(tables)
            wandb.summary.update(summary)
        else:
            self._plot_latent_vs_ground(params_zCx, None, latnt_sizes=lat_sizes_list)

        return params_zCx, labels, metric_helpers

    def _plot_latent_vs_ground(self, param, n_dict=None, z_inds=None, latnt_sizes=None):
        """Plot the mean latent response against every ground-truth factor.

        One row is drawn per selected latent unit, with four columns:
        position (image), scale, rotation and shape (curves). The layout
        assumes five factors ordered (shape, scale, rotation, pos_x, pos_y)
        with three shapes - TODO confirm for datasets other than the default.

        Parameters
        ----------
        param: sequence of torch.Tensor
            Only ``param[0]`` (the posterior means, one row per example in
            factor-grid order) is used.

        n_dict: dict, optional
            If given, the figure is stored in it as a ``wandb.Image`` under
            ``'gt_vs_latent'``; otherwise it is saved as
            ``gt_vs_latent.png`` in ``save_dir``.

        z_inds: iterable of int or 0-dim tensors, optional
            Latent units to plot. Defaults to the active units, i.e. those
            whose mean has variance above ``VAR_THRESHOLD``.

        latnt_sizes: sequence of int, optional
            Number of values per factor. Defaults to ``[3, 6, 40, 32, 32]``.
        """
        if latnt_sizes is None:
            latnt_sizes = [3, 6, 40, 32, 32]

        K = param[0].shape[-1]
        qz_means = param[0].view(*(list(latnt_sizes) + [K])).detach().cpu()
        var = torch.std(qz_means.contiguous().view(-1, K), dim=0).pow(2)

        active_units = torch.arange(0, K)[var > VAR_THRESHOLD].long()
        print('Active units: ' + ','.join(map(str, active_units.tolist())))
        n_active = len(active_units)
        print('Number of active units: {}/{}'.format(n_active, K))

        if z_inds is None:
            z_inds = active_units
        units = [int(j) for j in z_inds]

        if not units:
            # A grid with zero rows cannot be created; nothing to draw.
            self.logger.warning('No latent unit to plot, skipping the ground-truth figure.')
            return

        # subplots where subplot[i, j] is gt_i vs. z_j
        mean_scale = qz_means.mean(2).mean(2).mean(2)  # (shape, scale, latent)
        mean_rotation = qz_means.mean(1).mean(2).mean(2)  # (shape, rotation, latent)
        mean_shape = mean_rotation.mean(1)  # (shape, latent)
        mean_pos = qz_means.mean(0).mean(0).mean(0)  # (pos_x, pos_y, latent)

        columns = 4
        last_row = len(units) - 1
        fig = plt.figure(figsize=(columns, len(units) + 1))  # default is (8,6)
        gs = gridspec.GridSpec(len(units), columns)
        gs.update(wspace=0, hspace=0)  # set the spacing between axes.

        # One color per shape, in the order shapes 0, 1, 2 are drawn.
        shape_colors = (colors[2], colors[0], colors[1])

        def per_shape_curves(means, unit):
            """Curves of one latent unit, one per shape."""
            return [(means[s, :, unit], shape_colors[s]) for s in range(3)]

        def draw_line_column(col, means, label, curves_of):
            """Draw one column of line plots sharing a common y-range."""
            lo, hi = means.min().item(), means.max().item()
            for row, unit in enumerate(units):
                ax = fig.add_subplot(gs[col + row * columns])
                for curve, color in curves_of(unit):
                    ax.plot(curve.numpy(), color=color)
                ax.set_ylim([lo, hi])
                ax.set_xticks([])
                ax.set_yticks([])
                # Force square cells whatever the data range is.
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))
                if row == last_row:
                    ax.set_xlabel(label)

        # first column: mean response over the position grid, as an image
        vmin_pos = mean_pos.min().item()
        vmax_pos = mean_pos.max().item()
        for row, unit in enumerate(units):
            ax = fig.add_subplot(gs[row * columns])
            ax.imshow(mean_pos[:, :, unit].numpy(), cmap=plt.get_cmap('coolwarm'),
                      vmin=vmin_pos, vmax=vmax_pos)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylabel(r'$z_' + str(unit) + r'$')
            if row == last_row:
                ax.set_xlabel(r'pos')

        # second and third columns: one curve per shape
        draw_line_column(1, mean_scale, r'scale',
                         lambda unit: per_shape_curves(mean_scale, unit))
        draw_line_column(2, mean_rotation, r'rotation',
                         lambda unit: per_shape_curves(mean_rotation, unit))
        # fourth column: a single curve over the shapes
        draw_line_column(3, mean_shape, r'shape',
                         lambda unit: [(mean_shape[:, unit], colors[2])])

        fig.text(0.5, 0.03, 'Ground Truth', ha='center')
        fig.text(0.01, 0.5, 'Learned Latent Variables ', va='center', rotation='vertical')
        if n_dict is None:
            fig.savefig(os.path.join(self.save_dir, 'gt_vs_latent.png'))
        else:
            n_dict['gt_vs_latent'] = wandb.Image(fig)

        # Close this figure explicitly rather than whichever one is current.
        plt.close(fig)

    def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
        """Compute the mutual information gap as in [1].

        Parameters
        ----------
        sorted_mut_info: torch.Tensor
            Mutual information of shape (n_factors, n_latents), sorted in
            descending order along the last dimension.

        lat_sizes: array-like of int
            Number of values of every factor of variation (numpy array,
            list, tuple or tensor).

        storer: dict, optional
            If given, the per-factor gaps are stored under ``"mig_k"``.

        Returns
        -------
        mig: torch.Tensor
            Scalar mean of the per-factor gaps.

        References
        ----------
           [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
           autoencoders." Advances in Neural Information Processing Systems. 2018.
        """
        # difference between the largest and second largest mutual info
        delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
        # NOTE: currently only works if balanced dataset for every factor of variation
        # then H(v_k) = - |V_k|/|V_k| log(1/|V_k|) = log(|V_k|)
        # (a factor with a single value has zero entropy and yields inf/nan).
        H_v = torch.as_tensor(lat_sizes, dtype=torch.float32,
                              device=sorted_mut_info.device).log()
        mig_k = delta_mut_info / H_v
        mig = mig_k.mean()  # mean over factor of variations

        if storer is not None:
            storer["mig_k"] = mig_k

        return mig

    def _axis_aligned_metric(self, sorted_mut_info, storer=None):
        """Compute the proposed axis aligned metrics.

        For every factor, the largest mutual information minus the sum of all
        the others (floored at zero) is divided by the largest one; undefined
        ratios (NaN) count as zero. The per-factor values are stored under
        ``"aam_k"`` when a storer is given, and their mean is returned.
        """
        top_info = sorted_mut_info[:, 0]
        other_info = sorted_mut_info[:, 1:].sum(dim=1)

        aam_k = (top_info - other_info).clamp(min=0) / top_info
        undefined = torch.isnan(aam_k)
        aam_k[undefined] = 0

        if storer is not None:
            storer["aam_k"] = aam_k

        # mean over factor of variations
        return aam_k.mean()

    def compute(self, dataloader):
        """Compute the empirical distribution of q(z|x) and gather the labels.

        Runs the encoder on every batch without tracking gradients. No
        sample of q(z|x) is drawn here; callers that need one can pass the
        returned statistics to ``model.reparameterize``.

        Parameter
        ---------
        dataloader: torch.utils.data.DataLoader
            Batch data iterator yielding ``(x, label)`` pairs.

        Return
        ------
        params_zCX: tuple of torch.Tensor
            The two sufficient statistics of q(z|x) returned by the encoder
            for each example (E.g. for gaussian (mean, log_var)), each of
            shape : (len_dataset, latent_dim).

        labels: torch.tensor
            Labels of all batches concatenated along the first dimension.
        """
        len_dataset = len(dataloader.dataset)
        latent_dim = self.model.latent_dim
        n_suff_stat = 2

        # Pre-allocated so that every batch is written in place.
        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)
        labels = []
        n = 0
        with torch.no_grad():
            for x, label in dataloader:
                batch_size = x.size(0)
                idcs = slice(n, n + batch_size)
                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))

                labels.append(label)
                n += batch_size

        params_zCX = q_zCx.unbind(-1)
        labels = torch.cat(labels)
        return params_zCX, labels

    def _compute_q_zCx(self, dataloader):
        """Encode the whole dataset and draw one latent per example.

        Parameter
        ---------
        dataloader: torch.utils.data.DataLoader
            Batch data iterator yielding pairs whose first element is the
            input batch; the second element is ignored.

        Return
        ------
        samples_zCx: torch.tensor
            Output of ``model.reparameterize`` applied to the statistics of
            the full dataset.

        params_zCX: tuple of torch.Tensor
            The two sufficient statistics of q(z|x) returned by the encoder
            (E.g. for gaussian (mean, log_var)), each of shape :
            (len_dataset, latent_dim).
        """
        n_examples = len(dataloader.dataset)
        stats = torch.zeros(n_examples, self.model.latent_dim, 2, device=self.device)

        offset = 0
        with torch.no_grad():
            for batch, _ in dataloader:
                stop = offset + batch.size(0)
                first_stat, second_stat = self.model.encoder(batch.to(self.device))
                stats[offset:stop, :, 0] = first_stat
                stats[offset:stop, :, 1] = second_stat
                offset = stop

        params_zCX = stats.unbind(-1)
        return self.model.reparameterize(*params_zCX), params_zCX

    def _estimate_H_zCv(self, samples_z, params_zCx, lat_sizes, lat_names):
        """Estimate conditional entropies :math:`H[z|v]`.

        For every factor of variation, the entropy of each latent is
        estimated on the subset of examples sharing one value of that factor
        and averaged over all values of the factor.

        Parameters
        ----------
        samples_z: torch.Tensor
            Latent samples of shape (N, K), ordered so that they can be
            viewed as ``(*lat_sizes, K)``.

        params_zCx: torch.Tensor
            Stacked posterior parameters of shape (N, K, 2), in the same
            order as ``samples_z``.

        lat_sizes: sequence of int
            Number of values of every factor; their product must be N.

        lat_names: sequence of str
            Names of the factors, used for logging only.

        Returns
        -------
        cond_entropies: torch.Tensor
            Tensor of shape (len(lat_sizes), K) on ``self.device``.
        """
        lat_sizes = list(lat_sizes)
        K = samples_z.size(-1)
        N = reduce((lambda x, y: x * y), lat_sizes)

        nparams = 2
        q_dist = Normal()
        qz_params = params_zCx.view(*(lat_sizes + [K, nparams]))
        qz_samples = samples_z.view(*(lat_sizes + [K]))

        cond_entropies = torch.zeros(len(lat_sizes), K, device=self.device)
        for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
            self.logger.info("Estimating conditional entropies for the {}.".format(lat_name))
            n_conditioned = N // lat_size
            for i in range(lat_size):
                # Fix factor `i_fac_var` to its i-th value and keep all other
                # axes. The index must be a tuple: indexing a tensor with a
                # list of slices is deprecated.
                idcs = (slice(None),) * i_fac_var + (i,)

                qz_samples_scale = qz_samples[idcs].contiguous()
                qz_params_scale = qz_params[idcs].contiguous()

                cond_entropies_i = estimate_entropies(
                    qz_samples_scale.view(n_conditioned, K).transpose(0, 1),
                    qz_params_scale.view(n_conditioned, K, nparams),
                    q_dist)

                cond_entropies[i_fac_var] += cond_entropies_i / lat_size
        return cond_entropies
