import glob
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import cm
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tensorboard.backend.event_processing.event_accumulator import \
    EventAccumulator

# Directory searched for ``*.csv`` result summaries; the generated PDFs are
# written into this same directory.
path = "/".join(("..", "acrobot", ""))


def test_csv2D(log_dir_file, log_dir_out, n, ax, fig, name="plot", max_length1=8):
    """Draw a (length1 x length2) heatmap of mean reward onto ``ax[n]``.

    The ``;``-separated CSV must provide ``mean``, ``length1`` and ``length2``
    columns.  Rows whose ``mean`` is NaN are dropped, the remaining values are
    reshaped into a square grid, and only grid rows with
    ``length1 < max_length1`` are displayed.  The number of non-NaN rows must
    be a perfect square, otherwise the reshape raises ``ValueError``.

    NOTE(review): assumes the CSV is ordered length1-major so that every grid
    row shares one ``length1`` value -- confirm against the CSV producer.

    Args:
        log_dir_file: Path of the CSV summary to read.
        log_dir_out: Unused; kept so existing call sites keep working.
        n: Index into *ax* of the subplot to draw on.
        ax: Indexable collection of matplotlib Axes.
        fig: Figure that receives the colorbar.
        name: Subplot title.
        max_length1: Exclusive upper bound on the ``length1`` rows shown.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")
    mean = data["mean"].values
    length1 = data["length1"].values
    length2 = data["length2"].values

    # Apply the NaN mask to every column so they stay aligned for the reshapes.
    valid = ~np.isnan(mean)
    mean, length1, length2 = mean[valid], length1[valid], length2[valid]

    # Builtin int: the np.int alias was removed in NumPy 1.24.
    side = int(np.sqrt(mean.shape[0]))
    length1 = length1.reshape(side, side)
    length2 = length2.reshape(side, side)

    row_labels = length1[:, 0]  # length1 value of each grid row
    col_labels = length2[0, :]  # length2 value of each grid column

    keep = length1 < max_length1
    row_labels = row_labels[row_labels < max_length1]
    # Boolean masking flattens the grid; with length1 constant along rows the
    # mask drops whole rows, so the 2-D shape can be rebuilt.
    grid = mean.reshape(side, side)[keep].reshape(-1, side)

    axis = ax[n]
    img = axis.imshow(grid, cmap="hot", origin="lower", aspect="auto")

    # Show all ticks and label them with the respective length values.
    axis.set_xticks(np.arange(len(col_labels)))
    axis.set_yticks(np.arange(len(row_labels)))
    axis.set_xticklabels(col_labels)
    axis.set_yticklabels(row_labels)
    plt.setp(axis.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Label this subplot explicitly rather than pyplot's "current" axes.
    axis.set_ylabel("length 1")
    axis.set_xlabel("length 2")
    axis.set_title(name)

    # append_axes (unlike new_horizontal) also adds the new axes to the
    # figure, so exactly one visible colorbar is drawn beside the heatmap.
    divider = make_axes_locatable(axis)
    cbar_axis = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(img, cax=cbar_axis)


def test_csv2D2(
    log_dir_file,
    log_dir_out,
    n,
    ax,
    fig,
    name="plot",
    max_length1=8,
    vmin=-500,
    vmax=-50,
):
    """Draw a (length1 x length2) heatmap of mean reward onto a single Axes.

    Same data handling as ``test_csv2D`` but draws on *ax* itself, with the
    colour scale clamped to ``[vmin, vmax]`` so several figures share one
    scale.  The ``;``-separated CSV must provide ``mean``, ``length1`` and
    ``length2`` columns.  NaN ``mean`` rows are dropped; the remaining count
    must be a perfect square or the reshape raises ``ValueError``.

    NOTE(review): assumes the CSV is ordered length1-major so that every grid
    row shares one ``length1`` value -- confirm against the CSV producer.

    Args:
        log_dir_file: Path of the CSV summary to read.
        log_dir_out: Unused; kept so existing call sites keep working.
        n: Unused; kept so existing call sites keep working.
        ax: The matplotlib Axes to draw on.
        fig: Figure that receives the colorbar.
        name: Axes title.
        max_length1: Exclusive upper bound on the ``length1`` rows shown.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")
    mean = data["mean"].values
    length1 = data["length1"].values
    length2 = data["length2"].values

    # Apply the NaN mask to every column so they stay aligned for the reshapes.
    valid = ~np.isnan(mean)
    mean, length1, length2 = mean[valid], length1[valid], length2[valid]

    # Builtin int: the np.int alias was removed in NumPy 1.24.
    side = int(np.sqrt(mean.shape[0]))
    length1 = length1.reshape(side, side)
    length2 = length2.reshape(side, side)

    row_labels = length1[:, 0]  # length1 value of each grid row
    col_labels = length2[0, :]  # length2 value of each grid column

    keep = length1 < max_length1
    row_labels = row_labels[row_labels < max_length1]
    # Boolean masking flattens the grid; with length1 constant along rows the
    # mask drops whole rows, so the 2-D shape can be rebuilt.
    grid = mean.reshape(side, side)[keep].reshape(-1, side)

    img = ax.imshow(
        grid, cmap="hot", origin="lower", aspect="auto", vmin=vmin, vmax=vmax
    )

    # Show all ticks and label them with the respective length values.
    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    ax.set_ylabel("length 1")
    ax.set_xlabel("length 2")
    ax.set_title(name)

    # append_axes (unlike new_horizontal) also adds the new axes to the
    # figure, so exactly one visible colorbar is drawn beside the heatmap.
    divider = make_axes_locatable(ax)
    cbar_axis = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(img, cax=cbar_axis)


def test_csv(
    log_dir_file,
    log_dir_out,
    n,
    name="plot",
    max_length1=100,
    column=5,
    n_points=13,
):
    """Plot one 1-D cut of the mean-reward grid with a +/- std band.

    The ``;``-separated CSV must provide ``mean``, ``std`` and ``length1``
    columns.  NaN ``mean`` rows are dropped and the rest are reshaped into a
    square grid.  Rows with ``length1 >= max_length1`` are discarded.  The
    cut is grid column *column*, truncated to its first *n_points* rows.  It
    is plotted against the ``length1`` value of each row onto pyplot's
    current axes.

    The line colour is chosen from a 7-colour seaborn palette by looking for
    tokens in *name* ("PPO", "0" ... "5"; the last matching token wins).  If
    none matches, matplotlib's colour cycle is used.

    NOTE(review): assumes the CSV is ordered length1-major so that every grid
    row shares one ``length1`` value -- confirm against the CSV producer.

    Args:
        log_dir_file: Path of the CSV summary to read.
        log_dir_out: Unused; kept so existing call sites keep working.
        n: Unused; kept so existing call sites keep working.
        name: Legend label; also drives the colour choice.
        max_length1: Exclusive upper bound on the ``length1`` rows kept.
        column: Index of the grid column taken as the 1-D cut.
        n_points: Maximum number of points plotted.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")
    mean = data["mean"].values
    std = data["std"].values
    length1 = data["length1"].values

    # Apply the NaN mask to every column so they stay aligned.
    valid = ~np.isnan(mean)
    mean, std, length1 = mean[valid], std[valid], length1[valid]

    # Builtin int: the np.int alias was removed in NumPy 1.24.
    side = int(np.sqrt(mean.shape[0]))
    length1 = length1.reshape(side, side)

    keep = length1 < max_length1
    row_lengths = length1[:, 0]
    row_lengths = row_lengths[row_lengths < max_length1]

    def _cut(values):
        # Apply the identical grid/row-filter/column/truncation to any column,
        # so mean and std stay point-for-point aligned.
        grid = values.reshape(side, side)[keep].reshape(-1, side)
        return grid[:, column][:n_points]

    mean_cut = _cut(mean)
    std_cut = _cut(std)
    # The cut runs down the grid rows, i.e. along length1.
    lengths = row_lengths[:n_points]

    palette = sns.color_palette(n_colors=7)
    color = None  # None -> matplotlib picks the next colour in its cycle
    for token, index in (
        ("PPO", 0),
        ("0", 1),
        ("1", 2),
        ("2", 3),
        ("3", 4),
        ("4", 5),
        ("5", 6),
    ):
        if token in name:
            color = palette[index]

    (line,) = plt.plot(lengths, mean_cut, label=name, color=color)
    # Reuse the line's actual colour so the band matches it even when
    # the colour came from the default cycle.
    plt.fill_between(
        lengths,
        mean_cut - std_cut,
        mean_cut + std_cut,
        alpha=0.5,
        linestyle="dashdot",
        facecolor=line.get_color(),
    )

def test_plot2D(path):
    """Render and save one mean-reward heatmap per ``*.csv`` summary in *path*.

    Each CSV gets its own figure, drawn by ``test_csv2D2``.  The title is a
    readable name derived from the file name.  Each figure is saved as
    ``fig_<name>.pdf`` inside *path*, and all figures are shown at the end.

    Raises:
        FileNotFoundError: if *path* contains no ``*.csv`` file.
    """
    # Sorted so figure order and numbering are deterministic across platforms.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not csv_files:
        raise FileNotFoundError(
            "no *.csv result summaries found in {!r}".format(path)
        )

    # Applied in order to each file name to build the figure title.
    replacements = (
        (r"results_summary_bestAcrobot-v1QRDQN_acro_", "Acrobot QRDQN \u03B1="),
        (r"results_summary_bestAcrobot-v1PPO_acro_", "Acrobot PPO "),
        ("_std.csv", ""),
    )

    for count, csv_file in enumerate(csv_files):
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))

        name = os.path.basename(csv_file)
        for old, new in replacements:
            name = name.replace(old, new)

        test_csv2D2(csv_file, path, name=name, n=count, ax=ax, fig=fig)

        # Lay out and save each figure here.  Saving only after the loop would
        # keep just the last figure, and a tight_layout called after savefig
        # would not affect the PDF.
        fig.tight_layout()
        fig.savefig(os.path.join(path, "fig_{}.pdf".format(name)))

    plt.show()


def coupe_1D(path):
    """Overlay the 1-D reward cut of every ``*.csv`` summary in *path*.

    Each file is drawn by ``test_csv`` (inside a ``ggplot`` style context).
    Its legend label is derived from the file name through an ordered table
    of substring replacements.  The legend is sorted alphabetically, and the
    figure is saved as ``fig_<name>.pdf`` in *path*, where ``<name>`` is the
    label of the last file processed.  The figure is then shown.

    Raises:
        FileNotFoundError: if *path* contains no ``*.csv`` file.
    """
    # Sorted so drawing order and the output file name are deterministic.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not csv_files:
        raise FileNotFoundError(
            "no *.csv result summaries found in {!r}".format(path)
        )

    # Applied strictly in this order: e.g. "PPO std.csv" only exists after the
    # "...v1PPO_acro_" prefix has been replaced.
    replacements = (
        (r"results_summary_bestAcrobot-v1QRDQN_acro_1", "\u03B1=0"),
        (r"results_summary_bestAcrobot-v1QRDQN_acro_2", "\u03B1=1"),
        (r"results_summary_bestAcrobot-v1QRDQN_acro_3", "\u03B1=2"),
        (r"results_summary_bestAcrobot-v1QRDQN_acro_4", "\u03B1=3"),
        (r"results_summary_bestAcrobot-v1PPO_acro_", "PPO "),
        (r"PPO std.csv", "PPO "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_6", "\u03B1=5 "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_5", "\u03B1=4 "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_4", "\u03B1=3 "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_3", "\u03B1=2 "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_2", "\u03B1=1 "),
        (r"results_summary_bestAcrobot-v1replay2_QRDQN_acro_1", "\u03B1=0 "),
        (r"results_summary_bestAcrobot-v1replay3_PPO_acro", "PPO "),
        ("_std.csv", ""),
    )

    for count, csv_file in enumerate(csv_files):
        name = os.path.basename(csv_file)
        for old, new in replacements:
            name = name.replace(old, new)

        with plt.style.context("ggplot"):
            test_csv(csv_file, path, name=name, n=count)

    plt.xlabel("length of the pole")
    plt.ylabel("Mean Reward")
    # Black star at (1.7, -78). NOTE(review): presumably a reference result;
    # confirm its source.
    plt.plot(1.7, -78, marker="*", color="k", markersize=12)

    # Sort legend entries by label text.
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    plt.legend(handles, labels)

    # Lay out before saving; calling tight_layout after savefig would not
    # affect the PDF.
    plt.tight_layout()
    plt.savefig(os.path.join(path, "fig_{}.pdf".format(name)))
    plt.show()


# Script entry point, guarded so importing this module does not start plotting.
if __name__ == "__main__":
    # test_plot2D(path)  # alternative: render the 2-D heatmaps instead
    coupe_1D(path)
