import argparse
import os
import pickle

import numpy as np
import pandas as pd
import pytablewriter
import seaborn
from matplotlib import pyplot as plt

# Command line interface.
# The text is passed as `description`: the first positional parameter of
# ArgumentParser is `prog`, so passing it positionally made the whole sentence
# appear as the program name in the usage line.
parser = argparse.ArgumentParser(description="Gather results, plot them and create table")
# The three options below are iterated unconditionally by the rest of the script,
# so they are mandatory: omitting one used to fail later with a TypeError on None.
parser.add_argument("-a", "--algos", help="Algorithms to include", nargs="+", type=str, required=True)
parser.add_argument("-e", "--env", help="Environments to include", nargs="+", type=str, required=True)
parser.add_argument("-f", "--exp-folders", help="Folders to include", nargs="+", type=str, required=True)
parser.add_argument("-l", "--labels", help="Label for each folder", nargs="+", type=str)
parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best")
parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=int, default=[14, 6])
parser.add_argument("--fontsize", help="Font size", type=int, default=14)
parser.add_argument("-o", "--output", help="Output filename (pickle file), where to save the post-processed data", type=str)
parser.add_argument("-ci", "--ci-size", help="Confidence interval size", type=float, default=95)
parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)


args = parser.parse_args()

# Labels are looked up by folder position, so each folder needs its own label.
# Extra labels are still accepted, as before.
if args.labels is not None and len(args.labels) < len(args.exp_folders):
    parser.error(
        f"--labels needs at least one label per folder: got {len(args.labels)} label(s) "
        f"for {len(args.exp_folders)} folder(s)"
    )

# Global seaborn theme used by the final figure
seaborn.set(style="whitegrid", font_scale=1.8)


# Render text through LaTeX only when requested on the command line
if args.latex:
    plt.rc("text", usetex=True)

# Per-environment containers, filled while the experiment folders are scanned
results = {}
post_processed_results = {}

# Algorithm identifiers are upper-cased; the display name swaps "_" for " "
args.algos = list(map(str.upper, args.algos))
algo_names = dict(zip(args.algos, (name.replace("_", " ") for name in args.algos)))

# Without explicit labels, the folder paths double as labels
args.labels = args.exp_folders if args.labels is None else args.labels

# Column buffers for the pandas dataframe consumed by seaborn (one entry per score)
algos_df = []
labels_df = []
envs_df = []
scores = []

# Walk every (environment, algorithm, experiment folder) combination, load the
# evaluation rewards of each matching run and aggregate them.
for env in args.env:
    results[env] = {}
    post_processed_results[env] = {}

    for algo in args.algos:
        for folder_idx, exp_folder in enumerate(args.exp_folders):
            log_path = os.path.join(exp_folder, algo.lower())

            if not os.path.isdir(log_path):
                print(f"Skipping {log_path}")
                continue

            # Run folders are the sub-directories whose name contains the env id
            dirs = [
                os.path.join(log_path, d)
                for d in os.listdir(log_path)
                if (env in d and os.path.isdir(os.path.join(log_path, d)))
            ]

            merged_results = []
            result_list = []
            for dirname in dirs:
                try:
                    # Context manager so the underlying npz archive is closed after reading
                    with np.load(os.path.join(dirname, "evaluations.npz")) as evaluations:
                        episode_rewards = evaluations["episode_rewards"]
                    # An empty evaluation file is treated like a missing one
                    if len(episode_rewards) == 0:
                        raise FileNotFoundError
                except FileNotFoundError:
                    print(f"Eval not found for {dirname}")
                    continue

                label = args.labels[folder_idx]
                for score in episode_rewards:
                    algos_df.append(algo_names[algo])
                    labels_df.append(label)
                    envs_df.append(env)
                    # Normalize score to compare different envs
                    # if len(args.env) > 1:
                    #     score = normalize_score(score, env)
                    scores.append(float(score))
                    merged_results.append(float(score))
                result_list.append(episode_rewards)

            # Nothing could be loaded for this combination: there is nothing to
            # aggregate. Previously this path raised (undefined `key` / invalid axis)
            # or overwrote the entry of the previous combination through a stale `key`.
            # The table falls back to a default cell for missing keys.
            if not result_list:
                print(f"No evaluations found for {env} in {log_path}, skipping")
                continue

            key = f"{algo_names[algo]}-{args.labels[folder_idx]}"

            merged_results = np.array(merged_results)

            # Standard error of the mean over all pooled scores
            std_error = merged_results.std() / np.sqrt(len(merged_results))
            # NOTE(review): stacking assumes every run has the same number of
            # evaluations -- confirm, ragged runs are not handled here.
            aggregated_scores = np.array(result_list)

            # Rewards of the last successfully loaded run (same value as before
            # whenever the last run folder loaded correctly).
            results[env][key] = result_list[-1]
            # (n_runs, n_evals) -> (n_runs,)
            mean_best_candidates = aggregated_scores.mean(axis=1)

            # Convert to RL-Zoo format
            post_processed_results[env][key] = {
                # "timesteps": timesteps,
                "mean": merged_results.mean(),
                "std_error": std_error,
                # shape: (n_runs,)
                "last_evals": mean_best_candidates,
                "std_error_last_eval": mean_best_candidates.std() / np.sqrt(len(mean_best_candidates)),
                # "mean_per_eval": mean_per_eval,
                "display": f"{np.mean(merged_results):.0f} +/- {std_error:.0f}",
            }

# Markdown table: one column per (algorithm, label) pair, one row per environment.
writer = pytablewriter.MarkdownTableWriter(max_precision=3)
writer.table_name = "results_table"

# The header row carries the algorithm names; the first table row acts as a
# sub-header carrying the label of each column.
headers = ["Environments"]
subheader = [""]
for algo in args.algos:
    for label in args.labels:
        headers.append(algo_names[algo])
        subheader.append(label)

writer.headers = headers

# Sub-header first, then one row per environment
value_matrix = [subheader]
for env in args.env:
    row = [env]
    for algo in args.algos:
        for label in args.labels:
            key = f"{algo_names[algo]}-{label}"
            # Combinations without any result get a placeholder cell
            row.append(f'{post_processed_results[env].get(key, {}).get("display", "0.0 +/- 0.0")}')
    value_matrix.append(row)

writer.value_matrix = value_matrix
writer.write_table()

# Long-format dataframe for seaborn: one row per collected score
columns = {
    "Algorithm": algos_df,
    "Label": labels_df,
    "Environment": envs_df,
    "Normalized Score": scores,
}
data_frame = pd.DataFrame(data=columns)

# Keep the rendered table next to the per-environment statistics
results_table = {"headers": headers, "value_matrix": value_matrix}
post_processed_results["results_table"] = results_table

# Optionally dump the post-processed data; ".pkl" is appended to the given name
if args.output is not None:
    output_path = f"{args.output}.pkl"
    print(f"Saving to {output_path}")
    with open(output_path, "wb") as pickle_file:
        pickle.dump(post_processed_results, pickle_file)

# Bar plot of the scores: one group per environment, one bar per algorithm.
# Set the figure size
plt.figure(figsize=args.figsize)
# plt.title("Robustness to Sensor Noise and Failures", fontsize=25)
ax = seaborn.barplot(
    data=data_frame,
    x="Environment",
    y="Normalized Score",
    hue="Algorithm",
    # Bootstrapped confidence interval whose size comes from --ci-size
    errorbar=("ci", args.ci_size),
    n_boot=2000,
    capsize=0.05,
)
# ax.set_xlabel("")
# Honor --legend-loc, which was parsed but never applied; its default ("best")
# leaves the legend placement unchanged.
# NOTE(review): --fontsize is still unused; text size comes from seaborn's font_scale.
plt.legend(loc=args.legend_loc)
plt.tight_layout()
plt.show()
