
from __future__ import annotations

import itertools
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, List
from pathlib import Path
from types import SimpleNamespace
import os, sys, subprocess
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
# Silence UserWarnings whose message starts with "Could not infer format"
# (presumably the pandas `to_datetime` format-inference warning -- confirm).
warnings.filterwarnings("ignore", message="Could not infer format", category=UserWarning)
# Only sets the allocator option when the user has not already exported one.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")  # reduce CUDA fragmentation

# Tensors default to float32; TF32 is additionally allowed for CUDA matmul and cuDNN kernels.
torch.set_default_dtype(torch.float32)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# -----------------------------------------------------------------------------



# Clamp range for node potentials phi (applied in GLFModel.forward_energy_params).
PHI_MIN, PHI_MAX = -50.0, 50.0
# Clamp range for edge weights A (applied in EdgeConnectionNet.forward and GLFModel).
A_MIN, A_MAX = 1e-3, 10.0

# Run-time switches and debugging flags.
# NOTE(review): CANDIDATE_POLICY and DEBUG_STRICT_PAST are not read in this part of the
# file; "closed" presumably selects the closed-world branch of _candidate_items_for_eval.
CANDIDATE_POLICY = "closed"
DEBUG_STRICT_PAST = True
PRINT_ITEM_TOPK = 10  # how many most-frequent items _sanity_item_histogram prints
torch.autograd.set_detect_anomaly(False)

# Dataset and evaluation hyperparameters
TINY_MODE = False
MAX_EDGES_PER_DATASET = 10000  # previously 50_000
N_BINS_TOTAL_TINY = 12
LIMIT_TIME_QUANTILE_TINY = None

MIN_EDGES_PER_BIN = 800       # used by _compute_cuts_by_edgecount to size the bin count
MIN_EDGES_PER_SPLIT = 1200    # default minimum for val/test in _split_bins_with_min_edges
MIN_BINS_GLOBAL = 8           # lower bound on the number of time bins
MIN_UNIQUE_TS_FOR_TIME = 3 * MIN_BINS_GLOBAL  # threshold reported by _sanity_timestamp_unique

EVAL_NUM_NEGS = 256  # unused when FULL_SOFTMAX=True
REPORT_HITSK = 10

# When False, _ensure_tgb() returns immediately and TGB is neither imported nor installed.
INCLUDE_TGB = False

# -----------------------------------------------------------------------------
# TGB bootstrap (optional)

def _ensure_tgb():
    """Return True when the optional ``tgb`` package is importable.

    Returns False straight away when ``INCLUDE_TGB`` is off.  Otherwise the
    import is attempted; on failure a quiet ``pip install tgb`` is run (its
    exit status is ignored) and the import is attempted once more.
    """
    if not INCLUDE_TGB:
        return False

    def _import_error():
        # Returns None on success, otherwise the exception that was raised.
        try:
            from tgb.linkproppred.dataset import LinkPropPredDataset  # noqa
        except Exception as exc:
            return exc
        return None

    if _import_error() is None:
        return True

    print("[INFO] Installing 'tgb'...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "tgb"], check=False)

    failure = _import_error()
    if failure is None:
        return True
    print("[WARN] TGB still not importable:", failure)
    return False

# Resolved once at import time; _load_any_dataset checks this flag before using TGB.
_TGB_AVAILABLE = _ensure_tgb()
if _TGB_AVAILABLE:
    from tgb.linkproppred.dataset import LinkPropPredDataset
else:
    # Keep the name defined so later references do not raise NameError.
    LinkPropPredDataset = None

# -----------------------------------------------------------------------------
# Paths/Globals

# Prefer CUDA when PyTorch reports it as available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side effect at import time: creates ./checkpoints if it does not exist.
CHECKPOINTS = Path("checkpoints"); CHECKPOINTS.mkdir(parents=True, exist_ok=True)

# Per-dataset CSV descriptors: file location, column names for source / destination /
# timestamp, field separator, and whether the file has a header row.
USER_CSV_DATASETS = {
    # c0 = user_id, c1 = item/page/artist id, c2 = timestamp
    "jodie-wikipedia": {"path": "datasets/jodie_wikipedia.csv", "src": "c0", "dst": "c1", "time": "c2", "sep": ",", "header": False},
    "jodie-lastfm":    {"path": "datasets/jodie_lastfm.csv",    "src": "c0", "dst": "c1", "time": "c2", "sep": ",", "header": False},
}

# Download locations consulted by _maybe_download_csv, keyed like USER_CSV_DATASETS.
# NOTE(review): these are plain-http URLs; consider https if the host supports it.
JODIE_URLS = {
    "jodie-wikipedia": "http://snap.stanford.edu/jodie/wikipedia.csv",
    "jodie-lastfm":    "http://snap.stanford.edu/jodie/lastfm.csv",
}

def _maybe_download_csv(key: str, path: str):
    """Download a CSV dataset if it does not exist locally.

    Args:
        key: Dataset key looked up in ``JODIE_URLS``.
        path: Destination file path.  May be a bare filename (no directory part).

    Raises:
        AssertionError: if ``JODIE_URLS`` has no URL for ``key``.
    """
    if os.path.exists(path):
        return
    url = JODIE_URLS.get(key)
    if not url:
        # Raised explicitly (rather than via `assert`) so the check survives `python -O`.
        raise AssertionError(f"No URL for {key}")
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create a directory when the path has one.
        os.makedirs(parent, exist_ok=True)
    print(f"[INFO] Download {key} -> {path}")
    import urllib.request
    # Download to a temporary sibling and rename on success, so an interrupted
    # transfer never leaves a truncated file that later looks "already downloaded".
    tmp_path = f"{path}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# -----------------------------------------------------------------------------
# Timestamp parsing utilities

def _normalize_numeric_epoch_to_seconds(ts: np.ndarray) -> np.ndarray:
    """Convert numeric epoch timestamps of unknown unit to whole int64 seconds.

    The unit is guessed from the median of the finite entries (>=1e17 ns,
    >=1e14 us, >=1e11 ms, otherwise seconds).  Non-finite entries are replaced
    by the smallest converted finite value; if nothing is finite the result is
    all zeros.
    """
    values = ts.astype("float64")
    is_finite = np.isfinite(values)
    if not is_finite.any():
        return np.zeros_like(values, dtype=np.int64)

    median = np.nanmedian(values[is_finite])
    if median >= 1e17:
        seconds = np.floor(values / 1e9)   # nanoseconds
    elif median >= 1e14:
        seconds = np.floor(values / 1e6)   # microseconds
    elif median >= 1e11:
        seconds = np.floor(values / 1e3)   # milliseconds
    else:
        seconds = np.floor(values)         # already seconds

    # At least one finite entry exists here, so the minimum is well defined.
    fill = seconds[is_finite].min()
    return np.where(is_finite, seconds, fill).astype(np.int64)

def _ensure_time_numeric_robust(ts_col: pd.Series) -> np.ndarray:
    """Turn a timestamp column into int64 epoch seconds.

    * Numeric columns are delegated to ``_normalize_numeric_epoch_to_seconds``.
    * Other columns are parsed with ``pd.to_datetime(..., utc=True)``.  Rows that
      fail to parse are filled with the smallest successfully parsed value, which
      is the same policy the numeric path uses for non-finite entries.  (Casting
      NaT straight to int64 would instead yield a huge negative sentinel.)
    * If nothing parses as a datetime, a float cast is attempted; if that fails
      too, an all-zero array is returned.
    """
    vals = ts_col.values
    if np.issubdtype(vals.dtype, np.number):
        return _normalize_numeric_epoch_to_seconds(vals)

    dt = pd.to_datetime(ts_col, errors="coerce", utc=True)
    valid = np.asarray(dt.notna())
    if not valid.any():
        # Deliberate best effort: any failure here degrades to zeros.
        try:
            numeric = ts_col.astype("float64").values
            return _normalize_numeric_epoch_to_seconds(numeric)
        except Exception:
            return np.zeros(len(ts_col), dtype=np.int64)

    # Timedelta floor-division is independent of the datetime storage unit.
    epoch = pd.Timestamp("1970-01-01", tz="UTC")
    secs = np.empty(len(dt), dtype=np.int64)
    secs[valid] = np.asarray((dt[valid] - epoch) // pd.Timedelta(seconds=1), dtype=np.int64)
    secs[~valid] = secs[valid].min()
    return secs

def _auto_detect_time_col(df: pd.DataFrame, exclude_cols: List[str]) -> Optional[str]:
    """Guess which column of ``df`` holds timestamps.

    Every column not listed in ``exclude_cols`` is scored by its number of unique
    second-resolution timestamps.  A column is treated as numeric when at least
    70% of its entries convert with ``pd.to_numeric``; otherwise it is parsed as
    datetimes.  Columns with zero time span are ignored.  The column with the
    most unique timestamps wins (the first one on ties); ``None`` is returned
    when no column qualifies.
    """
    winner, winner_unique, winner_kind = None, -1, None

    for name in df.columns:
        if name in exclude_cols:
            continue
        column = df[name]

        as_number = pd.to_numeric(column, errors="coerce")
        numeric_share = float(as_number.notna().mean())
        if numeric_share >= 0.70:
            seconds = _normalize_numeric_epoch_to_seconds(as_number.values)
            n_unique = int(np.unique(seconds).shape[0])
            if np.isfinite(seconds).any():
                spread = float(np.nanmax(seconds) - np.nanmin(seconds))
            else:
                spread = 0.0
            if n_unique > winner_unique and spread > 0:
                winner, winner_unique, winner_kind = name, n_unique, "numeric"
            continue

        parsed = pd.to_datetime(column, errors="coerce", utc=True)
        if not parsed.notna().any():
            continue
        seconds = (parsed.astype("int64") // 10**9).astype(np.int64)
        seconds = seconds[np.isfinite(seconds)]
        if not seconds.size:
            continue
        n_unique = int(np.unique(seconds).shape[0])
        spread = float(np.nanmax(seconds) - np.nanmin(seconds))
        if n_unique > winner_unique and spread > 0:
            winner, winner_unique, winner_kind = name, n_unique, "datetime"

    if winner is not None:
        print(f"[INFO] auto-detected time column: {winner} ({winner_kind}), unique_ts={winner_unique}")
    return winner

# -----------------------------------------------------------------------------
# ID + column helpers

def pessimistic_item_ranks(probs_items: torch.Tensor) -> torch.Tensor:
    """Rank items by descending probability, resolving ties pessimistically.

    Every member of a group of equal probabilities receives the worst (largest)
    rank of that group.  Ranks are 1-based; an empty input gives an empty long
    tensor on the same device.
    """
    n_items = probs_items.numel()
    dev = probs_items.device
    if n_items == 0:
        return torch.empty(0, dtype=torch.long, device=dev)

    # Stable descending sort keeps equal values contiguous.
    perm = torch.argsort(probs_items, descending=True, stable=True)
    _, run_lengths = torch.unique_consecutive(probs_items[perm], return_counts=True)

    # 0-based position of the last element of each run, repeated for each run member.
    run_last = run_lengths.cumsum(0) - 1
    worst_slot_sorted = run_last.repeat_interleave(run_lengths)

    # Scatter back to the original item order and convert to 1-based ranks.
    ranks = torch.empty(n_items, dtype=torch.long, device=dev)
    ranks[perm] = worst_slot_sorted + 1
    return ranks

def _ensure_int(arr):
    """Return ``arr`` as an int64 NumPy array.

    Numeric input is cast directly (floats are truncated by the cast).  Any other
    dtype is replaced by the index of each value within the sorted unique values.
    """
    values = np.asarray(arr)
    if not np.issubdtype(values.dtype, np.number):
        codes = np.unique(values, return_inverse=True)[1]
        return codes.astype(np.int64)
    return values.astype(np.int64, copy=False)

def _infer_cols(df: pd.DataFrame):
    """Guess the (source, destination, time) column names of ``df``.

    For each role a list of candidate names is tried, first as an exact match on
    the lower-cased, stripped column name, then as a substring of any column
    name.  When nothing matches, the role falls back to the 1st / 2nd / 3rd
    column respectively.  The chosen names are printed and returned as a tuple.

    NOTE(review): the substring fallback uses one-letter candidates ("u", "i",
    "t"), so it can select unrelated columns, and two roles may resolve to the
    same column -- confirm this is acceptable for the datasets in use.

    Raises:
        ValueError: if a chosen name is not a column of ``df``.
    """
    original = df.columns.tolist()
    normalized = [name.lower().strip() for name in original]
    by_normalized = {name.lower().strip(): name for name in original}

    def pick(cands):
        # Exact matches take priority over substring matches.
        hit = next((k for k in cands if k in by_normalized), None)
        if hit is None:
            hit = next((c for k in cands for c in normalized if k in c), None)
        return None if hit is None else by_normalized[hit]

    src = pick(["u", "user", "user_id", "uid"]) or df.columns[0]
    dst = pick(["i", "item", "item_id", "iid"]) or df.columns[1]
    tim = pick(["ts", "time", "timestamp", "t", "datetime", "date"]) or df.columns[2]

    if not {src, dst, tim} <= set(original):
        raise ValueError("Could not infer (src,dst,time) columns.")
    print(f"[INFO] Using columns: src={src}, dst={dst}, time={tim}")
    return src, dst, tim

def _dataset_column_hints(key: str, df: pd.DataFrame):
    """Return dataset-specific (src, dst, time) column names, or Nones.

    Hints exist for keys containing "wikipedia" or "lastfm" (case-insensitive,
    checked in that order).  Each role is resolved by exact, case-insensitive
    match against a preference list; a role with no matching column is ``None``.
    Any other key yields ``(None, None, None)``.
    """
    key = (key or "").lower()
    by_lower = {c.lower(): c for c in df.columns}

    def first_present(names):
        return next((by_lower[n] for n in names if n in by_lower), None)

    user_names = ("user_id", "uid", "u", "user")
    time_names = ("timestamp", "ts", "time", "datetime", "unix_ts", "unix_time")
    item_names_by_dataset = (
        ("wikipedia", ("page_id", "item_id", "article_id", "iid", "item", "page")),
        ("lastfm", ("artist_id", "track_id", "item_id", "iid", "artist", "track")),
    )

    for marker, item_names in item_names_by_dataset:
        if marker in key:
            return (first_present(user_names),
                    first_present(item_names),
                    first_present(time_names))
    return None, None, None

def _sanity_item_histogram(df: pd.DataFrame, item_col: str, tag: str):
    """Print a frequency summary of ``df[item_col]`` and flag degenerate mappings.

    Reports the number of distinct items, the number of edges, their ratio, the
    share of the most frequent item and the ``PRINT_ITEM_TOPK`` most frequent
    items.  A red-flag line is printed when one item accounts for more than half
    of the edges or when more than 90% of the edges have distinct items.
    """
    counts = df[item_col].value_counts()
    n_edges = int(counts.sum())
    n_items = int(counts.shape[0])
    head = counts.head(PRINT_ITEM_TOPK)

    if n_edges:
        top1_share = float(head.iloc[0]) / n_edges
        uniq_ratio = n_items / n_edges
    else:
        # Empty column: ratios are undefined.
        top1_share = float("nan")
        uniq_ratio = float("nan")

    print(f"[SANITY:{tag}] items={n_items} edges={n_edges} uniq/edge={uniq_ratio:.3f} top1_share={top1_share:.3f}")
    print(f"[SANITY:{tag}] top-{min(PRINT_ITEM_TOPK, len(head))} items:\n{head.to_string()}")

    looks_degenerate = (top1_share > 0.5) or (uniq_ratio > 0.9)
    if looks_degenerate:
        print(f"[RED-FLAG:{tag}] Item mapping likely wrong (degenerate histogram). "
              f"Consider specifying cfg['dst'] explicitly (e.g., page_id / artist_id).")

def _sanity_timestamp_unique(ts: np.ndarray, tag: str):
    """Print how many distinct timestamps ``ts`` has and warn when there are too few.

    The warning threshold is ``MIN_UNIQUE_TS_FOR_TIME``.
    """
    n_unique = int(np.unique(ts).shape[0])
    print(f"[SANITY:{tag}] unique_ts={n_unique} (>= {MIN_UNIQUE_TS_FOR_TIME} recommended)")
    if n_unique >= MIN_UNIQUE_TS_FOR_TIME:
        return
    print(f"[RED-FLAG:{tag}] Very few unique timestamps. Parse the true timestamp column "
          f"(units/timezone) or use official splits. Set STRICT_TRUE_TIMESTAMPS=True to forbid fallback.")

# -----------------------------------------------------------------------------
# Binning utilities

def _edges_grouped_by_bin(t: torch.Tensor, cuts: torch.Tensor) -> List[torch.Tensor]:
    """Group edge indices by time bin.

    Bin ``i`` covers ``cuts[i] <= t < cuts[i + 1]``; values outside the cut range
    are clamped into the first or last bin.  Returns one index tensor per bin
    (``cuts.numel() - 1`` of them); empty bins give empty tensors.
    """
    n_bins = cuts.numel() - 1
    which_bin = (torch.bucketize(t.float(), cuts, right=True) - 1).clamp(0, n_bins - 1)
    perm = torch.argsort(which_bin)
    sizes = torch.bincount(which_bin, minlength=n_bins).tolist()

    # Consecutive slices of the bin-sorted permutation, one per bin.
    stops = list(itertools.accumulate(sizes))
    starts = [0] + stops[:-1]
    return [perm[lo:hi] for lo, hi in zip(starts, stops)]

def _compute_cuts_splitwise(t: torch.Tensor,
                            train_mask: torch.Tensor,
                            val_mask: torch.Tensor,
                            test_mask: torch.Tensor,
                            n_bins_total: int = 40,
                            q_train: float = 0.70,
                            q_val: float = 0.85) -> torch.Tensor:
    """Build sorted time-bin boundaries from per-split quantiles.

    The bin budget ``n_bins_total`` is divided between train / val / test using
    ``q_train`` and ``q_val`` (at least 2 train bins and 1 each for val and test).
    Quantile cuts are computed separately on the timestamps selected by each mask,
    then merged, de-duplicated and sorted.  If the merged result has fewer than
    ``max(5, MIN_BINS_GLOBAL + 1)`` boundaries, it is replaced by evenly spaced
    boundaries over the global time range (or a fixed 5-point grid around the
    single timestamp when all times are equal).

    Args:
        t: Edge timestamps (cast to float internally).
        train_mask, val_mask, test_mask: Index/boolean selectors into ``t``.
        n_bins_total: Total bin budget across the three splits.
        q_train, q_val: Cumulative split fractions used to divide the budget.

    Returns:
        1-D tensor of ascending cut values.
    """
    n_train_bins = max(2, int(round(n_bins_total * q_train)))
    n_val_bins   = max(1, int(round(n_bins_total * (q_val - q_train))))
    n_test_bins  = max(1, n_bins_total - n_train_bins - n_val_bins)
    t = t.float()
    t_tr = t[train_mask]; t_va = t[val_mask]; t_te = t[test_mask]
    t_min, t_max = torch.min(t), torch.max(t)
    def _cuts_local(x, n_bins):
        # Quantile boundaries for one split; always returns at least two values.
        if x.numel() == 0:
            return torch.tensor([0.0, 1.0], device=t.device)
        x_min, x_max = torch.min(x), torch.max(x)
        if x_max <= x_min:
            # Single distinct timestamp: bracket it with a unit margin on each side.
            return torch.stack([x_min - 1.0, x_max + 1.0])
        q = torch.linspace(0, 1, n_bins + 1, device=t.device)
        c = torch.quantile(x, q).unique()
        if c.numel() < 3:
            # Too many duplicate quantiles: fall back to evenly spaced boundaries.
            c = torch.linspace(x_min, x_max, n_bins + 1, device=t.device)
        return torch.sort(c).values
    cuts_train = _cuts_local(t_tr, n_train_bins)
    cuts_val   = _cuts_local(t_va, n_val_bins)
    cuts_test  = _cuts_local(t_te, n_test_bins)
    # The first boundary of val and test is dropped, so the preceding split's last
    # boundary serves as their lower edge.
    parts = [cuts_train]
    if cuts_val.numel() > 1:
        parts.append(cuts_val[1:])
    if cuts_test.numel() > 1:
        parts.append(cuts_test[1:])
    cuts = torch.unique(torch.cat(parts)); cuts = torch.sort(cuts).values
    if cuts.numel() < max(5, MIN_BINS_GLOBAL+1):
        # Not enough distinct boundaries: ignore the splits and use a global grid.
        if t_max <= t_min:
            cuts = torch.tensor([t_min-2, t_min-1, t_min, t_min+1, t_min+2], device=t.device)
        else:
            cuts = torch.linspace(t_min, t_max, max(5, MIN_BINS_GLOBAL+1), device=t.device)
    return cuts

def _compute_cuts_by_edgecount(t: torch.Tensor, n_bins_total: int) -> torch.Tensor:
    """Build time-bin boundaries so that bins hold roughly equal numbers of edges.

    The number of bins is ``n_bins_total`` capped by twice the number of
    ``MIN_EDGES_PER_BIN``-sized groups and by the edge count, but never below
    ``MIN_BINS_GLOBAL``.  Boundaries are read from the sorted timestamps at
    evenly spaced ranks, and the last boundary is pushed strictly above the
    largest timestamp.  If de-duplication leaves fewer than
    ``MIN_BINS_GLOBAL + 1`` boundaries, an evenly spaced grid is returned.

    Args:
        t: Edge timestamps (cast to float internally).
        n_bins_total: Requested number of bins.

    Returns:
        1-D tensor of ascending, distinct cut values; ``[0., 1.]`` for empty input.
    """
    t = t.float()
    m = t.numel()
    if m == 0:
        return torch.tensor([0., 1.], device=t.device)

    bins_supported = int(max(1, m // max(1, MIN_EDGES_PER_BIN))) * 2
    target_bins = max(MIN_BINS_GLOBAL, min(int(n_bins_total), bins_supported, m))

    t_sorted = torch.sort(t).values
    idx = torch.linspace(0, m, target_bins + 1, device=t.device).long()
    idx[-1] = m
    right_idx = torch.clamp(idx - 1, 0, m - 1)
    cuts = t_sorted[right_idx]
    if cuts.numel() >= 2:
        # An absolute 1e-6 bump is below float32 resolution for large values
        # (e.g. epoch seconds), so also take the next representable float to
        # guarantee the last boundary is strictly above the largest timestamp.
        last = cuts[-1]
        next_up = torch.nextafter(last, torch.full_like(last, float("inf")))
        cuts[-1] = torch.maximum(last + 1e-6, next_up)
    cuts = torch.sort(cuts.unique()).values

    if cuts.numel() < MIN_BINS_GLOBAL + 1:
        # Too many duplicate timestamps: fall back to an evenly spaced grid.
        t_min, t_max = t_sorted[0], t_sorted[-1]
        if t_max <= t_min:
            base = torch.linspace(-3, 3, MIN_BINS_GLOBAL + 1, device=t.device)
            cuts = base + t_min
        else:
            cuts = torch.linspace(t_min, t_max, MIN_BINS_GLOBAL + 1, device=t.device)
    return cuts

def _split_bins_with_min_edges(bins_edges: List[torch.Tensor],
                               ratios=(0.70, 0.15, 0.15),
                               min_edges_split: int = MIN_EDGES_PER_SPLIT):
    """Partition consecutive time bins into train / val / test bin-index lists.

    Bins are first assigned in order using ``ratios`` (at least 2 train bins and
    1 val bin; test takes every remaining bin).  The val/test boundary is then
    nudged, one bin per step, so that each of the two holds at least
    ``min_edges_split`` edges where possible.  Finally, if val (or test) contains
    no edges at all, the latest non-empty bin of train (or val) is moved into it.

    Args:
        bins_edges: One tensor of edge indices per bin; only ``numel()`` is used.
        ratios: (train, val, test) fractions; the third entry is not used directly.
        min_edges_split: Minimum edge count targeted for val and for test.
            The default is bound to ``MIN_EDGES_PER_SPLIT`` at definition time.

    Returns:
        ``(train, val, test)`` lists of bin indices.
    """
    B = len(bins_edges); r0, r1, r2 = ratios
    ntr = max(2, int(round(r0 * B))); nva = max(1, int(round(r1 * B))); nte = max(1, B - ntr - nva)
    if ntr + nva + nte > B:
        # The minimum sizes overshoot B: shrink train (still keeping at least 2).
        ntr = max(2, B - nva - nte)
    train = list(range(0, ntr)); val = list(range(ntr, ntr + nva)); test = list(range(ntr + nva, B))
    cnts = [e.numel() for e in bins_edges]
    def split_sum(lst): return int(sum(cnts[b] for b in lst))
    # Rebalance val <-> test; at most B passes, stopping once a pass changes nothing.
    for _ in range(B):
        ok = True
        if split_sum(val) < min_edges_split and len(test) > 1:
            # val too small: take the earliest test bin.
            val.append(test[0]); test = test[1:]; ok = False
        if split_sum(test) < min_edges_split and len(val) > 1:
            # test too small: take the latest val bin.
            test = [val[-1]] + test; val = val[:-1]; ok = False
        if ok:
            break
    def ensure_nonempty(target, donor):
        # If `target` has no edges, move the latest non-empty donor bin to its front.
        if any(cnts[b] > 0 for b in target):
            return target, donor
        for b in reversed(donor):
            if cnts[b] > 0:
                target = [b] + target; donor  = [x for x in donor if x != b]; break
        return target, donor
    val,  train = ensure_nonempty(val,  train)
    test, val   = ensure_nonempty(test, val)
    return train, val, test

# -----------------------------------------------------------------------------


def _load_any_dataset(names, root="datasets",
                      max_edges: Optional[int]=None, limit_q: Optional[float]=None):
    """Load a dataset by name, from a registered CSV or from TGB.

    ``names`` may be a single name or a list/tuple, in which case the last entry
    is used.  A name of the form ``"csv:<key>"`` is passed to
    ``_load_temporal_csv_by_key`` together with ``max_edges`` and ``limit_q``.
    Any other name is loaded through TGB's ``LinkPropPredDataset`` under ``root``.

    Raises:
        RuntimeError: if a non-CSV dataset is requested but TGB is unavailable.
    """
    if isinstance(names, (list, tuple)):
        nm = names[-1]
    else:
        nm = str(names)

    wants_csv = isinstance(nm, str) and nm.startswith("csv:")
    if wants_csv:
        key = nm.split("csv:", 1)[1].strip()
        print(f"Using CSV dataset: {key}")
        return _load_temporal_csv_by_key(key, max_edges=max_edges, limit_q=limit_q)

    if not _TGB_AVAILABLE:
        raise RuntimeError("TGB not installed and a non-CSV dataset was requested.")

    ds = LinkPropPredDataset(name=nm, root=root, preprocess=True)
    # Access the attribute once before returning (presumably forces loading -- confirm).
    _ = ds.full_data
    print(f"Using dataset: {nm}")
    return ds

# -----------------------------------------------------------------------------
# Math helpers

def set_seed(seed: int = 1337):
    """Seed Python's ``random``, NumPy and PyTorch (all CUDA devices when available)."""
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def softplus_pos(x):
    """Softplus shifted up by 1e-6, so the result is strictly positive."""
    floor = 1e-6
    return F.softplus(x) + floor

def project_to_simplex(v: torch.Tensor) -> torch.Tensor:
    """Project each vector along the last axis onto the probability simplex.

    Sort-and-threshold projection: entries are sorted in descending order, the
    largest support size satisfying the positivity condition is found, and the
    matching shift ``theta`` is subtracted before clipping at zero.  The output
    has the same shape as ``v``.
    """
    orig_shape = v.shape
    rows = v.reshape(-1, orig_shape[-1])
    n_rows, dim = rows.shape

    sorted_desc = torch.sort(rows, descending=True, dim=1).values
    shifted_cumsum = sorted_desc.cumsum(dim=1) - 1
    k = torch.arange(1, dim + 1, device=rows.device).float()

    # Number of sorted entries that stay positive after the shift.
    support = (sorted_desc - shifted_cumsum / k > 0).sum(dim=1)
    rho = support - 1
    theta = shifted_cumsum[torch.arange(n_rows), rho] / support.float()

    projected = torch.clamp(rows - theta.unsqueeze(1), min=0)
    return projected.reshape(orig_shape)

def entropy(mu: torch.Tensor):
    """Return ``sum(mu * (log(mu) - 1))`` over the last axis, with ``mu`` floored at 1e-12."""
    safe = mu.clamp(min=1e-12)
    per_entry = safe * (safe.log() - 1.0)
    return per_entry.sum(dim=-1)


def _candidate_items_for_eval(bi: int,
                              mode: str,
                              prefix: _PrefixView,
                              dst: torch.Tensor,
                              item_id_to_index: torch.Tensor,
                              items_idx: torch.Tensor,
                              dev: torch.device):
    """
    Select the candidate items to score at bin `bi`.

    Returns (items_cand_local, items_cand_global)
      - local: indices into readout.item embedding [0..num_items-1]
      - global: node ids in the big graph space

    mode == "all" uses the full item universe `items_idx`.  Any other mode is
    closed-world: only destinations among the first `prefix.m_at_cut[bi]` entries
    of `dst` are kept, and ids that `item_id_to_index` maps to a negative value
    are dropped.

    NOTE(review): `dst` is sliced positionally, which presumably requires it to be
    in the same time-sorted order as `prefix.dst_s` -- confirm with the caller.
    """
    if mode == "all":
        # Full item universe: comparable to JODIE/TGAT tables
        all_local = torch.arange(items_idx.numel(), device=dev)
        return all_local, items_idx

    # closed-world prefix: items observed up to the current prefix cut
    n_prefix_edges = int(prefix.m_at_cut[bi].item())
    seen_global = torch.unique(dst[:n_prefix_edges]).to(dev)
    seen_local = item_id_to_index[seen_global]
    known = seen_local >= 0
    return seen_local[known], seen_global[known]



# -----------------------------------------------------------------------------
# GLF (geometry + PDHG/JKO)

@dataclass
class GraphSnapshot:
    """Graph state at one time prefix: sizes, edge endpoints and model inputs.

    Instances are built by `_snapshot_from_prefix`, which always leaves `M_coo`
    and `M_val` empty and fills `w_base` only when edge features are requested.
    """
    n: int                # number of nodes
    m: int                # number of edges in the snapshot
    M_coo: torch.Tensor   # [2, 2m], unused in simplified representation
    M_val: torch.Tensor   # [2m], unused in simplified representation
    w_base: torch.Tensor  # [m], edge features (for A)
    X_nodes: torch.Tensor # [n, xdim], node features for phi
    edge_u: torch.Tensor  # [m] edge source node
    edge_v: torch.Tensor  # [m] edge destination node

def _xraw_from_cum(in_cum, out_cum):
    """Build raw node features and a normalised in-count vector from cumulative counts.

    Returns:
        X_raw: columns ``[in, out, in + out, log1p(in + out), 1]`` stacked along dim 1.
        y_pers: ``in_cum`` divided by its sum (the sum is floored at 1e-12).
    """
    total = in_cum + out_cum
    columns = (in_cum, out_cum, total, total.log1p(), torch.ones_like(total))
    features = torch.stack(columns, dim=1)
    normaliser = in_cum.sum().clamp_min(1e-12)
    return features, in_cum / normaliser

@dataclass
class _PrefixView:
    """Time-sorted edge list with per-cut prefix lengths, as built by `_prepare_prefix`."""
    # src_s / dst_s / t_s: edge sources, destinations and (float) timestamps,
    # reordered by ascending timestamp.
    src_s: torch.Tensor; dst_s: torch.Tensor; t_s: torch.Tensor
    # cols: arange over the edges; ones: a vector of ones of the same length;
    # m_at_cut[k]: number of sorted edges with timestamp <= cuts[k];
    # num_nodes: size of the node id space.
    cols: torch.Tensor; ones: torch.Tensor; m_at_cut: torch.Tensor; num_nodes: int

def _prepare_prefix(src, dst, t, cuts, num_nodes, device):
    """Sort edges by time and precompute how many edges fall at or before each cut.

    Returns a `_PrefixView` whose `m_at_cut[k]` is the number of time-sorted edges
    with timestamp <= `cuts[k]`.  The helper vectors `cols` and `ones` are created
    on `device`.
    """
    by_time = torch.argsort(t)
    src_sorted = src[by_time]
    dst_sorted = dst[by_time]
    t_sorted = t[by_time].float()

    # right=True counts edges whose timestamp equals the cut as part of the prefix.
    edges_at_cut = torch.bucketize(cuts.float(), t_sorted, right=True)

    n_edges = src_sorted.numel()
    return _PrefixView(
        src_s=src_sorted,
        dst_s=dst_sorted,
        t_s=t_sorted,
        cols=torch.arange(n_edges, device=device, dtype=torch.long),
        ones=torch.ones(n_edges, device=device),
        m_at_cut=edges_at_cut,
        num_nodes=num_nodes,
    )

def _snapshot_from_prefix(prefix: _PrefixView, bi: int, X_nodes: torch.Tensor) -> GraphSnapshot:
    """
    Build the GraphSnapshot holding every edge up to the upper cut of bin `bi`
    (the first `prefix.m_at_cut[bi + 1]` time-sorted edges).

    With no edges in the prefix, every tensor except `X_nodes` is empty.
    When the module-level `LEARN_A` flag is set, `w_base` carries six
    degree-based features per edge; otherwise it is empty.  `M_coo` and `M_val`
    are always left empty.
    """
    dev = X_nodes.device
    n_edges = int(prefix.m_at_cut[bi + 1].item())

    def _empty_long():
        return torch.empty(0, dtype=torch.long, device=dev)

    if n_edges == 0:
        return GraphSnapshot(
            prefix.num_nodes, 0,
            M_coo=torch.empty(2, 0, dtype=torch.long, device=dev),
            M_val=torch.empty(0, device=dev),
            w_base=torch.empty(0, device=dev),
            X_nodes=X_nodes,
            edge_u=_empty_long(),
            edge_v=_empty_long(),
        )

    u = prefix.src_s[:n_edges].long()
    v = prefix.dst_s[:n_edges].long()

    if not LEARN_A:
        edge_features = torch.empty(0, device=dev)
    else:
        # Degrees are counted over the prefix only, then looked up per edge endpoint.
        out_deg = torch.bincount(prefix.src_s[:n_edges], minlength=prefix.num_nodes).float()
        in_deg = torch.bincount(prefix.dst_s[:n_edges], minlength=prefix.num_nodes).float()
        per_edge = (
            out_deg[u],
            in_deg[u],
            out_deg[v],
            in_deg[v],
            out_deg[u] + in_deg[v],
        )
        log_columns = [torch.log1p(col) for col in per_edge]
        bias_column = torch.ones(n_edges, device=dev)
        edge_features = torch.stack(log_columns + [bias_column], dim=1)

    return GraphSnapshot(
        prefix.num_nodes, n_edges,
        M_coo=torch.empty(2, 0, dtype=torch.long, device=dev),
        M_val=torch.empty(0, device=dev),
        w_base=edge_features,
        X_nodes=X_nodes,
        edge_u=u, edge_v=v,
    )

class EdgeConnectionNet(nn.Module):
    """Three-layer ELU MLP mapping edge features to positive edge weights A(e)."""

    def __init__(self, in_dim=1, hidden=64):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ELU(),
            nn.Linear(hidden, hidden),
            nn.ELU(),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_feats):
        """Return one weight per edge, bounded to ``[A_MIN, A_MAX]``.

        A 1-D input is treated as a single feature column.  If the feature width
        differs from the first layer's ``in_features``, the input is zero-padded
        on the right or truncated to fit.
        """
        feats = edge_feats.unsqueeze(-1) if edge_feats.dim() == 1 else edge_feats

        expected = self.mlp[0].in_features
        width = feats.shape[-1]
        if width > expected:
            feats = feats[..., :expected]
        elif width < expected:
            filler = torch.zeros(feats.shape[0], expected - width,
                                 device=feats.device, dtype=feats.dtype)
            feats = torch.cat([feats, filler], dim=-1)

        raw = self.mlp(feats).squeeze(-1)
        # softplus keeps the weight above A_MIN; the clamp caps it at A_MAX.
        return torch.clamp(A_MIN + F.softplus(raw), max=A_MAX)

class NodePotentialNet(nn.Module):
    """Three-layer ELU MLP producing one scalar potential phi(i) per node."""

    def __init__(self, in_dim=2, hidden=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ELU(),
            nn.Linear(hidden, hidden),
            nn.ELU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X_nodes):
        """Map node features to potentials, dropping the trailing size-1 axis."""
        potentials = self.net(X_nodes)
        return potentials.squeeze(-1)

class GLFModel(nn.Module):
    """
    Learnable parts of the GLF: a node potential network (`phi_net`) and an edge
    connection network (`A_net`).  `A_net` is only trained and used when
    `learn_A` is true; otherwise all edge weights are 1.  A Hodge penalty can be
    added externally to regularise A on cycles (the method here is a stub).
    """

    def __init__(self, x_dim=2, edge_feat_dim=1, hidden_phi=128, hidden_A=64,
                 learn_A=False, phi_scale=1.5, hodge_reg=1e-4):
        super().__init__()
        self.learn_A = learn_A
        self.phi_scale = phi_scale
        self.hodge_reg = hodge_reg
        self.phi_net = NodePotentialNet(in_dim=x_dim, hidden=hidden_phi)
        self.A_net = EdgeConnectionNet(in_dim=edge_feat_dim, hidden=hidden_A)
        # A_net parameters receive gradients only when A is being learned.
        for param in self.A_net.parameters():
            param.requires_grad = learn_A

    @staticmethod
    def _degree_edge_features(G: GraphSnapshot, dev):
        """Six per-edge features from snapshot degrees: five log1p terms plus a constant 1."""
        u = G.edge_u.long()
        v = G.edge_v.long()
        out_deg = torch.bincount(u, minlength=G.n).float().to(dev)
        in_deg = torch.bincount(v, minlength=G.n).float().to(dev)
        columns = [
            torch.log1p(out_deg[u]),
            torch.log1p(in_deg[u]),
            torch.log1p(out_deg[v]),
            torch.log1p(in_deg[v]),
            torch.log1p(out_deg[u] + in_deg[v]),
            torch.ones(G.m, device=dev),
        ]
        return torch.stack(columns, dim=1)

    def forward_energy_params(self, G: GraphSnapshot):
        """Return ``(phi, A)`` for snapshot ``G``.

        ``phi`` is ``phi_scale * tanh(phi_net(X))`` with non-finite values
        replaced and the result clamped to ``[PHI_MIN, PHI_MAX]``.  ``A`` is the
        sanitised ``A_net`` output clamped to ``[A_MIN, A_MAX]`` when ``learn_A``
        is set and the snapshot has edges; otherwise it is a vector of ones of
        length ``G.m``.
        """
        node_feats = torch.nan_to_num(G.X_nodes, nan=0.0, posinf=0.0, neginf=0.0)
        dev = node_feats.device

        scaled = self.phi_scale * torch.tanh(self.phi_net(node_feats))
        phi = torch.clamp(
            torch.nan_to_num(scaled, nan=0.0, posinf=PHI_MAX, neginf=PHI_MIN),
            min=PHI_MIN, max=PHI_MAX,
        )

        if not (self.learn_A and G.m > 0):
            return phi, torch.ones(G.m, device=dev)

        features_missing = (G.w_base is None) or (G.w_base.numel() == 0)
        if features_missing:
            # No precomputed edge features: derive degree-based ones on the fly.
            edge_feats = self._degree_edge_features(G, dev)
        else:
            edge_feats = torch.nan_to_num(G.w_base, nan=0.0, posinf=0.0, neginf=0.0).float()
            if edge_feats.dim() == 1:
                # A flat vector is treated as one feature per edge.
                edge_feats = edge_feats.view(G.m, 1)

        weights = self.A_net(edge_feats)
        A = torch.clamp(
            torch.nan_to_num(weights, nan=1.0, posinf=A_MAX, neginf=A_MIN),
            min=A_MIN, max=A_MAX,
        )
        return phi, A

    def edge_smoothness_reg(self, A, G: GraphSnapshot):
        """
        Local smoothness penalty on log A.  For each node the mean of log A over
        its incident edges is computed; each edge is then compared (mean squared
        error) with the average of its two endpoint means.  A simplified
        alternative to a true Hodge/curl penalty, kept for backward compatibility.
        Returns a zero scalar when there are no edges.
        """
        if A.numel() == 0 or G.m == 0:
            return torch.tensor(0.0, device=(A.device if A.numel() else G.X_nodes.device))

        logA = torch.log(torch.clamp(A, 1e-8))
        u, v = G.edge_u, G.edge_v
        dev = logA.device

        # Every edge contributes once to its source node and once to its destination node.
        endpoints = torch.cat([u, v], dim=0)
        edge_ids = torch.arange(G.m, device=dev).repeat(2)

        incident = torch.zeros(G.n, device=dev)
        incident.index_add_(0, endpoints, torch.ones_like(endpoints, dtype=torch.float))
        log_sum = torch.zeros(G.n, device=dev)
        log_sum.index_add_(0, endpoints, logA[edge_ids])

        node_mean = torch.where(incident > 0,
                                log_sum / incident.clamp_min(1.0),
                                torch.zeros_like(log_sum))
        edge_target = 0.5 * (node_mean[u] + node_mean[v])
        return F.mse_loss(logA, edge_target, reduction="mean")

    def hodge_penalty(self, G: GraphSnapshot, A: torch.Tensor, cycles: Optional[List[Tuple[int,int,int]]] = None):
        """
        Placeholder for a discrete curl/Hodge penalty on the learned connection A.
        The caller is expected to pass small cycles (triples of node indices); the
        intended penalty sums the squared sum of log A along each directed cycle.
        Currently always returns a zero scalar on A's device.
        """
        zero = torch.tensor(0.0, device=A.device)
        if not self.learn_A or A.numel() == 0:
            return zero
        if cycles is None or len(cycles) == 0:
            return zero
        # TODO: implement cycle-based penalty using edge lookup (needs a
        # node-pair -> edge-index structure built outside this method).
        return zero

class BarrierSuite:
    """
    Convex barrier terms with memory-conscious implementations.
    - temporal TV (Huber-smoothed change versus `last_mu`)
    - graph TV (absolute difference across edges), sub-sampled and chunked
    - Bellman (softplus of a potential-difference slack), sub-sampled
    `last_mu` is not updated by this class; the caller is expected to set it.
    """

    def __init__(self, bellman_sample_m=1024, tv_smooth=1e-3, bellman_beta=5.0,
                 gtv_sample_m=65536, gtv_chunk=32768):
        self.bellman_sample_m = int(bellman_sample_m)
        self.tv_smooth = float(tv_smooth)
        self.bellman_beta = float(bellman_beta)
        # graph-TV memory guards: max edges sampled, and edges processed per chunk
        self.gtv_sample_m = int(gtv_sample_m)
        self.gtv_chunk = int(gtv_chunk)
        self.last_mu: Optional[torch.Tensor] = None

    def _comparable_to_last(self, mu):
        """True when a previous `mu` of identical shape is stored."""
        previous = self.last_mu
        return previous is not None and previous.shape == mu.shape

    # -------- temporal TV --------
    def tv_violation(self, mu):
        """Huber penalty on `mu - last_mu`, summed; zero without a comparable `last_mu`."""
        if not self._comparable_to_last(mu):
            return torch.tensor(0.0, device=mu.device)
        delta = self.tv_smooth
        change = mu - self.last_mu
        size = change.abs()
        quadratic = 0.5 * (change ** 2) / delta
        linear = size - 0.5 * delta
        return torch.where(size < delta, quadratic, linear).sum()

    def tv_subgrad(self, mu):
        """Elementwise derivative of the Huber penalty used by `tv_violation`."""
        if not self._comparable_to_last(mu):
            return torch.zeros_like(mu)
        delta = self.tv_smooth
        change = mu - self.last_mu
        return torch.where(change.abs() < delta, change / delta, change.sign())

    # -------- graph TV (memory-safe) --------
    @torch.no_grad()
    def _graph_tv_pairs(self, G: GraphSnapshot, max_m: int):
        """Return (u,v), randomly sampled with replacement down to max_m edges if larger."""
        if G.m == 0:
            return (torch.empty(0, dtype=torch.long, device=G.edge_u.device),
                    torch.empty(0, dtype=torch.long, device=G.edge_v.device))
        src, dst = G.edge_u, G.edge_v
        n_edges = src.numel()
        if n_edges > max_m:
            pick = torch.randint(0, n_edges, (max_m,), device=src.device)
            src, dst = src[pick], dst[pick]
        return src, dst

    def graph_tv_violation(self, mu, G: GraphSnapshot):
        """Sum of |mu[u] - mu[v]| over (sampled) edges, rescaled to the full edge count."""
        if G.m == 0:
            return torch.tensor(0.0, device=mu.device)
        u, v = self._graph_tv_pairs(G, self.gtv_sample_m)
        n_pairs = u.numel()
        if n_pairs == 0:
            return torch.tensor(0.0, device=mu.device)
        total = torch.tensor(0.0, device=mu.device)
        # process edges chunk by chunk to keep peak memory low
        for start in range(0, n_pairs, self.gtv_chunk):
            stop = start + self.gtv_chunk
            total += (mu[u[start:stop]] - mu[v[start:stop]]).abs().sum()
        # rescale the sampled sum to the full edge count
        return total * (G.m / max(1, n_pairs))

    def graph_tv_subgrad(self, mu, G: GraphSnapshot):
        """Subgradient of the graph-TV term w.r.t. `mu`, from a fresh edge sample."""
        if G.m == 0:
            return torch.zeros_like(mu)
        u, v = self._graph_tv_pairs(G, self.gtv_sample_m)
        n_pairs = u.numel()
        if n_pairs == 0:
            return torch.zeros_like(mu)
        grad = torch.zeros_like(mu)
        for start in range(0, n_pairs, self.gtv_chunk):
            stop = start + self.gtv_chunk
            cu, cv = u[start:stop], v[start:stop]
            sign = (mu[cu] - mu[cv]).sign()
            grad.index_add_(0, cu, sign)
            grad.index_add_(0, cv, -sign)
        return grad * (G.m / max(1, n_pairs))

    # -------- Bellman (already sampled) --------
    def bellman_violation_and_grad(self, mu, G: GraphSnapshot, A: torch.Tensor):
        """Softplus penalty on ``s = d[u] - d[v] - 1/A`` with ``d = -log(mu)``.

        Returns ``(violation, grad_wrt_mu)``.  At most `bellman_sample_m` edges
        are used (sampled with replacement) and both outputs are rescaled to the
        full edge count.
        """
        if G.m == 0:
            return torch.tensor(0.0, device=mu.device), torch.zeros_like(mu)
        A = torch.nan_to_num(A, nan=1.0, posinf=1.0, neginf=1.0).clamp_min(1e-8)
        u, v = G.edge_u, G.edge_v
        n_edges = u.numel()
        if n_edges > self.bellman_sample_m:
            pick = torch.randint(0, n_edges, (self.bellman_sample_m,), device=mu.device)
            u, v, A = u[pick], v[pick], A[pick]

        eps = 1e-12
        beta = self.bellman_beta
        potential = -torch.log(mu.clamp_min(eps))
        edge_cost = 1.0 / (A.clamp_min(1e-8))
        slack = potential[u] - potential[v] - edge_cost

        violation = F.softplus(beta * slack).sum() / beta
        # derivative of softplus(beta*s)/beta w.r.t. s; detached so it acts as a constant
        weight = torch.sigmoid(beta * slack).detach()
        grad_potential = torch.zeros_like(mu)
        grad_potential.index_add_(0, u, weight)
        grad_potential.index_add_(0, v, -weight)
        # chain rule through d = -log(mu)
        grad_mu = grad_potential * (-1.0 / mu.clamp_min(eps))

        scale = G.m / max(1, u.numel())
        return violation * scale, grad_mu * scale

    # Triangle and flow barriers are disabled here: both return zeros.
    def triangle_violation_and_grad(self, mu, G: GraphSnapshot, A: torch.Tensor, max_trip=2048):
        return torch.tensor(0.0, device=mu.device), torch.zeros_like(mu)

    def flow_violation_and_grad(self, mu, G: GraphSnapshot):
        return torch.tensor(0.0, device=mu.device), torch.zeros_like(mu)

    @torch.no_grad()
    def metrics(self, mu, G: GraphSnapshot, A: torch.Tensor, active: Dict[str, float]):
        """Return the current value of every barrier whose name is a key of `active`."""
        report = {}
        if "tv" in active:
            report["tv"] = float(self.tv_violation(mu).detach().cpu())
        if "graph_tv" in active:
            report["graph_tv"] = float(self.graph_tv_violation(mu, G).detach().cpu())
        if "bellman" in active:
            violation, _ = self.bellman_violation_and_grad(mu, G, A)
            report["bellman"] = float(violation.detach().cpu())
        return report


@dataclass
class GLFConfig:
    """Hyperparameter bundle for the GLF solver.

    NOTE(review): the code that consumes these fields is not visible in this part
    of the file; the per-field notes below are inferred from the names only and
    should be confirmed against the solver.
    """
    tau: float = 0.005            # presumably the JKO/proximal step size -- TODO confirm
    lam_energy: float = 1e-3      # presumably the weight of the energy term -- TODO confirm
    entropy_weight: float = 1.0   # presumably the weight of the entropy term -- TODO confirm
    pdhg_primal: float = 0.8      # presumably the PDHG primal step size -- TODO confirm
    pdhg_dual: float = 0.8        # presumably the PDHG dual step size -- TODO confirm
    pdhg_theta: float = 1.0       # presumably the PDHG extrapolation parameter -- TODO confirm
    pdhg_iters: int = 50          # presumably the max number of PDHG iterations -- TODO confirm
    mu_prox_iters: int = 6        # presumably inner iterations of the mu proximal step -- TODO confirm
    mu_prox_lr: float = 0.5       # presumably the step size of the mu proximal step -- TODO confirm
    # Added later: presumably early-stopping controls (tolerance and check interval) -- TODO confirm
    kkt_tol: float = 1e-3
    check_every: int = 10

def _M_times_DA_f(G: GraphSnapshot, A: torch.Tensor, f: torch.Tensor):
    if G.m == 0:
        return torch.zeros(G.n, device=f.device)
    u, v = G.edge_u, G.edge_v
    Af = A * f
    out = torch.zeros(G.n, device=f.device)
    out.index_add_(0, u, Af)   # + on sources
    out.index_add_(0, v, -Af)  # - on destinations
    return out

def _DA_Mt_vec(G: GraphSnapshot, A: torch.Tensor, y: torch.Tensor):
    if G.m == 0:
        return torch.zeros(0, device=y.device)
    u, v = G.edge_u, G.edge_v
    return A * (y[u] - y[v])

def _agg_mean_std(vals):
    import numpy as np
    arr = [v for v in vals if v is not None]
    if not arr:
        return float("nan"), float("nan"), 0
    a = np.array(arr, dtype=float)
    return float(a.mean()), float(a.std(ddof=0)), int(a.size)


def bench_multi_seed(
    name,
    seeds=(1337, 2027, 3109, 4441, 5557),
    n_bins_total=40,
    eval_every=6
):
    """
    Run ``run_paper`` once per seed and print mean ± std summaries.

    Each run result is read as ``r[split][family]`` with ``split`` in
    {"VAL", "TEST"} and ``family`` in {"JODIE", "TGAT"}; a falsy family entry
    contributes ``None`` for that run.  Returns the list of raw run results.
    """
    runs = []
    for seed in seeds:
        print(f"\n====== {name} | seed={seed} ======")
        runs.append(
            run_paper(
                name=name,
                seed=int(seed),
                n_bins_total=n_bins_total,
                eval_every=eval_every
            )
        )

    def gather(split, family, metric):
        # One value per run; None marks a run without results for this family.
        return [r[split][family].get(metric) if r[split][family] else None for r in runs]

    def summary(values):
        mean, std, count = _agg_mean_std(values)
        return f"{mean:.4f} ± {std:.4f}  (n={count})"

    # Collect every metric list up front, before anything is printed.
    jodie_val_mrr = gather("VAL", "JODIE", "MRR")
    jodie_val_h10 = gather("VAL", "JODIE", "H@10")
    jodie_test_mrr = gather("TEST", "JODIE", "MRR")
    jodie_test_h10 = gather("TEST", "JODIE", "H@10")

    tgat_val_auc = gather("VAL", "TGAT", "AUC")
    tgat_val_ap = gather("VAL", "TGAT", "AP")
    tgat_test_auc = gather("TEST", "TGAT", "AUC")
    tgat_test_ap = gather("TEST", "TGAT", "AP")

    print(f"\n[{name}] VAL (JODIE):   MRR {summary(jodie_val_mrr)}   H@10 {summary(jodie_val_h10)}")
    print(f"[{name}] TEST (JODIE):  MRR {summary(jodie_test_mrr)}   H@10 {summary(jodie_test_h10)}")
    print(f"[{name}] VAL (TGAT):    AUC {summary(tgat_val_auc)}   AP {summary(tgat_val_ap)}")
    print(f"[{name}] TEST (TGAT):   AUC {summary(tgat_test_auc)}   AP {summary(tgat_test_ap)}")

    return runs



def _np_to_torch(dct, key, device):
    if key in dct:
        return torch.from_numpy(dct[key]).to(device)
    alt = key[:-1] if key.endswith("s") else key + "s"
    if alt in dct:
        return torch.from_numpy(dct[alt]).to(device)
    raise KeyError(f"{key} not in dataset")

def _time_splits_by_quantile(t, q_train=0.70, q_val=0.85):
    t = t.float(); thr_train = torch.quantile(t, q_train).item(); thr_val = torch.quantile(t, q_val).item()
    train = (t <= thr_train); val = (t > thr_train) & (t <= thr_val); test = (t > thr_val)
    return train, val, test, thr_train, thr_val

def build_item_index_map(num_nodes, items_idx, device):
    """
    Build a dense lookup from global node id to its position in ``items_idx``.

    The result has length ``num_nodes``; ids that do not occur in
    ``items_idx`` map to -1.
    """
    n_items = items_idx.numel()
    lookup = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    positions = torch.arange(n_items, device=device)
    lookup[items_idx] = positions
    return lookup

def _acc_from_item_ranks(item_ranks: torch.Tensor, pos_idx: torch.Tensor):
    r = item_ranks[pos_idx].to(torch.float32)
    rr  = (1.0 / r).sum().item()
    hit = (r <= REPORT_HITSK).sum().item()
    return rr, hit, pos_idx.numel

def _anneal_temp(step, total):
    import math
    t0, t1 = 0.35, 0.20
    c = 0.5 * (1 + math.cos(math.pi * min(1.0, step / max(1,total))))
    return t1 + (t0 - t1) * c


class LogicJudge:
    """
    Finite-state judge that selects which barriers to activate.  It supports
    violation thresholds (epsilons), graph-drift guards, hysteresis (dwell
    counters), and escalation/de-escalation of barrier weights.  Train and
    test-time behaviour can differ via the `mode` attribute.

    Only barriers that appear in ``base`` are tracked: ``rho`` (current
    weight) and ``cool`` (dwell counter) are keyed by the ``base`` keys, so
    the 'triangle' and 'flow' drift rules are inert unless those keys are
    given a base weight.
    """
    def __init__(self, eps_tv=0.0, eps_bell=0.0, eps_gtv=0.0, eps_tri=0.0, eps_flow=0.0,
                 base=None, max_rho=None,
                 dwell=3, escalate=1.5, top_k=2, mode='train'):
        """
        Args:
            eps_tv, eps_bell, eps_gtv, eps_tri, eps_flow: per-barrier violation
                thresholds; only the excess above the threshold counts.
            base: starting (and floor) weight per barrier.  ``None`` selects
                ``{'tv': 0.02, 'graph_tv': 0.01, 'bellman': 0.02}``.
            max_rho: ceiling weight per barrier.  ``None`` selects
                ``{'tv': 0.2, 'graph_tv': 0.1, 'bellman': 0.2}``.
            dwell: consecutive-activation count that triggers an escalation.
            escalate: multiplicative escalation / de-escalation factor.
            top_k: maximum number of barriers activated per call.
            mode: 'train' or 'test'.
        """
        # None sentinels instead of dict literals in the signature: default
        # objects are shared across calls and must never be mutable.
        if base is None:
            base = {'tv': 0.02, 'graph_tv': 0.01, 'bellman': 0.02}
        if max_rho is None:
            max_rho = {'tv': 0.2, 'graph_tv': 0.1, 'bellman': 0.2}
        self.eps = {'tv': eps_tv, 'graph_tv': eps_gtv, 'bellman': eps_bell,
                    'triangle': eps_tri, 'flow': eps_flow}
        self.rho = dict(base)
        self.base = dict(base)
        self.max_rho = dict(max_rho)
        self.dwell = dwell
        self.escalate = escalate
        self.top_k = top_k
        self.cool = {k: 0 for k in self.rho}
        self.mode = mode  # 'train' or 'test'
        # Laplacian of the previously seen graph, used to measure drift.
        self.prev_L: Optional[torch.Tensor] = None

    def _graph_laplacian(self, G: GraphSnapshot) -> torch.Tensor:
        """
        Compute a (simplified) unsigned graph Laplacian for drift detection.
        The Laplacian is L = D - W where W counts edges between nodes.  For
        multi-edges, counts accumulate.  Edge weights A are ignored for drift.

        NOTE: the result is a dense ``n x n`` matrix, so memory grows
        quadratically with the node count.
        """
        n = G.n
        # Follow the graph's own device; fall back to CPU when the snapshot
        # carries no edge tensor.
        edge_u = getattr(G, "edge_u", None)
        dev = edge_u.device if torch.is_tensor(edge_u) else torch.device("cpu")
        L = torch.zeros((n, n), device=dev)
        if G.m > 0:
            u = G.edge_u.long()
            v = G.edge_v.long()
            plus = torch.ones(u.numel(), device=dev)
            minus = -plus
            # accumulate=True sums contributions of repeated (row, col)
            # pairs, matching a per-edge "+= / -=" loop without leaving the
            # device.
            L.index_put_((u, v), minus, accumulate=True)
            L.index_put_((v, u), minus, accumulate=True)
            L.index_put_((u, u), plus, accumulate=True)
            L.index_put_((v, v), plus, accumulate=True)
        return L

    def select(self, barrier_metrics: Dict[str, float], G: Optional[GraphSnapshot] = None) -> Dict[str, float]:
        """
        Decide which barriers to activate based on current violation margins,
        graph drift, hysteresis, and the top_k policy.  Returns a dictionary
        mapping active barrier names to their current weight rho.
        """
        # Margin = violation in excess of the barrier's epsilon (never negative).
        margins = {k: max(0.0, barrier_metrics.get(k, 0.0) - self.eps.get(k, 0.0))
                   for k in self.rho}
        # Optional: check graph drift and force activation
        if G is not None:
            L = self._graph_laplacian(G)
            if self.prev_L is None:
                self.prev_L = L.detach().clone()
            else:
                drift = torch.norm(L - self.prev_L).item()
                # Large drift forces the triangle / flow barriers on, but only
                # when they are tracked (i.e. present in rho).
                if drift > 1.0 and 'triangle' in self.rho:
                    margins['triangle'] = max(margins.get('triangle', 0.0), drift)
                if drift > 1.0 and 'flow' in self.rho:
                    margins['flow'] = max(margins.get('flow', 0.0), drift)
                self.prev_L = L.detach().clone()
        # Rank margins descending and select top_k active barriers
        ranked = sorted(margins.items(), key=lambda x: x[1], reverse=True)
        active = [k for k, v in ranked[:self.top_k] if v > 0]
        for k in self.rho:
            if k in active:
                self.cool[k] += 1
                if self.cool[k] >= self.dwell:
                    # Escalate weight up to max_rho
                    self.rho[k] = min(self.max_rho.get(k, self.rho[k]), self.rho[k] * self.escalate)
                    self.cool[k] = 0
            else:
                # Inactive barriers cool down and decay back towards base.
                self.cool[k] = max(0, self.cool[k] - 1)
                self.rho[k] = max(self.base.get(k, 0.0), self.rho[k] / self.escalate)
        if self.mode == 'test':
            # Test mode reports base weights (never escalated ones) for every
            # barrier that is active or has a positive base weight, so the
            # returned set is not empty when any base weight is positive.
            return {k: self.base.get(k, 0.0) for k in self.rho if k in active or self.base.get(k, 0.0) > 0.0}
        else:
            return {k: self.rho[k] for k in active if self.rho[k] > 0.0}


class UserItemReadout(nn.Module):
    """
    Dot-product user/item scoring head with per-side biases, a clamped
    temperature, and a per-user gate that callers can use to mix in a prior.
    """

    def __init__(self, num_users, num_items, d=192, temp_init=0.35, gamma_init=0.95, pdrop=0.05):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialisation
        # stays reproducible.
        self.user = nn.Embedding(num_users, d)
        self.item = nn.Embedding(num_items, d)
        self.u_bias = nn.Embedding(num_users, 1)
        self.i_bias = nn.Embedding(num_items, 1)
        self.ln_u = nn.LayerNorm(d)
        self.ln_i = nn.LayerNorm(d)
        self.proj_u = nn.Linear(d, d)
        self.proj_i = nn.Linear(d, d)
        self.dropout = nn.Dropout(pdrop)
        self.temp = nn.Parameter(torch.tensor(float(temp_init)))
        # gamma is stored here but not read by logits_over_items.
        self.gamma = nn.Parameter(torch.tensor(float(gamma_init)))
        # per-user gate for the prior
        self.gate_u = nn.Linear(d, 1)

    def logits_over_items(self, u_ids: torch.Tensor, item_ids: torch.Tensor):
        """
        Score every user in ``u_ids`` against every item in ``item_ids``.

        Returns:
            ``(scores, gate)``: ``scores`` is the temperature-scaled dot
            product plus user and item biases; ``gate`` is
            ``sigmoid(gate_u(h_u))`` with the trailing dimension squeezed.
            For 1-D id tensors these are [B, I] and [B].
        """
        # Users are encoded (and dropped out) before items.
        user_repr = self.dropout(F.gelu(self.proj_u(self.ln_u(self.user(u_ids)))))
        item_repr = self.dropout(F.gelu(self.proj_i(self.ln_i(self.item(item_ids)))))

        # The temperature is floored at 0.05 so the division cannot blow up.
        temperature = self.temp.clamp_min(0.05)
        similarity = torch.matmul(user_repr, item_repr.t()) / temperature

        user_offset = self.u_bias(u_ids)
        item_offset = self.i_bias(item_ids).t()
        scores = similarity + user_offset + item_offset

        gate = torch.sigmoid(self.gate_u(user_repr)).squeeze(-1)
        return scores, gate  # gate is returned for mixing with a prior


# -----------------------------------------------------------------------------
# Main runner omitted for brevity.  In practice, the run_dataset function
# integrates dataset loading, splitting, per-bin prefix construction, model
# instantiation, training loop, and evaluation.  The logic remains similar
# to the original code but must now incorporate the revised judge, barrier
# suite, and penalty terms.

def _build_X_nodes(in_hist, out_hist, num_nodes, device):
    # paper features: [in, out, tot, log1p(tot), 1]
    tot = in_hist + out_hist
    X_raw = torch.stack([in_hist, out_hist, tot, torch.log1p(tot), torch.ones_like(tot)], dim=1)
    return torch.nan_to_num(X_raw, 0.0, 0.0, 0.0)

@torch.no_grad()
def _pack_scores(mu_items, pos_idx):
    """
    Rank-based statistics of the positive items under an item mass vector.

    ``mu_items`` is floored at 1e-12 and renormalised to sum to one, then
    ranked by ``pessimistic_item_ranks`` (presumably ties resolve to the
    worst rank -- confirm in that helper).

    Returns:
        ``(rr_sum, hit_count, n_pos)``: the sum of reciprocal ranks, the number
        of positives ranked within ``REPORT_HITSK``, and the number of positives.
    """
    floor = 1e-12
    mass = mu_items.clamp_min(floor)
    mass = mass / mass.sum().clamp_min(floor)
    pos_ranks = pessimistic_item_ranks(mass)[pos_idx].float()
    reciprocal_sum = (1.0 / pos_ranks).sum().item()
    hits = (pos_ranks <= REPORT_HITSK).sum().item()
    return reciprocal_sum, hits, pos_idx.numel()



@torch.no_grad()
def eval_ap_auc(bins, num_negs=100):
    """
    Sampled-negative AUC / AP over the given time bins, replaying the JKO state.

    For each bin the model is evaluated on the strict-past prefix graph, the
    density is advanced with one barrier-free JKO step, and every edge of the
    bin is scored against up to ``num_negs`` negatives drawn (with
    replacement) from the items already seen in the prefix.  Positives whose
    item has not appeared in the prefix yet are skipped.

    NOTE(review): this function reads ``num_nodes``, ``dev``, ``src``,
    ``dst``, ``t``, ``cuts``, ``bins_edges``, ``model``, ``solver``,
    ``barriers``, ``readout``, ``user_id_to_index`` and ``item_id_to_index``
    as free names.  In the visible part of the file these are locals of
    ``run_paper``; they must exist at module scope for this function to
    run -- confirm before relying on it.

    Returns:
        ``{"AUC": ..., "AP": ..., "Count": N}``, with NaN metrics when no
        example could be scored.
    """
    from math import isnan
    total_ap, total_auc, N = 0.0, 0.0, 0

    # strict-past state
    in_h = torch.zeros(num_nodes, device=dev)
    out_h = torch.zeros(num_nodes, device=dev)
    mu_prev_eval = project_to_simplex(torch.ones(num_nodes, device=dev).unsqueeze(0)).squeeze(0)

    # The prefix table depends only on the dataset, so build it once instead
    # of twice per bin.
    prefix = _prepare_prefix(src, dst, t, cuts, num_nodes, dev)

    for bi in bins:
        # prefix graph up to bi-1
        X = _build_X_nodes(in_h, out_h, num_nodes, dev)
        if bi == 0:
            m = 0
            edge_u = torch.empty(0, dtype=torch.long, device=dev)
            edge_v = torch.empty(0, dtype=torch.long, device=dev)
        else:
            m = int(prefix.m_at_cut[bi].item())
            edge_u = src[:m].long()
            edge_v = dst[:m].long()
        Gp = GraphSnapshot(num_nodes, m, torch.empty(2, 0, dtype=torch.long, device=dev),
                           torch.empty(0, device=dev), torch.empty(0, device=dev), X,
                           edge_u=edge_u, edge_v=edge_v)

        phi_eval, A_eval = model.forward_energy_params(Gp)

        # Phase 1: NO barriers.  Barrier metrics only feed judge.select, which
        # is disabled here, so they are not computed.
        active = {}
        mu_eval, _, _, _ = solver.jko_step(Gp, mu_prev_eval, phi_eval, A_eval, barriers, active)
        mu_prev_eval = mu_eval.detach()

        eidx = bins_edges[bi]
        if not eidx.numel():
            continue

        # Advance the strict-past histories now.  They are only read when X is
        # built at the top of the next iteration, and doing it before the
        # early exits below keeps them correct for bins that cannot be scored.
        in_h += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
        out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

        # candidates available so far
        m_prefix = int(prefix.m_at_cut[bi].item())
        items_now_all = torch.unique(dst[:m_prefix]).to(dev)
        items_now = item_id_to_index[items_now_all]
        mask = (items_now >= 0)
        items_now = items_now[mask]
        items_now_all = items_now_all[mask]
        n_cand = items_now.numel()
        if n_cand < 2:
            continue

        u_b = user_id_to_index[src[eidx]]
        v_b = item_id_to_index[dst[eidx]]
        keep = (u_b >= 0) & (v_b >= 0)
        u_b = u_b[keep]
        v_b = v_b[keep]
        if u_b.numel() == 0:
            continue

        # The logits below have one column per entry of items_now, so the
        # positives (global item indices) must be translated to columns.
        # Positives that are not candidates yet map to -1 and are dropped.
        n_items = int(torch.max(items_now.max(), v_b.max()).item()) + 1
        col_of_item = torch.full((n_items,), -1, dtype=torch.long, device=dev)
        col_of_item[items_now] = torch.arange(n_cand, device=dev)
        pos_col = col_of_item[v_b]
        seen = pos_col >= 0
        u_b = u_b[seen]
        pos_col = pos_col[seen]
        if u_b.numel() == 0:
            continue

        # scores for positives
        prior = -phi_eval[items_now_all]                         # [I]
        head_out = readout.logits_over_items(u_b, items_now)
        if isinstance(head_out, tuple):
            # The readout returns (scores, per-user gate); the gate scales how
            # strongly the prior is mixed in.
            scores, gate = head_out
            gamma_eff = readout.gamma * gate.unsqueeze(1)
        else:
            scores, gamma_eff = head_out, readout.gamma
        logits_all = scores + gamma_eff * prior.unsqueeze(0)     # [B, I]

        pos_scores = logits_all[torch.arange(u_b.numel(), device=dev), pos_col]

        # sample negatives and compute AUC/AP per example
        all_cols = torch.arange(n_cand, device=dev)
        n_draw = min(num_negs, n_cand - 1)
        for b in range(u_b.numel()):
            # every candidate column except the positive one (never empty: n_cand >= 2)
            neg_pool = all_cols[all_cols != pos_col[b]]
            pick = neg_pool[torch.randint(0, neg_pool.numel(), (n_draw,), device=dev)]
            s_pos = pos_scores[b]
            s_neg = logits_all[b, pick]

            # AUC: P(score_pos > score_neg)
            auc = (s_pos > s_neg).float().mean().item()
            # AP with one positive reduces to precision@rank = 1/rank
            rank = ((s_neg > s_pos).sum().item() + 1)
            ap = 1.0 / rank

            # auc is NaN only when no negatives were drawn (num_negs == 0)
            if not isnan(auc):
                total_auc += auc
            total_ap += ap
            N += 1

    return {"AUC": (total_auc / N) if N else float('nan'),
            "AP":  (total_ap  / N) if N else float('nan'),
            "Count": N}


def run_paper(name="csv:jodie-wikipedia", seed=1337, n_bins_total=40, eval_every=6):
    # --- setup ---
    set_seed(seed)
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = _load_any_dataset((name,), root="datasets")
    D = ds.full_data
    src = _np_to_torch(D, "sources", dev).long()
    dst = _np_to_torch(D, "destinations", dev).long()
    t   = _np_to_torch(D, "timestamps", dev).long()

    num_nodes = int(max(src.max(), dst.max()).item() + 1)

    # user/item universes and id→index maps (for readout)
    users_idx = torch.unique(src).to(dev)
    items_idx = torch.unique(dst).to(dev)
    user_id_to_index = build_item_index_map(num_nodes, users_idx, dev)
    item_id_to_index = build_item_index_map(num_nodes, items_idx, dev)

    # --- model & readout ---
    readout = UserItemReadout(
        num_users=users_idx.numel(),
        num_items=items_idx.numel(),
        d=192,
        temp_init=0.35,
        gamma_init=0.20,   # start small; will ramp via schedule
        pdrop=0.15
    ).to(dev)

    # time binning (edge-count based; strict-past)
    cuts = _compute_cuts_by_edgecount(t, n_bins_total)
    bins_edges = _edges_grouped_by_bin(t, cuts)
    B = len(bins_edges)
    ntr = max(2, int(0.70 * B)); nva = max(1, int(0.15 * B))
    tr  = list(range(0, ntr))
    va  = list(range(ntr, ntr + nva))
    te  = list(range(ntr + nva, B))

    # strict-past histories
    in_hist = torch.zeros(num_nodes, device=dev)
    out_hist= torch.zeros(num_nodes, device=dev)
    X0 = _build_X_nodes(in_hist, out_hist, num_nodes, dev)
    x_dim = X0.shape[1]

    # Geometry model: start with A frozen; unfreeze later
    model = GLFModel(
        x_dim=x_dim, edge_feat_dim=6, hidden_phi=64, hidden_A=64,
        learn_A=False,            # freeze first
        phi_scale=2.0, hodge_reg=0.0
    ).to(dev)
    for p in model.A_net.parameters():
        p.requires_grad = False

    # Solver: settings that allow non-trivial movement
    solver = GLF_PDHG(GLFConfig(
        tau=0.08,
        entropy_weight=0.003,
        lam_energy=1e-3,
        pdhg_primal=0.9,
        pdhg_dual=0.9,
        pdhg_theta=1.0,
        pdhg_iters=600,
        mu_prox_iters=120,
        mu_prox_lr=1.2,
        kkt_tol=1e-4,
        check_every=10
    ))

    # Barriers: off during train; tiny ones at eval only
    barriers = BarrierSuite(bellman_sample_m=256, tv_smooth=1e-3)

    # Warmups & schedules
    KL_WEIGHT_MAX   = 0.10
    KL_WARMUP_STEPS = max(100, len(tr) // 2)
    A_WARMUP_STEPS  = max(int(0.35 * len(tr)), 40)

    # Optimiser + scheduler
    opt = torch.optim.AdamW(
        [
            {"params": model.parameters(),   "lr": 2e-4, "weight_decay": 1e-4},
            {"params": readout.parameters(), "lr": 3e-3, "weight_decay": 1e-4},
        ]
    )
    total_steps = max(200, len(tr))
    warmup = max(50, total_steps // 10)
    cos = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=1e-5)

    def gamma_schedule(step, total_steps, g0=0.20, g1=0.90):
        import math
        c = min(1.0, step / max(1, total_steps))
        return g1 - (g1 - g0) * 0.5 * (1 + math.cos(math.pi * c))

    def lr_step(step):
        """
        Per-step schedule: linear LR warmup, then cosine decay, plus the
        temperature anneal and gamma ramp of the readout.

        During warmup each param group's LR is set to ``base_lr * scale``.
        The base LR is taken from the group's ``initial_lr`` entry (recorded
        on first use if the scheduler has not already stored it).
        """
        if step < warmup:
            scale = (step + 1) / float(warmup)
            for g in opt.param_groups:
                # Always scale from the fixed base LR.  Multiplying g['lr'] in
                # place would compound the factor on every call and collapse
                # the learning rate towards zero within a few steps.
                base_lr = g.setdefault('initial_lr', g['lr'])
                g['lr'] = base_lr * scale
        else:
            cos.step()
        with torch.no_grad():
            # anneal temp (uses the module-level helper)
            readout.temp.copy_(torch.tensor(_anneal_temp(step, len(tr))))
            # ramp gamma
            readout.gamma.copy_(torch.tensor(gamma_schedule(step, len(tr))))

    def _head_scores(readout, u_ids, item_ids, prior_tensor=None):
        """
        Works whether readout.logits_over_items returns (scores, gate) or scores.
        If prior_tensor is provided (shape [I]), it is added with gamma (and gate if available).
        """
        out = readout.logits_over_items(u_ids, item_ids)
        if isinstance(out, tuple):
            scores, g_u = out  # [B, I], [B]
            if prior_tensor is None:
                return scores
            gamma_eff = torch.clamp(readout.gamma, 0.0, 1.0) * g_u.unsqueeze(1)
            return scores + gamma_eff * prior_tensor.unsqueeze(0)
        else:
            scores = out
            if prior_tensor is None:
                return scores
            return scores + torch.clamp(readout.gamma, 0.0, 1.0) * prior_tensor.unsqueeze(0)

    prefix = _prepare_prefix(src, dst, t, cuts, num_nodes, dev)


    @torch.no_grad()
    def _semi_hard_cands(items_all_local, zscore_prior, pos_local_keep, max_items=512, hard_k=256):
        """
        Robust candidate picker:
        - pos_local_keep is already LOCAL indices in [0..I-1]; clamp just in case.
        - always return a subset of items_all_local with valid indices only.
        """
        I = int(items_all_local.numel())
        if I <= max_items:
            return items_all_local

        if pos_local_keep.numel():
            pos_local_keep = pos_local_keep[(pos_local_keep >= 0) & (pos_local_keep < I)]

        hard_k = int(min(max_items, max(64, hard_k)))
        # zscore_prior is aligned to items_all_local; guard length
        if zscore_prior.numel() != I:
            # fallback: uniform priorities
            zscore_prior = torch.zeros(I, device=items_all_local.device)
        k_eff = min(hard_k, I)
        _, top_idx = torch.topk(zscore_prior, k=k_eff, largest=True, sorted=False)

        keep = torch.zeros(I, dtype=torch.bool, device=items_all_local.device)
        if top_idx.numel():
            keep[top_idx] = True
        if pos_local_keep.numel():
            keep[pos_local_keep] = True

        need = max_items - int(keep.sum().item())
        if need > 0:
            pool = torch.nonzero(~keep, as_tuple=False).squeeze(1)
            if pool.numel():
                perm = torch.randperm(pool.numel(), device=pool.device)[:need]
                pick = pool[perm]
                keep[pick] = True

        sel = torch.nonzero(keep, as_tuple=False).squeeze(1)
        # safety: ensure we output valid indices
        sel = sel[(sel >= 0) & (sel < I)]
        return items_all_local[sel]

    @torch.no_grad()
    def eval_jodie_mrr_recall_all(
        bins: list,
        hist_end: int,
        *,
        dev: torch.device,
        num_nodes: int,
        src: torch.Tensor, dst: torch.Tensor,
        items_idx: torch.Tensor,
        user_id_to_index: torch.Tensor,
        item_id_to_index: torch.Tensor,
        bins_edges: list,
        prefix, model, solver, barriers
        ):
        # strict-past hist up to eval window start
        in_h = torch.zeros(num_nodes, device=dev)
        out_h = torch.zeros(num_nodes, device=dev)
        for b in range(hist_end):
            eidx = bins_edges[b]
            if eidx.numel():
                in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

        mu_prev = project_to_simplex(torch.ones(num_nodes, device=dev).unsqueeze(0)).squeeze(0)
        items_all_local  = torch.arange(items_idx.numel(), device=dev)
        items_all_global = items_idx

        rr_sum = 0.0; hit_sum = 0; cnt = 0
        for bi in bins:
            # prefix graph up to this bin
            if bi == 0:
                X = _build_X_nodes(in_h, out_h, num_nodes, dev)
                Gp = GraphSnapshot(num_nodes, 0,
                                   torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev),
                                   torch.empty(0, device=dev),
                                   X,
                                   edge_u=torch.empty(0,dtype=torch.long,device=dev),
                                   edge_v=torch.empty(0,dtype=torch.long,device=dev))
            else:
                m = int(prefix.m_at_cut[bi].item())
                u = src[:m]; v = dst[:m]
                X = _build_X_nodes(in_h, out_h, num_nodes, dev)
                Gp = GraphSnapshot(num_nodes, m,
                                   torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev),
                                   torch.empty(0, device=dev),
                                   X,
                                   edge_u=u.long(), edge_v=v.long())
            phi, A = model.forward_energy_params(Gp)
            mu_t, _, _, _ = solver.jko_step(Gp, mu_prev, phi, A, barriers, ACTIVE_EVAL_BARRIERS)
            mu_prev = mu_t.detach()

            eidx = bins_edges[bi]
            if not eidx.numel():
                continue
            u_b = user_id_to_index[src[eidx]]
            v_b_all = item_id_to_index[dst[eidx]]
            keep = (u_b >= 0) & (v_b_all >= 0)
            u_b = u_b[keep]; v_b_all = v_b_all[keep]
            if u_b.numel() == 0:
                in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)
                continue

            prior = _zscore_prior(-phi[items_all_global])
            logits_all = _head_scores(readout, u_b, items_all_local, prior)
            probs = F.softmax(logits_all, dim=1)
            for b_ in range(u_b.numel()):
                r = int(pessimistic_item_ranks(probs[b_])[v_b_all[b_]].item())
                rr_sum += 1.0 / max(1, r)
                hit_sum += int(r <= REPORT_HITSK)
                cnt += 1

            in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
            out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

        return {"MRR": rr_sum / cnt if cnt else float('nan'),
                "H@10": hit_sum / cnt if cnt else float('nan'),
                "Count": cnt}

    @torch.no_grad()
    def eval_tgat_ap_balanced(
        bins: list,
        *,
        dev: torch.device,
        num_nodes: int,
        src: torch.Tensor, dst: torch.Tensor,
        items_idx: torch.Tensor,
        user_id_to_index: torch.Tensor,
        item_id_to_index: torch.Tensor,
        bins_edges: list,
        prefix, model, solver, barriers,
        num_negs: int = AP_NUM_NEGS
    ):
        in_h = torch.zeros(num_nodes, device=dev)
        out_h = torch.zeros(num_nodes, device=dev)
        mu_prev = project_to_simplex(torch.ones(num_nodes, device=dev).unsqueeze(0)).squeeze(0)
        items_all_local  = torch.arange(items_idx.numel(), device=dev)
        items_all_global = items_idx
        total_ap, total_auc, N = 0.0, 0.0, 0

        for bi in bins:
            if bi == 0:
                X = _build_X_nodes(in_h, out_h, num_nodes, dev)
                Gp = GraphSnapshot(num_nodes, 0,
                                   torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev),
                                   torch.empty(0, device=dev),
                                   X,
                                   edge_u=torch.empty(0,dtype=torch.long,device=dev),
                                   edge_v=torch.empty(0,dtype=torch.long,device=dev))
            else:
                m = int(prefix.m_at_cut[bi].item())
                u = src[:m]; v = dst[:m]
                X = _build_X_nodes(in_h, out_h, num_nodes, dev)
                Gp = GraphSnapshot(num_nodes, m,
                                   torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev),
                                   torch.empty(0, device=dev),
                                   X,
                                   edge_u=u.long(), edge_v=v.long())

            phi, A = model.forward_energy_params(Gp)
            mu_eval, _, _, _ = solver.jko_step(Gp, mu_prev, phi, A, barriers, ACTIVE_EVAL_BARRIERS)
            mu_prev = mu_eval.detach()

            eidx = bins_edges[bi]
            if not eidx.numel():
                continue
            u_b = user_id_to_index[src[eidx]]
            v_b_all = item_id_to_index[dst[eidx]]
            keep = (u_b >= 0) & (v_b_all >= 0)
            u_b = u_b[keep]; v_b_all = v_b_all[keep]
            if u_b.numel() == 0:
                in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)
                continue

            prior = _zscore_prior(-phi[items_all_global])
            logits_all = _head_scores(readout, u_b, items_all_local, prior)

            I = items_all_local.numel()
            for b_ in range(u_b.numel()):
                pos_idx = int(v_b_all[b_].item())
                if I <= 1:
                    continue
                if I - 1 <= num_negs:
                    neg_idx = torch.arange(I, device=dev)
                    neg_idx = neg_idx[neg_idx != pos_idx]
                else:
                    pool = torch.cat([torch.arange(0, pos_idx, device=dev),
                                      torch.arange(pos_idx+1, I, device=dev)], dim=0)
                    perm = torch.randperm(pool.numel(), device=dev)[:num_negs]
                    neg_idx = pool[perm]
                s_pos = logits_all[b_, pos_idx]
                s_neg = logits_all[b_, neg_idx]
                gt = (s_pos > s_neg).float()
                eq = (s_pos == s_neg).float()
                auc = (gt + 0.5 * eq).mean().item()
                rank = int((s_neg > s_pos).sum().item()) + 1
                ap = 1.0 / max(1, rank)
                total_auc += auc; total_ap += ap; N += 1

            in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
            out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

        return {"AUC": (total_auc / N) if N else float('nan'),
                "AP":  (total_ap  / N) if N else float('nan'),
                "Count": N}

    # ----------------- quick head pretrain (1 pass over train bins) -----------------
    readout.train(); model.eval()
    with torch.no_grad():
        _ = project_to_simplex(torch.ones(num_nodes, device=dev).unsqueeze(0)).squeeze(0)
    for bi in tr:
        eidx = bins_edges[bi]
        if not eidx.numel():
            continue
        m_prefix = int(prefix.m_at_cut[bi].item())
        items_now_all = torch.unique(dst[:m_prefix]).to(dev)
        items_now = item_id_to_index[items_now_all]
        mask = (items_now >= 0); items_now = items_now[mask]; items_now_all = items_now_all[mask]
        if items_now.numel() < 2:
            continue
        u_b = user_id_to_index[src[eidx]]
        v_b = item_id_to_index[dst[eidx]]
        keep = (u_b >= 0) & (v_b >= 0)
        u_b = u_b[keep]; v_b = v_b[keep]
        if u_b.numel() == 0:
            continue
        logits = _head_scores(readout, u_b, items_now)  # no prior in pretrain
        local_pos = torch.full((items_idx.numel(),), -1, dtype=torch.long, device=dev)
        local_pos[items_now] = torch.arange(items_now.numel(), device=dev)
        targets = local_pos[v_b]
        keep2 = (targets >= 0) & (targets < logits.size(1))
        if keep2.any():
            ce = F.cross_entropy(logits[keep2], targets[keep2], label_smoothing=0.10)
            opt.zero_grad()
            ce.backward()
            torch.nn.utils.clip_grad_norm_(list(readout.parameters()), 1.0)
            opt.step()

    # ----------------- main training -----------------
    model.train(); readout.train()
    mu_prev = project_to_simplex(torch.ones(num_nodes, device=dev).unsqueeze(0)).squeeze(0)
    cache = []  # strict-past per-bin caches for head-only replay

    # Early-stop-by-VAL tracking
    best_val = {"score": -1e9, "blob": None, "step": -1}
    VAL_KEY = ("JODIE", "MRR")  # or ("TGAT","AUC")
    _last_val_jodie, _last_val_tgat = None, None

    def _pack_state():
        return {"model": model.state_dict(), "readout": readout.state_dict(), "opt": opt.state_dict()}

    def _load_state(blob):
        model.load_state_dict(blob["model"])
        readout.load_state_dict(blob["readout"])
        opt.load_state_dict(blob["opt"])

    for step, bi in enumerate(tr, 1):
        lr_step(step)

        # Unfreeze A after warmup
        if (not model.learn_A) and (step == A_WARMUP_STEPS):
            for p in model.A_net.parameters(): p.requires_grad = True
            model.learn_A = True

        # Build prefix snapshot
        X = _build_X_nodes(in_hist, out_hist, num_nodes, dev)
        if bi == 0:
            Gp = GraphSnapshot(num_nodes, 0, torch.empty(2,0, dtype=torch.long, device=dev),
                               torch.empty(0, device=dev), torch.empty(0, device=dev), X,
                               edge_u=torch.empty(0,dtype=torch.long,device=dev),
                               edge_v=torch.empty(0,dtype=torch.long,device=dev))
        else:
            m = int(prefix.m_at_cut[bi].item())
            u = src[:m]; v = dst[:m]
            Gp = GraphSnapshot(num_nodes, m, torch.empty(2,0, dtype=torch.long, device=dev),
                               torch.empty(0, device=dev), torch.empty(0, device=dev), X,
                               edge_u=u.long(), edge_v=v.long())

        phi, A = model.forward_energy_params(Gp)

        # barriers OFF during train (keep memory predictable)
        active = {}
        mu_t, certs, _, _ = solver.jko_step(Gp, mu_prev, phi, A, barriers, active)

        # DEBUG: compute muΔ BEFORE overriding mu_prev
        with torch.no_grad():
            mu_delta = torch.norm(project_to_simplex(mu_t.unsqueeze(0)).squeeze(0) - mu_prev).item()

        if (step % eval_every) == 0:
            print(f"[CERT] kkt={certs['kkt_residual']:.3e}  action={certs['transport_action']:.3f}  "
                  f"dE={certs['delta_energy']:.5f}  bars={certs['barriers']}")
            with torch.no_grad():
                print(f"[DBG] phi std={phi.std().item():.4f}  muΔ={mu_delta:.4f}")

        mu_prev = mu_t.detach()
        mu_t = project_to_simplex(mu_t.unsqueeze(0)).squeeze(0)

        # --- TRAIN READOUT (semi-hard negatives, per-bin) ---
        eidx = bins_edges[bi]
        if eidx.numel():
            m_prefix = int(prefix.m_at_cut[bi].item())
            items_now_all = torch.unique(dst[:m_prefix]).to(dev)
            items_now = item_id_to_index[items_now_all]
            valid_mask = (items_now >= 0)
            items_now = items_now[valid_mask]; items_now_all = items_now_all[valid_mask]

            u_b = user_id_to_index[src[eidx]]
            v_b = item_id_to_index[dst[eidx]]
            keep_uv = (u_b >= 0) & (v_b >= 0)
            u_b = u_b[keep_uv]; v_b = v_b[keep_uv]

            if (u_b.numel() > 0) and (items_now.numel() >= 2):
                prior = _zscore_prior(-phi[items_now_all])
                # map positives to current universe
                local_pos = torch.full((items_idx.numel(),), -1, dtype=torch.long, device=dev)
                local_pos[items_now] = torch.arange(items_now.numel(), device=dev)
                pos_local = local_pos[v_b]

                keep2 = pos_local >= 0
                idx2 = torch.nonzero(keep2, as_tuple=False).squeeze(1)
                u_b_eff = u_b[idx2]; v_b_eff = v_b[idx2]; pos_local_eff = pos_local[idx2]

                if u_b_eff.numel() > 0:
                    # semi-hard candidate subset (keep all positives)
                    cand_items = _semi_hard_cands(items_now, prior, pos_local_eff.unique(),
                                                  max_items=512, hard_k=256)

                    # rebuild mapping after subsample and align again
                    local_pos2 = torch.full((items_idx.numel(),), -1, dtype=torch.long, device=dev)
                    local_pos2[cand_items] = torch.arange(cand_items.numel(), device=dev)
                    pos_local2 = local_pos2[v_b_eff]
                    keep3 = pos_local2 >= 0
                    idx3 = torch.nonzero(keep3, as_tuple=False).squeeze(1)
                    u_b_eff = u_b_eff[idx3]; v_b_eff = v_b_eff[idx3]; pos_local2 = pos_local2[idx3]

                    if (u_b_eff.numel() > 0) and (cand_items.numel() >= 2):
                        cand_items_all = items_idx[cand_items]
                        prior2 = _zscore_prior(-phi[cand_items_all])
                        logits = _head_scores(readout, u_b_eff, cand_items, prior2)
                        I = logits.size(1)
                        keep4 = (pos_local2 >= 0) & (pos_local2 < I)

                        ce = (F.cross_entropy(logits[keep4], pos_local2[keep4], label_smoothing=0.10)
                              if keep4.any() else torch.tensor(0.0, device=dev))
                        # small InfoNCE term
                        logp = F.log_softmax(logits, dim=1)
                        pos_mask = torch.zeros_like(logits, dtype=torch.bool)
                        if keep4.any():
                            r = torch.arange(logits.size(0), device=dev)[keep4]
                            c = pos_local2[keep4]
                            pos_mask[r, c] = True
                        nce = (-(logp[pos_mask]).mean()
                               if pos_mask.any() else torch.tensor(0.0, device=dev))
                        loss = ce + 0.4 * nce + 1e-4 * (torch.clamp(readout.gamma, 0)**2)

                        # optional KL distillation of μ
                        if KL_WEIGHT_MAX > 0 and keep4.any():
                            with torch.no_grad():
                                q = project_to_simplex(mu_t[cand_items_all].unsqueeze(0))
                                q = q / q.sum(dim=1, keepdim=True).clamp_min(1e-12)
                            p_log = F.log_softmax(logits, dim=1)
                            kl = F.kl_div(p_log, q.expand_as(p_log), reduction='batchmean')
                            kl_w = min(1.0, step / float(KL_WARMUP_STEPS)) * (0.5 * KL_WEIGHT_MAX)
                            loss = loss + kl_w * kl

                        # optional smoothing on A
                        if ('EDGE_SMOOTH_WEIGHT' in globals()) and (EDGE_SMOOTH_WEIGHT > 0) and A.numel():
                            loss = loss + EDGE_SMOOTH_WEIGHT * model.edge_smoothness_reg(A, Gp)

                        opt.zero_grad(set_to_none=True)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(readout.parameters()), 1.0)
                        opt.step()

                        # cache aligned views for replay
                        cache.append({
                            "items_all": cand_items.detach().cpu(),
                            "items_all_global": cand_items_all.detach().cpu(),
                            "u": u_b_eff.detach().cpu(),
                            "v": v_b_eff.detach().cpu(),
                            "prior": prior2.detach().cpu(),
                        })

            # advance strict-past hist
            if eidx.numel():
                in_hist  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_hist += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

        # ---- quick VAL & early-stop tracking ----
        if (step % eval_every) == 0 and va:
            _last_val_jodie, _last_val_tgat = None, None
            if EVAL_PROTOCOL in ("JODIE", "BOTH"):
                res_j = eval_jodie_mrr_recall_all(
                    va, hist_end=va[0],
                    dev=dev, num_nodes=num_nodes,
                    src=src, dst=dst, items_idx=items_idx,
                    user_id_to_index=user_id_to_index,
                    item_id_to_index=item_id_to_index,
                    bins_edges=bins_edges,
                    prefix=prefix, model=model, solver=solver, barriers=barriers
                )
                _last_val_jodie = res_j
                print(f"[VAL:JODIE] {name}  MRR={res_j['MRR']:.5f}  H@10={res_j['H@10']:.5f}  N={res_j['Count']}")

            if EVAL_PROTOCOL in ("TGAT", "BOTH"):
                res_t = eval_tgat_ap_balanced(
                    va,
                    dev=dev, num_nodes=num_nodes,
                    src=src, dst=dst, items_idx=items_idx,
                    user_id_to_index=user_id_to_index,
                    item_id_to_index=item_id_to_index,
                    bins_edges=bins_edges,
                    prefix=prefix, model=model, solver=solver, barriers=barriers,
                    num_negs=AP_NUM_NEGS
                )
                _last_val_tgat = res_t
                print(f"[VAL:TGAT]  {name}  AUC={res_t['AUC']:.4f}  AP={res_t['AP']:.4f}  M={res_t['Count']}")

            # pick selection metric
            score = None
            if VAL_KEY[0] == "JODIE" and _last_val_jodie:
                score = float(_last_val_jodie.get(VAL_KEY[1], float("-inf")))
            elif VAL_KEY[0] == "TGAT" and _last_val_tgat:
                score = float(_last_val_tgat.get(VAL_KEY[1], float("-inf")))
            if (score is not None) and (score > best_val["score"]):
                best_val["score"] = score
                best_val["step"]  = step
                best_val["blob"]  = _pack_state()

    # ---- HEAD-ONLY REPLAY (cached) ----
    model.eval()       # freeze GLF; no JKO calls here
    readout.train()
    for rep in range(2):
        for entry in cache:
            items_now = entry["items_all"].to(dev, non_blocking=True)
            u_b = entry["u"].to(dev, non_blocking=True)
            v_b = entry["v"].to(dev, non_blocking=True)
            prior = entry["prior"].to(dev, non_blocking=True)
            if u_b.numel() == 0 or items_now.numel() < 2:
                continue
            local_pos = torch.full((items_idx.numel(),), -1, dtype=torch.long, device=dev)
            local_pos[items_now] = torch.arange(items_now.numel(), device=dev)
            pos_local = local_pos[v_b]
            keep = pos_local >= 0
            u_b = u_b[keep]; pos_local = pos_local[keep]
            if u_b.numel() == 0: 
                continue
            logits = _head_scores(readout, u_b, items_now, prior)
            I = logits.size(1)
            keep2 = (pos_local >= 0) & (pos_local < I)
            if keep2.any():
                ce = F.cross_entropy(logits[keep2], pos_local[keep2], label_smoothing=0.10)
                opt.zero_grad(set_to_none=True); ce.backward()
                torch.nn.utils.clip_grad_norm_(readout.parameters(), 1.0); opt.step()

    # ---- EXTRA HEAD-ONLY FINETUNE OVER TRAIN BINS (strict-past), GLF frozen ----
    model.eval()
    readout.train()
    HEAD_FT_EPOCHS = 2
    for _ep in range(HEAD_FT_EPOCHS):
        in_h = torch.zeros(num_nodes, device=dev); out_h = torch.zeros(num_nodes, device=dev)
        for bi in tr:
            eidx = bins_edges[bi]
            if not eidx.numel(): 
                continue
            m_prefix = int(prefix.m_at_cut[bi].item())
            items_now_all = torch.unique(dst[:m_prefix]).to(dev)
            items_now = item_id_to_index[items_now_all]
            keep_items = items_now >= 0
            items_now, items_now_all = items_now[keep_items], items_now_all[keep_items]
            if items_now.numel() < 2:
                continue

            # prefix snapshot (for phi prior only)
            X = _build_X_nodes(in_h, out_h, num_nodes, dev)
            if bi == 0:
                Gp = GraphSnapshot(num_nodes, 0, torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev), torch.empty(0, device=dev), X,
                                   edge_u=torch.empty(0,dtype=torch.long,device=dev),
                                   edge_v=torch.empty(0,dtype=torch.long,device=dev))
            else:
                m = int(prefix.m_at_cut[bi].item())
                Gp = GraphSnapshot(num_nodes, m, torch.empty(2,0, dtype=torch.long, device=dev),
                                   torch.empty(0, device=dev), torch.empty(0, device=dev), X,
                                   edge_u=src[:m].long(), edge_v=dst[:m].long())
            with torch.no_grad():
                phi, _A = model.forward_energy_params(Gp)

            u_b = user_id_to_index[src[eidx]]; v_b = item_id_to_index[dst[eidx]]
            keep = (u_b >= 0) & (v_b >= 0)
            u_b, v_b = u_b[keep], v_b[keep]
            if u_b.numel() == 0:
                in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)
                continue

            prior = _zscore_prior(-phi[items_now_all])
            # cand list with safety
            v_unique_local = torch.unique(v_b)
            v_unique_local = v_unique_local[(v_unique_local >= 0) & (v_unique_local < items_idx.numel())]
            cand_items = _semi_hard_cands(items_now, prior, v_unique_local, max_items=768, hard_k=384)

            local_pos = torch.full((items_idx.numel(),), -1, dtype=torch.long, device=dev)
            local_pos[cand_items] = torch.arange(cand_items.numel(), device=dev)
            pos_local = local_pos[v_b]
            keep2 = pos_local >= 0
            if not keep2.any():
                in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
                out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)
                continue

            u_eff = u_b[keep2]; pos_eff = pos_local[keep2]
            prior2 = _zscore_prior(-phi[items_idx[cand_items]])
            logits = _head_scores(readout, u_eff, cand_items, prior2)
            ce = F.cross_entropy(logits, pos_eff, label_smoothing=0.10)
            opt.zero_grad(set_to_none=True); ce.backward()
            torch.nn.utils.clip_grad_norm_(readout.parameters(), 1.0); opt.step()

            in_h  += torch.bincount(dst[eidx], minlength=num_nodes).float().to(dev)
            out_h += torch.bincount(src[eidx], minlength=num_nodes).float().to(dev)

    # ---- restore best-by-VAL before TEST ----
    if best_val["blob"] is not None:
        _load_state(best_val["blob"])
        model.eval(); readout.eval()

    # ---- TEST ----
    _test_jodie, _test_tgat = None, None
    if te:
        if EVAL_PROTOCOL in ("JODIE", "BOTH"):
            _test_jodie = eval_jodie_mrr_recall_all(
                te, hist_end=te[0],
                dev=dev, num_nodes=num_nodes,
                src=src, dst=dst, items_idx=items_idx,
                user_id_to_index=user_id_to_index,
                item_id_to_index=item_id_to_index,
                bins_edges=bins_edges,
                prefix=prefix, model=model, solver=solver, barriers=barriers
            )
            print(f"[TEST:JODIE] {name}  MRR={_test_jodie['MRR']:.5f}  H@10={_test_jodie['H@10']:.5f}  N={_test_jodie['Count']}")
        if EVAL_PROTOCOL in ("TGAT", "BOTH"):
            _test_tgat = eval_tgat_ap_balanced(
                te,
                dev=dev, num_nodes=num_nodes,
                src=src, dst=dst, items_idx=items_idx,
                user_id_to_index=user_id_to_index,
                item_id_to_index=item_id_to_index,
                bins_edges=bins_edges,
                prefix=prefix, model=model, solver=solver, barriers=barriers,
                num_negs=AP_NUM_NEGS
            )
            print(f"[TEST:TGAT ] {name}  AUC={_test_tgat['AUC']:.4f}  AP={_test_tgat['AP']:.4f}  M={_test_tgat['Count']}")

    return {
        "dataset": name,
        "seed": int(seed),
        "VAL": {"JODIE": _last_val_jodie, "TGAT": _last_val_tgat},
        "TEST": {"JODIE": _test_jodie,   "TGAT": _test_tgat},
    }


if __name__ == "__main__":
    # Script entry point: call bench_multi_seed once per dataset spec, in order.
    for dataset_spec in ("csv:jodie-wikipedia", "csv:jodie-lastfm"):
        bench_multi_seed(dataset_spec)
