import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pytablewriter
import seaborn
from matplotlib import pyplot as plt
from rliable import metrics
from score_normalization import normalize_score

# Command line interface.
# The text is passed as ``description``: the first positional parameter of
# ``ArgumentParser`` is ``prog``, which would put the sentence in the usage line.
parser = argparse.ArgumentParser(description="Gather results, plot them and create table")
# Algorithms, environments and folders are iterated over unconditionally later on,
# so make them mandatory to get a clear usage error instead of a TypeError on None.
parser.add_argument("-a", "--algos", help="Algorithms to include", nargs="+", type=str, required=True)
parser.add_argument("-e", "--env", help="Environments to include", nargs="+", type=str, required=True)
parser.add_argument("-f", "--exp-folders", help="Folders to include", nargs="+", type=str, required=True)
# Optional: the experiment folders are used as labels when omitted
parser.add_argument("-l", "--labels", help="Label for each folder", nargs="+", type=str)
parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best")
# Sizes are in inches, fractional values are valid for matplotlib
parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=float, default=[14, 6])
parser.add_argument("--fontsize", help="Font size", type=int, default=14)

parser.add_argument("-ci", "--ci-size", help="Confidence interval size", type=float, default=95)
parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)


args = parser.parse_args()

# Plot appearance: seaborn theme, and optionally LaTeX text rendering
seaborn.set(font_scale=1.8, style="whitegrid")

if args.latex:
    plt.rc("text", usetex=True)

# Raw episode rewards and formatted "mean +/- std error" strings,
# both indexed by environment and then by "<algo name>-<label>"
results, post_processed_results = {}, {}

# Algorithms are handled upper-cased; underscores become spaces for display
args.algos = list(map(str.upper, args.algos))
algo_names = dict(zip(args.algos, (name.replace("_", " ") for name in args.algos)))

# Without explicit labels, each experiment folder is labelled by its own path
args.labels = args.exp_folders if args.labels is None else args.labels

# Column buffers (one entry per episode), later turned into a pandas dataframe for seaborn
algos_df = []
labels_df = []
envs_df = []
scores = []

# Maximum number of episodes per (env, algo, folder) used for the plot.
# Forcing the same number of episodes per env avoids an imbalance.
MAX_EPISODES_PER_RUN = 25


def _load_episode_rewards(log_path):
    """Return the episode rewards found in ``log_path``, or ``None`` if there are none.

    ``evaluations.npz`` is read first. When it is missing or holds no episode,
    the rewards are read from the monitor CSV files with stable-baselines3 and
    cached to ``evaluations.npz`` so that the next run can load them directly.

    :param log_path: Folder of a single (experiment, algorithm, environment) run.
    :return: Array of episode rewards, or ``None`` when no data is available.
    """
    eval_path = os.path.join(log_path, "evaluations.npz")
    try:
        # Close the archive as soon as the array has been read
        with np.load(eval_path) as eval_data:
            episode_rewards = eval_data["episode_rewards"]
        if len(episode_rewards) > 0:
            return episode_rewards
    except FileNotFoundError:
        pass

    print(f"Eval not found for {log_path}")

    # Use CSV reader.
    # Imported lazily (only needed for this fallback) and outside of the ``try``
    # so that a missing package is reported as an ImportError.
    from stable_baselines3.common.monitor import LoadMonitorResultsError, load_results

    try:
        monitor_data = load_results(log_path)
        episode_rewards = np.array(monitor_data["r"])
        episode_lengths = np.array(monitor_data["l"])
    except (LoadMonitorResultsError, KeyError):
        print(f"No data available for {log_path}")
        return None

    if len(episode_rewards) == 0:
        # Nothing to aggregate, and nothing worth caching
        print(f"No data available for {log_path}")
        return None

    np.savez(eval_path, episode_rewards=episode_rewards, episode_lengths=episode_lengths)
    # The recovered rewards are used right away: the first run then gives
    # the same output as the following ones that read the cached file.
    return episode_rewards


# Scores are normalized only when several envs are compared
normalize_scores = len(args.env) > 1

for env in args.env:
    results[env] = {}
    post_processed_results[env] = {}

    for algo in args.algos:
        for folder_idx, exp_folder in enumerate(args.exp_folders):
            log_path = os.path.join(exp_folder, algo.lower(), env)

            # Special case for open-loop: not affected by sensor failure
            # as it is not using it
            if algo == "OPEN_LOOP" and "no_noise" not in exp_folder and "external_force" not in exp_folder:
                log_path = str(Path(exp_folder).parent / "no_noise" / algo.lower() / env)

            if not os.path.isdir(log_path):
                continue

            episode_rewards = _load_episode_rewards(log_path)
            if episode_rewards is None:
                continue

            algo_name = algo_names[algo]
            label = args.labels[folder_idx]
            key = f"{algo_name}-{label}"

            for score in episode_rewards[:MAX_EPISODES_PER_RUN]:
                algos_df.append(algo_name)
                labels_df.append(label)
                envs_df.append(env)
                # Normalize score to compare different envs
                if normalize_scores:
                    score = normalize_score(score, env)
                scores.append(float(score))

            # Standard error of the mean, computed over all the loaded episodes
            std_error = episode_rewards.std() / np.sqrt(len(episode_rewards))

            results[env][key] = episode_rewards
            post_processed_results[env][key] = f"{np.mean(episode_rewards):.0f} +/- {std_error:.0f}"

# Markdown table: one row per environment, one column per (algorithm, label) pair
writer = pytablewriter.MarkdownTableWriter(max_precision=3)
writer.table_name = "results_table"

# Column layout: algorithms vary slowest, labels fastest
columns = [(algo_names[algo], label) for algo in args.algos for label in args.labels]

# Header row: the algorithm name is repeated for each of its labels
headers = ["Environments"] + [algo_name for algo_name, _ in columns]
writer.headers = headers

# The first row of values is used as a sub-header with the label of each column
value_matrix = [[""] + [label for _, label in columns]]

for env in args.env:
    env_results = post_processed_results[env]
    row = [env]
    # Pairs without any result are displayed with a placeholder
    row.extend(env_results.get(f"{algo_name}-{label}", "0.0 +/- 0.0") for algo_name, label in columns)
    value_matrix.append(row)

writer.value_matrix = value_matrix
writer.write_table()

# Scores are only normalized when more than one env is given,
# so do not label raw rewards as normalized
score_label = "Normalized Score" if len(args.env) > 1 else "Score"

# Pandas dataframe for plotting with seaborn (one row per episode)
data_frame = pd.DataFrame(
    data={
        "Algorithm": algos_df,
        "Failure": labels_df,
        "Environment": envs_df,
        score_label: scores,
    }
)

# Set the figure size
plt.figure(figsize=args.figsize)
# NOTE(review): --fontsize is parsed but not applied here, the title size is fixed
plt.title("Robustness to Sensor Noise and Failures", fontsize=25)
# One group of bars per label, one bar per algorithm: interquartile mean of the
# scores with a bootstrapped confidence interval
ax = seaborn.barplot(
    data=data_frame,
    x="Failure",
    y=score_label,
    hue="Algorithm",
    estimator=metrics.aggregate_iqm,
    errorbar=("ci", args.ci_size),
    n_boot=2000,
    capsize=0.05,
)
ax.set_xlabel("")
# Honor the --legend-loc option instead of a hard-coded location
plt.legend(loc=args.legend_loc)
plt.tight_layout()
plt.show()
