import os
import numpy as np
import matplotlib.pyplot as plt
from deep_sprl.experiments import PointMass2DExperiment
from misc.visualize_maze_results import add_plot

# Typeset all figure text with LaTeX (amsmath loaded in the preamble) using a serif font.
plt.rcParams.update({
    "text.latex.preamble": r'\usepackage{amsmath}',
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Roman"],
})

# TICK_SIZE: tick labels and legend text of the performance figure.
# FONT_SIZE: axis labels (the in-plot annotations use FONT_SIZE + 2).
TICK_SIZE = 6
FONT_SIZE = 8


def performance_plot(ax, hard=False):
    """Draw the performance curves of all compared methods into ``ax``.

    One curve is added per method via ``add_plot``. The five main methods are
    loaded from the likelihood variant selected by ``hard``; the GoalGAN and
    ALP-GMM baselines are always loaded from the hard variant.

    Args:
        ax: Matplotlib axes to draw into.
        hard: Whether to show the hard-likelihood variant of the experiment.

    Returns:
        List with the return value of ``add_plot`` for every method, in the
        order self_paced, np_self_paced, random, default, wasserstein,
        goal_gan, alp_gmm (``full_performance_plot`` uses it as legend handles).
    """
    log_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")

    # (method, color, hard_likelihood) -- the last two baselines are always
    # taken from the hard environment.
    curve_specs = (
        ("self_paced", "C0", hard),
        ("np_self_paced", "C1", hard),
        ("random", "C2", hard),
        ("default", "C3", hard),
        ("wasserstein", "C4", hard),
        ("goal_gan", "C5", True),
        ("alp_gmm", "C6", True),
    )

    handles = []
    for method, color, hard_likelihood in curve_specs:
        experiment = PointMass2DExperiment(log_root, method, "ppo", {"hard_likelihood": hard_likelihood}, seed=0)
        # The experiment's log dir is seed-specific; its parent is handed to add_plot.
        method_dir = os.path.dirname(experiment.get_log_dir())
        handles.append(add_plot(method_dir, ax, color))

    ax.set_yticks([3, 6])
    ax.set_yticklabels([3, 6])

    # Label the panel with the target distribution it belongs to.
    label = r'$\mu_2(\mathbf{c})$' if hard else r'$\mu_1(\mathbf{c})$'
    ax.text(8, 4.15, label, fontsize=FONT_SIZE + 2)

    return handles


def kl_plot(hard_likelihood, path=None, iteration=195):
    """Bar chart of the KL divergence reported by the G-SPRL and NP-SPRL teachers.

    For each of the two self-paced methods, the teacher checkpoint of the given
    iteration is loaded for every ``seed-*`` directory in the method's log
    directory and ``teacher.target_context_kl(numpy=True)`` is evaluated. Bars
    show the mean over seeds; the error bars span the minimum to the maximum.

    Args:
        hard_likelihood: Forwarded as the ``"hard_likelihood"`` experiment
            parameter; also selects the left margin of the axes.
        path: File to save the figure to. If ``None``, the figure is shown
            interactively instead.
        iteration: Index of the ``iteration-<n>`` checkpoint to load
            (defaults to 195, the value that used to be hard-coded).

    Raises:
        ValueError: If a method's log directory contains no ``seed-*`` entry.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")

    fig = plt.figure(figsize=(1.2, 1.4))
    # The hard variant gets a wider left margin than the soft one.
    left = 0.35 if hard_likelihood else 0.27
    ax = plt.Axes(fig, [left, 0.15, 0.99 - left, 0.99 - 0.15])
    fig.add_axes(ax)

    np.random.seed(0)
    alg_kls = []
    for alg in ["self_paced", "np_self_paced"]:
        exp = PointMass2DExperiment(base_log_dir, alg, "ppo", {"hard_likelihood": hard_likelihood}, seed=0)
        log_dir = os.path.dirname(exp.get_log_dir())
        # os.listdir returns entries in arbitrary order; sort them so that the
        # order of evaluation after seeding NumPy is the same on every machine.
        # NOTE(review): this only affects results if target_context_kl draws
        # random numbers, which is not visible from this file.
        seed_dirs = sorted(entry for entry in os.listdir(log_dir) if entry.startswith("seed-"))
        if not seed_dirs:
            raise ValueError("No 'seed-*' directories found in '%s'" % log_dir)

        teacher = exp.create_self_paced_teacher()
        kls = []
        for seed_dir in seed_dirs:
            teacher.load(os.path.join(log_dir, seed_dir, "iteration-%d" % iteration))
            kls.append(teacher.target_context_kl(numpy=True))
        alg_kls.append(np.array(kls))

    mu = np.array([np.mean(ak) for ak in alg_kls])
    # Row 0: minima, row 1: maxima -- matplotlib expects (lower, upper) offsets
    # relative to the bar height, hence the absolute distance to the mean.
    extremes = np.array([[np.min(ak) for ak in alg_kls],
                         [np.max(ak) for ak in alg_kls]])
    ax.bar([0.5, 1.], mu, width=0.4, color=["C0", "C1"], yerr=np.abs(extremes - mu[None, :]), capsize=2)
    ax.set_xticks([0.5, 1.])
    ax.set_xticklabels(["G-SPRL", "NP-SPRL"])
    ax.tick_params(labelsize=FONT_SIZE)
    ax.set_ylabel("KL-Divergence", fontsize=FONT_SIZE)

    if path is None:
        plt.show()
    else:
        fig.savefig(path)


def g_sprl_plot(seed, iterations, hard_likelihood, path=None):
    """Scatter-plot contexts sampled from the G-SPRL teacher at several iterations.

    For every entry of ``iterations`` the ``iteration-<n>`` checkpoint of the
    given seed is loaded into the teacher, 1000 contexts are drawn with
    ``teacher.sample()`` and shown in their own panel. Panels are placed side
    by side at a quarter of the figure width each, so the layout is sized for
    four iterations.

    Args:
        seed: Seed of the "self_paced" run whose checkpoints are visualized.
        iterations: Iterable of checkpoint indices, one panel per entry.
        hard_likelihood: Forwarded as the ``"hard_likelihood"`` experiment parameter.
        path: File to save the figure to. If ``None``, the figure is shown
            interactively instead.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    exp = PointMass2DExperiment(base_log_dir, "self_paced", "ppo", {"hard_likelihood": hard_likelihood}, seed=seed)
    seed_path = exp.get_log_dir()

    teacher = exp.create_self_paced_teacher()
    fig = plt.figure(figsize=(4.1, 0.5))
    for panel_idx, iteration in enumerate(iterations):
        teacher.load(os.path.join(seed_path, "iteration-%d" % iteration))

        ax = plt.Axes(fig, [0.25 * panel_idx + 0.01, 0, 0.23, 1])
        fig.add_axes(ax)

        # A throwaway loop variable keeps the panel index intact (the sampling
        # loop previously reused the name of the enumerate index).
        samples = np.array([teacher.sample() for _ in range(1000)])

        ax.scatter(samples[:, 0], samples[:, 1], s=1, alpha=0.2, color="C0")
        ax.set_xlim(exp.LOWER_CONTEXT_BOUNDS[0], exp.UPPER_CONTEXT_BOUNDS[0])
        # The y-axis is capped at half of the upper context bound.
        ax.set_ylim(exp.LOWER_CONTEXT_BOUNDS[1], 0.5 * exp.UPPER_CONTEXT_BOUNDS[1])
        # Ticks only define the grid lines; labels and tick marks are hidden.
        ax.set_xticks([-3., -2., -1., 0., 1., 2., 3.])
        ax.set_yticks([1.125, 1.875, 2.625, 3.375])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params('both', length=0, width=0, which='major')
        ax.set_axisbelow(True)
        ax.grid()

    if path is not None:
        fig.savefig(path)
    else:
        plt.show()


def np_sprl_plot(seed, iterations, hard_likelihood, path=None):
    """Scatter-plot contexts sampled from the NP-SPRL teacher at several iterations.

    For every entry of ``iterations`` the ``iteration-<n>`` checkpoint of the
    given seed is loaded into the teacher, 1000 contexts are drawn with
    ``teacher.sample()`` and shown in their own panel. Panels are placed side
    by side at a quarter of the figure width each, so the layout is sized for
    four iterations.

    Args:
        seed: Seed of the "np_self_paced" run whose checkpoints are visualized.
        iterations: Iterable of checkpoint indices, one panel per entry.
        hard_likelihood: Forwarded as the ``"hard_likelihood"`` experiment parameter.
        path: File to save the figure to. If ``None``, the figure is shown
            interactively instead.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    exp = PointMass2DExperiment(base_log_dir, "np_self_paced", "ppo", {"hard_likelihood": hard_likelihood}, seed=seed)
    seed_path = exp.get_log_dir()

    teacher = exp.create_self_paced_teacher()
    fig = plt.figure(figsize=(4.1, 0.5))
    for panel_idx, iteration in enumerate(iterations):
        ax = plt.Axes(fig, [0.25 * panel_idx + 0.01, 0, 0.23, 1])
        fig.add_axes(ax)
        teacher.load(os.path.join(seed_path, "iteration-%d" % iteration))

        # A throwaway loop variable keeps the panel index intact (the sampling
        # loop previously reused the name of the enumerate index).
        samples = np.array([teacher.sample() for _ in range(1000)])

        ax.scatter(samples[:, 0], samples[:, 1], s=1, alpha=0.2, color="C1")
        ax.set_xlim(exp.LOWER_CONTEXT_BOUNDS[0], exp.UPPER_CONTEXT_BOUNDS[0])
        # The y-axis is capped at half of the upper context bound.
        ax.set_ylim(exp.LOWER_CONTEXT_BOUNDS[1], 0.5 * exp.UPPER_CONTEXT_BOUNDS[1])
        # Ticks only define the grid lines; labels and tick marks are hidden.
        ax.set_xticks([-3., -2., -1., 0., 1., 2., 3.])
        ax.set_yticks([1.125, 1.875, 2.625, 3.375])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params('both', length=0, width=0, which='major')
        ax.set_axisbelow(True)
        ax.grid()

    if path is not None:
        fig.savefig(path)
    else:
        plt.show()


def wb_sprl_plot(seed, iterations, hard_likelihood, path=None):
    """Scatter-plot the current samples of the WB-SPRL teacher at several iterations.

    For every entry of ``iterations`` the ``iteration-<n>`` checkpoint of the
    given seed is loaded and ``teacher.teacher.current_samples`` is shown in
    its own panel. Panels are placed side by side at a quarter of the figure
    width each.

    Args:
        seed: Selects the ``seed-<n>`` directory the checkpoints are read from.
        iterations: Iterable of checkpoint indices, one panel per entry.
        hard_likelihood: Forwarded as the ``"hard_likelihood"`` experiment parameter.
        path: File to save the figure to. If ``None``, the figure is shown
            interactively instead.
    """
    log_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    # The experiment itself is always constructed with seed 0; the requested
    # seed only picks the sibling "seed-<n>" directory to load from.
    experiment = PointMass2DExperiment(log_root, "wasserstein", "ppo", {"hard_likelihood": hard_likelihood}, seed=0)
    method_dir = os.path.dirname(experiment.get_log_dir())
    checkpoint_root = os.path.join(method_dir, "seed-%d" % seed)

    fig = plt.figure(figsize=(4.1, 0.5))
    wb_teacher = experiment.create_self_paced_teacher()

    lower, upper = experiment.LOWER_CONTEXT_BOUNDS, experiment.UPPER_CONTEXT_BOUNDS
    for column, iteration in enumerate(iterations):
        panel = plt.Axes(fig, [0.25 * column + 0.01, 0, 0.23, 1])
        fig.add_axes(panel)
        wb_teacher.load(os.path.join(checkpoint_root, "iteration-%d" % iteration))

        particles = wb_teacher.teacher.current_samples
        panel.scatter(particles[:, 0], particles[:, 1], s=1, alpha=0.2, color="C4")
        panel.set_xlim(lower[0], upper[0])
        panel.set_ylim(lower[1], upper[1])
        # Ticks only define the grid lines; labels and tick marks are hidden.
        panel.set_xticks([-3., -2., -1., 0., 1., 2., 3.])
        panel.set_yticks([2., 3.5, 5, 6.5])
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.tick_params('both', length=0, width=0, which='major')
        panel.set_axisbelow(True)
        panel.grid()

    if path is None:
        plt.show()
    else:
        plt.savefig(path)


def full_performance_plot(path=None):
    """Create the two-panel performance figure (soft likelihood left, hard right).

    Each panel is filled by ``performance_plot``. A shared legend is placed
    above the panels, using the handles returned for the last (hard) panel.

    Args:
        path: File to save the figure to. If ``None``, the figure is shown
            interactively instead.
    """
    fig = plt.figure(figsize=(4.1, 1.4))

    # (left edge, width, hard likelihood?) of the two side-by-side panels.
    panel_specs = [(0.07, 0.43, False), (0.555, 0.43, True)]

    panels = []
    for left, width, hard in panel_specs:
        panel = plt.Axes(fig, [left, 0.23, width, 0.62])
        fig.add_axes(panel)
        handles = performance_plot(panel, hard=hard)
        panels.append(panel)
        panel.tick_params(axis='both', which='both', labelsize=TICK_SIZE)
        panel.set_xlabel("Epoch", fontsize=FONT_SIZE, labelpad=2)
        panel.grid()

    legend_labels = ["G-SPRL", "NP-SPRL", "Random", "Default", "WB-SPRL", "GoalGAN", "ALP-GMM"]
    fig.legend(handles, legend_labels, fontsize=TICK_SIZE, loc='upper left', bbox_to_anchor=(-0.005, 1.015),
               ncol=7, columnspacing=0.6, handlelength=1.0, handletextpad=0.3)
    # Only the left panel carries the y-axis label.
    panels[0].set_ylabel("Cum. Disc. Ret.", fontsize=FONT_SIZE, labelpad=2)

    if path is not None:
        plt.savefig(path)
    else:
        plt.show()


if __name__ == "__main__":
    os.makedirs("figures", exist_ok=True)
    full_performance_plot(path="figures/point_mass_performance.pdf")

    # The remaining plots temporarily instantiate CRL algorithms, whose tensorflow computation graphs clash with
    # each other. As a simple workaround, every plot is dispatched as its own joblib job so that an individual
    # process is used per plot.
    from joblib import Parallel, delayed

    # (plot function, positional arguments, output file)
    plot_specs = [
        (np_sprl_plot, (1, [10, 20, 40, 90], False), "figures/np_point_mass_dist_soft.pdf"),
        (g_sprl_plot, (1, [5, 10, 35, 55], False), "figures/g_point_mass_dist_soft.pdf"),
        (kl_plot, (False,), "figures/point_mass_kl_soft.pdf"),
        (np_sprl_plot, (2, [5, 10, 15, 20], True), "figures/np_point_mass_dist_hard.pdf"),
        (g_sprl_plot, (2, [5, 10, 35, 60], True), "figures/g_point_mass_dist_hard.pdf"),
        (kl_plot, (True,), "figures/point_mass_kl_hard.pdf"),
        (wb_sprl_plot, (1, [10, 15, 50, 100], False), "figures/wb_point_mass_dist_soft.pdf"),
        (wb_sprl_plot, (1, [10, 15, 50, 100], True), "figures/wb_point_mass_dist_hard.pdf"),
    ]
    jobs = [delayed(plot_fn)(*args, path=out_file) for plot_fn, args, out_file in plot_specs]

    Parallel(n_jobs=8)(jobs)
