import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from typing import Callable, Tuple

# Number of points used by the rolling mean that smooths the curves.
EPISODES_WINDOW = 15


# Known result folders. Previously every folder was assigned to ``path`` in
# turn, so only the last assignment took effect; the folders are now listed
# once and the one to plot is chosen explicitly through SELECTED_RESULTS.
RESULT_DIRS = {
    # For penal A
    "penal_a/walker2d": "../progress_test_continuous/walker2d/progress/",
    "penal_a/hopper": "../progress_test_continuous/hopper/progress/",
    "penal_a/halfcheetah": "../progress_test_continuous/halfcheetah/progress/",
    # For penal AC
    "penal_ac/hopper": "../Bellman_critic_results/Hopper/results_train/",
    "penal_ac/walker": "../Bellman_critic_results/Walker/results_train/",
    "penal_ac/half": "../Bellman_critic_results/Half/results_train/",
    # For moving env
    "noisy/half": "../noisy_env/half/train/",
    "noisy/hopper": "../noisy_env/hopper/train/",
    # NOTE(review): this entry was commented out (as "#ot_cpath=...") in the
    # previous version of the file -- confirm the folder is still valid.
    "noisy/walker": "../noisy_env/walker/train/",
}

# Key of the folder to plot; change this single line to switch environment.
SELECTED_RESULTS = "noisy/hopper"

# Folder scanned for CSV files by the module-level plotting call.
path = RESULT_DIRS[SELECTED_RESULTS]



def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
    """
    Apply a rolling window to a np.ndarray

    The windows slide along the last axis of ``array``. The result is a
    read-only strided view on ``array`` (no data is copied): overlapping
    windows share memory, so writing through such a view would silently
    modify several windows at once.

    :param array: the input Array
    :param window: length of the rolling window, between 1 and ``array.shape[-1]``
    :return: rolling window on the input array, with shape
        ``array.shape[:-1] + (array.shape[-1] - window + 1, window)``
    :raises ValueError: if ``window`` is smaller than 1 or larger than the
        last dimension of ``array``
    """
    length = array.shape[-1]
    if not 1 <= window <= length:
        raise ValueError(
            "window must be between 1 and the length of the last axis "
            "({}), got {}".format(length, window)
        )
    shape = array.shape[:-1] + (length - window + 1, window)
    # Re-use the stride of the last axis so that consecutive windows start
    # one element apart.
    strides = array.strides + (array.strides[-1],)
    return np.lib.stride_tricks.as_strided(
        array, shape=shape, strides=strides, writeable=False
    )


def window_func(
    var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reduce the rolling windows of one array and align a second array with them
    :param var_1: variable 1, returned without its first ``window - 1`` entries
    :param var_2: variable 2, the array the rolling window is built on
    :param window: length of the rolling window
    :param func: reduction called as ``func(windows, axis=-1)`` on the rolling
        window of variable 2 (such as np.mean)
    :return: the trimmed variable 1 and the reduced rolling windows of variable 2
    """
    # One reduced value per window position.
    reduced = func(rolling_window(var_2, window), axis=-1)
    # Drop the leading entries that have no complete window behind them.
    trimmed = var_1[window - 1 :]
    return trimmed, reduced


# Ordered (label fragment, palette index) rules used to colour the curves.
# Every rule whose fragment occurs in the label is applied in order and the
# LAST match wins, so a more specific fragment (e.g. "\u03B1=0.1") must come
# after the shorter fragment it contains (e.g. "\u03B1=0").
_HALFCHEETAH_COLOR_RULES = (
    ("SAC", 0),
    ("\u03B1=0", 1),
    ("\u03B1=0.1", 2),
    ("\u03B1=0.5", 3),
    ("\u03B1=1", 4),
    ("\u03B1=1.5", 5),
    ("\u03B1=2", 6),
)
_DEFAULT_COLOR_RULES = (
    ("SAC", 0),
    ("\u03B1=0", 1),
    ("\u03B1=1", 2),
    ("\u03B1=2", 3),
    ("\u03B1=3", 4),
    ("\u03B1=4", 5),
    ("\u03B1=5", 6),
)


def _pick_color(name, filename, palette):
    """Return the palette colour matching the curve label, or None if no rule matches."""
    if name is None:
        return None
    if "HalfCheetah" in filename:
        rules = _HALFCHEETAH_COLOR_RULES
    else:
        rules = _DEFAULT_COLOR_RULES
    color = None
    for fragment, index in rules:
        if fragment in name:
            color = palette[index]
    return color


def plot_csv(log_dir_file, log_dir_out, n, name=None, filename="half"):
    """
    Plot the smoothed evaluation curve of one CSV file on the current axes

    The columns "time/total_timesteps", "eval/mean_reward" and
    "eval/std_reward" are read, rows whose mean reward is NaN are dropped, and
    both the mean and the standard deviation are smoothed with a rolling mean
    over EPISODES_WINDOW points. The mean is drawn as a line and the band
    mean +/- std is shaded around it. Nothing is drawn when fewer than
    EPISODES_WINDOW valid rows are available.

    :param log_dir_file: path of the CSV file to read
    :param log_dir_out: unused, kept so existing callers keep working
    :param n: unused, kept so existing callers keep working
    :param name: legend label of the curve, also used to choose its colour;
        when it is None or matches no known label, matplotlib's default
        colour is used
    :param filename: base name of the CSV file; "HalfCheetah" selects the
        HalfCheetah colour rules and "Hopper" shortens the plotted range
    :return: None
    """
    data = pd.read_csv(log_dir_file)

    y = data["eval/mean_reward"].values
    std = data["eval/std_reward"].values
    x = data["time/total_timesteps"].values

    # Keep only the rows that hold an evaluation result.
    valid = ~np.isnan(y)
    x, y, std = x[valid], y[valid], std[valid]

    # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
    if x.shape[0] < EPISODES_WINDOW:
        return

    # Rolling means over EPISODES_WINDOW points; x is trimmed to stay aligned.
    x_smooth, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
    _, std_mean = window_func(x, std, EPISODES_WINDOW, np.mean)

    # Plotted range of steps: the first 5e5 steps are always skipped and the
    # upper bound is lower for files whose name contains "Hopper".
    max_step = 2.5e6 if "Hopper" in filename else 5e6
    keep = (x_smooth < max_step) & (x_smooth > 5e5)

    color = _pick_color(name, filename, sns.color_palette(n_colors=7))
    # Only force a colour when a rule matched; otherwise let matplotlib pick
    # the next colour of its cycle instead of failing on an unknown label.
    line_kwargs = {} if color is None else {"c": color}
    (line,) = plt.plot(x_smooth[keep], y_mean[keep], label=name, **line_kwargs)

    # add the standard deviation on the plot, in the same colour as the line
    plt.fill_between(
        x_smooth[keep],
        y_mean[keep] - std_mean[keep],
        y_mean[keep] + std_mean[keep],
        alpha=0.5,
        linestyle="dashdot",
        facecolor=line.get_color(),
    )


_HOPPER_TRUE = "progress_Hopper-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal"
_HOPPER_FALSE = "progress_Hopper-v3bellman_False0.5__noise_a_0.01_noise_s_0quantile0_penal"
_WALKER_TRUE = "progress_Walker2d-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal"
_HALFCHEETAH_TRUE = "progress_HalfCheetah-v3bellman_True0.5__noise_a_0.01_noise_s_0quantile0_penal"

# Ordered (file name, label) substitutions turning a CSV file name into a
# legend label. They are applied one after the other, so the order matters:
# some entries rewrite a name into another file name that a later entry then
# turns into the final label.
_LABEL_REPLACEMENTS = (
    ("progress_TQC_HalfCheetah_1_std.csv", "\u03B1=0"),
    ("progress_TQC_HalfCheetah_2_std.csv", "\u03B1=0.5"),
    ("progress_TQC_HalfCheetah_3_std.csv", "\u03B1=0.6"),
    ("progress_TQC_HalfCheetah_4_std.csv", "\u03B1=1"),
    ("progress_TQC_HalfCheetah_5_std.csv", "\u03B1=1.5"),
    ("progress_TQC_HalfCheetah_6_std.csv", "\u03B1=2"),
    ("progress_SAC_HalfCheetah_1_std.csv", "SAC"),
    ("progress_TQC_Walker2d-v3_5_std.csv", "\u03B1=4"),
    ("progress_TQC_Walker2d-v3_6_std.csv", "\u03B1=5"),
    ("progress_TQC_Walker2d-v3_4_std.csv", "\u03B1=3"),
    ("progress_TQC_Walker2d-v3_3_std.csv", "\u03B1=2"),
    ("progress_TQC_Walker2d-v3_2_std.csv", "\u03B1=1"),
    ("progress_TQC_Walker2d-v3_1_std.csv", "\u03B1=0"),
    ("progress_SAC_Walker2d-v3_1_std.csv", "SAC"),
    ("progress_Walker2d-v3bellman_True1__noise_a_0_noise_s_0quantile0_penal2.0.csv", "\u03B1=2"),
    ("progress_Walker2d-v3bellman_True1__noise_a_0_noise_s_0quantile0_penal5.0.csv", "\u03B1=5"),
    # The next three entries rename files to "progress_TQC_HOPPER_*" names,
    # which are converted to labels further down.
    # NOTE(review): because of this first entry the Hopper "True" penal3.0
    # file ends up labelled "\u03B1=2" and its direct "\u03B1=3" entry below
    # never matches -- confirm this is intended.
    (_HOPPER_TRUE + "3.0.csv", "progress_TQC_HOPPER_3_std.csv"),
    ("progress_Bellman_aC_TQC_Hopper_5_std.csv", "progress_TQC_HOPPER_6_std.csv"),
    ("progress_Bellman_aC_TQC_Hopper_3_std.csv", "progress_TQC_HOPPER_3_std.csv"),
    ("progress_Bellman_aC_TQC_HalfCheetah_6_std.csv", "\u03B1=2"),
    ("progress_nsBellman_aC_TQC_HalfCheetah_2_std.csv", "\u03B1=0.5"),
    ("progress_TQC_HOPPER_1_std.csv", "\u03B1=0"),
    ("progress_TQC_HOPPER_2_std.csv", "\u03B1=1"),
    ("progress_TQC_HOPPER_3_std.csv", "\u03B1=2"),
    ("progress_TQC_HOPPER_4_std.csv", "\u03B1=3"),
    ("progress_TQC_HOPPER_5_std.csv", "\u03B1=4"),
    ("progress_TQC_HOPPER_6_std.csv", "\u03B1=5"),
    ("progress_replay_SAC_Hopper-v3_std.csv", "SAC"),
    # NOTE(review): in both Hopper groups the penal0.012 and penal1.0 files
    # share the label "SAC" -- confirm this is intended.
    (_HOPPER_TRUE + "0.0001.csv", "\u03B1=0"),
    (_HOPPER_TRUE + "0.012.csv", "SAC"),
    (_HOPPER_TRUE + "1.0.csv", "SAC"),
    (_HOPPER_TRUE + "2.0.csv", "\u03B1=2"),
    (_HOPPER_TRUE + "3.0.csv", "\u03B1=3"),
    (_HOPPER_TRUE + "4.0.csv", "\u03B1=4"),
    (_HOPPER_TRUE + "5.0.csv", "\u03B1=5"),
    (_HOPPER_FALSE + "0.0001.csv", "\u03B1=0"),
    (_HOPPER_FALSE + "0.012.csv", "SAC"),
    (_HOPPER_FALSE + "1.0.csv", "SAC"),
    (_HOPPER_FALSE + "2.0.csv", "\u03B1=2"),
    (_HOPPER_FALSE + "3.0.csv", "\u03B1=3"),
    (_HOPPER_FALSE + "4.0.csv", "\u03B1=4"),
    (_HOPPER_FALSE + "5.0.csv", "\u03B1=5"),
    (_WALKER_TRUE + "0.001.csv", "\u03B1=0"),
    (_WALKER_TRUE + "0.012.csv", "SAC"),
    (_WALKER_TRUE + "1.0.csv", "\u03B1=1"),
    (_WALKER_TRUE + "2.0.csv", "\u03B1=2"),
    (_WALKER_TRUE + "3.0.csv", "\u03B1=3"),
    (_WALKER_TRUE + "4.0.csv", "\u03B1=4"),
    (_WALKER_TRUE + "5.0.csv", "\u03B1=5"),
    (_HALFCHEETAH_TRUE + "0.1.csv", "\u03B1=0.1"),
    (_HALFCHEETAH_TRUE + "0.012.csv", "SAC"),
    (_HALFCHEETAH_TRUE + "0.5.csv", "\u03B1=0.5"),
    (_HALFCHEETAH_TRUE + "2.0.csv", "\u03B1=2"),
    # NOTE(review): penal1.5 is labelled "\u03B1=2", like penal2.0 -- confirm.
    (_HALFCHEETAH_TRUE + "1.5.csv", "\u03B1=2"),
    (_HALFCHEETAH_TRUE + "2.5.csv", "\u03B1=2.5"),
    (_HALFCHEETAH_TRUE + "0.0001.csv", "\u03B1=0"),
)


def _curve_label(filename):
    """Return the legend label of a CSV file name (the name itself if it is not a known file)."""
    label = filename
    for old, new in _LABEL_REPLACEMENTS:
        label = label.replace(old, new)
    return label


def all_plot(path):
    """
    Plot every CSV file found in a folder on a single figure

    Each curve is the mean reward smoothed over EPISODES_WINDOW points as a
    function of the number of steps, with its standard deviation shaded (see
    plot_csv). The figure is saved as "fig_<label>.pdf" in the same folder,
    where <label> is the label of the last file plotted, and is then shown.

    The files are processed in sorted order so that the order of the curves
    and the name of the saved figure do not depend on the order in which the
    file system lists them. When the folder holds no CSV file a message is
    printed and nothing is plotted or saved.

    :param path: folder containing the CSV files; the figure is saved there
    :return: None
    """
    # List every file with a .csv extension in the folder given by ``path``.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    print(csv_files)
    if not csv_files:
        print("No CSV file found in {}, nothing to plot".format(path))
        return

    for f in csv_files:
        print(f)
        name = _curve_label(os.path.basename(f))
        print(name)

        with plt.style.context("ggplot"):
            plot_csv(f, path, name=name, n=len(csv_files), filename=os.path.basename(f))

    # Sort the legend entries by label. No entry exists when every file was
    # too short to be plotted, in which case the legend is skipped.
    handles, labels = plt.gca().get_legend_handles_labels()
    if handles:
        labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
        plt.legend(handles, labels)

    plt.title("Mean Reward of 20 trajectories ± standard deviation")
    plt.xlabel("Number of steps")
    plt.ylabel("Mean Reward")
    plt.tight_layout()
    os.makedirs(path, exist_ok=True)
    # os.path.join keeps the figure inside ``path`` even when ``path`` has no
    # trailing separator.
    plt.savefig(os.path.join(path, "fig_{}.pdf".format(name)))
    plt.show()


# Only plot when the file is executed as a script, so that importing it (for
# instance to reuse rolling_window or window_func) does not read files or
# open a figure window.
if __name__ == "__main__":
    all_plot(path)
