import functools

import pathlib
import warnings
from collections import defaultdict
from multiprocessing import Pool, RLock, current_process

import gym
import matplotlib.colors
import numpy as np
import scipy.stats
import torch
from statsmodels.stats.weightstats import DescrStatsW
import wandb
from matplotlib import pyplot as plt, ticker
from tqdm.auto import tqdm, trange

from collector.strategy import record_episode
from helpers import init_envs, init_model
from policy.base import BasePolicy


# Module-level configuration for the evaluation/plotting script below.

# Have matplotlib typeset figure text with LaTeX (the labels in GROUP_LABEL_MAP
# use math mode). NOTE(review): this needs a working LaTeX installation.
plt.rcParams['text.usetex'] = True

# Key of the x-value inside each history row, the x-axis label, and the largest
# x-value that is evaluated, plotted and reported in the printed table.
X_METRIC = "episode"
X_LABEL = "Episode"
X_METRIC_LIMIT = 10_000

# Name of the validation environment: used as the key into the dict returned by
# init_envs() and as part of the output file paths.
VALID_ENV = "len300"
# Key of the y-value inside each history row, and the y-axis label.
Y_METRIC = f"validation/{VALID_ENV}/reward_mean"
Y_LABEL = "Reward"

# Number of validation episodes rolled out per checkpoint.
VALID_EPS = 1000
# Reduces the per-episode rewards of one checkpoint to a single number.
AGGREGATION_FUNC = np.mean

# Key of the run config whose value decides which curve (group) a run belongs to.
GROUPBY_CONFIG_KEY = "memory_type"

# Placeholders: fill these in before running; they form the "<entity>/<project>"
# path passed to the W&B API.
WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"

# Placeholders, intended as the display names of the W&B runs to evaluate.
RUNS_NAMES = [
    "run-name-1",
    "run-name-2",
    "run-name-3",
    "run-name-4",
]

# print(matplotlib.colors.TABLEAU_COLORS)

# Group name -> line/fill color of that group's curve.
# NOTE: there is no active entry for "lstm" / "lstm_frozen"; main() indexes this
# map by group name, so a run in one of those groups would raise KeyError.
COLOR_MAP = {
    "sith": "tab:blue",
    "F_sub": "tab:orange",
    "F": "tab:green",
    "sith_sub_sum": "tab:cyan",
    "sith_sub_nosum": "tab:purple",
    "rnn": "tab:brown",
    # "lstm": "tab:cyan",
    "gru": "tab:pink",
    "rnn_frozen": "tab:gray",
    # "lstm_frozen": "tab:gray",
    "gru_frozen": "tab:olive",
}

# Alternative palette, kept for reference:
# COLOR_MAP = {
#     "gru_frozen": "tab:blue",
#     "rnn_frozen": "tab:orange",
#     "rnn": "tab:green",
#     "gru": "tab:pink",
# }
# Values may be a hex string such as #00FFFF or any named color listed at
# https://matplotlib.org/stable/gallery/color/named_colors.html

# Group name -> matplotlib marker. main() currently only prints this value; the
# plot call passes marker=None. There is no "lstm_frozen" entry.
MARKER_MAP = {
    "sith": "o",
    "F_sub": "^",
    "F": "v",
    "sith_sub_sum": "2",
    "sith_sub_nosum": "<",
    "rnn": ">",
    "lstm": "D",
    "gru": "s",
    "rnn_frozen": "*",
    "gru_frozen": "1",
}

# Group name -> LaTeX label used for the legend and for the printed table rows.
GROUP_LABEL_MAP = {
    "sith": r"$\tilde{f}$",
    "F_sub": r"$F_{sub}$",
    "F": "$F$",
    "sith_sub_sum": r"$\tilde{f}_{subsum}$",
    "sith_sub_nosum": r"$\tilde{f}_{sub}$",
    "rnn": "RNN",
    "lstm": "LSTM",
    "gru": "GRU",
    "rnn_frozen": r"RNN$_{frozen}$",
    "lstm_frozen": r"LSTM$_{frozen}$",
    "gru_frozen": r"GRU$_{frozen}$"
}


def validation_no_log(valid_env: gym.Env, policy: BasePolicy, n_episodes: int, pbar_pos: int, pbar_desc: str) -> np.ndarray:
    """Record ``n_episodes`` validation episodes and return their total rewards.

    Nothing is logged; progress is only shown on a tqdm bar.

    Args:
        valid_env: Environment the episodes are recorded in.
        policy: Policy handed to ``record_episode`` (called with ``explore=False``).
        n_episodes: Number of episodes to record.
        pbar_pos: Row (``position``) of the tqdm progress bar, so that bars of
            parallel workers do not overwrite each other.
        pbar_desc: Description shown next to the progress bar.

    Returns:
        One-dimensional array holding ``trajectory.total_reward`` of every
        recorded episode, in recording order.
    """
    progress = trange(n_episodes, position=pbar_pos, desc=pbar_desc, leave=True)

    return np.array([
        record_episode(valid_env, policy, explore=False).total_reward
        for _ in progress
    ])


def worker_gen_metrics(run_name: str):
    """Re-evaluate every saved checkpoint of one W&B run on ``VALID_ENV``.

    For each checkpoint episode from 0 up to ``min(last logged episode,
    X_METRIC_LIMIT)`` in steps of the run's ``log_checkpoint_freq``, the
    per-episode validation rewards are loaded from
    ``./out/generated_rewards/<VALID_ENV>/<run_name>_ep<episode>.npy`` if that
    file exists. Otherwise the checkpoint is downloaded to ``./tmp/<run name>``,
    loaded into the model, evaluated for ``VALID_EPS`` episodes and the rewards
    are saved to that file.

    Args:
        run_name: Display name of the run; it must match exactly one run in
            ``WANDB_ENTITY/WANDB_PROJECT``.

    Returns:
        A list with one dict per checkpoint, holding the keys ``"episode"`` and
        ``"validation/<VALID_ENV>/reward_mean"`` (the ``AGGREGATION_FUNC`` of
        that checkpoint's rewards).

    Raises:
        RuntimeError: If the name does not match exactly one run.
    """
    # `_identity` is a private tuple that is empty in the main process, so fall
    # back to row 0 when this function is called outside of a Pool worker.
    identity = current_process()._identity
    pbar_pos = identity[0] if identity else 0

    api = wandb.Api()
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={"display_name": run_name})

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(runs) != 1:
        raise RuntimeError(f"Expected exactly one run named {run_name!r}, found {len(runs)}.")

    run = runs[0]
    config = dict(run.config)

    custom_valid_envs = {"len3000": {"seq_len": 3000, "_max_episode_steps": 3200}}
    train_env, valid_envs_dict = init_envs(config, custom_valid_envs)
    model = init_model(config, train_env)

    tqdm.write(f"{run.name} ({run.url})")

    max_episode = min(run.summary_metrics["episode"], X_METRIC_LIMIT)
    checkpoint_freq = run.config['log_checkpoint_freq']

    history = []

    for episode in np.arange(0, max_episode+1, checkpoint_freq):
        path = pathlib.Path(f"./out/generated_rewards/{VALID_ENV}/{run_name}_ep{episode}.npy")
        path.parent.mkdir(parents=True, exist_ok=True)

        if path.exists():
            # Rewards of this checkpoint were generated by an earlier invocation.
            tqdm.write(f"File already exists, skipping: {path.absolute()}")
            valid_rewards = np.load(path)
        else:
            file = run.file(f"checkpoints/{episode}.pt")
            file.download(root=f"./tmp/{run.name}", replace=True)
            model.load_state_dict(torch.load(f"./tmp/{run.name}/{file.name}")["model_state_dict"])
            model.eval()

            valid_env = valid_envs_dict[VALID_ENV]

            valid_rewards = validation_no_log(valid_env, model, VALID_EPS, pbar_pos, f"{run.name} ({episode})")

            # Cache the raw rewards so that later invocations can skip this checkpoint.
            np.save(path, valid_rewards)
            tqdm.write(f"Saved generated rewards to: {path.absolute()}")

        avg_reward = AGGREGATION_FUNC(valid_rewards).item()

        # `.item()` turns the numpy scalars into plain Python numbers.
        history.append({"episode": episode.item(), f"validation/{VALID_ENV}/reward_mean": avg_reward})

    return history


def _group_histories(runs, histories):
    """Bucket the per-checkpoint values by run group and x-value.

    Returns a nested mapping ``group -> x-value -> [y-value of each run]``,
    where the group is ``run.config[GROUPBY_CONFIG_KEY]``. Rows whose x-value
    exceeds ``X_METRIC_LIMIT`` are dropped. A short summary is printed per run.
    """
    grouped: defaultdict[str, defaultdict[int, list[float]]] = defaultdict(lambda: defaultdict(list))

    for run, history in zip(runs, histories):
        print(f"{run.name} ({run.url})")
        print(f"{len(history)} entries, from episode {history[0][X_METRIC]} to {history[-1][X_METRIC]}.")

        run_group = run.config[GROUPBY_CONFIG_KEY]

        for row in history:
            if row[X_METRIC] > X_METRIC_LIMIT:
                continue

            grouped[run_group][row[X_METRIC]].append(row[Y_METRIC])

    return grouped


def _plot_groups(ax, x_vals, y_vals):
    """Draw mean +/- SEM of every group on ``ax`` and print one LaTeX table row per group."""
    for group, xs in x_vals.items():
        # Rows of y_vals[group] are x-values, columns are runs.
        means = y_vals[group].mean(axis=1)
        sem = scipy.stats.sem(y_vals[group], axis=1, ddof=0)
        lower, upper = means - sem, means + sem

        ax.fill_between(xs, lower, upper, alpha=0.3, facecolor=COLOR_MAP[group])
        ax.plot(xs,
                means,
                color=COLOR_MAP[group],
                linewidth=1,
                markersize=2,
                marker=None,
                label=GROUP_LABEL_MAP[group],
                )

        limit_idx = np.where(xs == X_METRIC_LIMIT)[0]
        if limit_idx.size == 0:
            # Without a point exactly at the limit there is nothing to report;
            # skip the row instead of failing before the figure is saved.
            warnings.warn(
                f"No data point at {X_METRIC} == {X_METRIC_LIMIT} for group {group!r}; "
                "skipping its table row."
            )
            continue

        print(f"{GROUP_LABEL_MAP[group]:<20} & {means[limit_idx][0]:.3f} $\\pm$ {sem[limit_idx][0]:.3f} \\\\")
        print(r"\hline                                                ")


def main():
    """Evaluate the configured runs in parallel and plot reward over episodes.

    Looks up the runs named in ``RUNS_NAMES``, computes each run's validation
    history in a process pool, groups the histories by
    ``GROUPBY_CONFIG_KEY`` and writes ``./out/rewards_<VALID_ENV>.png``
    (mean +/- SEM per group) and ``./out/rewards_legend.png``. A LaTeX table
    row with the value at ``X_METRIC_LIMIT`` is printed for every group.

    Raises:
        SystemExit: With code 1 if the number of runs found differs from the
            number of requested names.
        ValueError: If, within a group, some x-values are missing for some runs.
    """
    api = wandb.Api()
    # The module defines the list of names as RUNS_NAMES.
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={
        "display_name": {"$in": RUNS_NAMES}
    })

    print(f"{len(runs)} runs found.")

    if len(RUNS_NAMES) != len(runs):
        ret_names = [run.name for run in runs]
        print(f"Missing runs: {set(RUNS_NAMES) ^ set(ret_names)}")
        # The `exit` builtin is only injected by the `site` module.
        raise SystemExit(1)

    # Share one lock between all workers so their progress bars do not interleave.
    tqdm.set_lock(RLock())
    with Pool(initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) as pool:
        run_names = [run.name for run in runs]
        histories = pool.map(worker_gen_metrics, run_names)

    vals = _group_histories(runs, histories)

    print()

    x_vals = {}
    y_vals = {}

    for group, group_dict in vals.items():
        # Every x-value needs one entry per run, otherwise the array below is ragged.
        if len({len(ys) for ys in group_dict.values()}) != 1:
            raise ValueError("values missing for some runs")
        x_vals[group] = np.array(list(group_dict.keys()))
        y_vals[group] = np.array(list(group_dict.values()))
        print(f"{group:>6}: {y_vals[group].shape[1]} runs {COLOR_MAP[group], MARKER_MAP[group]}")

    print()

    # Mean +/- SEM figure
    fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=600, constrained_layout=True)
    ax.set_xlabel(X_LABEL, fontsize=10)
    ax.set_xlim(0, X_METRIC_LIMIT)

    ax.xaxis.set_minor_locator(ticker.MultipleLocator(50_000))
    # Show thousands as "1k", "2k", ...
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: "0" if x == 0 else f"{int(x / 1000)}k"))

    ax.set_ylabel(Y_LABEL, fontsize=10)
    ax.set_ylim(0, 10.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(which="both", axis="both", direction="in")

    _plot_groups(ax, x_vals, y_vals)

    path = pathlib.Path(f"./out/rewards_{VALID_ENV}.png")
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)
    print(f"Figure saved to { path.absolute()}")

    # The legend is rendered into its own, axis-less figure.
    fig_leg, ax_leg = plt.subplots(1, 1, figsize=(5, 0.75), dpi=600)
    ax_leg.legend(*ax.get_legend_handles_labels(), loc='center', ncol=4, frameon=False)
    ax_leg.axis('off')
    fig_leg.savefig('./out/rewards_legend.png')


# Run the pipeline only when this file is executed as a script.
# NOTE(review): main() starts a multiprocessing Pool; with the "spawn" start
# method the workers re-import this module, so this guard presumably also
# keeps them from re-running main() — confirm for the target platform.
if __name__ == "__main__":
    main()
