# %%
from pathlib import Path
from warnings import simplefilter

import pandas as pd
import wandb

simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# %%
# Switch between the evaluation group (True) and the training-run groups (False).
eval_runs = False

# Handle to the wandb public API.
api = wandb.Api()

# Project and run groups to read from.
project = "lp-cleanba"
groups = ["074-evaluate-fig1-error-bars"] if eval_runs else ["061-pfinal2-drc11", "061-pfinal2"]

# Candidate metric names: each template is expanded for indices 00..32.
# The index is the outer loop, so the six names for one index stay adjacent.
metric_templates = (
    "test_unfiltered/{i:02d}_episode_successes",
    "test_unfiltered/{i:02d}_episode_returns",
    "valid_medium/{i:02d}_episode_successes",
    "valid_medium/{i:02d}_episode_returns",
    "hard/{i:02d}_episode_successes",
    "hard/{i:02d}_episode_returns",
)
metrics = []
for level in range(33):
    for template in metric_templates:
        metrics.append(template.format(i=level))

# Select every run whose group is one of `groups`.
group_filters = [{"group": group_name} for group_name in groups]
runs = api.runs(project, {"$or": group_filters})

# Keep only the candidate metrics that appear among the columns returned by
# `history(1)` on the first run.
available_columns = set(runs[0].history(1).columns)
metrics = [name for name in metrics if name in available_columns]

# One (run metadata, metric history) pair per run.
all_fetch = []
for run in runs:
    hist = run.history(keys=metrics)
    run_info = {"name": run.name, "id": run.id, "cfg": run.config}
    all_fetch.append((run_info, hist))

# %%


def arch_from_cfg(cfg):
    """Return a short architecture label for a run configuration.

    Reads ``cfg["net"]["_type_"]``. When it contains ``"ConvLSTM"`` the label
    is ``"drc_"`` followed by ``repeats_per_step`` and ``n_recurrent`` from
    ``cfg["net"]``, concatenated. Otherwise the type is asserted to contain
    ``"ResNet"`` and the label is ``"resnet"``.
    """
    net_cfg = cfg["net"]
    net_type = net_cfg["_type_"]
    if "ConvLSTM" not in net_type:
        # Only two network families are expected here.
        assert "ResNet" in net_type
        return "resnet"
    return f"drc_{net_cfg['repeats_per_step']}{net_cfg['n_recurrent']}"


# Output location depends on which set of runs was fetched.
out_dir = Path("data") if eval_runs else Path("learning_curves")
out_dir.mkdir(exist_ok=True)

# Per-architecture wide table (one column per metric per run) and the number
# of runs merged into it so far.
archs_to_hist: dict[str, pd.DataFrame] = {}
archs_count = {}

for run_info, run_hist in all_fetch:
    assert isinstance(run_hist, pd.DataFrame)
    arch = arch_from_cfg(run_info["cfg"])

    # Index the history by integer step and drop the now-redundant column.
    steps = run_hist["_step"].astype(int)
    frame = run_hist.set_index(steps)
    del frame["_step"]
    frame.index.name = "step"

    n_seen = archs_count.get(arch, 0)
    if n_seen == 0:
        archs_to_hist[arch] = frame
    else:
        # Columns of the k-th run (k >= 1) collide with the first run's
        # unsuffixed names, so they receive the suffix "_k".
        archs_to_hist[arch] = archs_to_hist[arch].join(frame, how="outer", rsuffix=f"_{n_seen}")
    archs_count[arch] = n_seen + 1

for arch, merged in archs_to_hist.items():
    # The first run's columns carry no suffix yet; give them "_0" so every
    # run follows the same "<metric>_<k>" naming.
    table = merged.rename(columns={name: f"{name}_0" for name in metrics})
    n_runs = archs_count[arch]
    for metric in metrics:
        run_columns = [f"{metric}_{k}" for k in range(n_runs)]
        # Compute all three summaries before adding any of them to the table.
        lowest = table[run_columns].min(axis=1)
        highest = table[run_columns].max(axis=1)
        average = table[run_columns].mean(axis=1)
        table[f"{metric}_min"] = lowest
        table[f"{metric}_max"] = highest
        table[f"{metric}_mean"] = average

    table.to_csv(out_dir / f"{arch}.csv")

# %% Training curves

if not eval_runs:
    # Training-time metrics exported per run, one CSV per run under
    # `out_dir/<arch>/<run id>.train.csv`.
    train_metrics = [
        "losses/entropy",
        "losses/loss",
        "losses/policy_loss",
        "losses/value_loss",
        "charts/0/avg_episode_returns",
        "charts/0/avg_episode_success",
        "param_rms/total",
        "grad_rms/total",
        "stats/training_time",
        "var_explained",
    ]

    train_fetch = []

    # Value passed as the first argument of `run.history`; the returned
    # history is asserted to contain exactly this many rows.
    LEN = 10000
    for run in runs:
        hist = run.history(LEN, keys=train_metrics)
        assert len(hist) == LEN

        arch = arch_from_cfg(run.config)
        this_dir = out_dir / arch
        # `to_csv` does not create missing directories, so make sure the
        # per-architecture folder exists before writing into it.
        this_dir.mkdir(parents=True, exist_ok=True)
        try:
            hist = hist.set_index(hist["_step"].astype(int))
        except KeyError:
            # Report which run has no "_step" column, then fail here rather
            # than one line later at the `del` with less context.
            print(run.id, run.name)
            raise
        del hist["_step"]
        hist.index.name = "step"
        _id = run.id
        hist.to_csv((this_dir / f"{_id}.train.csv"))
