import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import seaborn as sns
import matplotlib
from scipy.signal import savgol_filter


def _load_smoothed_runs(paths, window_length, min_rows):
    """Read each two-column CSV, keep runs longer than ``min_rows`` and smooth their ``r``."""
    runs = []
    for file_name in paths:
        data = pd.read_csv(file_name, names=["time_steps", "r"])
        n_rows = data.shape[0]
        # Filter *before* smoothing: savgol_filter raises when the series is
        # shorter than its window, so short runs must never reach it.
        if n_rows <= min_rows:
            continue
        # Runs shorter than the nominal window get the largest odd window that fits.
        window = window_length if n_rows >= window_length else (
            n_rows if n_rows % 2 else n_rows - 1
        )
        data["r"] = savgol_filter(data["r"], window_length=window, polyorder=0)
        runs.append(data)
    return runs


def _resample_runs(runs, column, interval, grid):
    """Build a wide frame sampling ``column`` of every run on ``idx * interval`` for idx in ``grid``.

    Each cell holds the value of the last row whose time step is strictly
    below the grid point; run ``i`` becomes column ``i``.
    """
    columns = ["time_steps"] + list(np.arange(len(runs)))
    arrays = [(run["time_steps"].to_numpy(), run[column].to_numpy()) for run in runs]
    rows = []
    for idx in grid:
        step = idx * interval
        rows.append(
            [step]
            + [values[np.flatnonzero(steps < step)[-1]] for steps, values in arrays]
        )
    return pd.DataFrame(rows, columns=columns)


def load_results_several_path(
    paths: list,
    interval=100,
    with_l=False,
    additionals=None,
    *,
    window_length=15,
    min_rows=10,
) -> pd.DataFrame:
    """
    Load several reward logs, smooth them and resample them on a common grid.

    Each file is read as a header-less CSV with columns ``time_steps`` and
    ``r``. Runs with ``min_rows`` rows or fewer are dropped. The reward of the
    remaining runs is smoothed with a Savitzky-Golay filter of order 0, whose
    window is shrunk for runs shorter than ``window_length``. Every run is
    then sampled at multiples of ``interval`` that lie strictly after the
    latest first time step and before the earliest last time step of all
    runs. Each grid point takes the last logged value strictly before it.

    NOTE(review): files are read with only the ``time_steps`` and ``r``
    columns, so ``with_l=True`` or a non-empty ``additionals`` raises
    ``KeyError`` unless that read is extended — confirm the intended format.

    :param paths: (list) paths of the CSV files, one run per file
    :param interval: (int) spacing of the resampling grid, in time steps
    :param with_l: (bool) also resample column ``l`` into a ``length`` column
    :param additionals: (list) extra column names to resample into same-named columns
    :param window_length: (int) nominal Savitzky-Golay window
    :param min_rows: (int) runs with this many rows or fewer are discarded
    :return: (pandas.DataFrame) long format with ``time_steps``, ``variable``
        (run index), ``value`` and ``reward`` (both the smoothed reward), plus
        any requested extra columns
    :raises ValueError: if no run survives the length filter
    """
    additionals = [] if additionals is None else additionals
    runs = _load_smoothed_runs(paths, window_length, min_rows)
    if not runs:
        raise ValueError(f"no run with more than {min_rows} rows to load")

    # Only the time span covered by every run is resampled.
    min_steps = min(run["time_steps"].to_numpy()[-1] for run in runs)
    max_steps = max(run["time_steps"].to_numpy()[0] for run in runs)
    grid = range(int(max_steps / interval) + 1, int(min_steps / interval))

    result = pd.melt(_resample_runs(runs, "r", interval, grid), id_vars=["time_steps"])
    result["reward"] = result["value"]

    on = ["time_steps", "variable"]
    for key in additionals:
        extra = pd.melt(
            _resample_runs(runs, key, interval, grid),
            id_vars=["time_steps"],
            value_name=key,
        )
        result = result.merge(extra, how="left", on=on)
    if with_l:
        lengths = pd.melt(
            _resample_runs(runs, "l", interval, grid),
            id_vars=["time_steps"],
            value_name="length",
        )
        result = result.merge(lengths, how="left", on=on)
    return result


# Global matplotlib style applied to every figure drawn after this module is
# imported: plain-text rendering (no LaTeX) in DejaVu Sans, 16pt base and
# axis-label text, 13pt legend and tick labels.
tex_fonts = {
    "text.usetex": False,
    "font.family": "DejaVu Sans",
    **dict.fromkeys(("axes.labelsize", "font.size"), 16),
    **dict.fromkeys(("legend.fontsize", "xtick.labelsize", "ytick.labelsize"), 13),
}

plt.rcParams.update(tex_fonts)


def plot_from_folder(log_folder, methods, save_name):
    """
    Plot the reward curves of several methods and save the figure.

    Every regular file in ``<log_folder>/<method>/`` is treated as one run and
    loaded with ``load_results_several_path`` on a 5000-step grid. Method
    ``"classic"`` is labelled "Q-learning"; every other method is labelled
    "Q-learning + EASEE".

    :param log_folder: (str) directory containing one sub-directory per method
    :param methods: (list) names of the method sub-directories to plot
    :param save_name: (str) file name of the figure, written under ``images/``
    """
    frames = []
    for method in methods:
        method_dir = os.path.join(log_folder, method)
        # Sorted for a deterministic run order; sub-directories are skipped
        # because they cannot be read as CSV logs.
        run_paths = [
            os.path.join(method_dir, name)
            for name in sorted(os.listdir(method_dir))
            if os.path.isfile(os.path.join(method_dir, name))
        ]
        method_df = load_results_several_path(run_paths, interval=5000)
        method_df["method"] = (
            "Q-learning" if method == "classic" else "Q-learning + EASEE"
        )
        frames.append(method_df)
    df = pd.concat(frames, axis=0).reset_index(drop=True)

    # A dedicated figure keeps consecutive calls from drawing on top of each other.
    fig, ax = plt.subplots()
    sns.lineplot(
        data=df,
        x="time_steps",
        y="reward",
        hue="method",
        palette=sns.color_palette("colorblind")[:2][::-1],
        ax=ax,
    )
    ax.set_xlabel("Number of timesteps")
    ax.set_ylabel("Reward")
    ax.legend(title=None)
    fig.tight_layout()
    print(log_folder)

    out_path = os.path.join("images", save_name)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)
