# %%

import argparse
from datetime import datetime
from pathlib import Path

import numpy as np
import wandb
from matplotlib import pyplot as plt

import learned_planners.interp.plot  # noqa
from learned_planners import IS_NOTEBOOK, ON_CLUSTER, get_default_args

# Command-line interface: the only option is the directory that receives the saved arrays and figures.
parser = argparse.ArgumentParser()
parser.add_argument("--output_base_path", type=str, default="/training/icml-plots/")

if IS_NOTEBOOK:
    # presumably returns the parser defaults without reading sys.argv — confirm in learned_planners
    args = get_default_args(parser)
    # add any custom here when running as notebook
    args.output_base_path = "/tmp/icml-plots/"
else:
    args = parser.parse_args()

base_path = Path(args.output_base_path)
# Create the output directory up front: neither np.save nor plt.savefig creates missing parent
# directories, so a fresh location (e.g. the notebook default under /tmp) would otherwise make
# the script fail only after all the W&B histories have been downloaded.
base_path.mkdir(parents=True, exist_ok=True)

api = wandb.Api()

# W&B project and run groups the histories are downloaded from.
project = "lp-cleanba"
# groups = ["071-noop-training-check", "072-evaluate-fig1-for-noop-models", "073-evaluate-noops-cycles"]
groups = ["073-evaluate-noop-cycles-new"]
# History keys to download; `pretty_metrics` holds the matching display labels (same order).
metrics = [
    "valid_medium/00_cycles_steps_per_eps_incl_noops",
    "valid_medium/00_episode_num_noops_per_eps",
    "valid_medium/00_episode_successes",
]
# NOTE: reassigned in the "Min/Max" plotting cell before it is used for the legend.
pretty_metrics = ["Cycles steps\n(incl. NOOPs)", "NOOPs", r"% solved"]

# Run-selection values: only runs with this entropy coefficient and one of these NOOP rewards are used.
ent_coef = 0.01
reward_noops = np.array([0, 0.01, 0.03, 0.05, 0.09])
# presumably the per-step penalty of a move action — TODO confirm against the training config
penalty = 0.1
# x-axis of the "Min/Max" plot, labelled there as the "NOOP to move-action penalty ratio".
noop_ratio = (penalty - reward_noops) / penalty


def run_dir(run):
    """Build the local W&B directory path of ``run`` from its start timestamp.

    Args:
        run: Object exposing ``metadata["startedAt"]`` (a string formatted as
            ``%Y-%m-%dT%H:%M:%S.%f``), ``group`` and ``id``.

    Returns:
        The string ``/training/cleanba/<group>/wandb/run-<YYYYmmdd_HHMMSS>-<id>/``.

    Raises:
        ValueError: If ``startedAt`` does not match the expected format.
    """
    started = datetime.strptime(run.metadata["startedAt"], "%Y-%m-%dT%H:%M:%S.%f")
    # The directory name carries the start time at second resolution (fraction dropped).
    created_at = started.strftime("%Y%m%d_%H%M%S")
    return f"/training/cleanba/{run.group}/wandb/run-{created_at}-{run.id}/"


def run_latest_dir(run=None, run_group=None, run_id=None):
    """Return the last (in sorted order) local W&B directory that exists for a run.

    The run is identified either by ``run`` or by passing both ``run_group`` and
    ``run_id``. If either of the latter two is ``None``, both are taken from ``run``.

    Args:
        run: Optional object exposing ``group`` and ``id``.
        run_group: Optional group name; used only together with ``run_id``.
        run_id: Optional run id; used only together with ``run_group``.

    Returns:
        The greatest ``Path`` (by path ordering) matching
        ``/training/cleanba/<run_group>/wandb/run-*-<run_id>/``. The directory names
        built by ``run_dir`` embed a ``YYYYmmdd_HHMMSS`` timestamp, so this is
        presumably the most recently started directory.

    Raises:
        ValueError: If neither ``run`` nor the ``run_group``/``run_id`` pair is given.
        FileNotFoundError: If no matching directory exists.
    """
    if run_group is None or run_id is None:
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if run is None:
            raise ValueError("Either `run` or both `run_group` and `run_id` must be provided")
        run_group, run_id = run.group, run.id
    # Named differently from the module-level `base_path` (the plot output directory).
    wandb_dir = Path(f"/training/cleanba/{run_group}/wandb/")
    found_dirs = sorted(wandb_dir.glob(f"run-*-{run_id}/"))
    if not found_dirs:
        raise FileNotFoundError(f"No run found for {run_group} {run_id}")
    return found_dirs[-1]


def download_learning_curves(reward_noops, groups):
    """Fetch the metric histories of the matching W&B runs, bucketed by NOOP reward.

    Runs are queried from the module-level ``project`` through ``api`` with a filter
    requiring that the run belongs to one of ``groups``, that its
    ``config.train_env.reward_noop`` is one of ``reward_noops``, that its
    ``config.loss.ent_coef`` equals the module-level ``ent_coef`` and that it was created
    before 2025-01-27. A returned run is skipped when its summary lacks any of the
    module-level ``metrics``, when ``"061-pfinal2"`` occurs in its ``load_other_run``
    config entry, or when its ``reward_noop`` equals no element of ``reward_noops``.

    Args:
        reward_noops: NumPy array of NOOP reward values; defines the buckets and their order.
        groups: Iterable of W&B group names.

    Returns:
        A list with one entry per element of ``reward_noops``. Each entry is a list of
        ``(info_dict, hist)`` tuples, where ``info_dict`` has the keys ``name``, ``id``,
        ``group``, ``reward_noop`` and ``run_dir``, and ``hist`` is the run history for
        ``metrics`` indexed by the integer ``step``.
    """
    filters = {
        "$and": [
            {"$or": [{"group": group_name} for group_name in groups]},
            {"config.train_env.reward_noop": {"$in": reward_noops.tolist()}},
            {"config.loss.ent_coef": ent_coef},
            {"created_at": {"$lt": datetime(2025, 1, 27).isoformat()}},
        ]
    }
    matching_runs = api.runs(project, filters)
    print(len(matching_runs))

    hists_per_reward = [[] for _ in reward_noops]
    for run in matching_runs:
        config = run.config
        # Runs that did not log every requested metric cannot be plotted.
        if any(metric_name not in run.summary for metric_name in metrics):
            continue
        # NOTE(review): assumes `load_other_run` is a string (or absent); a None value would raise here — confirm.
        if "061-pfinal2" in config.get("load_other_run", ""):
            continue

        run_reward_noop = config["train_env"]["reward_noop"]
        bucket_matches = np.where(reward_noops == run_reward_noop)[0]
        if bucket_matches.size == 0:
            continue
        bucket = bucket_matches[0]

        # Re-index the history by the integer logging step and drop the now redundant column.
        history = run.history(keys=metrics)
        history = history.set_index(history["_step"].astype(int)).drop(columns="_step")
        history.index.name = "step"

        info = {"name": run.name, "id": run.id, "group": run.group, "reward_noop": run_reward_noop}
        # On the cluster the directory is looked up on disk; elsewhere it is derived from run metadata.
        info["run_dir"] = run_latest_dir(run) if ON_CLUSTER else run_dir(run)
        hists_per_reward[bucket].append((info, history))
    return hists_per_reward


run_hists = download_learning_curves(reward_noops, groups)

# List the directory of every selected run, grouped in the order of `reward_noops`.
print("Run dirs:")
for hists_for_noop in run_hists:
    for info_dict, _ in hists_for_noop:
        print(info_dict["run_dir"])
# %%
# Find step where the solve rate reaches 40%
solve_rate_threshold = 0.4

# steps[noop_i][run_i] is the earliest logged step at which that run's success rate is at or
# above the threshold (NaN when the run never gets there).
steps = []
for hists_for_noop in run_hists:
    steps_for_noop = []
    for _, hist in hists_for_noop:
        reached = hist["valid_medium/00_episode_successes"] >= solve_rate_threshold
        earliest_step = hist.loc[reached].index.min()
        print(earliest_step)
        steps_for_noop.append(earliest_step)
    steps.append(steps_for_noop)

# Every metric read at each run's threshold step; nesting order is metric, NOOP reward, run.
metric_values = np.array(
    [
        [
            [hist.loc[threshold_step, m].item() for (_, hist), threshold_step in zip(hists_for_noop, steps_for_noop)]
            for hists_for_noop, steps_for_noop in zip(run_hists, steps)
        ]
        for m in metrics
    ]
)
info_dicts = [[info for info, _ in hists_for_noop] for hists_for_noop in run_hists]


# %%

# Read every metric from the single logged row whose step lies in
# [step * 10**8, (step + 1) * 10**8); `.item()` raises unless exactly one row matches.
# This overwrites the `metric_values` / `info_dicts` computed by the threshold cell above.
step = 8
metric_values = np.array(
    [
        [
            [hist[m].loc[hist.index // (10**8) == step].item() for _, hist in hists_for_noop]
            for hists_for_noop in run_hists
        ]
        for m in metrics
    ]
)
info_dicts = [[info for info, _ in hists_for_noop] for hists_for_noop in run_hists]

# Persist the plotted values together with their x-axis.
np.save(base_path / "noop_reward.npy", metric_values)
np.save(base_path / "noop_ratio.npy", noop_ratio)
# %% Min/Max
# Two stacked panels sharing the x-axis; the upper one is twice as tall as the lower one.
full_col = True
remove_last, legend_to_right = True, False
pretty_metrics = (
    ["Cycle steps (incl. NOOPs)", "NOOPs", r"% solved"] if full_col else ["Cycle\nsteps", "NOOPs", r"% solved"]
)
fig, axs = plt.subplots(
    2,
    1,
    figsize=(2.8 if full_col else 2.0, 1.6),
    sharex=True,
    gridspec_kw={"height_ratios": [2, 1]},
)
# ax2 = ax1.twinx()

# With `remove_last` the final NOOP-reward setting is left out of every curve.
x = noop_ratio[:-1] if remove_last else noop_ratio
for i, (m, metric_value) in enumerate(zip(metrics, metric_values)):
    if m == "valid_medium/00_episode_successes":
        # The solve rate goes on the lower panel, shown as a percentage.
        ax = axs[1]
        metric_value = metric_value * 100
        ax.set_ylabel(r"% solved")
        ax.set_xlabel("NOOP to move-action penalty ratio")
        ax.set_yticks([60, 70])
    else:
        # Both step-count metrics share the upper panel.
        ax = axs[0]
        ax.set_ylabel("steps")
        ax.set_yticks([0, 1, 2, 3])

    if remove_last:
        metric_value = metric_value[:-1]
    # Line: mean over runs. Band: min-to-max over runs.
    across_runs_mean = metric_value.mean(axis=1)
    ax.plot(x, across_runs_mean, label=pretty_metrics[i], color=f"C{i}")
    across_runs_min = metric_value.min(axis=1)
    across_runs_max = metric_value.max(axis=1)
    ax.fill_between(x, across_runs_min, across_runs_max, alpha=0.2, color=f"C{i}")
    # ax.grid(True)
axs[0].set_yticks([0, 1, 2, 3])
axs[0].set_xticks(noop_ratio)
# Tick labels are left-padded with spaces.
axs[0].set_yticklabels([f"   {int(t)}" for t in axs[0].get_yticks()])

if not legend_to_right:
    axs[0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.6), ncol=2, handlelength=1.0, columnspacing=0.5)
else:
    axs[0].legend(bbox_to_anchor=(1.0, 1.0), handlelength=1.0)

plt.savefig(
    base_path / f"noop_reward_{'side_legend' if legend_to_right and not full_col else ''}.pdf",
    bbox_inches="tight" if full_col else None,
)
plt.show()
# %% Planning emergence
# Plot the cycle-steps metric over training for the runs in the first `reward_noops` bucket:
# mean across runs as a line, min-to-max across runs as a shaded band.
upto_steps = 810 * 10**6
reference_index = run_hists[0][0][1].index
# Position (in the first run's history) of the last logged step that is <= upto_steps.
upto_idx = np.where(reference_index <= upto_steps)[0][-1]
metric = "valid_medium/00_cycles_steps_per_eps_incl_noops"
# NOTE(review): `[:upto_idx]` leaves out the row at `upto_idx` itself — confirm that is intended.
# NOTE(review): rows are aligned by position, which assumes all runs logged at the same steps — confirm.
metric_values = np.array([hist.loc[:, metric].tolist()[:upto_idx] for _, hist in run_hists[0]])

# The x-coordinates must come from the plotted histories. Previously `x` was whatever the
# "Min/Max" cell left behind (the NOOP ratios), which has neither the right length nor meaning.
x = reference_index[:upto_idx].to_numpy()

fig, ax = plt.subplots(figsize=(2.5, 1.6))
mean = np.mean(metric_values, axis=0)
min_vals = np.min(metric_values, axis=0)
max_vals = np.max(metric_values, axis=0)
ax.plot(x, mean, label="Mean", color="C0")
ax.fill_between(x, min_vals, max_vals, alpha=0.2, color="C0")
ax.set_xlabel("Episode timesteps")
ax.set_ylabel("Cycles steps")
ax.grid()
plt.savefig(
    base_path / "cycle_steps_vs_env_steps.pdf",
    bbox_inches="tight",
)
