import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

################# Uncomment the path of the environment to plot ######################################################
# Keep exactly one `path` assignment active: each assignment overwrites the
# previous one, so only the last active line would ever take effect.

################################ For penal A formulation ##########
# For plotting Hopper results
# path = "../Hopper/"

# For plotting HalfCheetah results
# path = "../HalfCheetah/"

# For plotting Walker2d results
# path = "../Walker2d/"

########### For penal A/C formulation
# path = "../Bellman_critic_results/Half/results_test/"

# path = "../Bellman_critic_results/Walker/results_test/"

# path = "../Bellman_critic_results/Halker/results_test/"

######### For noisy env plots
path = "../noisy_env/half/test/"
# path = "../noisy_env/hopper/test/"
# path = "../noisy_env/walker/test/"



def test_csv_min(log_dir_file, log_dir_out, name="plot"):
    """Plot the per-mass minimum reward from one evaluation summary CSV.

    Reads a ``;``-separated CSV that has at least the columns
    ``relative_mass``, ``mean`` and ``min``, drops every row whose ``mean``
    is NaN, and draws ``min`` against ``relative_mass`` on the current
    matplotlib axes.

    Args:
        log_dir_file: Path of the CSV file to read.
        log_dir_out: Unused; kept so existing call sites keep working.
        name: Legend label for the plotted line.
    """
    data = pd.read_csv(log_dir_file, delimiter=";")

    # Rows are kept or dropped based on the *mean* column, even though only
    # ``min`` is plotted, so this curve covers the same masses as the mean.
    valid = ~np.isnan(data["mean"].values)
    x = data["relative_mass"].values[valid]
    min_y = data["min"].values[valid]

    plt.plot(x, min_y, label=name)


def test_csv_mean(log_dir_file, log_dir_out, n, i, name="plot"):

    data = pd.read_csv(log_dir_file, delimiter=";")
    y = data["mean"].values
    std = data["std"].values
    x = data["relative_mass"].values
    min_y = data["min"].values

    filtre = np.invert(np.isnan(y))
    x = x[filtre]
    y = y[filtre]
    max = np.max(y)
    std = std[filtre]
    min_y = min_y[filtre]

    clrs = sns.color_palette(n_colors=7)
###### For half
    
    if "SAC" in name:
        color = clrs[0]
    if "\u03B1=0" in name:
        color = clrs[1]
    if "\u03B1=0.1" in name:
        color = clrs[2]
    if "\u03B1=0.5" in name:
        color = clrs[3]
    if "\u03B1=1" in name:
        color = clrs[4]
    if "\u03B1=1.5" in name:
        color = clrs[5]
    if "\u03B1=2" in name:
        color = clrs[6]

   
    # else: 
    #     if "SAC" in name:
    #         color = clrs[0]
    #     if "\u03B1=0" in name:
    #         color = clrs[1]
    #     if "\u03B1=1" in name:
    #         color = clrs[2]
    #     if "\u03B1=2" in name:
    #         color = clrs[3]
    #     if "\u03B1=3" in name:
    #         color = clrs[4]
    #     if "\u03B1=4" in name:
    #         color = clrs[5]
    #     if "\u03B1=5" in name:
    #         color = clrs[6]

    #plt.plot(x, y/1, label=name, c=color)


    plt.fill_between(
            x,
            (y - std)/max,
            (y + std)/max,
            alpha=0.5,
            linestyle="dashdot",
            facecolor=color,
        ) 
    ## For normalised plot uncomment the following line
    plt.plot(x,y/max,label=name,c=color)


def test_plot(path):

    csv_files = glob.glob(os.path.join(path, "*.csv"))
    i = 0
    for f in csv_files:
        filename = os.path.basename(f)
        print("f",f)

        

        # Change name of the file for legend

        print(f)

        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_5_std.csv", "\u03B1=4"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_6_std.csv", "\u03B1=5"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_4_std.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_3_std.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_2_std.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_TQC_Hopper-v3_1_std.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3replay_SAC_Hopper-v3_std.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3SAC_HOPPER_1_std.csv", "SAC"
        )

        filename = filename.replace(
            "results_summary_bestWalker2d-v3TQC_Walker2d-v3_6_std.csv", "\u03B1=5"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3TQC_Walker2d-v3_4_std.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3TQC_Walker2d-v3_3_std.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3TQC_Walker2d-v3_2_std.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3SAC_Walker2d-v3_1_std.csv", "SAC"
        )
        filename = filename.replace(
            "results_summaryWalker2d-v3TQC_Walker2d-v3_1_var.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3TQC_Walker2d-v3_5_std.csv", "\u03B1=4"
        )

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_1_std.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_2_std.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_3_std.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_4_std.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_5_std.csv", "\u03B1=4"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3TQC_HalfCheetah_6_std.csv", "\u03B1=5"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3SAC_HalfCheetah_1_std.csv", "SAC"
        )

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_1_std.csv",
            "\u03B1=0",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_2_std.csv",
            "\u03B1=0.5",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_3_std.csv",
            "\u03B1=0.6",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_4_std.csv",
            "\u03B1=1",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_5_std.csv",
            "\u03B1=1.5",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah_6_std.csv",
            "\u03B1=2",
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_SAC_HalfCheetah_1_std.csv", "SAC"
        )

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_SAC_HalfCheetah_std.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3replay_TQC_HalfCheetah-v3_1_std.csv",
            "\u03B1=0",
        )
        filename = filename.replace(
            "results_summaryWalker2d-v3replay_TQC_Walker2D_5_std.csv", "\u03B1=4"
        )
        filename = filename.replace(
            " results_summary_bestWalker2d-v3TQC_Walker2d-v3_5_std.csv", "\u03B1=4"
        )



        filename = filename.replace(
            "results_summary_bestHopper-v3Bellman_aC_TQC_Hopper_3_std.csv", "\u03B1=5"
        )

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3Bellman_aC_TQC_HalfCheetah_6_std.csv", "\u03B1=2"
        )

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3nsBellman_aC_TQC_HalfCheetah_2_std.csv", "\u03B1=0.5"
        )


        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True1__noise_a_0_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"
        )

        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True1__noise_a_0_noise_s_0quantile0_penal5.0.csv", "\u03B1=5"
        )

        ############
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.0001.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.012.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal1.0.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal3.0.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal4.0.csv", "\u03B1=4"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal5.0.csv", "\u03B1=5"
        )





        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal0.0001.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal0.012.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal1.0.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal3.0.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal4.0.csv", "\u03B1=4"
        )
        filename = filename.replace(
            "results_summary_bestHopper-v3Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal5.0.csv", "\u03B1=5"
        )
       #########
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.001.csv", "\u03B1=0"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.012.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal1.0.csv", "\u03B1=1"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal3.0.csv", "\u03B1=3"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal4.0.csv", "\u03B1=4"
        )
        filename = filename.replace(
            "results_summary_bestWalker2d-v3Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal5.0.csv", "\u03B1=5"
        )
        ###################

        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.1.csv", "\u03B1=0.1"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.012.csv", "SAC"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.5.csv", "\u03B1=0.5"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal1.5.csv", "\u03B1=1.5"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal2.5.csv", "\u03B1=2.5"
        )
        filename = filename.replace(
            "results_summary_bestHalfCheetah-v3HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal0.0001.csv", "\u03B1=0"
        )
       
        with plt.style.context("ggplot"):
            name = filename
            print(name)
            test_csv_mean(f, path, name=name, n=len(csv_files), i=i)
            i += 1

    plt.xlabel("Relative mass")
    plt.ylabel("Mean Reward")
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    plt.legend(handles, labels)
    os.makedirs(path, exist_ok=True)
    plt.savefig(path + "fig_{}.pdf".format(name))
    plt.tight_layout()
    plt.show()


# Plot the graph for the selected `path` when run as a script; importing this
# module (e.g. to reuse the plotting helpers) no longer opens a plot window.
if __name__ == "__main__":
    test_plot(path)
