import re
import wandb 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


# W&B workspace and project holding the AdamE experiment runs.
entity, project = "coder66-lab", "diagonal-net-loss-AdamE"

# Public-API client and the run collection for "<entity>/<project>".
api = wandb.Api()
runs = api.runs("/".join((entity, project)))

# Numeric float field inside a run name: a plain decimal ("0.01", "1") or,
# when the name was built from str(float) for a small value, scientific
# notation ("1e-05"). The exponent part is non-capturing, so the capture
# group numbering of pat_AdamE is unchanged.
_FLOAT_RE = r"\d+(?:\.\d+)?(?:[eE][-+]?\d+)?"

# Run-name format: AdamE-N<n>-D<d>-K<k>-LR<lr>-DT<delta>-E<exponent>-test
# Capture groups: 1=n, 2=d, 3=k, 4=lr, 5=delta (DT), 6=exponent (E).
pat_AdamE = re.compile(
    rf"^AdamE-N(\d+)-D(\d+)-K(\d+)"
    rf"-LR({_FLOAT_RE})-DT({_FLOAT_RE})-E({_FLOAT_RE})-test$"
)

# Grid of sweep values describing the intended experiment.
# NOTE(review): in this script these are only referenced by a disabled
# (commented-out) filter, so they do not currently restrict which runs
# are exported — confirm whether that filter should be re-enabled.
valid_d, valid_k, valid_delta = 10000, 50, 0.5
valid_n = list(range(300, 500, 10)) + list(range(500, 1050, 50))
valid_lr = [1e-2]
valid_beta2 = [0.999]
valid_beta1 = [0.9]
valid_exponent = [0.25, 0.5, 0.75, 1]

# One dict per accepted run, accumulated while iterating over `runs`.
results = []

# Collect one row per run whose name matches pat_AdamE; every other run is
# reported and skipped.
for run in runs:
    match = pat_AdamE.match(run.name)
    if match is None:
        print(f"Skipping {run.name}")
        continue

    # Hyper-parameters encoded in the run name (capture-group order of pat_AdamE).
    result = {
        "opt": "AdamE",
        "n": int(match.group(1)),
        "d": int(match.group(2)),
        "k": int(match.group(3)),
        "lr": float(match.group(4)),
        "delta": float(match.group(5)),     # the "DT" field of the name
        "exponent": float(match.group(6)),  # the "E" field of the name
    }

    # Sweep filter, currently disabled; uncomment to keep only runs that lie
    # on the valid_* grid.
    # if result["n"] not in valid_n \
    #     or result["lr"] not in valid_lr \
    #     or result["d"] != valid_d \
    #     or result["k"] != valid_k \
    #     or result["delta"] != valid_delta:
    #     continue

    # Final metrics from the run summary. A missing metric aborts the whole
    # export with a ValueError instead of writing an incomplete row.
    summary_columns = (("loss", "eval_test/loss"), ("total_norm", "total_norm"))
    for column, summary_key in summary_columns:
        if summary_key not in run.summary:
            raise ValueError(f"Missing {summary_key} in {run.name}")
        result[column] = run.summary[summary_key]

    # Optimizer betas from the run config. Checked explicitly so that a
    # missing key reports the offending run (like the summary metrics above)
    # instead of surfacing as a bare KeyError.
    for config_key in ("beta1", "beta2"):
        if config_key not in run.config:
            raise ValueError(f"Missing {config_key} in config of {run.name}")
        result[config_key] = run.config[config_key]

    print("result:", result)
    results.append(result)


# Report how many runs were collected, then write them as one CSV row each
# (no index column).
n_runs = len(results)
print(f"Valid runs of AdamE: {n_runs}")
df = pd.DataFrame(data=results)
df.to_csv(path_or_buf="results_AdamE.csv", index=False)
