#!/usr/bin/env python3
"""
generate_alignment_tables.py

Scan an experiment directory for run folders (e.g. named
run_{run_id}_graph_compare_regression_mlp_transformer_runs; the glob actually used is FILE_NAME),
treat every sub-directory of a run folder as one run, extract its tail-averaged Hebbian alignment
for the layer selected by CONFIG_METRIC, and for each model type in CONFIG_MODELS produce a table
whose rows are optimizers and whose columns are weight-decay values, showing mean±std across runs.
"""

from pathlib import Path
import json
from typing import Optional, Union
import numpy as np
import pandas as pd

# ── CONFIGURATION ────────────────────────────────────────────────────────────
# Parent "results" folder; it is globbed with FILE_NAME to find the run folders.
CONFIG_EXP_DIR      = Path("results")
# One table is produced per entry; each entry is compared against the "model"
# value stored in a run's metrics/config.json.
CONFIG_MODELS       = ["regression-mlp","transformer"] # model names to tabulate
CONFIG_TAIL         = 300                       # how many final steps to average
# Key looked up in each logged step's "alignments" dict.
CONFIG_METRIC       = "L1"                      # choose from "L1","L2","L3"
# When set, each table is written to "<prefix>_<model>.csv" instead of printed.
CONFIG_OUTPUT_PREFX = None                      # e.g. Path("alignment") or None to print
# Glob pattern (relative to CONFIG_EXP_DIR) selecting the top-level run folders;
# every sub-directory of a match is treated as one individual run.
FILE_NAME           = "run_*__path__"
# ──────────────────────────────────────────────────────────────────────────────


def _parse_num(x: str) -> Optional[float]:
    """Convert *x* to a float, or return ``None`` when the conversion fails."""
    try:
        number = float(x)
    except Exception:
        # Best-effort parse: any failure simply means "not a number".
        number = None
    return number


def _get_param(run_dir: Path, key: str) -> Optional[Union[float, str]]:
    """Read one hyper-parameter from ``<run_dir>/metrics/config.json``.

    Args:
        run_dir: Directory of a single run.
        key: Top-level key to look up in the config JSON object.

    Returns:
        ``float`` for numeric values and for strings that parse as a number,
        the original string for non-numeric strings, and ``None`` when the
        config file is missing or unreadable, the key is absent, or the value
        has any other type. A diagnostic line is printed whenever ``None`` is
        returned.
    """
    cfg = run_dir / "metrics" / "config.json"
    try:
        data = json.loads(cfg.read_text())
    except FileNotFoundError:
        print(f"no config: {cfg} (run_dir={run_dir}, key={key!r})")
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"error: cannot read {cfg}: {type(exc).__name__}: {exc}")
        return None

    if not isinstance(data, dict):
        print(f"error: {cfg} does not contain a JSON object")
        return None

    val = data.get(key)
    if isinstance(val, (int, float)):
        # NOTE: bool is an int subclass, so JSON true/false become 1.0/0.0.
        try:
            return float(val)
        except OverflowError:
            # JSON integers are unbounded; one too large for a float is unusable.
            print(f"error: value for {key!r} in {cfg} is too large for a float")
            return None
    if isinstance(val, str):
        num = _parse_num(val)
        return num if num is not None else val

    print(f"error: no usable value for {key!r} in {cfg} (got {val!r})")
    return None


def _tail_alignment(run_dir: Path, tail: int, metric: str) -> Optional[float]:
    """Average the last ``tail`` logged alignment values of one run.

    The values are read from ``<run_dir>/metrics/frac_pos_alignment.json``,
    which must hold a JSON list. Each list entry is either a dict whose
    ``"alignments"`` mapping contains ``metric``, or a raw value that is used
    as-is. Entries that yield no value (``None``, missing key, missing or
    non-dict ``"alignments"``) are skipped before the tail is taken.

    Args:
        run_dir: Directory of a single run.
        tail: Number of trailing values to average; must be positive.
        metric: Key to read from each entry's ``"alignments"`` mapping.

    Returns:
        The mean of the last ``tail`` values (or of all of them when fewer
        exist), or ``None`` when the file is missing, unreadable, not a
        non-empty JSON list, or contains no usable value.

    Raises:
        ValueError: If ``tail`` is not positive. ``vals[-0:]`` would silently
            average the whole history and a negative tail would drop the
            *first* steps instead of keeping the last ones.
    """
    if tail <= 0:
        raise ValueError(f"tail must be a positive number of steps, got {tail}")

    path = run_dir / "metrics" / "frac_pos_alignment.json"
    try:
        data = json.loads(path.read_text())
    except (OSError, ValueError):
        # Missing/unreadable file, undecodable bytes, or malformed JSON.
        return None
    if not isinstance(data, list) or not data:
        return None

    vals = []
    for entry in data:
        if isinstance(entry, dict):
            alignments = entry.get("alignments")
            entry = alignments.get(metric) if isinstance(alignments, dict) else None
        if entry is not None:
            vals.append(entry)

    if not vals:
        return None
    return float(np.mean(vals[-tail:]))


def collect_table(
    exp_dir: Path,
    tail: int,
    metric: str,
    model_filter: Optional[str] = None,
) -> pd.DataFrame:
    """Build an optimizer × weight-decay table of tail-averaged alignments.

    Folders under ``exp_dir`` matching ``FILE_NAME`` are scanned and each of
    their sub-directories is treated as one run. Runs are skipped (with a
    printed note) when their optimizer or weight decay is missing, their
    weight decay is not numeric, or no alignment value can be extracted.

    Args:
        exp_dir: Directory containing the run folders.
        tail: Number of final logged steps to average per run.
        metric: Alignment key passed on to ``_tail_alignment``.
        model_filter: If given, keep only runs whose ``"model"`` config value
            equals it.

    Returns:
        A DataFrame of strings indexed by optimizer name, with one column per
        weight-decay value. Cells read ``"mean±std"`` across the runs of that
        combination, or ``""`` when there is none.

    Raises:
        RuntimeError: If no run yields a usable record.
    """
    # Only directories can hold runs; a stray file matching the glob would
    # make iterdir() raise NotADirectoryError.
    root_runs = sorted(p for p in exp_dir.glob(FILE_NAME) if p.is_dir())
    print("total roots", len(root_runs))
    # Flatten: each sub-directory of a matched folder is an actual run.
    runs = sorted(sub for root in root_runs for sub in root.iterdir() if sub.is_dir())
    print("total runs", len(runs))

    records = []
    for run in runs:
        model = _get_param(run, "model")
        if model_filter is not None and model != model_filter:
            continue
        opt = _get_param(run, "optimizer")
        wd = _get_param(run, "weight_decay")
        if opt is None or wd is None:
            print("missing opt or wd", opt, wd)
            continue
        try:
            wd_f = float(wd)
        except ValueError:
            # _get_param returns the raw string when it is not a number.
            print("non-numeric weight_decay", run, wd)
            continue
        opt_s = str(opt)

        align = _tail_alignment(run, tail, metric)
        if align is None:
            print("None found", opt_s, wd_f, align)
            continue

        records.append((opt_s, wd_f, align))

    print("total_records", len(records))
    if not records:
        raise RuntimeError(f"No runs for model={model_filter} in {exp_dir}")

    opt_list = sorted({rec[0] for rec in records})
    wd_list = sorted({rec[1] for rec in records})

    df = pd.DataFrame(records, columns=["optimizer", "weight_decay", "alignment"])
    grp = df.groupby(["optimizer", "weight_decay"])["alignment"]
    mean_tbl = grp.mean().unstack(fill_value=np.nan).reindex(index=opt_list, columns=wd_list)
    std_tbl = grp.std().unstack(fill_value=np.nan).reindex(index=opt_list, columns=wd_list)

    col_names = [f"{wd:g}" for wd in wd_list]
    if len(set(col_names)) != len(col_names):
        # %g keeps 6 significant digits, so nearly equal weight decays could
        # share a label; fall back to full precision to keep columns unique.
        col_names = [repr(wd) for wd in wd_list]

    # Format every cell as "mean±std".
    formatted = pd.DataFrame(index=opt_list, columns=col_names, dtype=object)
    for wd, col in zip(wd_list, col_names):
        for opt in opt_list:
            mean = mean_tbl.loc[opt, wd]
            if np.isnan(mean):
                formatted.at[opt, col] = ""
                continue
            std = std_tbl.loc[opt, wd]
            if np.isnan(std):
                # The sample std is undefined for a single run; show zero spread.
                std = 0.0
            formatted.at[opt, col] = f"{mean:.2f}±{std:.2f}"

    return formatted


def main() -> None:
    """Build and emit one alignment table per model in ``CONFIG_MODELS``.

    Each table is written to ``"<CONFIG_OUTPUT_PREFX>_<model>.csv"`` when a
    prefix is configured, otherwise printed as a GitHub-style markdown table.
    A failure for one model is reported and does not stop the remaining
    models from being processed.
    """
    for model in CONFIG_MODELS:
        try:
            tbl = collect_table(CONFIG_EXP_DIR, CONFIG_TAIL, CONFIG_METRIC, model)
            header = f"### Model: {model}"
            if CONFIG_OUTPUT_PREFX:
                out_csv = Path(f"{CONFIG_OUTPUT_PREFX}_{model}.csv")
                tbl.to_csv(out_csv)
                print(f"{header}\nSaved to {out_csv}\n")
            else:
                try:
                    rendered = tbl.to_markdown(tablefmt="github")
                except ImportError:
                    # to_markdown needs the optional 'tabulate' package.
                    rendered = tbl.to_string()
                print(header)
                print(rendered)
                print()
        except Exception as exc:
            # Top-level boundary: report the cause and move on to the next
            # model. KeyboardInterrupt/SystemExit are deliberately not caught.
            print(f"oof: could not build table for model {model!r}: "
                  f"{type(exc).__name__}: {exc}")

# Generate the tables only when run as a script, not when imported.
if __name__ == "__main__":
    main()