import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

################# Select the results directory to plot: keep exactly ONE `path` line active #################

# For plotting Hopper results
# path = "../Hopper/"

# For plotting HalfCheetah results
# path = "../HalfCheetah/"

# For plotting Walker2d results
# path = "../Walker2d/"

# Ablation-study results (absolute paths on the author's machine).
# path='/Users/pierreclavier/Documents/These/Code/Software/Robust-Distributional-RL/Robust_RL/results_plot/ablation/hopper_file/4'
path = '/Users/pierreclavier/Documents/These/Code/Software/Robust-Distributional-RL/Robust_RL/results_plot/ablation/walker2d_file/3'
# path='/Users/pierreclavier/Documents/These/Code/Software/Robust-Distributional-RL/Robust_RL/results_plot/ablation/halfcheetah_file/0'
# path='/Users/pierreclavier/Documents/These/Code/Software/Robust-Distributional-RL/Robust_RL/results_plot/ablation/hopper_file/train_vs_null'
def test_csv_min(log_dir_file, log_dir_out, name="plot"):
    """Plot the ``min`` column of a results CSV against ``relative_mass``.

    Parameters
    ----------
    log_dir_file : str
        Path to a ``;``-delimited CSV containing at least the columns
        ``relative_mass``, ``mean`` and ``min``.
    log_dir_out : str
        Unused; kept for call-site compatibility.
    name : str
        Legend label for the plotted curve.

    Rows whose ``mean`` is NaN are dropped before plotting. The ``min``
    column itself is not checked for NaN.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")

    # Filter on the mean column, as test_csv_mean does, so both plots share rows.
    valid = ~np.isnan(data["mean"].values)
    x = data["relative_mass"].values[valid]
    min_y = data["min"].values[valid]

    plt.plot(x, min_y, label=name)


def test_csv_mean(log_dir_file, log_dir_out, n, i, name="plot", normalize=False):
    """Plot the ``mean`` column of a results CSV against ``relative_mass``.

    Parameters
    ----------
    log_dir_file : str
        Path to a ``;``-delimited CSV containing at least the columns
        ``relative_mass`` and ``mean``.
    log_dir_out, n, i :
        Unused; kept for call-site compatibility with ``test_plot``.
    name : str
        Legend label. Substrings of it also select the line colour.
    normalize : bool, optional
        If True, divide the curve by its maximum so its peak is 1
        (replaces the previously commented-out "normalised plot" line).

    Rows whose ``mean`` is NaN are dropped. The colour comes from a
    5-colour seaborn palette chosen by substring match on ``name``. If no
    substring matches, matplotlib's default colour cycle is used instead of
    failing on an unset colour.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")
    y = data["mean"].values
    x = data["relative_mass"].values

    # Drop rows whose mean is NaN.
    valid = ~np.isnan(y)
    x = x[valid]
    y = y[valid]

    # Guard against an empty curve: np.max raises on a zero-size array.
    if normalize and y.size:
        y = y / np.max(y)

    palette = sns.color_palette(n_colors=5)
    # Checked in order; a later match overrides an earlier one, matching the
    # original sequence of independent `if` statements.
    rules = (
        ("Full penalization", 0),
        ("Train penalization", 1),
        ("Test penalization", 2),
        ("TQC", 3),
        ("old", 4),
    )
    color = None
    for key, idx in rules:
        if key in name:
            color = palette[idx]

    # Only pass a colour when one was selected; otherwise let matplotlib cycle.
    style = {} if color is None else {"c": color}
    plt.plot(x, y, label=name, **style)


def _legend_label(filename):
    """Map a results CSV file name to its legend label, or None to skip it.

    Checks are applied in order and the first match decides. For example, a
    name containing ``seed`` is skipped even if it also contains ``1_var``.
    """
    if "no_penal_test_summary" in filename:
        return "Train penalization"
    if "no_penal_train" in filename:
        return "Test penalization"
    if "seed" in filename:
        return None
    if "results_summary_no_penal" in filename:
        return "old"
    if "no_penal_testbest" in filename:
        return None
    if "1_var" in filename:
        return "TQC"
    return "Full penalization"


def test_plot(path):
    """Plot every results CSV in ``path`` on one figure, save it as PDF, and show it.

    Each ``*.csv`` in ``path`` is mapped to a legend label by
    ``_legend_label``; files mapped to None are skipped. Each remaining file
    is plotted with ``test_csv_mean`` under the ``ggplot`` style. Legend
    entries are sorted by label. The figure is saved as
    ``fig_<label>.pdf`` inside ``path``, where ``<label>`` is the label of
    the last plotted file.

    Raises
    ------
    FileNotFoundError
        If no CSV file in ``path`` ends up being plotted.
    """
    # glob returns files in arbitrary order; sort so the plotting order
    # (and therefore the output file name) is deterministic.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    name = None
    i = 0
    for f in csv_files:
        print(f)
        label = _legend_label(os.path.basename(f))
        if label is None:
            continue
        name = label
        with plt.style.context("ggplot"):
            test_csv_mean(f, path, name=name, n=len(csv_files), i=i)
        i += 1

    if name is None:
        raise FileNotFoundError("No plottable CSV files found in {!r}".format(path))

    plt.xlabel("Relative mass")
    plt.ylabel("Mean Reward")
    # Order legend entries alphabetically by label.
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    plt.legend(handles, labels)
    # Apply the layout before saving so the PDF gets it, not just the window.
    plt.tight_layout()
    os.makedirs(path, exist_ok=True)
    # os.path.join: `path` may lack a trailing separator, in which case
    # plain concatenation would write the file next to the directory.
    plt.savefig(os.path.join(path, "fig_{}.pdf".format(name)))

    plt.show()


# Plot the graph for the configured `path` only when run as a script, so that
# importing this module (e.g. to reuse the plotting helpers) has no side effects.
if __name__ == "__main__":
    test_plot(path)
