import re
import wandb 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


# W&B workspace and project that hold the sweep being summarized.
entity, project = "coder66-lab", "diagonal-net-loss-trend"

# Fetch every run logged under "<entity>/<project>"; filtering happens below.
api = wandb.Api()
runs = api.runs("/".join((entity, project)))

# Run names encode the hyper-parameters, e.g. "Adam-N100-D10000-K50-LR0.01-DT0.5":
# groups are optimizer, n, d, k, learning rate and delta (decimal LR/DT only).
pat = re.compile(r"^(SGD|Adam)-N(\d+)-D(\d+)-K(\d+)-LR(\d+(?:\.\d+)?)-DT(\d+(?:\.\d+)?)$")

# Fixed problem settings a run must share to be kept.
valid_d = 10000
valid_k = 50
valid_delta = 0.5
# Sample sizes: every 10 from 50 to 490, then every 50 from 500 to 1000.
valid_n = [*range(50, 500, 10), *range(500, 1050, 50)]
valid_lr = [5e-3, 1e-2, 5e-2, 1e-1]
# Adam momentum settings of interest.
valid_beta2 = [0.95, 0.999]
valid_beta1 = [0.9]

# One dict per accepted run, filled in by the collection loop.
results = []

# Collect one result dict per run that matches the sweep configuration above.
# Every kept run must report both "eval_test/loss" and "total_norm" in its
# summary. Kept Adam runs must also have beta1/beta2 in their config, and those
# values must be in valid_beta1/valid_beta2.
for run in runs:
    match = pat.match(run.name)
    if match is None:
        # Not a run name this script knows how to parse.
        continue

    opt, n, d, k, lr, delta = match.groups()
    result = {
        "opt": opt,
        "n": int(n),
        "d": int(d),
        "k": int(k),
        "lr": float(lr),
        "delta": float(delta),
    }
    if (
        result["n"] not in valid_n
        or result["lr"] not in valid_lr
        or result["d"] != valid_d
        or result["k"] != valid_k
        or result["delta"] != valid_delta
    ):
        continue

    betas = None
    if result["opt"] == "Adam":
        # Betas are not encoded in the run name, so read them from the config.
        missing = [name for name in ("beta1", "beta2") if name not in run.config]
        if missing:
            raise ValueError(f"Missing {', '.join(missing)} in config of {run.name}")
        betas = (run.config["beta1"], run.config["beta2"])
        # Apply the declared beta filters, as is done for every other valid_* setting.
        if betas[0] not in valid_beta1 or betas[1] not in valid_beta2:
            continue

    if "eval_test/loss" in run.summary:
        result["loss"] = run.summary["eval_test/loss"]
    else:
        raise ValueError(f"Missing eval_test/loss in {run.name}")

    if "total_norm" in run.summary:
        result["total_norm"] = run.summary["total_norm"]
    else:
        raise ValueError(f"Missing total_norm in {run.name}")

    # Betas are added after the metrics to keep the original column order.
    if betas is not None:
        result["beta1"], result["beta2"] = betas

    print("result:", result)
    results.append(result)


print(f"Valid runs: {len(results)}")
df = pd.DataFrame(results)
# SGD rows carry no betas and are recorded as 0.0. When no Adam run survived
# filtering (or no run at all), the column does not exist yet, and indexing it
# directly would raise KeyError, so create it.
for beta_col in ("beta1", "beta2"):
    if beta_col in df.columns:
        df[beta_col] = df[beta_col].fillna(0.0)
    else:
        df[beta_col] = 0.0
df.to_csv("results.csv", index=False)

