# %%

import argparse
from datetime import datetime
from pathlib import Path

import numpy as np
import wandb
from matplotlib import pyplot as plt

import learned_planners.interp.plot  # noqa
from learned_planners import IS_NOTEBOOK, LP_DIR, get_default_args

parser = argparse.ArgumentParser()
parser.add_argument("--output_base_path", type=str, default="/training/icml-plots/")

if not IS_NOTEBOOK:
    args = parser.parse_args()
else:
    # Interactive sessions have no command line: start from the parser defaults
    # and override individual options here.
    args = get_default_args(parser)
    args.output_base_path = "/tmp/icml-plots/"

# NOTE(review): base_path is not used by the plotting cell in this file, which
# saves under LP_DIR instead — confirm which output location is intended.
base_path = Path(args.output_base_path)
api = wandb.Api()

# W&B project and run groups that the probe results are fetched from.
project = "learned-planners"
groups = ["devbox"]
# Reference F1 (in percent) drawn as the dashed "All layers" line. Hard-coded;
# presumably taken from a separate all-layers probe run — verify before reuse.
global_f1 = 84.5

# Summary metrics read for every run; this order defines the columns of each
# per-layer metric array.
metric_names = ["precision", "recall", "f1"]


def download_learning_curves(groups, created_after=datetime(2025, 3, 30)):
    """Fetch the final test metrics of layer-wise probe runs from W&B.

    Despite the name, this reads each run's *summary* values, not the full
    learning curve.

    Args:
        groups: W&B run groups to include; a run matches if it is in any of them.
        created_after: only runs created strictly after this time are queried.
            Defaults to the previously hard-coded cutoff of 2025-03-30.

    Returns:
        ``(run_names, metrics)``: two dicts keyed by the probe's layer index.
        ``run_names[layer]`` is the W&B run name and ``metrics[layer]`` is a
        numpy array with one entry per name in the module-level
        ``metric_names``, in that order.

        Runs without a non-negative ``cmd.train_on.layer`` config value are
        skipped. Runs whose summary lacks any of the test metrics are skipped
        with a message. If several runs share a layer, the one visited last wins.
    """
    if not groups:
        # An empty "$or" is not a valid filter; there is nothing to fetch.
        return {}, {}

    query = {
        "$and": [
            {"$or": [{"group": g} for g in groups]},
            {"created_at": {"$gt": created_after.isoformat()}},
        ]
    }
    runs = api.runs(project, query)
    print(len(runs))
    run_names = {}
    metrics = {}
    for run in runs:
        try:
            layer = run.config["cmd"]["train_on"]["layer"]
        except KeyError:
            continue  # config has no per-layer training target
        if layer is None or layer < 0:
            continue  # only runs with a concrete, non-negative layer are plotted

        prefix = f"test/l-{layer}_c-all_ds-boxes_future_direction_map_mpg-False"
        try:
            values = [run.summary[f"{prefix}/{name}"] for name in metric_names]
        except KeyError as err:
            # E.g. crashed or unfinished runs: skip instead of aborting the script.
            print(f"Skipping run {run.name}: summary is missing {err}")
            continue

        # Write both dicts together so they always describe the same run.
        run_names[layer] = run.name
        metrics[layer] = np.array(values)
    return run_names, metrics


# Fetch per-layer test metrics and plot each single-layer probe's F1 against
# the hard-coded "All layers" reference value (`global_f1`).
run_names, metrics = download_learning_curves(groups)
if not metrics:
    # Otherwise the empty result surfaces later as an opaque IndexError from
    # indexing a 1-D empty array with `y[:, ...]`.
    raise RuntimeError(f"No layer-wise probe runs found in groups {groups}")

# (layer, metric-array) pairs in increasing layer order. Dict keys are unique,
# so sorting the items only ever compares the layer indices.
metrics = sorted(metrics.items())
x = [layer for layer, _ in metrics]
# One row per layer, one column per entry of `metric_names`, scaled to percent.
y = 100 * np.array([values for _, values in metrics])
f1_col = metric_names.index("f1")

plt.figure(figsize=(2.0, 1.6))
plt.plot(x, y[:, f1_col], marker="o", markersize=4)
plt.axhline(global_f1, linestyle="--", color="C1", label="All layers")
plt.xticks(x)
plt.grid()
plt.legend()

plt.xlabel("Layer")
plt.ylabel(r"% F1 score")
out_path = LP_DIR / "resnet_plots" / "selected_plots" / "layer_wise_probe_f1.pdf"
# savefig does not create missing parent directories.
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, bbox_inches="tight")
plt.show()

# %%
