import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Directory that is searched for the result CSV files; the generated figure
# is saved into this same directory.
path = "../Cartpole/"


def test_csv(log_dir_file, log_dir_out, name="plot"):
    """Plot the mean-reward curve of one results file on the current figure.

    Reads a ``;``-separated CSV, drops the rows whose ``mean`` is NaN,
    discards the last four remaining samples and draws ``mean`` against the
    pole length on a log-scaled x axis, with a shaded band of one ``std``
    on either side of the curve.

    Args:
        log_dir_file: Path of the CSV file to read. It must provide ``mean``
            and ``std`` columns plus a ``length`` column (the misspelt
            header ``lenght`` is accepted as a fallback).
        log_dir_out: Unused; kept so that existing callers keep working.
        name: Legend label of the curve. It also selects the colour: each
            marker substring found in the label maps to a palette entry and
            the last matching marker wins. A label without any marker falls
            back to matplotlib's own colour cycle.

    Raises:
        KeyError: If a required column is missing from the file.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")

    # Some result files spell the column header "lenght".
    x_column = "length" if "length" in data.columns else "lenght"
    x = data[x_column].values
    y = data["mean"].values
    std = data["std"].values

    # Keep only the rows with a valid mean, then drop the last four samples.
    # NOTE(review): the reason for trimming four points is not documented.
    valid = ~np.isnan(y)
    x = x[valid][:-4]
    y = y[valid][:-4]
    std = std[valid][:-4]

    palette = sns.color_palette(n_colors=7)
    # Checked in order and the LAST match wins, so that e.g. a label holding
    # both "1" and "6" gets palette[6], exactly as the former chain of
    # independent ``if`` statements did.
    color_markers = (
        ("PPO", 0),
        ("0", 1),
        ("1", 2),
        ("3", 3),
        ("5", 4),
        ("7", 5),
        ("6", 6),
    )
    color = None
    for marker, index in color_markers:
        if marker in name:
            color = palette[index]

    plt.xscale("log")

    # Only force a colour when one was selected; otherwise let matplotlib
    # pick the next colour of its cycle instead of failing on an unbound name.
    line_kwargs = {"label": name}
    if color is not None:
        line_kwargs["c"] = color
    (line,) = plt.plot(x, y, **line_kwargs)

    # The band always reuses the colour the line actually got.
    plt.fill_between(
        x,
        y - std,
        y + std,
        alpha=0.5,
        linestyle="dashdot",
        facecolor=line.get_color(),
    )


def test_plot(path):
    """Plot every results CSV found in ``path`` on one figure and save it.

    Each ``*.csv`` file directly inside ``path`` is drawn with ``test_csv``
    under the ``ggplot`` style. Known result file names are turned into
    short legend labels (``α=<value>`` or ``PPO``). The legend is sorted
    alphabetically, a black star is drawn at ``(1, 500)``, and the figure is
    written to ``<path>/fig_<label of the last file>.pdf`` before being
    shown.

    Args:
        path: Directory that holds the CSV files and receives the PDF.

    Raises:
        ValueError: If ``path`` contains no ``*.csv`` file.
    """
    # Sorted so that the drawing order, and therefore the name of the saved
    # file (taken from the last label), does not depend on the arbitrary
    # order in which the file system lists the directory.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not csv_files:
        raise ValueError("no CSV result files found in {!r}".format(path))

    # (substring of the file name, replacement) pairs, applied in order.
    label_substitutions = (
        ("results_summary_bestCartPole-v1QRDQN_cart_1_std.csv", "\u03B1=0"),
        ("results_summary_bestCartPole-v1QRDQN_cart_2_std.csv", "\u03B1=1"),
        ("results_summary_bestCartPole-v1QRDQN_cart_3_std.csv", "\u03B1=3"),
        ("results_summary_bestCartPole-v1QRDQN_cart_4_std.csv", "\u03B1=5"),
        ("results_summary_bestCartPole-v1QRDQN_cart_5_std.csv", "\u03B1=7"),
        ("results_summaryCartPole-v1PPO_cart_std.csv", "PPO"),
        # Shorten any remaining, unrecognised "best" result file name.
        ("results_summary_best", ""),
    )

    name = None
    for csv_file in csv_files:
        name = os.path.basename(csv_file)
        for old, new in label_substitutions:
            name = name.replace(old, new)

        with plt.style.context("ggplot"):
            test_csv(csv_file, path, name=name)

    # Collect the legend entries before adding the star marker.
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))

    # Reference point at pole length 1 / reward 500.
    # NOTE(review): presumably the nominal CartPole setting - confirm.
    plt.plot(1, 500, marker="*", color="k", markersize=12)
    plt.legend(handles, labels)
    plt.xlabel("length of the pole")
    plt.ylabel("Mean Reward")

    # The layout must be tightened BEFORE saving, otherwise the PDF is
    # written with the untightened layout.
    plt.tight_layout()

    os.makedirs(path, exist_ok=True)
    # os.path.join keeps the file inside ``path`` even when ``path`` has no
    # trailing separator.
    plt.savefig(os.path.join(path, "fig_{}.pdf".format(name)))
    plt.show()


# Build and save the figure for the default results directory.
# NOTE(review): this runs whenever the module is imported, not only when it
# is executed as a script; consider an `if __name__ == "__main__":` guard.
test_plot(path)
