import json
from pathlib import Path

import wandb

# %%

# Initialize wandb API
api = wandb.Api()

# Define the project, groups, and metrics
project = "lp-cleanba"
groups = ["061-pfinal2-drc11", "061-pfinal2"]
# Candidate evaluation keys: for every index 00..32, a successes and a returns
# metric under both the "test_unfiltered" and "valid_medium" prefixes.
metrics = [
    s.format(i=i)
    for i in range(33)
    for s in [
        "test_unfiltered/{i:02d}_episode_successes",
        "test_unfiltered/{i:02d}_episode_returns",
        "valid_medium/{i:02d}_episode_successes",
        "valid_medium/{i:02d}_episode_returns",
    ]
]


# All runs in `project` that belong to any of `groups`.
runs = api.runs(project, {"$or": [{"group": g} for g in groups]})

# Fail with a descriptive message instead of a bare IndexError when the query
# matches no runs at all.
try:
    first_run = runs[0]
except IndexError:
    raise RuntimeError(
        f"No runs found in project {project!r} for groups {groups!r}"
    ) from None

# Keep only the candidate metrics that appear among the columns of a 1-sample
# history of the first run.
# NOTE(review): assumes every run logs the same evaluation keys as the first
# one — confirm.
available_columns = set(first_run.history(1).columns)

metrics = [m for m in metrics if m in available_columns]
# Stop here rather than fetch with an empty `keys` list.
# NOTE(review): wandb presumably treats an empty key list as "no key filter",
# which would silently export every logged metric — verify.
if not metrics:
    raise RuntimeError(
        f"None of the expected evaluation metrics were found in run {first_run.id}"
    )

# Fetch each run's evaluation history together with its name, id and config.
# `history` is called without `samples`, so wandb's default sample count
# applies. NOTE(review): confirm that does not downsample the evaluation points.
all_fetch = []
for run in runs:
    hist = run.history(keys=metrics)
    all_fetch.append(({"name": run.name, "id": run.id, "cfg": run.config}, hist))

# %%


def arch_from_cfg(cfg: dict) -> str:
    """Return a short architecture label for a run config.

    ``cfg["net"]["_type_"]`` selects the family:

    * a type name containing ``"ConvLSTM"`` yields
      ``"drc_<repeats_per_step><n_recurrent>"`` (e.g. ``"drc_33"``);
    * otherwise a type name containing ``"ResNet"`` yields ``"resnet"``.

    Raises:
        KeyError: if a required key is missing from ``cfg``.
        ValueError: if the network type is neither ConvLSTM nor ResNet.
    """
    net = cfg["net"]
    net_type = net["_type_"]
    if "ConvLSTM" in net_type:
        return f"drc_{net['repeats_per_step']}{net['n_recurrent']}"
    if "ResNet" in net_type:
        return "resnet"
    # Explicit raise instead of `assert`: under `python -O` an assert is
    # stripped, and every unknown network type would be mislabelled "resnet".
    raise ValueError(f"Unrecognized network type in cfg['net']['_type_']: {net_type!r}")


out_dir = Path("learning_curves")
out_dir.mkdir(exist_ok=True)

# For every fetched run, write its config (JSON) and its evaluation history
# (CSV) into a per-architecture subdirectory, both files named by run id.
for run_info, eval_hist in all_fetch:
    run_cfg = run_info["cfg"]
    arch_dir = out_dir / arch_from_cfg(run_cfg)
    arch_dir.mkdir(exist_ok=True)

    run_id = run_info["id"]
    cfg_path = arch_dir / f"{run_id}.cfg.json"
    with cfg_path.open("w") as fh:
        json.dump(run_cfg, fh, indent=2)

    # Index rows by the integer `_step`, drop `_step` as a column, and name
    # the index "step". Each call returns a new frame, so the entry stored in
    # `all_fetch` is left untouched.
    eval_hist = (
        eval_hist.set_index(eval_hist["_step"].astype(int))
        .drop(columns="_step")
        .rename_axis("step")
    )
    eval_hist.to_csv(arch_dir / f"{run_id}.test.csv")


# %% Training curves


train_metrics = [
    "losses/entropy",
    "losses/loss",
    "losses/policy_loss",
    "losses/value_loss",
    "charts/0/avg_episode_returns",
    "charts/0/avg_episode_success",
    "param_rms/total",
    "grad_rms/total",
    "stats/training_time",
    "var_explained",
]

# Sample count passed to `run.history`; every run must yield exactly this
# many rows.
LEN = 10000
for run in runs:
    hist = run.history(LEN, keys=train_metrics)
    # Explicit check rather than `assert`, which `python -O` strips.
    if len(hist) != LEN:
        raise ValueError(
            f"run {run.id}: expected {LEN} training-history rows, got {len(hist)}"
        )

    arch = arch_from_cfg(run.config)
    this_dir = out_dir / arch
    # Create the directory here as well, so this cell does not depend on the
    # evaluation-export cell having been executed first.
    this_dir.mkdir(parents=True, exist_ok=True)

    # Index rows by the integer `_step` and name the index "step".
    hist = hist.set_index(hist["_step"].astype(int))
    del hist["_step"]
    hist.index.name = "step"
    _id = run.id
    hist.to_csv(this_dir / f"{_id}.train.csv")
