from dataclasses import dataclass, field
from pathlib import Path
from itertools import product
from string import ascii_lowercase, ascii_uppercase
from collections import defaultdict
from loguru import logger
from tqdm.auto import tqdm
import json, csv, random, numpy as np, torch, matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback

# Opt in to TF32 for CUDA matrix multiplications (process-wide PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

# Sweep tables. A variant is either (E, L, H, IM) or ("name", E, L, H, IM), where
# E = embedding width, L = number of layers, H = attention heads and IM = the
# MLP inner-width multiplier (fixed at 4 for every generated variant).
# Every BASE_* table maps L -> E for one size bucket; each (L, E) pair is then
# crossed with every head count in HEADS.
HEADS = (1, 2, 4, 8, 16, 32)

BASE_100M: dict[int, int] = {
    2: 6656, 4: 3328, 6: 2240, 8: 1664,
    10: 1344, 12: 1120, 14: 960, 16: 832,
    18: 736, 20: 672, 22: 608, 24: 544,
    26: 512, 28: 480, 30: 448, 32: 416,
}

BASE_1M: dict[int, int] = {
    2: 384,
    4: 192,
    6: 128,
    8: 96,
    12: 64,
}

BASE_10M: dict[int, int] = {
    2: 1984, 4: 992, 6: 672, 8: 512,
    10: 384, 12: 320, 14: 288, 16: 256,
    18: 224, 20: 192, 22: 192,
}
def make_variants(prefix: str, base: dict[int, int]) -> list[tuple[str, int, int, int, int]]:
    """Expand one size bucket into named (name, E, L, H, IM) variants.

    Every (L, E) entry of ``base`` is combined with every head count in
    ``HEADS``; the inner multiplier is always 4. Entries keep the iteration
    order of ``base`` (outer) and ``HEADS`` (inner).
    """
    variants: list[tuple[str, int, int, int, int]] = []
    for depth, width in base.items():
        for heads in HEADS:
            label = f"{prefix}_E{width}_L{depth}_H{heads}_IM4"
            variants.append((label, width, depth, heads, 4))
    return variants
BASE_1B: dict[int, int] = {
    2: 6496,
    4: 4576,
    6: 3744,
    8: 3232,
    10: 2912,
    12: 2656,
}

# Buckets to train. Uncomment / concatenate with "+" to sweep several buckets.
VARIANTS = (
    # make_variants("1B", BASE_1B)
    # make_variants("100M", BASE_100M)
    make_variants("10M", BASE_10M)
    # + make_variants("1M", BASE_1M)
)
# Run in reverse enumeration order: the last-listed entries (deepest models,
# highest head counts) are trained first.
VARIANTS = list(reversed(VARIANTS))

def _assert_consistent(base: dict[int, int], name: str) -> None:
    """Assert that L*E varies by at most 10% across the entries of ``base``.

    An empty table passes trivially. ``name`` only labels the failure message.
    """
    if not base:
        return
    products = sorted(depth * width for depth, width in base.items())
    mn, mx = products[0], products[-1]
    assert mx <= 1.10*mn, f"{name} L*E spread >10%: {mn}..{mx}"

# Import-time sanity checks on the width tables (BASE_1B is not checked here).
_assert_consistent(BASE_1M, "1M")
_assert_consistent(BASE_10M, "10M")
_assert_consistent(BASE_100M, "100M")

# Learning rate per size bucket, keyed by the prefix of the variant name.
LR_PER: dict[str, float] = {
    "1M": 9.5e-4,
    "10M": 5.3e-4,
    "100M": 3.0e-4,
    "1B": 1.7e-4,
}

# --------------------------- Config ------------------------------
@dataclass(slots=True)
class CFG:
    # data
    train_path: str = "data/paths_paper/train.txt"
    eval_path: str = "data/paths_paper/eval.txt"
    graph_path: str = "data/gt_graph_paper.txt"
    out_dir: str = "out/paths_paper_final"

    # model/training
    seed: int = 32
    upper_grams: tuple[int, ...] = (3,)
    n_positions: int = 128
    epochs: int = 16
    batch_size: int = 2048
    # batch_size: int = 1024
    eval_bs: int = 2048
    lr: float = 1e-4
    wd: float = 0.01
    warmup_ratio: float = 0.05
    bf16: bool = True
    log_steps: int = 1
    max_new_tokens: int = 32
    gradient_checkpointing: bool = True
    grad_accum_steps: int = 1
    # grad_accum_steps: int = 2

    # eval
    eval_sample_gen: int = 0     # 0 = evaluate all eval rows each epoch
    diff_bins: int = 20

# ----------------------- Tokenizer + data ------------------------
@dataclass
class Tok:
    """Vocabulary tables plus the ids of the special tokens."""

    upper_grams: tuple[int, ...]   # accepted lengths of uppercase node names
    stoi: dict[str, int]           # token string -> id
    itos: list[str]                # id -> token string
    pad_id: int
    unk_id: int
    eos_id: int

    def id(self, t: str) -> int:
        """Return the id of token ``t``, or ``unk_id`` when it is not in the vocabulary."""
        return self.stoi[t] if t in self.stoi else self.unk_id

def build_tok(upper_grams: tuple[int, ...]) -> Tok:
    """Build the fixed vocabulary for the given node-name lengths.

    Vocabulary order is: special tokens, punctuation, every uppercase string
    of each length in ``upper_grams``, then the 26 lowercase letters.
    """
    vocab: list[str] = ["<pad>", "<unk>", "<eos>"]
    vocab.extend(["-", "[", "]", "(", ")", ":"])
    for k in upper_grams:
        vocab.extend("".join(letters) for letters in product(ascii_uppercase, repeat=k))
    vocab.extend(ascii_lowercase)
    stoi = dict(zip(vocab, range(len(vocab))))
    return Tok(
        upper_grams,
        stoi,
        vocab,
        stoi["<pad>"],
        stoi["<unk>"],
        stoi["<eos>"],
    )

def parse_difficulty(line: str) -> float:
    """Return the number following the last ``#d=`` marker, or 0.0 if there is none."""
    _, marker, value = line.rpartition("#d=")
    if not marker:
        return 0.0
    return float(value)

def split_query_path(s: str) -> tuple[str, str]:
    """Split ``s`` at its first colon into (query, path).

    Raises:
        ValueError: if ``s`` contains no colon.
    """
    query, colon, path = s.partition(":")
    if not colon:
        raise ValueError("substring not found")
    return query, path

def parse_qmeta(q: str) -> tuple[str, str, str, str]:
    """Parse a query ``U-V[inc](exc)`` into ``(u, v, inc, exc)``.

    The ``[...]`` and ``(...)`` groups are optional; a missing group yields
    an empty string.

    Raises:
        ValueError: if the endpoints contain no ``-`` or a group is not closed.
    """
    inc_open = q.find("[")
    exc_open = q.find("(")
    # The endpoints end where the first bracket group (of either kind) begins.
    group_starts = [pos for pos in (inc_open, exc_open) if pos >= 0]
    head_end = min(group_starts) if group_starts else len(q)
    u, v = q[:head_end].split("-", 1)
    inc = q[inc_open + 1:q.index("]", inc_open)] if inc_open >= 0 else ""
    exc = q[exc_open + 1:q.index(")", exc_open)] if exc_open >= 0 else ""
    return u, v, inc, exc

def tokenize_query(q: str) -> list[str]:
    """Tokenize a query into ``[u, "-", v, ("[", *inc, "]")?, ("(", *exc, ")")?, ":"]``.

    The include group is always emitted before the exclude group, and each
    constraint letter becomes its own token.
    """
    u, v, inc, exc = parse_qmeta(q)
    toks: list[str] = [u, "-", v]
    for opener, letters, closer in (("[", inc, "]"), ("(", exc, ")")):
        if letters:
            toks.append(opener)
            toks.extend(letters)
            toks.append(closer)
    toks.append(":")
    return toks

def tokenize_path(p: str, upper_grams: tuple[int, ...]) -> list[str]:
    """Split a path string into alternating node / edge-label tokens.

    A node is a maximal run of uppercase ASCII letters whose length must be
    one of ``upper_grams``; an edge label is a single lowercase ASCII letter.

    Raises:
        ValueError: on a node of unsupported length or a non-lowercase label.
    """
    tokens: list[str] = []
    rest = p
    while rest:
        # Whatever lstrip removes is exactly the leading uppercase run.
        after_node = rest.lstrip(ascii_uppercase)
        node = rest[:len(rest) - len(after_node)]
        if not node or len(node) not in upper_grams:
            raise ValueError(f"bad node: {node}")
        tokens.append(node)
        if not after_node:
            break
        label = after_node[0]
        if label not in ascii_lowercase:
            raise ValueError("expected lowercase label")
        tokens.append(label)
        rest = after_node[1:]
    return tokens

@dataclass
class Ex:
    """One parsed example: token ids plus the metadata needed for evaluation."""

    tokens: list[int]   # ids of query + path + <eos>
    colon: int          # index of the ":" token (last token of the query)
    u: str              # start node
    v: str              # end node
    inc: set[str]       # edge labels listed in the [...] group
    exc: set[str]       # edge labels listed in the (...) group
    diff: float         # value of the "#d=" annotation
    hops: int           # number of edge labels in the reference path
    level: int          # 1 + number of include letters + number of exclude letters

def parse_line(line: str, tok: Tok) -> Ex:
    """Parse one ``query:path #d=<difficulty>`` line into an ``Ex``.

    The token sequence is query tokens + path tokens + ``<eos>``; ``colon``
    is the index of the last query token.
    """
    text = line.strip()
    difficulty = parse_difficulty(text)
    # Drop the difficulty annotation before splitting query from path.
    core = text.partition(" #d=")[0]
    query, path = split_query_path(core)
    u, v, inc, exc = parse_qmeta(query)

    query_tokens = tokenize_query(query)
    path_tokens = tokenize_path(path, tok.upper_grams)
    # Every single lowercase letter in the path is one traversed edge.
    edge_labels = [t for t in path_tokens if len(t) == 1 and t in ascii_lowercase]

    token_ids = [tok.id(t) for t in (*query_tokens, *path_tokens, "<eos>")]
    return Ex(
        tokens=token_ids,
        colon=len(query_tokens) - 1,
        u=u,
        v=v,
        inc=set(inc),
        exc=set(exc),
        diff=difficulty,
        hops=len(edge_labels),
        level=1 + len(inc) + len(exc),
    )

class PathDS(Dataset):
    """Dataset of parsed path examples with per-example metadata arrays."""

    def __init__(self, lines: list[str], tok: Tok):
        """Parse every non-blank line and cache difficulty / hop / level arrays."""
        self.ex: list[Ex] = [parse_line(ln, tok) for ln in lines if ln.strip()]
        self.tok = tok
        examples = self.ex
        self.diffs = np.array([e.diff for e in examples], dtype=np.float32)
        self.hops = np.array([e.hops for e in examples], dtype=np.int32)
        self.levels = np.array([e.level for e in examples], dtype=np.int32)
        # Distinct levels present in the data, ascending.
        distinct_levels = set(self.levels.tolist())
        self.level_set = tuple(sorted(distinct_levels))

    def __len__(self) -> int:
        """Number of parsed examples."""
        return len(self.ex)

    def __getitem__(self, i: int) -> dict[str, torch.Tensor]:
        """Return unpadded ``input_ids``, ``labels`` and ``attention_mask`` for example ``i``.

        Labels copy the input ids, with every query position (up to and
        including the colon) set to -100.
        """
        example = self.ex[i]
        input_ids = torch.tensor(example.tokens, dtype=torch.long)
        labels = input_ids.clone()
        prompt_len = example.colon + 1
        labels[:prompt_len] = -100
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": torch.ones_like(input_ids),
        }

def pad_collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Right-pad every field to the longest ``input_ids`` in the batch and stack.

    Fill values: 0 for ``input_ids`` and ``attention_mask``, -100 for ``labels``.
    """
    width = max(len(item["input_ids"]) for item in batch)

    def _stack(key: str, fill: int) -> torch.Tensor:
        # Pad each row on the right up to `width`, then stack into (batch, width).
        rows = []
        for item in batch:
            row = item[key]
            rows.append(torch.nn.functional.pad(row, (0, width - len(row)), value=fill))
        return torch.stack(rows)

    return {
        "input_ids": _stack("input_ids", 0),
        "labels": _stack("labels", -100),
        "attention_mask": _stack("attention_mask", 0),
    }

# ----------------------- Graph + eval utils -----------------------
@dataclass(frozen=True)
class Edge:
    """Immutable labelled edge: node ``u``, label ``l``, node ``v``."""

    u: str
    l: str
    v: str

def parse_edge_token(t: str) -> Edge:
    """Split an edge token such as ``AAAxBBB`` at its first lowercase letter.

    Everything before the letter is ``u``, the letter is the label, and
    everything after it is ``v``.

    Raises:
        ValueError: if ``t`` contains no lowercase ASCII letter. Previously a
            bare ``StopIteration`` escaped from ``next()``, which gave no hint
            about the offending graph line.
    """
    for i, ch in enumerate(t):
        if ch in ascii_lowercase:
            return Edge(t[:i], ch, t[i + 1:])
    raise ValueError(f"edge token has no lowercase label: {t!r}")

def load_graph(path: Path) -> dict[tuple[str, str], set[str]]:
    """Read an edge list (one edge token per non-blank line) into an adjacency map.

    The map is keyed by ``(node, label)`` and holds the set of neighbours
    reachable over that label. Each edge is inserted in both directions.
    """
    adj: dict[tuple[str, str], set[str]] = defaultdict(set)
    for raw in path.read_text(encoding="utf-8").splitlines():
        token = raw.strip()
        if not token:
            continue
        edge = parse_edge_token(token)
        adj[(edge.u, edge.l)].add(edge.v)
        adj[(edge.v, edge.l)].add(edge.u)
    return adj

def decode_tokens(ids: list[int], tok: Tok) -> list[str]:
    """Map token ids back to their token strings via ``tok.itos``."""
    table = tok.itos
    return [table[token_id] for token_id in ids]
def take_path_tokens(gen: list[str]) -> list[str]:
    """Keep the node / label tokens generated before the first ``<eos>``.

    Query punctuation, ``<pad>`` and ``<unk>`` are dropped; everything from
    the first ``<eos>`` onwards is ignored.
    """
    ignored = {"-", "[", "]", "(", ")", ":", "<pad>", "<unk>"}
    try:
        end = gen.index("<eos>")
    except ValueError:
        end = len(gen)
    return [t for t in gen[:end] if t not in ignored]

def join_path_tokens(ts: list[str]) -> str:
    """Concatenate path tokens into one string with no separator."""
    separator = ""
    return separator.join(ts)
def fmt_query(u: str, v: str, inc: set[str], exc: set[str]) -> str:
    """Render a query as ``u-v[inc](exc):`` with constraint letters sorted.

    An empty include or exclude set omits its bracket group entirely.
    """
    pieces = [f"{u}-{v}"]
    if inc:
        pieces.append(f"[{''.join(sorted(inc))}]")
    if exc:
        pieces.append(f"({''.join(sorted(exc))})")
    pieces.append(":")
    return "".join(pieces)

# Failure categories reported by check_path_violations. The order is significant:
# it fixes the layout of the failure-count vectors accumulated during evaluation.
FAIL_KEYS: tuple[str, ...] = (
    "used_excluded",
    "missing_included",
    "wrong_end",
    "non_edge",
    "bad_start",
    "bad_label",
    "truncated",
    "empty_path",
)

def check_path_violations(adj: dict[tuple[str, str], set[str]], u: str, v: str, inc: set[str], exc: set[str], toks: list[str]) -> tuple[bool, dict[str, int], int]:
    """Validate a generated path against the graph and the query constraints.

    ``toks`` is expected to alternate node, label, node, ... starting with a
    node. The walk always starts from ``u`` (a different first token is only
    recorded as ``bad_start``) and follows ``adj[(node, label)]``.

    Args:
        adj: adjacency map keyed by ``(node, label)``.
        u: required start node.
        v: required end node.
        inc: labels that must all appear on the path.
        exc: labels that must not appear on the path.
        toks: generated path tokens.

    Returns:
        ``(ok, reasons, hops)``: ``ok`` is True only when no failure flag is
        set, ``reasons`` maps every key in ``FAIL_KEYS`` to 0/1, and ``hops``
        counts the edges successfully traversed. A bad label, a truncated
        path or a non-edge stops the walk immediately, so ``missing_included``
        and ``wrong_end`` are not evaluated in those cases.
    """
    reasons = {key: 0 for key in FAIL_KEYS}
    if not toks:
        reasons["empty_path"] = 1
        return False, reasons, 0
    if toks[0] != u:
        reasons["bad_start"] = 1

    used: set[str] = set()
    node = u
    hops = 0
    i = 1
    while i < len(toks):
        label = toks[i]
        # A label must be exactly one lowercase letter. A bare `in` test is a
        # substring check and would also accept "" or multi-letter tokens.
        if len(label) != 1 or label not in ascii_lowercase:
            reasons["bad_label"] = 1
            return False, reasons, hops
        if label in exc:
            reasons["used_excluded"] = 1
        if i + 1 >= len(toks):
            # The path ends on a label with no destination node.
            reasons["truncated"] = 1
            return False, reasons, hops
        nxt = toks[i + 1]
        if nxt not in adj.get((node, label), set()):
            reasons["non_edge"] = 1
            return False, reasons, hops
        used.add(label)
        node = nxt
        hops += 1
        i += 2

    if not inc.issubset(used):
        reasons["missing_included"] = 1
    if node != v:
        reasons["wrong_end"] = 1
    ok = not any(reasons.values())
    return ok, reasons, hops

def is_dist() -> bool:
    """True when torch.distributed is both available and initialised."""
    dist = torch.distributed
    if not dist.is_available():
        return False
    return dist.is_initialized()
def rank() -> int:
    """Rank of this process, or 0 when not running distributed."""
    if is_dist():
        return torch.distributed.get_rank()
    return 0
def world() -> int:
    """Number of processes, or 1 when not running distributed."""
    if is_dist():
        return torch.distributed.get_world_size()
    return 1
def shard_indices(n: int) -> list[int]:
    """Indices of ``range(n)`` owned by this rank: rank, rank + world, rank + 2*world, ..."""
    first, stride = rank(), world()
    return [*range(first, n, stride)]
def allreduce_sum(x: torch.Tensor) -> torch.Tensor:
    """Sum ``x`` across all ranks in place and return it; a no-op when not distributed."""
    if not is_dist():
        return x
    torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
    return x
def allgather_list(rows: list[list | tuple]) -> list[list | tuple]:
    """Gather every rank's ``rows`` and concatenate them in rank order.

    When not running distributed, the input list itself is returned.
    """
    if not is_dist():
        return rows
    per_rank: list[list | tuple] = [None] * world()
    torch.distributed.all_gather_object(per_rank, rows)
    merged: list[list | tuple] = []
    for chunk in per_rank:
        merged.extend(chunk)
    return merged

# ----------------------- Evaluation (batched) ---------------------
def batched_generate_eval_levels(model: GPT2LMHeadModel, tok: Tok, ds: PathDS, adj: dict[tuple[str,str], set[str]], max_new: int, batch_size: int, sample_k: int, bins: np.ndarray) -> tuple[dict, list[tuple], list[tuple]]:
    """Greedy-decode a path for each selected eval query and score it against the graph.

    Args:
        model: model to evaluate; it is switched to eval mode.
        tok: tokenizer supplying the pad/eos ids and the id -> token table.
        ds: evaluation dataset; only the query prefix of each example is fed in.
        adj: adjacency map used by check_path_violations.
        max_new: maximum number of tokens to generate per query.
        batch_size: number of queries per generate() call.
        sample_k: if non-zero and smaller than len(ds), evaluate a fixed random
            subset of that size (seeded with 0, so identical on every rank).
        bins: difficulty bin edges; len(bins) - 1 bins are reported.

    Returns:
        (metrics, rows, traces): a nested dict of overall / per-level metrics,
        one per-sample result tuple per evaluated example, and one trace tuple
        (query, reference path, generated path, ...) per evaluated example.
        Counters are summed and rows/traces gathered across ranks, so every
        rank returns the global result.
    """
    model.eval(); device = next(model.parameters()).device
    idx = np.arange(len(ds))
    if sample_k and sample_k < len(ds): idx = np.random.default_rng(0).choice(len(ds), size=sample_k, replace=False)
    # Each rank evaluates a strided slice of the selected indices.
    idx = idx[shard_indices(len(idx))]
    levels_all = sorted(set(ds.levels.tolist()))
    # Per-level accumulators are indexed directly by level, hence size Lmax+1.
    Lmax = max(levels_all) if levels_all else 1

    succ_overall = torch.zeros(4, device=device)  # ok,total,inc_ok,end_ok
    eos_hits = 0; new_len_sum = 0
    nb = len(bins) - 1
    bin_succ = torch.zeros(nb, device=device); bin_cnt = torch.zeros(nb, device=device)
    fail_vec = torch.zeros(len(FAIL_KEYS), device=device)
    lvl_succ = torch.zeros(Lmax+1, device=device); lvl_cnt = torch.zeros(Lmax+1, device=device)
    lvl_inc_ok = torch.zeros(Lmax+1, device=device); lvl_end_ok = torch.zeros(Lmax+1, device=device)
    lvl_bin_succ = torch.zeros(Lmax+1, nb, device=device); lvl_bin_cnt = torch.zeros(Lmax+1, nb, device=device)
    lvl_fail = torch.zeros(Lmax+1, len(FAIL_KEYS), device=device)

    rows_local: list[tuple] = []
    traces_local: list[tuple] = []

    for s in tqdm(range(0, len(idx), batch_size), desc=f"gen-eval[r{rank()}]", disable=rank()!=0):
        span = idx[s:s+batch_size]
        batch_q, qlens = [], []
        for i in span:
            # Prompt = query tokens up to and including the colon.
            ex = ds.ex[i]; q = ex.tokens[: ex.colon + 1]
            batch_q.append(q); qlens.append(len(q))
        mx = max(qlens); pad = tok.pad_id
        # Left-pad so every prompt ends at the same position and generation starts at input_len.
        x = torch.tensor([[pad]*(mx-len(q)) + q for q in batch_q], dtype=torch.long, device=device)
        am = torch.tensor([[0]*(mx-len(q)) + [1]*len(q) for q in batch_q], dtype=torch.long, device=device)
        with torch.no_grad():
            out = model.generate(input_ids=x, attention_mask=am, do_sample=False, max_new_tokens=max_new, pad_token_id=tok.pad_id, eos_token_id=tok.eos_id)
        input_len = x.size(1)

        for j, i in enumerate(span):
            # NOTE(review): generate() presumably pads rows that finished early with pad_token_id,
            # in which case len(new_ids) is the batch-wide generated length, not this row's own
            # length up to <eos> — confirm before interpreting new_len / avg_new_len.
            seq = out[j].tolist(); new_ids = seq[input_len:]; new_len_sum += len(new_ids); eos_hits += int(tok.eos_id in new_ids)
            gen = decode_tokens(new_ids, tok); path_toks = take_path_tokens(gen)
            ex = ds.ex[i]; lvl = ex.level
            ok, reasons, hops_pred = check_path_violations(adj, ex.u, ex.v, ex.inc, ex.exc, path_toks)

            succ_overall += torch.tensor([1.0 if ok else 0.0, 1.0, 1.0 if reasons["missing_included"] == 0 else 0.0, 1.0 if reasons["wrong_end"] == 0 else 0.0], device=device)
            # Difficulty bin of this example, clamped into [0, nb-1].
            b = int(np.clip(np.digitize(ds.diffs[i], bins) - 1, 0, nb - 1))
            bin_cnt[b] += 1.0;
            if ok: bin_succ[b] += 1.0
            for k, key in enumerate(FAIL_KEYS): fail_vec[k] += float(reasons[key] > 0)

            lvl_cnt[lvl] += 1.0;
            if ok: lvl_succ[lvl] += 1.0
            if reasons["missing_included"] == 0: lvl_inc_ok[lvl] += 1.0
            if reasons["wrong_end"] == 0: lvl_end_ok[lvl] += 1.0
            lvl_bin_cnt[lvl, b] += 1.0
            if ok: lvl_bin_succ[lvl, b] += 1.0
            for k, key in enumerate(FAIL_KEYS): lvl_fail[lvl, k] += float(reasons[key] > 0)

            # Per-sample row: idx, u, v, hops, level, include, exclude, difficulty, ok,
            # one 0/1 flag per failure reason, new_len, eos_hit.
            rows_local.append((
                i, ex.u, ex.v, int(ds.hops[i]), int(lvl),
                "".join(sorted(ex.inc)), "".join(sorted(ex.exc)),
                float(ds.diffs[i]), int(ok),
                int(reasons["used_excluded"]), int(reasons["missing_included"]), int(reasons["wrong_end"]),
                int(reasons["non_edge"]), int(reasons["bad_start"]), int(reasons["bad_label"]), int(reasons["truncated"]), int(reasons["empty_path"]),
                len(new_ids), int(tok.eos_id in new_ids)
            ))

            q_str = fmt_query(ex.u, ex.v, ex.inc, ex.exc)
            # Reference path = stored tokens after the colon, without the trailing <eos>.
            ref_path = "".join(decode_tokens(ex.tokens[ex.colon+1:-1], tok))
            gen_path = join_path_tokens(path_toks)
            gen_raw = " ".join(gen)
            traces_local.append((i, q_str, ref_path, gen_path, gen_raw, int(ok), hops_pred, len(new_ids), int(tok.eos_id in new_ids)))

    # Combine the per-rank counters and result lists (no-ops when not distributed).
    succ_overall = allreduce_sum(succ_overall); bin_succ = allreduce_sum(bin_succ); bin_cnt = allreduce_sum(bin_cnt); fail_vec = allreduce_sum(fail_vec)
    lvl_succ = allreduce_sum(lvl_succ); lvl_cnt = allreduce_sum(lvl_cnt); lvl_inc_ok = allreduce_sum(lvl_inc_ok); lvl_end_ok = allreduce_sum(lvl_end_ok)
    lvl_bin_succ = allreduce_sum(lvl_bin_succ); lvl_bin_cnt = allreduce_sum(lvl_bin_cnt); lvl_fail = allreduce_sum(lvl_fail)
    eos_hits_t = allreduce_sum(torch.tensor(float(eos_hits), device=device)); new_len_sum_t = allreduce_sum(torch.tensor(float(new_len_sum), device=device))
    rows_all = allgather_list(rows_local); traces_all = allgather_list(traces_local)

    # total is clamped to >= 1 so the rates below never divide by zero
    # (the reported "n" is therefore 1, not 0, when nothing was evaluated).
    total = max(succ_overall[1].item(), 1.0); avg_new = new_len_sum_t.item() / total
    solve_by_bin = (bin_succ / torch.clamp(bin_cnt, min=1.0)).tolist()
    fail_counts = {k: int(fail_vec[i].item()) for i, k in enumerate(FAIL_KEYS)}
    # NOTE(review): the `total == 0` branch cannot trigger because total >= 1.0 above.
    fail_fracs = {k: (0.0 if total == 0 else float(fail_counts[k] / total)) for k in FAIL_KEYS}

    per_level = {}
    # NOTE(review): recomputes the same levels_all as at the top of the function.
    levels_all = sorted(set(ds.levels.tolist()))
    for L in levels_all:
        n = max(lvl_cnt[L].item(), 1.0)
        per_level[L] = {
            "n": int(lvl_cnt[L].item()),
            "solve_rate": float(lvl_succ[L].item() / n),
            "inc_ok": float(lvl_inc_ok[L].item() / n),
            "end_ok": float(lvl_end_ok[L].item() / n),
            "solve_by_bin": [float(x) for x in (lvl_bin_succ[L] / torch.clamp(lvl_bin_cnt[L], min=1.0)).tolist()],
            "counts_by_bin": [int(c) for c in lvl_bin_cnt[L].tolist()],
            "fail_counts": {k: int(lvl_fail[L, i].item()) for i, k in enumerate(FAIL_KEYS)},
        }

    metrics = {
        "levels": levels_all,
        "overall": {
            "solve_rate": succ_overall[0].item() / total,
            "inc_ok": succ_overall[2].item() / total,
            "end_ok": succ_overall[3].item() / total,
            "n": int(total),
            "eos_frac": float(eos_hits_t.item() / total),
            "avg_new_len": float(avg_new),
            "solve_by_bin": [float(x) for x in solve_by_bin],
            "counts_by_bin": [int(c) for c in bin_cnt.tolist()],
            "fail_counts": fail_counts, "fail_fracs": fail_fracs,
        },
        "per_level": per_level,
    }
    return metrics, rows_all, traces_all

def loss_by_difficulty(model: GPT2LMHeadModel, ds: PathDS, batch_size: int, bins: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Mean per-example next-token loss of ``model`` on ``ds``, grouped by difficulty bin.

    Each rank scores its own shard of ``ds``; per-bin sums and counts are then
    summed across ranks, so every rank returns the global means.

    Args:
        model: causal LM to score; it is switched to eval mode.
        ds: dataset whose ``labels`` use -100 for positions that are not scored.
        batch_size: examples per forward pass.
        bins: difficulty bin edges (``len(bins) - 1`` bins).

    Returns:
        ``(centers, means)``: bin centres and the mean per-example loss in
        each bin (0.0 for empty bins), both as NumPy arrays.
    """
    model.eval(); device = next(model.parameters()).device
    idxs = shard_indices(len(ds))
    sub_ds = torch.utils.data.Subset(ds, idxs)
    dl = DataLoader(sub_ds, batch_size=batch_size, shuffle=False, collate_fn=pad_collate, pin_memory=True)
    nb = len(bins) - 1
    sums = torch.zeros(nb, device=device); cnts = torch.zeros(nb, device=device)
    # The loader is unshuffled, so batch rows line up with diffs_all[ptr:ptr+bs].
    ptr = 0; diffs_all = ds.diffs[idxs]
    for b in tqdm(dl, desc=f"loss-eval[r{rank()}]", disable=rank()!=0):
        x = b["input_ids"].to(device); y = b["labels"].to(device)
        with torch.no_grad():
            logits = model(input_ids=x, attention_mask=b["attention_mask"].to(device)).logits
            # Causal shift: logits at position t predict the token at t+1, so
            # drop the last logit and the first label before comparing them.
            # (Labels are stored unshifted, as the model's own loss expects.)
            shift_logits = logits[:, :-1, :]
            shift_y = y[:, 1:]
            m = shift_y != -100
            logp = torch.log_softmax(shift_logits, dim=-1)
            # Replace ignored labels by a valid index; their NLL is masked out below.
            y_safe = torch.where(m, shift_y, torch.zeros_like(shift_y))
            tok_nll = -torch.gather(logp, 2, y_safe.unsqueeze(-1)).squeeze(-1) * m
            ex_loss = tok_nll.sum(dim=1) / m.sum(dim=1).clamp_min(1)
        bs = x.size(0); diffs = diffs_all[ptr:ptr+bs]; ptr += bs
        bid = np.clip(np.digitize(diffs, bins) - 1, 0, nb - 1)
        # Scatter-add each example's loss and a count of 1 into its bin.
        bid_t = torch.as_tensor(bid, dtype=torch.long, device=device)
        sums.index_add_(0, bid_t, ex_loss.to(sums.dtype))
        cnts.index_add_(0, bid_t, torch.ones(bs, dtype=cnts.dtype, device=device))
    sums = allreduce_sum(sums); cnts = allreduce_sum(cnts)
    means = torch.where(cnts>0, sums/cnts, torch.zeros_like(sums))
    centers = 0.5*(bins[:-1]+bins[1:])
    return centers, means.detach().cpu().numpy()

# ---------------- Eval callback (saves metrics + traces) ----------
class EvalLevelsCallback(TrainerCallback):
    """Trainer callback that runs generation and loss evaluation after every epoch.

    Every rank takes part in the evaluation (it uses collectives); only rank 0
    writes files into ``run_dir``:

    * ``epoch_<e>_solve.json``: the full metrics dict
    * ``level_solve_epoch.csv``: solve rate per level (appended each epoch)
    * ``solve_by_diff_epoch.csv``: solve rate and loss per difficulty bin
    * ``per_sample_epoch_<e>.csv``: one row per evaluated example
    * ``eval_traces_epoch_<e>.csv`` / ``.jsonl``: query, reference and generated path
    """

    # Column names of the per-sample and trace CSVs (epoch + the tuple layouts
    # produced by batched_generate_eval_levels).
    _PER_SAMPLE_HEADER = (
        "epoch", "idx", "u", "v", "hops", "level", "include", "exclude", "difficulty", "ok",
        "used_excluded", "missing_included", "wrong_end", "non_edge", "bad_start", "bad_label",
        "truncated", "empty_path", "new_len", "eos_hit",
    )
    _TRACE_HEADER = (
        "epoch", "idx", "query", "ref_path", "gen_path", "gen_raw", "ok", "hops_pred", "new_len", "eos_hit",
    )

    def __init__(self, cfg: CFG, tok: Tok, ev_ds: PathDS, adj: dict[tuple[str,str], set[str]], run_dir: Path):
        """Store the evaluation inputs and precompute the difficulty bin edges."""
        self.cfg, self.tok, self.ev_ds, self.adj, self.run_dir = cfg, tok, ev_ds, adj, run_dir
        self.epochs: list[int] = []
        self.solve_overall: list[float] = []
        self.levels = list(ev_ds.level_set)
        # Widen the range by eps so the min and max difficulties fall strictly inside the bins.
        dmin, dmax = float(ev_ds.diffs.min()), float(ev_ds.diffs.max())
        eps = 1e-6
        self.bins = np.linspace(dmin - eps, dmax + eps, cfg.diff_bins + 1)
        self.bin_centers = 0.5 * (self.bins[:-1] + self.bins[1:])

    @staticmethod
    def _append_csv(path: Path, header, rows) -> None:
        """Append ``rows`` to a UTF-8 CSV, writing ``header`` first if the file is empty.

        Header and rows go through the same csv writer so the file uses one
        line terminator and one encoding throughout.
        """
        with open(path, "a", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            if f.tell() == 0:
                w.writerow(header)
            w.writerows(rows)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Evaluate ``model`` and, on rank 0, persist metrics, per-sample rows and traces."""
        if model is None:
            return control
        metrics, rows, traces = batched_generate_eval_levels(model, self.tok, self.ev_ds, self.adj, self.cfg.max_new_tokens, self.cfg.eval_bs, self.cfg.eval_sample_gen, self.bins)
        _, loss_means = loss_by_difficulty(model, self.ev_ds, self.cfg.eval_bs, self.bins)
        if rank() != 0:
            return control

        e = int(state.epoch)
        self.epochs.append(e)
        self.solve_overall.append(metrics["overall"]["solve_rate"])
        (self.run_dir / f"epoch_{e}_solve.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")

        level_rows = []
        for L in self.levels:
            pl = metrics["per_level"].get(L, {"solve_rate": 0.0, "n": 0})
            level_rows.append([e, L, float(pl["solve_rate"]), int(pl["n"])])
        self._append_csv(self.run_dir / "level_solve_epoch.csv", ["epoch", "level", "solve_rate", "count"], level_rows)

        overall = metrics["overall"]
        diff_rows = [
            [e, float(c), float(s), int(n), float(lm)]
            for c, s, n, lm in zip(self.bin_centers, overall["solve_by_bin"], overall["counts_by_bin"], loss_means)
        ]
        self._append_csv(self.run_dir / "solve_by_diff_epoch.csv", ["epoch", "bin_center", "solve_rate", "count", "eval_loss"], diff_rows)

        if rows:
            self._append_csv(self.run_dir / f"per_sample_epoch_{e}.csv", self._PER_SAMPLE_HEADER, [(e, *r) for r in rows])

        if traces:
            # CSV
            self._append_csv(self.run_dir / f"eval_traces_epoch_{e}.csv", self._TRACE_HEADER, [(e, *t) for t in traces])
            # JSONL
            pjsonl = self.run_dir / f"eval_traces_epoch_{e}.jsonl"
            with open(pjsonl, "a", encoding="utf-8") as f:
                for i, q, refp, genp, genraw, ok, hopsp, newlen, eos in traces:
                    f.write(json.dumps({
                        "epoch": int(e), "idx": int(i), "query": q, "ref_path": refp, "gen_path": genp,
                        "gen_raw": genraw, "ok": int(ok), "hops_pred": int(hopsp), "new_len": int(newlen), "eos_hit": int(eos)
                    }) + "\n")
        return control

# -------------------------- Helpers ------------------------------
def count_params(m: torch.nn.Module) -> int:
    """Total number of elements over all parameters of ``m``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
def is_main() -> bool:
    """True on rank 0, and always when torch.distributed is unavailable or uninitialised."""
    dist = torch.distributed
    if not dist.is_available():
        return True
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0

def normalize_variants(spec: list[tuple]) -> list[tuple[str, int, int, int, int]]:
    """Convert every variant to the named form ``(name, E, L, H, IM)``.

    Unnamed ``(E, L, H, IM)`` entries get the name ``E{E}_L{L}_H{H}_IM{IM}``.

    Raises:
        ValueError: if an entry has neither accepted shape, or E is not a
            multiple of H.
    """
    normalized: list[tuple[str, int, int, int, int]] = []
    for entry in spec:
        size = len(entry)
        unnamed = size == 4 and isinstance(entry[0], int)
        named = size == 5 and isinstance(entry[0], str)
        if unnamed:
            E, L, H, IM = entry
            name = f"E{E}_L{L}_H{H}_IM{IM}"
        elif named:
            name, E, L, H, IM = entry
        else:
            raise ValueError("variant must be (E,L,H,IM) or ('name',E,L,H,IM)")
        if E % H:
            raise ValueError(f"E must be divisible by H: {E=} {H=}")
        normalized.append((name, E, L, H, IM))
    return normalized

# ---------------------------- Main -------------------------------
def _plot_train_loss(logs: list[dict], run_dir: Path) -> None:
    """Save ``train_loss.png`` (loss vs. step, log-scaled x axis) from the Trainer log history."""
    tr_steps = [e.get("step", j) for j, e in enumerate(logs) if "loss" in e]
    tr_losses = [e["loss"] for e in logs if "loss" in e]
    if not (tr_steps and tr_losses):
        return
    fig_r, ax_r = plt.subplots(1, 1, figsize=(6.2, 3.6), constrained_layout=True)
    # Steps <= 0 are drawn at 1 because the x axis is logarithmic.
    ax_r.plot([s if s>0 else 1 for s in tr_steps], tr_losses, linewidth=2)
    ax_r.set_xscale("log"); ax_r.set_xlabel("step"); ax_r.set_ylabel("train_loss"); ax_r.grid(True, alpha=0.3); ax_r.set_title("Training loss")
    fig_r.savefig(run_dir / "train_loss.png", dpi=220, bbox_inches="tight"); plt.close(fig_r)


def _train_variant(cfg: CFG, tok: Tok, tr_ds: PathDS, ev_ds: PathDS, adj: dict[tuple[str, str], set[str]], run_dir: Path, name: str, E: int, L: int, H: int, IM: int) -> None:
    """Build, train and (on the main rank) save one model variant into ``run_dir``.

    The model and trainer are locals of this function, so they become
    unreachable when it returns and their memory can be reclaimed before the
    next variant is built.
    """
    cfgm = GPT2Config(
        vocab_size=len(tok.stoi), n_positions=cfg.n_positions, n_ctx=cfg.n_positions,
        n_embd=E, n_layer=L, n_head=H, n_inner=IM*E, use_cache=False,
        pad_token_id=tok.pad_id, eos_token_id=tok.eos_id
    )
    model = GPT2LMHeadModel(cfgm); params = count_params(model)
    # The learning rate is chosen by the size-bucket prefix of the variant name.
    bucket = name.split("_", 1)[0]; lr = LR_PER.get(bucket, cfg.lr)
    if is_main(): logger.info(f"[{name}] emb={E} heads={H} layers={L} inner_mult={IM} params={params:,} lr={lr}")

    args = TrainingArguments(
        output_dir=str(run_dir),
        per_device_train_batch_size=cfg.batch_size, per_device_eval_batch_size=cfg.eval_bs,
        learning_rate=lr, weight_decay=cfg.wd, num_train_epochs=cfg.epochs, warmup_ratio=cfg.warmup_ratio,
        logging_steps=cfg.log_steps, report_to=[], bf16=cfg.bf16, save_strategy="no", eval_strategy="epoch",
        ddp_find_unused_parameters=False, disable_tqdm=False, gradient_checkpointing=cfg.gradient_checkpointing,
        gradient_accumulation_steps=cfg.grad_accum_steps, optim="adamw_torch", lr_scheduler_type="cosine"
    )
    trainer = Trainer(model=model, args=args, train_dataset=tr_ds, eval_dataset=ev_ds, data_collator=pad_collate)
    cb = EvalLevelsCallback(cfg, tok, ev_ds, adj, run_dir); trainer.add_callback(cb)
    trainer.train()

    if is_main():
        model.save_pretrained(run_dir); Path(run_dir / "vocab.json").write_text(json.dumps(tok.stoi, ensure_ascii=True), encoding="utf-8")
        _plot_train_loss(trainer.state.log_history, run_dir)


def main() -> None:
    """Train every configured variant in turn, skipping those already saved.

    A variant counts as done when its run directory holds ``config.json``
    plus a weights file. Each variant logs to its own ``run.log``; that sink
    is removed when the run ends so later variants do not write into it.
    """
    import gc

    cfg = CFG()
    torch.manual_seed(cfg.seed); torch.cuda.manual_seed_all(cfg.seed); random.seed(cfg.seed); np.random.seed(cfg.seed)
    tok = build_tok(cfg.upper_grams)
    out_root = Path(cfg.out_dir); out_root.mkdir(parents=True, exist_ok=True)
    Path(out_root / "vocab.json").write_text(json.dumps(tok.stoi, indent=2), encoding="utf-8")

    train_lines = [ln for ln in Path(cfg.train_path).read_text(encoding="utf-8").splitlines() if ln.strip()]
    eval_lines = [ln for ln in Path(cfg.eval_path).read_text(encoding="utf-8").splitlines() if ln.strip()]
    tr_ds, ev_ds = PathDS(train_lines, tok), PathDS(eval_lines, tok)
    adj = load_graph(Path(cfg.graph_path))
    if is_main(): logger.info(f"train={len(tr_ds)} eval={len(ev_ds)} vocab={len(tok.stoi)} ctx={cfg.n_positions} levels={ev_ds.level_set}")

    # max_new_tokens from 99th percentile of target lengths (only ever raised, never lowered)
    lens = [len(e.tokens) - (e.colon + 1) for e in tr_ds.ex]
    auto_max = int(np.percentile(lens, 99)) + 4
    if auto_max > cfg.max_new_tokens: cfg.max_new_tokens = auto_max
    if is_main(): logger.debug(f"max_new_tokens={cfg.max_new_tokens} (auto)")

    variants = normalize_variants(VARIANTS)
    if is_main(): logger.info(f"variants={variants}")

    for name, E, L, H, IM in variants:
        run_dir = out_root / name; run_dir.mkdir(parents=True, exist_ok=True)
        done = (run_dir / "config.json").exists() and ((run_dir / "model.safetensors").exists() or (run_dir / "pytorch_model.bin").exists())
        if done:
            if is_main(): logger.info(f"[{name}] already trained; skipping")
            continue
        # Per-run log file on the main rank only; keep the sink id to detach it afterwards.
        sink_id = logger.add(str(run_dir / "run.log"), enqueue=True, level="INFO") if is_main() else None
        try:
            _train_variant(cfg, tok, tr_ds, ev_ds, adj, run_dir, name, E, L, H, IM)
        finally:
            if sink_id is not None:
                logger.remove(sink_id)
            # The finished run's model/trainer are unreachable now; collect them
            # first so that emptying the CUDA cache can actually return memory.
            gc.collect()
            torch.cuda.empty_cache()

# Run the sweep only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
