import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from deep_sprl.experiments import MazeExperiment
from misc.visualize_maze import maze

# Typeset every piece of figure text through LaTeX with a serif font; amsmath is
# loaded in the LaTeX preamble so math-mode labels can use its macros.
plt.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Roman"],
    }
)

# Font sizes used for axis labels (FONT_SIZE) and for tick/legend text (TICK_SIZE).
FONT_SIZE, TICK_SIZE = 8, 6


def add_plot(base_log_dir, ax, color, max_seed=None, marker="o", markevery=3):
    """Plot the mean performance curve over all completed seeds of one run.

    Every ``seed-<n>`` sub-directory of ``base_log_dir`` (``n`` a non-negative
    integer, ``n <= max_seed``) that contains a ``performance.pkl`` file counts
    as a completed seed. The per-seed performance sequences are truncated to a
    common length, averaged, and drawn on ``ax`` together with a shaded band of
    +/- two standard errors of the mean.

    Args:
        base_log_dir: Directory holding the ``seed-<n>`` sub-directories.
        ax: Matplotlib axes to draw on (untouched if no seed is found).
        color: Color of both the mean line and the shaded band.
        max_seed: Largest seed number to include; ``None`` includes every seed.
        marker: Marker style of the mean line.
        markevery: Draw a marker on every ``markevery``-th data point.

    Returns:
        The ``Line2D`` of the mean curve, or ``None`` if no completed seed exists.
    """
    seed_prefix = "seed-"
    iteration_prefix = "iteration-"

    # Only accept purely numeric suffixes: this skips stray entries such as
    # "seed-abc" and avoids loading the same seed twice (e.g. "seed-1-old").
    # Sorting makes the aggregation order independent of os.listdir order.
    seeds = sorted(
        int(d[len(seed_prefix):])
        for d in os.listdir(base_log_dir)
        if d.startswith(seed_prefix) and d[len(seed_prefix):].isdigit()
    )

    iterations = np.array([], dtype=int)
    seed_performances = []
    for seed in seeds:
        if max_seed is not None and seed > max_seed:
            continue

        seed_log_dir = os.path.join(base_log_dir, seed_prefix + str(seed))
        performance_file = os.path.join(seed_log_dir, "performance.pkl")
        if not os.path.exists(performance_file):
            # Seed has no performance log (e.g. not finished yet): skip it.
            continue

        iteration_numbers = [
            int(d[len(iteration_prefix):])
            for d in os.listdir(seed_log_dir)
            if d.startswith(iteration_prefix) and d[len(iteration_prefix):].isdigit()
        ]
        # The x-axis uses the iteration numbers of the last loaded seed.
        iterations = np.sort(np.array(iteration_numbers, dtype=int))

        # NOTE: pickle can execute arbitrary code; these files are expected to be
        # produced by our own training runs. Never point this at untrusted logs.
        with open(performance_file, "rb") as f:
            seed_performances.append(pickle.load(f))

    if not seed_performances:
        return None

    print("Found %d completed seeds" % len(seed_performances))
    # Truncate to the shortest run and never past the available iteration
    # numbers, so the x and y data passed to ax.plot always match in length.
    min_length = min(min(len(p) for p in seed_performances), len(iterations))
    iterations = iterations[:min_length]
    seed_performances = [p[:min_length] for p in seed_performances]

    mid = np.mean(seed_performances, axis=0)
    # NOTE(review): np.std uses ddof=0 here, so this slightly underestimates the
    # sample standard error; kept as-is to reproduce the published figures.
    sem = np.std(seed_performances, axis=0) / np.sqrt(len(seed_performances))
    low = mid - 2 * sem
    high = mid + 2 * sem

    line, = ax.plot(iterations, mid, color=color, linewidth=1, marker=marker, markersize=2, markevery=markevery)
    ax.fill_between(iterations, low, high, color=color, alpha=0.5)
    return line


def performance_plot(path=None):
    """Plot the maze success rate of every curriculum method over training epochs.

    Methods without any completed seed are left out of both the plot and the
    legend.

    Args:
        path: If given, the figure is saved to this file and then closed;
            otherwise it is shown interactively.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")

    methods = ["self_paced", "np_self_paced", "random", "wasserstein", "goal_gan", "alp_gmm"]
    colors = ["C0", "C1", "C2", "C4", "C5", "C6"]
    labels = ["G-SPRL", "NP-SPRL", "Random", "WB-SPRL", "GoalGAN", "ALP-GMM"]

    f = plt.figure(figsize=(2.3, 1.4))
    ax = plt.Axes(f, [0.172, 0.23, 0.77, 0.52])
    f.add_axes(ax)

    handles, legend_labels = [], []
    for method, color, label in zip(methods, colors, labels):
        # seed=0 is only used to locate the experiment's log directory; its
        # parent directory is the one add_plot scans for seed-<n> folders.
        exp = MazeExperiment(base_log_dir, method, "ppo", {"hard_likelihood": False}, seed=0)
        log_dir = os.path.dirname(exp.get_log_dir())
        line = add_plot(log_dir, ax, color)
        # add_plot returns None when a method has no completed seeds; matplotlib
        # cannot draw a legend entry for None, so keep handles and labels paired.
        if line is not None:
            handles.append(line)
            legend_labels.append(label)

    ax.set_ylabel("Success Rate", fontsize=FONT_SIZE, labelpad=2.)
    ax.set_xlabel("Epoch", fontsize=FONT_SIZE, labelpad=2.)

    f.legend(handles, legend_labels, fontsize=TICK_SIZE,
             loc='upper left', bbox_to_anchor=(0.09, 1.02), ncol=3, columnspacing=0.6, handlelength=1.0)

    ax.set_xticks([0, 100, 200, 300, 400])
    ax.set_xticklabels([r"$0$", r"$100$", r"$200$", r"$300$", r"$400$"])
    ax.set_xlim([0, 400])

    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8])
    ax.set_yticklabels([r"$0$", r"$0.2$", r"$0.4$", r"$0.6$", r"$0.8$"])
    ax.set_ylim([0, 0.81])

    ax.tick_params(axis='both', which='both', labelsize=TICK_SIZE)
    ax.grid()
    # No tight_layout(): the single Axes is placed manually with an explicit
    # rectangle, which tight_layout ignores (it would only emit a warning).

    if path is None:
        plt.show()
    else:
        f.savefig(path)
        plt.close(f)


def _distribution_visualization(method, seed, iterations, n_samples, color, marker_size, figsize, axes_rect,
                                path=None):
    """Scatter samples of a trained teacher's context distribution over the maze.

    For each entry of ``iterations`` the teacher checkpoint in
    ``seed-<seed>/iteration-<iteration>`` (next to the experiment's log
    directory) is loaded, ``n_samples`` contexts are drawn from it and plotted
    on top of the maze image in a separate, tick-free axes.

    Args:
        method: Curriculum method name passed to ``MazeExperiment``.
        seed: Seed whose checkpoints are visualized.
        iterations: Iteration numbers of the checkpoints to show, one panel each.
        n_samples: Number of contexts drawn per checkpoint.
        color: Scatter color.
        marker_size: Scatter marker size (``s`` argument of ``scatter``).
        figsize: Figure size in inches.
        axes_rect: Callable mapping the panel index ``i`` to its
            ``[left, bottom, width, height]`` rectangle in figure coordinates.
        path: Output file; if ``None`` the figure is shown interactively.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    # seed=0 is only used to locate the experiment's log directory; its parent
    # holds the seed-<n> folders, from which the requested seed is picked.
    exp = MazeExperiment(base_log_dir, method, "ppo", {"hard_likelihood": False}, seed=0)
    log_dir = os.path.dirname(exp.get_log_dir())
    seed_path = os.path.join(log_dir, "seed-%d" % seed)

    teacher = exp.create_self_paced_teacher()
    f = plt.figure(figsize=figsize)

    # Loop-invariant plotting bounds, shared by the maze image and axis limits.
    lower, upper = exp.LOWER_CONTEXT_BOUNDS, exp.UPPER_CONTEXT_BOUNDS
    extent = (lower[0], upper[0], lower[1], upper[1])

    for i, iteration in enumerate(iterations):
        teacher.load(os.path.join(seed_path, "iteration-%d" % iteration))
        samples = np.array([teacher.sample() for _ in range(n_samples)])

        ax = plt.Axes(f, axes_rect(i))
        f.add_axes(ax)

        ax.imshow(maze(), extent=extent, origin="lower")
        ax.scatter(samples[:, 0], samples[:, 1], alpha=0.2, color=color, s=marker_size)

        ax.set_xlim(lower[0], upper[0])
        ax.set_ylim(lower[1], upper[1])

        # Panels are pure images: hide tick labels and tick marks.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params('both', length=0, width=0, which='major')
        # Presumably rasterized to keep the PDF small despite many scatter points.
        ax.set_rasterized(True)

    if path is not None:
        f.savefig(path)
        plt.close(f)
    else:
        plt.show()


def np_distribution_visualization(seed, iterations, path=None):
    """Show NP-SPRL context samples for (up to four) iterations in one row."""
    _distribution_visualization("np_self_paced", seed, iterations, n_samples=2000, color="C1", marker_size=1,
                                figsize=(3., 0.58), axes_rect=lambda i: [0.25 * i + 0.005, 0., 0.24, 1.],
                                path=path)


def wb_distribution_visualization(seed, iterations, path=None):
    """Show WB-SPRL context samples for (up to four) iterations in a 2x2 grid, row by row from the top."""
    _distribution_visualization("wasserstein", seed, iterations, n_samples=2000, color="C4", marker_size=1,
                                figsize=(1.2, 1.2),
                                axes_rect=lambda i: [0.5 * (i % 2) + 0.005, 0.5 * (1 - (i // 2)) + 0.005, 0.49, 0.49],
                                path=path)


def g_distribution_visualization(seed, iterations, path=None):
    """Show G-SPRL context samples for (up to four) iterations in one row."""
    _distribution_visualization("self_paced", seed, iterations, n_samples=1000, color="C0", marker_size=4,
                                figsize=(3., 0.58), axes_rect=lambda i: [0.25 * i + 0.005, 0., 0.24, 1.],
                                path=path)


if __name__ == "__main__":
    os.makedirs("figures", exist_ok=True)
    # Every figure is rendered in its own worker process. Temporarily instantiating
    # the CRL algorithms creates TensorFlow computation graphs, and those graphs
    # clash when several plots share a single process.
    from joblib import Parallel, delayed

    plot_calls = [
        (np_distribution_visualization, (1, [20, 40, 70, 250]), {"path": "figures/np_maze_dist.pdf"}),
        (wb_distribution_visualization, (1, [20, 40, 70, 270]), {"path": "figures/wb_maze_dist.pdf"}),
        (g_distribution_visualization, (1, [10, 30, 80, 390]), {"path": "figures/g_maze_dist.pdf"}),
        (performance_plot, ("figures/maze_performance.pdf",), {}),
    ]

    Parallel(n_jobs=4)(delayed(plot_fn)(*args, **kwargs) for plot_fn, args, kwargs in plot_calls)
