import pathlib
from collections import defaultdict

import matplotlib.colors
import numpy as np
import scipy.stats
from statsmodels.stats.weightstats import DescrStatsW
import wandb
from matplotlib import pyplot as plt, ticker

# Render all figure text through LaTeX (matplotlib's `text.usetex` option,
# which per the matplotlib docs needs a working LaTeX installation).
plt.rcParams['text.usetex'] = True

# Key of the logged metric used for the x axis, and the axis label shown for it.
X_METRIC = "episode"
X_LABEL = "Episode"
# History rows whose X_METRIC value exceeds this are skipped; it is also used
# as the right-hand limit of the x axis.
X_METRIC_LIMIT = 13_000

# Validation environment name; it selects the logged metric below and is part
# of the output file name.
VALID_ENV = "len300"
# Key of the logged metric plotted on the y axis, and its axis label.
Y_METRIC = f"validation/{VALID_ENV}/reward_mean"
Y_LABEL = "Reward Mean"

# Key in each run's config whose value decides which group (curve) a run belongs to.
GROUPBY_CONFIG_KEY = "memory_type"

# Placeholders: fill in the W&B project and entity before running.
WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"

# Display names of the runs to fetch. main() exits with status 1 when the
# number of runs returned by W&B differs from the number of names listed here.
RUN_NAMES = [
    "run-name-1",
    "run-name-2",
    "run-name-3",
    "run-name-4",
]

# Prints matplotlib's Tableau color table every time this module is loaded
# (presumably as a reference when choosing the "tab:*" colors -- confirm
# whether this debug output is still wanted).
print(matplotlib.colors.TABLEAU_COLORS)

# Per-group plot styling, one row per group key:
#     (group key, color, marker, legend label)
# ``None`` means the group has no entry in the corresponding map.
# A color can be a "#00FFFF"-style hex string or any named color listed at
# https://matplotlib.org/stable/gallery/color/named_colors.html
_GROUP_STYLES = (
    ("sith", "tab:blue", "o", r"$\tilde{f}$"),
    ("F_sub", "tab:orange", "^", r"$F_{sub}$"),
    ("F", "tab:green", "v", "$F$"),
    ("sith_sub_sum", "tab:cyan", "2", r"$\tilde{f}_{subsum}$"),
    ("sith_sub_nosum", "tab:purple", "<", r"$\tilde{f}_{sub}$"),
    ("rnn", "tab:brown", ">", "RNN"),
    # "lstm" has no active color (a "tab:cyan" entry is disabled).
    ("lstm", None, "D", "LSTM"),
    ("gru", "tab:pink", "s", "GRU"),
    ("rnn_frozen", "tab:gray", "*", r"RNN$_{frozen}$"),
    # "lstm_frozen" has no active color (a "tab:gray" entry is disabled)
    # and no marker.
    ("lstm_frozen", None, None, r"LSTM$_{frozen}$"),
    ("gru_frozen", "tab:olive", "1", r"GRU$_{frozen}$"),
)

# An alternative palette kept for reference (not active):
#     gru -> tab:blue, rnn -> tab:orange,
#     gru_frozen -> tab:green, rnn_frozen -> tab:red

# Group key -> matplotlib color used for the line and the shaded band.
COLOR_MAP = {
    key: color
    for key, color, _marker, _label in _GROUP_STYLES
    if color is not None
}

# Group key -> matplotlib marker symbol.
MARKER_MAP = {
    key: marker
    for key, _color, marker, _label in _GROUP_STYLES
    if marker is not None
}

# Group key -> legend / table label (LaTeX markup).
GROUP_LABEL_MAP = {
    key: label
    for key, _color, _marker, label in _GROUP_STYLES
}


def main():
    """Fetch the configured W&B runs and plot the mean reward per group.

    Steps:
      1. Query W&B for the runs whose display name is in ``RUN_NAMES``.
      2. For every run, collect ``Y_METRIC`` keyed by ``X_METRIC`` (rows with
         ``X_METRIC > X_METRIC_LIMIT`` are skipped) and bucket the samples by
         ``run.config[GROUPBY_CONFIG_KEY]``.
      3. Per group, plot the mean with a +/- 1 SEM band, and print one LaTeX
         table row with the value at the last plotted x position.
      4. Save the figure to ``./out/rewards_<VALID_ENV>.png`` and a
         legend-only figure to ``./out/rewards_legend.png``.

    Raises:
        SystemExit: with status 1 when the number of runs returned by W&B
            differs from ``len(RUN_NAMES)``.
        KeyError: when a run's group has no entry in ``COLOR_MAP``,
            ``MARKER_MAP`` or ``GROUP_LABEL_MAP``.
    """
    api = wandb.Api()
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={
        "display_name": {"$in": RUN_NAMES}
    })

    print(f"{len(runs)} runs found.")

    if len(RUN_NAMES) != len(runs):
        ret_names = [run.name for run in runs]
        # Only requested-but-absent names are "missing"; a count mismatch can
        # also come from several runs sharing one display name.
        print(f"Missing runs: {set(RUN_NAMES) - set(ret_names)}")
        duplicated = {name for name in ret_names if ret_names.count(name) > 1}
        if duplicated:
            print(f"Names matched by more than one run: {duplicated}")
        # `exit()` is injected by the `site` module and is not guaranteed to
        # exist; raising SystemExit is the portable equivalent.
        raise SystemExit(1)

    # group -> x value -> samples (one per run that logged this x value)
    vals: defaultdict[str, defaultdict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
    metric_names = [X_METRIC, Y_METRIC]
    style_maps = {
        "COLOR_MAP": COLOR_MAP,
        "MARKER_MAP": MARKER_MAP,
        "GROUP_LABEL_MAP": GROUP_LABEL_MAP,
    }

    for run in runs:
        print(f"{run.name} ({run.url})")

        run_group = run.config[GROUPBY_CONFIG_KEY]

        # Fail before scanning this run's history if the group cannot be
        # styled, instead of hitting a bare KeyError during plotting.
        unstyled = [name for name, mapping in style_maps.items() if run_group not in mapping]
        if unstyled:
            raise KeyError(
                f"group {run_group!r} (run {run.name}) has no entry in: {', '.join(unstyled)}"
            )

        for row in run.scan_history(keys=metric_names):
            if row[X_METRIC] > X_METRIC_LIMIT:
                continue

            vals[run_group][row[X_METRIC]].append(row[Y_METRIC])

    print()

    x_vals = {}
    y_vals = {}

    for group, group_dict in vals.items():
        # Sort by x so the curve is drawn left to right even when runs logged
        # different x values (dict order is first-seen order).
        xs = sorted(group_dict)
        x_vals[group] = np.array(xs)
        # Keep one sample array per x value: runs may not all have logged
        # every x, so the samples do not necessarily form a rectangle.
        y_vals[group] = [np.asarray(group_dict[x], dtype=float) for x in xs]

        counts = [len(samples) for samples in y_vals[group]]
        print(f"{group:>6}: {max(counts)} runs {COLOR_MAP[group], MARKER_MAP[group]}")
        if min(counts) != max(counts):
            print(f"{'':>6}  warning: between {min(counts)} and {max(counts)} "
                  f"samples per {X_METRIC} value")

    print()

    # The band drawn below is +/- 1 standard error of the mean. Alternatives
    # tried earlier: a 2-sigma band (mean +/- 2 * std) and a 95% z / t
    # confidence interval via statsmodels' DescrStatsW.
    fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=600, constrained_layout=True)
    ax.set_xlabel(X_LABEL, fontsize=10)
    ax.set_xlim(0, X_METRIC_LIMIT)

    # NOTE(review): with X_METRIC_LIMIT = 13_000 a 50_000 minor-tick spacing
    # puts no minor tick inside the axis except at 0 -- confirm this is intended.
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(50_000))
    # Show tick values in thousands ("2k", "4k", ...), with a plain "0" at the origin.
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: "0" if x == 0 else f"{int(x / 1000)}k"))

    ax.set_ylabel(Y_LABEL, fontsize=10)
    ax.set_ylim(0, 10.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(which="both", axis="both", direction="in")

    for group in x_vals:
        samples_per_x = y_vals[group]
        means = np.array([samples.mean() for samples in samples_per_x])
        sem = np.array([scipy.stats.sem(samples, ddof=0) for samples in samples_per_x])
        lower, upper = means - sem, means + sem

        ax.fill_between(x_vals[group], lower, upper, alpha=0.3, facecolor=COLOR_MAP[group])
        ax.plot(x_vals[group],
                means,
                color=COLOR_MAP[group],
                linewidth=1,
                markersize=2,
                marker=MARKER_MAP[group],
                label=GROUP_LABEL_MAP[group],
                )

        # x values are sorted and capped at X_METRIC_LIMIT, so the last entry
        # is the limit itself whenever it was logged. If it was not, report
        # the last logged point instead of failing on an empty lookup.
        final_x = x_vals[group][-1]
        if final_x != X_METRIC_LIMIT:
            print(f"% no {X_METRIC} == {X_METRIC_LIMIT} logged for {group}; "
                  f"reporting the value at {final_x}")
        print(f"{GROUP_LABEL_MAP[group]:<20} & {means[-1]:.3f} $\\pm$ {sem[-1]:.3f} \\\\")
        print(r"\hline                                                ")

    path = pathlib.Path(f"./out/rewards_{VALID_ENV}.png")
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)
    print(f"Figure saved to {path.absolute()}")

    # Separate legend-only figure reusing the handles and labels of the main plot.
    fig_leg, ax_leg = plt.subplots(1, 1, figsize=(5, 1), dpi=600)
    ax_leg.legend(*ax.get_legend_handles_labels(), loc='center', ncol=4)
    ax_leg.axis('off')
    fig_leg.savefig(path.parent / "rewards_legend.png")

# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
