import os

# CUDA is disabled for this script because the geomloss library can cause hard-to-debug problems depending on the
# installed CUDA version. The variable is set here, before torch and geomloss are imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from geomloss import SamplesLoss
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.neighbors import KernelDensity  # display as density curves

torch.set_num_threads(10)


def logsumexp(log_pdf, axis=None):
    """Compute ``log(sum(exp(log_pdf)))`` along ``axis`` in a numerically stable way.

    The maximum along ``axis`` is subtracted before exponentiating and added back afterwards. Only the reduced
    axis (or axes) is removed from the result; other singleton dimensions of the input are preserved.

    :param log_pdf: Array of log-values.
    :param axis: Axis (or tuple of axes) to reduce over; ``None`` reduces over all elements.
    :return: The log-sum-exp, with the reduced axes removed (a scalar when ``axis`` is ``None``).
    """
    log_pdf = np.asarray(log_pdf)
    max_lp = np.max(log_pdf, axis=axis, keepdims=True)
    # A non-finite maximum (e.g. all entries -inf) would turn the subtraction below into NaN (-inf - -inf);
    # shifting by 0 instead yields the correct -inf / inf result.
    max_lp = np.where(np.isfinite(max_lp), max_lp, 0.)
    with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended result for all -inf inputs
        summed = np.log(np.sum(np.exp(log_pdf - max_lp), axis=axis))
    # Squeeze only the reduced axes so the shift lines up with ``summed`` for inputs with other singleton dims.
    return summed + np.squeeze(max_lp, axis=axis)


def generate_targets(normal, samples=False, n=500, n_samples=300000, rng=None):
    """Build the initial and target distributions used in the interpolation figures.

    Two families are supported:

    * ``normal=True``: the initial distribution is a Gaussian with mean 0.5 and variance 0.01. The target is an
      equal-weight mixture of Gaussians with means 0.1 / 0.9 and variance 0.001.
    * ``normal=False``: the initial distribution concentrates on (0.3, 0.7). The target concentrates on
      (0.05, 0.15) and (0.85, 0.95). When sampling these are exact uniforms. On the grid they are step-shaped
      log-densities that keep a small floor of mass elsewhere.

    :param normal: Select the Gaussian (True) or the uniform-like (False) family.
    :param samples: If True, return samples from both distributions; otherwise return log-densities on a grid.
    :param n: Number of grid points in [0, 1].
    :param n_samples: Number of samples drawn per distribution (only used when ``samples`` is True).
    :param rng: Optional ``numpy.random.Generator`` or ``RandomState`` for reproducible sampling. ``None`` uses
        NumPy's global random state.
    :return: ``(x, init_samples, target_samples)`` with sample arrays of shape ``(n_samples, 1)`` if
        ``samples`` is True. Otherwise ``(x, log_pdf, target_log_pdf)``, where each log-density has shape
        ``(n,)`` and is shifted so that its exponentials sum to one over the grid.
    """
    x = np.linspace(0., 1., n)

    if samples:
        sampler = np.random if rng is None else rng

        if normal:
            def normal_rvs(mean, var):
                # scipy squeezes ``rvs`` output, so reshape explicitly (n_samples == 1 would otherwise be a scalar).
                draws = multivariate_normal.rvs(mean * np.ones(1), var * np.ones(1), size=(n_samples,),
                                                random_state=rng)
                return np.reshape(draws, (n_samples, 1))

            init_samples = normal_rvs(0.5, 0.01)
            # Each target sample comes from either mixture component with probability 1/2. The draw order
            # (selector, then both components) is kept so the default random stream is unchanged.
            target_samples = np.where(sampler.uniform(0, 1, size=(n_samples, 1)) > 0.5,
                                      normal_rvs(0.1, 0.001),
                                      normal_rvs(0.9, 0.001))
        else:
            init_samples = sampler.uniform(0.3, 0.7, size=(n_samples, 1))
            target_samples = np.where(sampler.uniform(0, 1, size=(n_samples, 1)) > 0.5,
                                      sampler.uniform(0.05, 0.15, size=(n_samples, 1)),
                                      sampler.uniform(0.85, 0.95, size=(n_samples, 1)))

        return x, init_samples, target_samples
    else:
        if normal:
            log_pdf = multivariate_normal.logpdf(x, 0.5 * np.ones(1), 0.01 * np.ones(1))
        else:
            # Step-shaped log-density: high inside (0.3, 0.7), a small floor elsewhere.
            log_pdf = -1 * np.ones(n)
            log_pdf[np.logical_and(x > 0.3, x < 0.7)] = 4
        log_pdf -= logsumexp(log_pdf)

        if normal:
            # Log of the (unnormalized) sum of the two component densities.
            target_log_pdf = logsumexp(
                np.stack((multivariate_normal.logpdf(x, 0.1 * np.ones(1), 0.001 * np.ones(1)),
                          multivariate_normal.logpdf(x, 0.9 * np.ones(1), 0.001 * np.ones(1))), axis=-1), axis=-1)
        else:
            target_log_pdf = -1 * np.ones(n)
            target_log_pdf[np.logical_and(x > 0.05, x < 0.15)] = 6
            target_log_pdf[np.logical_and(x > 0.85, x < 0.95)] = 6
        target_log_pdf -= logsumexp(target_log_pdf)

        return x, log_pdf, target_log_pdf


def sprl_interpolations(normal, axs, alphas, color):
    """Draw log-linear (geometric) interpolations between the initial and target grid densities.

    For every axis/weight pair, the density proportional to ``exp((1 - alpha) * init + alpha * target)`` is
    normalized over the grid, rescaled to a peak of one and drawn as a semi-transparent filled area.

    :param normal: Forwarded to ``generate_targets`` to pick the Gaussian or the uniform-like family.
    :param axs: Matplotlib axes to draw into, paired one-to-one with ``alphas``.
    :param alphas: Interpolation weights; 0 shows the initial density, 1 the target density.
    :param color: Fill color.
    """
    grid, init_log_pdf, goal_log_pdf = generate_targets(normal)

    for axis, weight in zip(axs, alphas):
        mixed_log_pdf = (1 - weight) * init_log_pdf + weight * goal_log_pdf
        density = np.exp(mixed_log_pdf - logsumexp(mixed_log_pdf))
        peak_normalized = density / density.max()
        axis.fill_between(grid, 0, peak_normalized, color=color, alpha=0.5)


def wasserstein_interpolations(normal, axs, alphas, color, device="cpu"):
    """Draw displacement interpolations between the initial and target sample sets.

    A per-sample transport direction is obtained from the gradient of a geomloss Sinkhorn loss with respect to
    the initial sample positions, divided by each sample's (uniform) weight. Each panel shows a Gaussian KDE of
    ``init_samples + alpha * direction`` on the grid, rescaled to a peak of one.

    :param normal: Forwarded to ``generate_targets`` to pick the Gaussian or the uniform-like family.
    :param axs: Matplotlib axes to draw into, paired one-to-one with ``alphas``.
    :param alphas: Interpolation weights; 0 shows the initial samples, 1 moves every sample along its full
        estimated transport direction.
    :param color: Fill color.
    :param device: Torch device on which the Sinkhorn loss is evaluated.
    """
    x, init_samples, target_samples = generate_targets(normal, samples=True)

    loss = SamplesLoss("sinkhorn", blur=0.01, scaling=0.9, backend="multiscale")
    # Uniform sample weights. They are kept separate from the loop's ``alpha`` (interpolation weight) to avoid
    # shadowing.
    weights = torch.ones(init_samples.shape[0], device=device, dtype=torch.float64) / init_samples.shape[0]
    # Move to the device first, then mark as requiring grad, so ``source`` is the leaf tensor we differentiate.
    source = torch.from_numpy(init_samples).to(device).requires_grad_(True)
    target = torch.from_numpy(target_samples).to(device)
    wdist = loss(weights, source, weights, target)
    grad, = torch.autograd.grad(wdist, [source])
    # -grad / weight serves as the per-sample transport direction. Bring it back to the CPU before converting to
    # NumPy, since ``.numpy()`` fails on tensors that live on another device.
    tp = -(grad / weights[:, None]).detach().cpu().numpy()

    for ax, alpha in zip(axs, alphas):
        interp_samples = init_samples + alpha * tp
        kde = KernelDensity(kernel="gaussian", bandwidth=0.01).fit(interp_samples)
        dens = np.exp(kde.score_samples(x[:, None]))
        dens /= np.max(dens)
        ax.fill_between(x, 0, dens, color=color, alpha=0.5)


def plot(wasserstein):
    """Create the four-panel interpolation figure and save it as a PDF below ``figures/``.

    :param wasserstein: If True, draw Wasserstein (displacement) interpolations and save to
        ``figures/wasserstein_interpolation.pdf``. Otherwise draw the log-linear interpolations and save to
        ``figures/kl_interpolation.pdf``. The ``figures`` directory must already exist.
    """
    f = plt.figure(figsize=(5.3, 0.6))
    try:
        scale = 0.25
        # Four side-by-side panels: each is 95% of a quarter of the figure width, offset by a small left margin.
        axs = [f.add_axes([scale * i + 0.025 * scale, 0, 0.95 * scale, 1]) for i in range(4)]

        if wasserstein:
            wasserstein_interpolations(True, axs, [0., 0.33, 0.66, 1.], "C0")
            wasserstein_interpolations(False, axs, [0., 0.35, 0.45, 1.], "C1")
        else:
            sprl_interpolations(True, axs, [0., 0.05, 0.2, 1.], "C0")
            sprl_interpolations(False, axs, [0., 0.35, 0.45, 1.], "C1")

        # Grid lines at 0.2, 0.4, 0.6 and 0.8 on both axes, without visible tick marks or labels.
        ticks = np.linspace(0, 1, 6)[1:-1]
        for ax in axs:
            ax.set_xlim(-0.01, 1.01)
            ax.set_ylim(-0.01, 1.01)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params('both', length=0, width=0, which='major')
            ax.set_axisbelow(True)
            ax.grid()

        if wasserstein:
            f.savefig("figures/wasserstein_interpolation.pdf")
        else:
            f.savefig("figures/kl_interpolation.pdf")
    finally:
        # Release the figure; otherwise every call leaves an open figure registered with pyplot.
        plt.close(f)


if __name__ == "__main__":
    # Both figures are written below ``figures/``, so create it up front.
    os.makedirs("figures", exist_ok=True)
    # KL (log-linear) figure first, then the Wasserstein figure.
    for use_wasserstein in (False, True):
        plot(wasserstein=use_wasserstein)
