import os
import numpy as np
from deep_sprl.experiments import PickAndPlaceExperiment
import matplotlib.colors as colors
import matplotlib.cm as cmx
import matplotlib.pyplot as plt
from misc.visualize_maze_results import add_plot
from PIL import Image
import pickle
from deep_sprl.environments.pick_and_place import generate_demonstration, PickAndPlace
from gym.wrappers import TimeLimit

# Global matplotlib configuration: text is rendered through LaTeX (with
# amsmath loaded in the preamble) using a serif font family.
plt.rcParams.update({
    "text.latex.preamble": r'\usepackage{amsmath}',
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Roman"],
})

# Font sizes shared by the plotting functions in this script.
FONT_SIZE = 8  # primarily axis labels and in-plot annotations
TICK_SIZE = 6  # primarily tick labels, legend and colorbar entries


def performance_plot(ax, hard=False):
    """Add one performance curve per curriculum method to ``ax``.

    For every method a ``PickAndPlaceExperiment`` (learner ``"sac"``, seed 0)
    is instantiated only to resolve its log directory; the parent of that
    directory is handed to ``add_plot`` together with the axes and a color.

    Args:
        ax: Matplotlib axes that receive the curves, ticks and the annotation.
        hard: Value of the ``hard_likelihood`` experiment parameter used for
            the ``np_self_paced``, ``random``, ``default`` and ``wasserstein``
            runs. It also selects the annotation (``mu_2`` if true, otherwise
            ``mu_1``).

    Returns:
        List with the return value of ``add_plot`` for each method, in the
        order np_self_paced, random, default, wasserstein, goal_gan, alp_gmm,
        acl.
    """
    log_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")

    # (method, line color, hard_likelihood flag), in plotting order.
    # NOTE(review): goal_gan / alp_gmm / acl are always looked up with
    # hard_likelihood=True, independent of ``hard`` -- confirm this is intended.
    curve_specs = (
        ("np_self_paced", "C1", hard),
        ("random", "C2", hard),
        ("default", "C3", hard),
        ("wasserstein", "C4", hard),
        ("goal_gan", "C5", True),
        ("alp_gmm", "C6", True),
        ("acl", "C7", True),
    )

    handles = []
    for method, color, hard_flag in curve_specs:
        experiment = PickAndPlaceExperiment(log_root, method, "sac", {"hard_likelihood": hard_flag}, seed=0)
        # get_log_dir() points at a single run; its parent is what add_plot receives.
        method_dir = os.path.dirname(experiment.get_log_dir())
        handles.append(add_plot(method_dir, ax, color))

    ax.set_yticks([0, 0.5, 1.])
    ax.set_xticks([0, 50, 100, 150, 200])

    # Label the panel with the target distribution it corresponds to.
    annotation = r'$\mu_2(\mathbf{c})$' if hard else r'$\mu_1(\mathbf{c})$'
    ax.text(5, 0.7, annotation, fontsize=FONT_SIZE + 2)

    return handles


def wb_sprl_plot(seed, iterations, path=None):
    """Visualize the sampling distribution of the "wasserstein" curriculum.

    For each requested iteration the teacher state stored under
    ``<log_dir>/seed-<seed>/iteration-<iteration>`` is loaded, the first
    dimension of its current samples is binned around the teacher's contexts
    and the max-normalized histogram is drawn as a bar plot, colored by
    iteration (viridis, normalized to ``[0, max(iterations)]``). Three
    environment snapshots are placed above the plot and connected to it by
    dashed vertical lines.

    Args:
        seed: Number of the ``seed-<seed>`` sub-directory to read from. The
            experiment object itself is always constructed with ``seed=0``; it
            is only used to resolve the log directory and create the teacher.
        iterations: Non-empty sequence of iteration numbers to overlay.
        path: Output file for ``plt.savefig`` (written with ``dpi=1200``). If
            ``None`` the figure is shown interactively instead.

    Note:
        The snapshot images ``pick_and_place_step_{0,34,75}.png`` are opened
        relative to the current working directory.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    exp = PickAndPlaceExperiment(base_log_dir, "wasserstein", "sac", {"hard_likelihood": False}, seed=0)
    log_dir = os.path.dirname(exp.get_log_dir())
    seed_path = os.path.join(log_dir, "seed-%d" % seed)

    teacher = exp.create_self_paced_teacher()

    # Maps an iteration number to a color.
    cmap = plt.get_cmap("viridis")
    c_norm = colors.Normalize(vmin=0, vmax=max(iterations))
    scalar_map = cmx.ScalarMappable(norm=c_norm, cmap=cmap)

    f = plt.figure(figsize=(2.06, 1.4))
    ax = plt.Axes(f, [0.08, 0.23, 0.91, 0.47])
    f.add_axes(ax)
    ax.set_facecolor((0.9, 0.9, 0.9))
    for iteration in iterations:
        teacher.load(os.path.join(seed_path, "iteration-%d" % iteration))

        # Bin the samples and make a bar plot. The edges sit half-way between
        # consecutive integer positions 0 .. n_contexts - 1, with the outermost
        # edges clamped to 0 and n_contexts - 1.
        n_contexts = teacher.contexts.shape[0]
        bin_edges = np.concatenate(
            ([0], np.linspace(0, n_contexts - 1, n_contexts)[:-1] + 0.5, [n_contexts - 1]))

        samples = teacher.teacher.current_samples
        probs = np.histogram(samples[:, 0], bins=bin_edges)[0] / samples.shape[0]
        # Normalize by the maximum so every iteration peaks at 1.
        ax.bar(teacher.contexts[:, 0], probs / np.max(probs), alpha=0.8, color=scalar_map.to_rgba(iteration),
               width=1.)

    # Axis decoration does not depend on the iteration, so apply it once.
    ax.set_xlim([-2, 77])
    ax.set_ylim([0., 1.])
    ax.set_xticks([0, 25, 50, 75])
    ax.set_xticklabels([r'$0$', r'$0.33$', r'$0.66$', r'$1$'], fontsize=TICK_SIZE)
    ax.set_yticklabels([])
    ax.tick_params('y', length=0, width=0, which='major')
    ax.set_xlabel(r"$t$", fontsize=FONT_SIZE, labelpad=2)
    ax.set_ylabel(r"$p(t)$", fontsize=FONT_SIZE, labelpad=-1)

    # scalar_map is not attached to any Axes, so name the Axes to take the
    # space from explicitly instead of relying on the implicit current axes.
    cbar = plt.colorbar(scalar_map, ax=ax)
    cbar.ax.tick_params(labelsize=TICK_SIZE)
    ax.tick_params(labelsize=TICK_SIZE)
    ax.set_axisbelow(True)
    ax.grid()

    # Environment snapshots in a row above the bar plot: (file, left offset).
    snapshots = (
        ("pick_and_place_step_0.png", 0.04),
        ("pick_and_place_step_34.png", 0.39),
        ("pick_and_place_step_75.png", 0.74),
    )
    for image_file, left in snapshots:
        # Copy the pixels out while the file is open so the handle is released.
        with Image.open(image_file) as img:
            pixels = np.array(img)
        image_ax = plt.Axes(f, [left, 0.72, 0.29, 0.29])
        image_ax.set_axis_off()
        f.add_axes(image_ax)
        image_ax.imshow(pixels)

    # Invisible full-figure axes (figure-fraction coordinates) for the dashed
    # guide lines between the snapshots and the bar plot.
    overlay = plt.Axes(f, [0., 0., 1., 1.])
    f.add_axes(overlay)
    overlay.set_axis_off()
    overlay.set_xlim(0, 1)
    overlay.set_ylim(0, 1)
    overlay.vlines([0.09, 0.44, 0.79], 0.23, 1, linestyle="--", alpha=0.5)

    if path is not None:
        plt.savefig(path, dpi=1200)
    else:
        plt.show()


def performance_plots(path=None):
    """Create the two-panel performance figure.

    The left panel is drawn with ``hard=False`` and the right panel with
    ``hard=True`` (see ``performance_plot``). A shared legend is placed above
    the panels.

    Args:
        path: File the figure is saved to. If ``None`` the figure is shown
            interactively instead.
    """
    f = plt.figure(figsize=(3.3, 1.4))
    axs = []
    # Horizontal placement of the two panels in figure-fraction coordinates.
    offsets = [0.115, 0.595]
    widths = [0.38, 0.38]
    for i, (offset, width) in enumerate(zip(offsets, widths)):
        ax = plt.Axes(f, [offset, 0.23, width, 0.53])
        f.add_axes(ax)
        # Only the second panel uses the hard likelihood.
        lines = performance_plot(ax, hard=i == 1)
        axs.append(ax)
        ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE)
        ax.tick_params(axis='both', which='minor', labelsize=TICK_SIZE)
        ax.set_xlabel("Epoch", fontsize=FONT_SIZE, labelpad=2)
        ax.grid()

    # The legend uses the handles of the last panel; label order follows the
    # order in which performance_plot returns them.
    f.legend(lines, ["NP-SPRL", "Random", "Default", "WB-SPRL", "GoalGAN", "ALP-GMM", "ACL"], fontsize=TICK_SIZE,
             loc='upper left', bbox_to_anchor=(0.16, 1.02), ncol=4, columnspacing=0.6, handlelength=1.0)
    axs[0].set_ylabel("Success Rate", fontsize=FONT_SIZE, labelpad=2)

    if path is None:
        plt.show()
    else:
        # Honor the requested location (previously a hard-coded file name was
        # written regardless of ``path``).
        plt.savefig(path)


def np_sprl_plot(seed, iterations, hard_likelihood=False, path=None):
    """Plot the "np_self_paced" curriculum distribution at several iterations.

    Each iteration gets its own panel showing ``exp(cur_log_pdf)`` of the
    loaded teacher, normalized by its maximum, as a bar plot over the sample
    index.

    Args:
        seed: Seed of the experiment whose log directory is read.
        iterations: Iteration numbers to plot. Panels are laid out in two
            columns; the first two entries go to the top row and every later
            entry to the bottom row, so the layout is designed for four
            iterations.
        hard_likelihood: Value of the ``hard_likelihood`` experiment parameter.
        path: File the figure is saved to. If ``None`` the figure is shown
            interactively instead.
    """
    base_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "logs")
    exp = PickAndPlaceExperiment(base_log_dir, "np_self_paced", "sac", {"hard_likelihood": hard_likelihood}, seed=seed)
    seed_path = exp.get_log_dir()

    teacher = exp.create_self_paced_teacher()
    f = plt.figure(figsize=(2.5, 1.))
    for i, iteration in enumerate(iterations):
        top_row = i < 2
        left_column = i % 2 == 0

        ax = plt.Axes(f, [0.055 + (i % 2) * (0.48), 0.25 + top_row * 0.385, 0.46, 0.36])
        f.add_axes(ax)
        ax.set_facecolor((0.9, 0.9, 0.9))
        teacher.load(os.path.join(seed_path, "iteration-%d" % iteration))

        n_samples = teacher.teacher.cur_log_pdf.shape[0]
        probs = np.exp(teacher.teacher.cur_log_pdf)
        # Draw on the panel explicitly rather than on the implicit current axes.
        ax.bar(np.linspace(0, n_samples - 1, n_samples), probs / np.max(probs),
               color="C1", width=1.)

        ax.set_xlim([-2, 77])
        ax.set_xticks([0, 37, 75])
        if top_row:
            # Top-row panels share the x axis annotation of the row below.
            ax.set_xticklabels([])
            ax.tick_params('x', length=0, width=0, which='major')
        else:
            ax.set_xticklabels([r'$0$', r'$0.5$', r'$1$'])
            ax.set_xlabel(r"$t$", fontsize=FONT_SIZE, labelpad=-1)

        ax.set_ylim([0., 1.])
        ax.set_yticklabels([])
        ax.tick_params('y', length=0, width=0, which='major')
        if left_column:
            ax.set_ylabel(r"$p_{\alpha, \eta}(t)$", fontsize=FONT_SIZE, labelpad=-1)

        ax.tick_params(labelsize=FONT_SIZE)
        ax.set_axisbelow(True)
        ax.grid()

        ax.text(3, 0.75, r'Iteration $%d$' % iteration, fontsize=FONT_SIZE)

    if path is None:
        plt.show()
    else:
        plt.savefig(path)


def qualitative_plot():
    """Print a LaTeX-style comparison of the demonstration and learned policies.

    Results are cached in ``pick_and_place_eval.zip`` next to this script (the
    file is a pickle despite its suffix). If the cache is missing, the
    demonstration actions and the ``iteration-198`` model of every
    ``seed-*`` run of the "wasserstein" experiment are each rolled out for 500
    episodes in a ``PickAndPlace`` environment limited to 150 steps, and the
    results are written to the cache.

    A rollout counts as a pick if the final distance between achieved goal and
    goal is below 0.05. For the learned policy the episode length is recorded
    only when the last ``infos['is_success']`` is truthy.

    NOTE(review): for the demonstration the recorded step count is always the
    full number of demonstration actions, independent of success or early
    termination -- confirm this is intended.
    NOTE(review): the cache is read with ``pickle.load``; only use cache files
    from a trusted source.
    NOTE(review): the printed skeleton uses ``\\begin{table}{rcl}`` and a
    header with one more ``&`` than the rows; it likely needs manual editing
    before it compiles.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    cache_file = os.path.join(script_dir, "pick_and_place_eval.zip")

    if not os.path.exists(cache_file):
        log_root = os.path.join(script_dir, "..", "logs")
        exp = PickAndPlaceExperiment(log_root, "wasserstein", "sac", {"hard_likelihood": False}, seed=0)
        env = TimeLimit(PickAndPlace(), max_episode_steps=150)

        # --- Open-loop replay of the demonstration actions ---
        demo_actions = generate_demonstration()[-1]
        n_demo_actions = demo_actions.shape[0]
        baseline_picks, baseline_n_steps = [], []
        for _ in range(500):
            env.reset()
            done = False
            step_count = 0
            while not done and step_count < n_demo_actions:
                _, _, done, _ = env.step(demo_actions[step_count])
                step_count += 1

            final_dist = np.linalg.norm(env.env._get_obs()["achieved_goal"] - env.env.goal)
            baseline_picks.append(final_dist < 0.05)
            baseline_n_steps.append(n_demo_actions)

        # --- Learned policies, pooled over all seed directories ---
        exp_dir = os.path.dirname(exp.get_log_dir())
        learned_picks, learned_n_steps = [], []
        for entry in os.listdir(exp_dir):
            if not entry.startswith("seed-"):
                continue

            model_file = os.path.join(exp_dir, entry, 'iteration-198', "model.zip")
            model = exp.learner.load_for_evaluation(model_file, exp.vec_eval_env)

            for _ in range(500):
                obs = env.reset()
                done = False
                step_count = 0
                while not done:
                    # The observation gets a leading axis before it is passed
                    # to the model; element 0 of the result is sent to the env.
                    action = model.step(obs[None, :], state=None, deterministic=False)
                    step_count += 1
                    obs, _, done, infos = env.step(action[0])

                # Episode length only counts for successful episodes.
                if infos['is_success']:
                    learned_n_steps.append(step_count)

                final_dist = np.linalg.norm(env.env._get_obs()["achieved_goal"] - env.env.goal)
                learned_picks.append(final_dist < 0.05)

        with open(cache_file, "wb") as cache:
            pickle.dump((baseline_picks, baseline_n_steps, learned_picks, learned_n_steps), cache)
    else:
        with open(cache_file, "rb") as cache:
            baseline_picks, baseline_n_steps, learned_picks, learned_n_steps = pickle.load(cache)

    demo_row = (np.mean(baseline_picks), np.mean(baseline_n_steps))
    learned_row = (np.mean(learned_picks), np.mean(learned_n_steps))

    print(r"\begin{table}{rcl}")
    print(r"& Pick Rate & & Steps")
    print(r"Demonstration & %.2f & %.2f" % demo_row)
    print(r"Learned (600k steps) & %.2f & %.2f" % learned_row)
    print(r"\end{table}")


if __name__ == "__main__":
    # All figures are written below ./figures, relative to the working directory.
    os.makedirs("figures", exist_ok=True)

    performance_plots("figures/pick_and_place_performance.pdf")
    wb_sprl_plot(1, [10, 40, 44, 110, 130, 140], path="figures/pick_and_place_curriculum.pdf")

    # Same seed and iterations for both likelihood variants: (hard flag, output file).
    np_curriculum_outputs = (
        (False, "figures/pick_and_place_np_curriculum_soft.pdf"),
        (True, "figures/pick_and_place_np_curriculum_hard.pdf"),
    )
    for use_hard, out_file in np_curriculum_outputs:
        np_sprl_plot(1, [24, 26, 28, 30], hard_likelihood=use_hard, path=out_file)

    qualitative_plot()
