"""
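Plot training curves (reward, success rate or continuity score) of several experiments, optionally as a Pareto front, and print a Markdown results table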
"""
import argparse
import os
import pickle

import numpy as np
import pandas as pd
import pytablewriter
import seaborn
from matplotlib import pyplot as plt
from stable_baselines3.common.monitor import LoadMonitorResultsError, load_results
from stable_baselines3.common.results_plotter import X_EPISODES, X_TIMESTEPS, X_WALLTIME, ts2xy, window_func

# Activate seaborn's default style for the figures produced below
seaborn.set()

parser = argparse.ArgumentParser("Gather results, plot training success")
parser.add_argument("-a", "--algo", help="Algorithm to include", type=str, required=True)
parser.add_argument("-e", "--env", help="Environment to include", type=str, required=True)
parser.add_argument("-f", "--exp-folder", help="Folders to include", type=str, required=True)
# float, not int: the default (6.4, 4.8) is fractional and users should be able to pass such values
parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=float, default=[6.4, 4.8])
parser.add_argument("--fontsize", help="Font size", type=int, default=14)
parser.add_argument("-max", "--max-timesteps", help="Max number of timesteps to display", type=int)
parser.add_argument("-x", "--x-axis", help="X-axis", choices=["steps", "episodes", "time"], type=str, default="steps")
parser.add_argument("-y", "--y-axis", help="Y-axis", choices=["success", "reward", "continuity"], type=str, default="reward")
parser.add_argument("-w", "--episode-window", help="Rolling window size", type=int, default=100)
parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best")
parser.add_argument("-o", "--output", help="Output filename (DataFrame), where to save the post-processed data", type=str)
parser.add_argument("--skip", help="Ignore experiments whose folder name contains this substring", type=str)
parser.add_argument("--pareto", action="store_true", default=False, help="Display Pareto front")
parser.add_argument("--line", help="With --pareto: noise type whose points are joined by a line", type=str)

args = parser.parse_args()


algo, env = args.algo, args.env

# CLI choice -> (results_plotter axis constant, also used as a DataFrame column name; axis label)
_x_axis_options = {
    "steps": (X_TIMESTEPS, "Timesteps"),
    "episodes": (X_EPISODES, "Episodes"),
    "time": (X_WALLTIME, "Walltime (in hours)"),
}
# CLI choice -> (monitor-file column to plot; axis label)
_y_axis_options = {
    "success": ("is_success", "Training Success Rate"),
    "reward": ("r", "Training Episodic Reward"),
    "continuity": ("continuity_score", "Training Continuity Score"),
}

x_axis, x_label = _x_axis_options[args.x_axis]
y_axis, y_label = _y_axis_options[args.y_axis]


# plt.figure(y_label, figsize=args.figsize)
# plt.title(y_label, fontsize=args.fontsize)
# plt.xlabel(f"{x_label}", fontsize=args.fontsize)
# plt.ylabel(y_label, fontsize=args.fontsize)

# Experiment-name prefix -> noise-type label used in the plots and the results table.
# Checked in order and the first matching prefix wins; names matching no prefix fall back
# to the part before their first underscore.
_NOISE_TYPE_PREFIXES = (
    ("noise", "noise repeat"),
    ("scaling", "scaling"),
    ("inde", "independent repeat"),
    ("unstructured", "unstructured"),
    ("ou_noise", "ou_noise"),
    ("param", "param_noise"),
    ("no_noise", "no_noise"),
    ("sde", "gSDE"),
)

# Sorted so that the per-run ids ("<experiment>_<index>") do not depend on filesystem order
experiments = sorted(os.listdir(args.exp_folder))

data_frames = []
names = []
for experiment in experiments:
    if not os.path.isdir(os.path.join(args.exp_folder, experiment)):
        continue

    # Skip some for readability
    if args.skip is not None and args.skip in experiment:
        continue

    log_path = os.path.join(args.exp_folder, experiment, algo)
    if not os.path.isdir(log_path):
        # This experiment has no logs for the requested algorithm
        print(f"No {algo} logs found in {log_path}")
        continue

    names.append(experiment)

    # One sub-folder per run; only folders whose name contains the env id are kept
    dirs = [
        os.path.join(log_path, folder)
        for folder in sorted(os.listdir(log_path))
        if env in folder and os.path.isdir(os.path.join(log_path, folder))
    ]

    for exp_id, folder in enumerate(dirs):
        try:
            data_frame = load_results(folder)
        except LoadMonitorResultsError:
            continue
        if args.max_timesteps is not None:
            # Keep episodes within the step budget. Copy so that the columns added below
            # are written to an independent frame rather than to a slice of the loaded one.
            data_frame = data_frame[data_frame.l.cumsum() <= args.max_timesteps].copy()
        if y_axis not in data_frame.columns:
            print(f"No data available for {folder}")
            continue

        x, _ = ts2xy(data_frame, x_axis)
        data_frame["experiment"] = experiment
        data_frame["experiment_id"] = f"{experiment}_{exp_id}"

        # Last entry of `results` averaged over axis 1 (presumably the mean return of the
        # final evaluation). ravel rather than squeeze keeps a 1-d array even when there is
        # a single evaluation, so [-1] does not fail on a 0-d array.
        with np.load(os.path.join(folder, "evaluations.npz")) as log:
            data_frame["last_eval"] = np.ravel(log["results"].mean(axis=1))[-1]

        noise_type = next(
            (label for prefix, label in _NOISE_TYPE_PREFIXES if experiment.startswith(prefix)),
            experiment.split("_")[0],
        )
        data_frame["noise_type"] = noise_type

        if noise_type == "unstructured":
            sample_freq = 1
        elif noise_type in ("no_noise", "ou_noise"):
            sample_freq = 1000
        else:
            # Assumes the experiment name ends with "_<sample_freq>"
            sample_freq = int(experiment.split("_")[-1])
        data_frame["sample_freq"] = sample_freq
        data_frame[x_axis] = x
        data_frames.append(data_frame)


def _experiment_sort_key(name):
    """Return a sort key for an experiment folder name.

    Names ending in ``_<int>`` come first, ordered by that integer (e.g. the
    sampling frequency). Other names follow in alphabetical order; these are
    names such as ``no_noise`` or ``ou_noise``, whose last ``_`` part is not an
    integer.
    """
    try:
        return (0, int(name.split("_")[-1]), name)
    except ValueError:
        return (1, 0, name)


# Order used for the legend / colours of the training curves
names.sort(key=_experiment_sort_key)

if not data_frames:
    raise SystemExit(f"No monitor data found for {algo} on {env} in {args.exp_folder}")

dataframe = pd.concat(data_frames, axis=0)

# Hard-coded y-limits: active values are for the "state" setup, commented ones for "acc"
if args.y_axis == "reward":
    ylim = (-90, -5)  # state
    # ylim = (-300, -50) # acc
else:
    ylim = (0, 10)  # state
    # ylim = (0, 20)  # acc

# Training curves with a +/- one standard deviation band (ci="sd").
# Without ci="sd", seaborn would draw a bootstrapped confidence interval instead.
if not args.pareto:
    facet_grid = seaborn.relplot(
        x=x_axis,
        y=y_axis,
        kind="line",
        ci="sd",
        hue="experiment",
        data=dataframe,
        hue_order=names,
        facet_kws=dict(legend_out=False, ylim=ylim),
    )


# Aggregate all trials of each experiment, using the last logged episode of every run.
# NOTE: `df` is the same object as `dataframe`, so "cost" and "return" are also added to it
# (and therefore to the CSV written with --output).
df = dataframe
df["cost"] = -df["r"]
df["return"] = df["r"]

# Keep only the last row of each run. Uses the plotted x-axis column: a "timesteps" column
# only exists when plotting against steps. (Presumably x increases within a run for every
# --x-axis choice.)
idx = df.groupby("experiment_id")[x_axis].transform("max") == df[x_axis]
df = df[idx]
df = df.sort_values(["sample_freq"])

group_keys = ["experiment", "sample_freq", "noise_type"]
# Only numeric columns can be averaged. pandas >= 2.0 raises on string columns such as
# "experiment_id" instead of silently dropping them.
value_columns = [column for column in df.select_dtypes(include="number").columns if column not in group_keys]
per_experiment = df.groupby(group_keys)[value_columns]

# Mean / standard deviation across the runs of each experiment
mean_per_trial = per_experiment.mean().reset_index().sort_values("sample_freq")
std_per_trial = per_experiment.std().reset_index().sort_values("sample_freq")

writer = pytablewriter.MarkdownTableWriter()
writer.table_name = "results_table"

# TODO: replace return by the deterministic evaluation
metrics = ["return", "continuity_score", "last_eval"]
# One table row per metric, starting with the metric name
value_matrix = [[metric] for metric in metrics]

noise_types = mean_per_trial["noise_type"].unique()

# First column holds the metric name; one column per "<noise_type>-<sample_freq>" follows
headers = ["Metrics"]

# {env: {"<noise_type>-<sample_freq>": {<metric>: mean, "noise_type", "sample_freq", "last_eval"}}}
post_processed_results = {args.env: {}}

for i, metric in enumerate(metrics):
    for noise_type in noise_types:
        one_type_only = mean_per_trial[mean_per_trial["noise_type"] == noise_type]
        one_type_only_std = std_per_trial[std_per_trial["noise_type"] == noise_type]
        for sample_freq in one_type_only["sample_freq"].unique():
            exp_name = f"{noise_type}-{sample_freq}"
            if i == 0:
                # Columns and result entries are created while filling the first metric row
                headers.append(exp_name)
                post_processed_results[args.env][exp_name] = {}

            one_exp = one_type_only[one_type_only["sample_freq"] == sample_freq]
            one_exp_std = one_type_only_std[one_type_only_std["sample_freq"] == sample_freq]
            # .item() requires exactly one experiment per (noise type, sample freq)
            mean_value = float(one_exp[metric].item())
            # Standard error of the mean, using the number of runs actually loaded
            # for this experiment (one row per run remains in `df`)
            experiment_name = one_exp["experiment"].item()
            n_trials = df.loc[df["experiment"] == experiment_name, "experiment_id"].nunique()
            std_error = float(one_exp_std[metric].item()) / np.sqrt(n_trials)
            value_matrix[i].append(f"{mean_value:.1f} +/- {std_error:.1f}")

            exp_results = post_processed_results[args.env][exp_name]
            exp_results[metric] = mean_value
            exp_results["noise_type"] = noise_type
            exp_results["sample_freq"] = sample_freq
            exp_results["last_eval"] = float(one_exp["last_eval"].item())

writer.headers = headers
writer.value_matrix = value_matrix
writer.write_table()

# Marker shape per noise type in the Pareto scatter plot ("gsde" and "gSDE" share one)
markers = {
    "noise repeat": "o",
    "gsde": "h",
    "gSDE": "h",
    "scaling": "X",
    "independent repeat": "s",
    "ou_noise": "s",
    "unstructured": "o",
    "param_noise": "X",
    "no_noise": ">",
}

if args.pareto:
    # One point per experiment: mean continuity score vs. mean return across its runs
    seaborn.scatterplot(
        data=mean_per_trial,
        x="continuity_score",
        y="return",
        hue="sample_freq",
        style="noise_type",
        markers=markers,
        palette="deep",
    )

    if args.line is not None:
        # Join the points of the requested noise type with a faint line
        selected = mean_per_trial.loc[mean_per_trial["noise_type"] == args.line]
        plt.plot(
            np.asarray(selected["continuity_score"]),
            np.asarray(selected["return"]),
            alpha=0.5,
        )

plt.tight_layout()
plt.show()

if args.output is not None:
    # The full per-episode DataFrame goes to CSV, the aggregated results dict to a pickle
    csv_path = f"{args.output}.csv"
    pkl_path = f"{args.output}.pkl"

    print(f"Saving to {csv_path}")
    dataframe.to_csv(csv_path, index=False)

    print(f"Saving to {pkl_path}")
    with open(pkl_path, "wb") as pickle_file:
        pickle.dump(post_processed_results, pickle_file)
