#!/usr/bin/env python3
import argparse
import json
import time
from pathlib import Path
from typing import Callable, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Training-set sizes (number of sampled triplets) swept by every subcommand (--num-triplets).
DEFAULT_NUM_TRIPLETS = [128, 512, 2048, 8192, 32768, 131072]
# Seeds (--seeds): each selects the synthetic dataset and the triplet-sampling seed.
DEFAULT_SEEDS = [0, 1, 2, 3, 4]
# `connectivity-alpha` sweep (--alphas): P- logits become alpha * reward.
DEFAULT_ALPHAS = [-16.0, -8.0, -4.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
# `optimize-connectivity` sweep (--betas): multipliers applied to the cosine reward.
DEFAULT_BETAS = [0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]

# Output/cache locations are resolved relative to this script, not the working directory.
STANDALONE_ROOT = Path(__file__).resolve().parent
RESULTS_DIR = STANDALONE_ROOT / "results"
DATA_DIR = STANDALONE_ROOT / "data"


def resolve_device(device: str) -> str:
    """Map a requested device string to one usable on this machine.

    ``"auto"`` selects CUDA when available and CPU otherwise. An explicit
    ``"cuda"`` request falls back to CPU (printing a notice) when CUDA is not
    available. Any other value is returned unchanged.
    """
    cuda_ok = torch.cuda.is_available()
    if device == "auto":
        return "cuda" if cuda_ok else "cpu"
    if device != "cuda" or cuda_ok:
        return device
    print("CUDA not available; falling back to CPU.")
    return "cpu"


# ----------------------------- Models -----------------------------


class EmbedNet(nn.Module):
    """Two-layer MLP embedding: ``Linear(d, hidden) -> ReLU -> Linear(hidden, embed)``."""

    def __init__(self, d: int, hidden: int, embed: int):
        super().__init__()
        layers = [nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, embed)]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Embed ``x``; its last dimension must equal ``d``."""
        embedded = self.net(x)
        return embedded


class DotBTModel(nn.Module):
    """Reward model: cosine similarity between embeddings from one shared ``EmbedNet``."""

    def __init__(self, d: int, hidden: int, embed: int):
        super().__init__()
        # Both arguments of ``forward`` go through this single embedding network.
        self.f = EmbedNet(d, hidden, embed)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``cos(f(x), f(y))`` along the last dimension."""
        emb_x = self.f(x)
        emb_y = self.f(y)
        return cosine_reward(emb_x, emb_y)


def cosine_reward(z1: torch.Tensor, z2: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Cosine similarity between ``z1`` and ``z2`` along the last dimension.

    ``eps`` is forwarded to ``F.cosine_similarity`` as its small-norm guard.
    """
    similarity = F.cosine_similarity(z1, z2, eps=eps, dim=-1)
    return similarity


# ----------------------------- Data -----------------------------


def load_or_make_data(
    *,
    seed: int = 0,
    n: int = 200,
    d: int = 32,
    hidden: int = 64,
    embed: int = 64,
    root: Path | str = DATA_DIR,
    force: bool = False,
):
    """Load cached synthetic data for ``seed`` or generate (and cache) it.

    Generates ``X = torch.rand(n, d)`` from a generator seeded with ``seed`` and a
    ground-truth ``EmbedNet`` initialised from the global RNG (seeded with
    ``seed``), with ``Z = gt(X)``. Files are cached under
    ``root/seed{seed}_d{d}_h{hidden}_e{embed}``.

    The cache directory name does not encode ``n``. A cached dataset is therefore
    reused only when its ``meta.json`` matches every requested generation
    parameter; otherwise it is regenerated. Previously a cached dataset of a
    different size was returned silently.

    Note: this reseeds the *global* torch RNG, which affects later random draws
    made by the caller.

    Returns:
        ``(X, Z, gt, meta, out_dir)`` where ``out_dir`` is the cache directory as a string.
    """
    torch.manual_seed(seed)
    g = torch.Generator().manual_seed(seed)
    root = Path(root)
    out_dir = root / f"seed{seed}_d{d}_h{hidden}_e{embed}"
    out_dir.mkdir(parents=True, exist_ok=True)

    x_path = out_dir / "X.pt"
    emb_path = out_dir / "emb.pt"
    state_path = out_dir / "gt_state.pt"
    meta_path = out_dir / "meta.json"
    meta = {"seed": seed, "n": n, "d": d, "hidden": hidden, "embed": embed}

    use_cache = not force and all(p.exists() for p in (x_path, emb_path, state_path, meta_path))
    if use_cache:
        with open(meta_path, "r", encoding="utf-8") as f:
            cached_meta = json.load(f)
        # Reject caches produced with different parameters (notably ``n``).
        use_cache = all(cached_meta.get(key) == value for key, value in meta.items())

    if use_cache:
        X = torch.load(x_path)
        Z = torch.load(emb_path)
        # Constructing the module before loading weights consumes the global RNG
        # exactly like the generation path, so callers see the same RNG state.
        gt = EmbedNet(d, hidden, embed)
        gt.load_state_dict(torch.load(state_path))
        return X, Z, gt, cached_meta, str(out_dir)

    X = torch.rand(n, d, generator=g)
    gt = EmbedNet(d, hidden, embed)
    with torch.no_grad():
        Z = gt(X)

    torch.save(X, x_path)
    torch.save(Z, emb_path)
    torch.save(gt.state_dict(), state_path)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    return X, Z, gt, meta, str(out_dir)


# ----------------------------- Sampling -----------------------------


def _draw(logits: torch.Tensor, g: torch.Generator):
    probs = torch.softmax(logits, dim=0)
    return torch.multinomial(probs, 1, generator=g).item()


def uniform_p_minus_logits(N: int, device) -> torch.Tensor:
    """All-zero ``(N, N)`` logits: a uniform P- distribution for every anchor row."""
    shape = (N, N)
    return torch.zeros(shape, device=device)


def compute_optimal_p_plus(reward: torch.Tensor, p_minus_logits: torch.Tensor) -> torch.Tensor:
    """P+ logits as P- logits shifted by the reward.

    Row-wise ``softmax(p_minus_logits + reward)`` is proportional to
    ``exp(p_minus_logits) * exp(reward)``.
    """
    return torch.add(p_minus_logits, reward)


def sample_triplets_bt_optimal_from_logits(
    p_plus_logits: torch.Tensor,
    p_minus_logits: torch.Tensor,
    num_triplets: int = 5000,
    seed: int = 0,
):
    """Sample ``(anchor, positive, negative)`` index triplets.

    Anchors are uniform over the ``N`` rows of ``p_plus_logits``. For each anchor
    ``a`` the negative is drawn from ``softmax(p_minus_logits[a])`` and then the
    positive from ``softmax(p_plus_logits[a])``. All draws use one generator
    seeded with ``seed``.

    Returns:
        Three CPU ``int64`` tensors of length ``num_triplets``.
    """
    device = p_plus_logits.device
    # The generator must live on the same device as the tensors it samples:
    # a CPU generator combined with CUDA logits makes randint/multinomial raise.
    g = torch.Generator(device=device).manual_seed(seed)
    N = p_plus_logits.size(0)
    anchors = torch.randint(0, N, (num_triplets,), generator=g, device=device)
    # Results are collected on the CPU directly (they are returned there anyway),
    # avoiding a per-element write into device memory.
    pos_idx = torch.empty(num_triplets, dtype=torch.long)
    neg_idx = torch.empty(num_triplets, dtype=torch.long)

    # The draw order (negative first, then positive) is part of the seeded stream.
    for i, a in enumerate(anchors.tolist()):
        neg_idx[i] = _draw(p_minus_logits[a], g)
        pos_idx[i] = _draw(p_plus_logits[a], g)

    return anchors.cpu(), pos_idx, neg_idx


# ----------------------------- Training -----------------------------


def bt_loss(model: DotBTModel, x: torch.Tensor, y_pos: torch.Tensor, y_neg: torch.Tensor):
    """Bradley–Terry negative log-likelihood that ``y_pos`` is preferred to ``y_neg``.

    Returns ``-mean(log sigmoid(model(x, y_pos) - model(x, y_neg)))``.

    ``F.logsigmoid`` replaces ``sigmoid(...).clamp_min(1e-9).log()``. It stays
    numerically stable for large negative margins. The clamp, by contrast, capped
    the loss at about 20.7 there and zeroed the gradient exactly where the model
    is most wrong.
    """
    margin = model(x, y_pos) - model(x, y_neg)
    return -F.logsigmoid(margin).mean()


def train_bt(
    X: torch.Tensor,
    triplets: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    *,
    hidden: int,
    embed: int,
    epochs: int = 20,
    bs: int = 128,
    lr: float = 3e-4,
    device: str | None = None,
    weight_decay: float = 0.0,
    val_triplets: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    patience: int | None = None,
    eval_every: int = 1,
    full_batch: bool = False,
) -> tuple[DotBTModel, int]:
    """Fit a fresh ``DotBTModel`` to index triplets with the Bradley–Terry loss (AdamW).

    Args:
        X: Feature matrix; triplet entries are row indices into it.
        triplets: ``(anchors, pos_idx, neg_idx)`` index tensors of equal length.
        hidden, embed: Sizes of the model's embedding network.
        epochs: Maximum number of passes over ``triplets``.
        bs: Mini-batch size (ignored when ``full_batch``).
        lr, weight_decay: AdamW hyper-parameters.
        device: Torch device; defaults to CUDA when available, else CPU.
        val_triplets: Optional validation triplets used for model selection.
        patience: Stop after this many consecutive validation checks that do not
            improve the best loss by more than 1e-4; ``None`` disables early stopping.
        eval_every: Validate every ``eval_every`` epochs; values <= 0 disable
            validation (previously a ZeroDivisionError).
        full_batch: Use one unshuffled batch containing every triplet.

    Returns:
        ``(model, best_epoch)``. If validation improved at least once, the model
        holds the best-validation weights and ``best_epoch`` is that 1-based
        epoch. Otherwise the final weights are returned with ``best_epoch == epochs``.
    """
    anchors, pos_idx, neg_idx = triplets

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = DotBTModel(d=X.size(1), hidden=hidden, embed=embed).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Move the features once; batches then gather rows on-device instead of
    # indexing on the CPU and copying to the device every step.
    Xd = X.to(device)

    def _loader(ds: TensorDataset, shuffle: bool) -> DataLoader:
        # DataLoader rejects batch_size=0, so an empty full batch uses size 1.
        size = max(len(ds), 1) if full_batch else bs
        return DataLoader(ds, batch_size=size, shuffle=shuffle)

    def _batch_loss(a_idx: torch.Tensor, p_idx: torch.Tensor, n_idx: torch.Tensor):
        # Mean BT loss of one batch together with the batch size.
        x = Xd[a_idx.to(device)]
        y_pos = Xd[p_idx.to(device)]
        y_neg = Xd[n_idx.to(device)]
        return bt_loss(model, x, y_pos, y_neg), x.size(0)

    loader = _loader(TensorDataset(anchors, pos_idx, neg_idx), shuffle=not full_batch)
    val_loader = None if val_triplets is None else _loader(TensorDataset(*val_triplets), shuffle=False)

    def _val_loss() -> float:
        # Size-weighted mean validation loss; inf when there is nothing to evaluate.
        total, count = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                loss, size = _batch_loss(*batch)
                total += loss.item() * size
                count += size
        return total / count if count else float("inf")

    best_state = None
    best_val = float("inf")
    best_epoch = epochs
    since_best = 0
    validate = val_loader is not None and eval_every > 0

    for epoch in range(epochs):
        for batch in loader:
            loss, _ = _batch_loss(*batch)
            opt.zero_grad()
            loss.backward()
            opt.step()

        if validate and (epoch + 1) % eval_every == 0:
            val_loss = _val_loss()
            if val_loss + 1e-4 < best_val:
                best_val = val_loss
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                best_epoch = epoch + 1
                since_best = 0
            else:
                since_best += 1
                if patience is not None and since_best >= patience:
                    break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_epoch


# ----------------------------- Metrics -----------------------------


def triplet_accuracy(
    model: DotBTModel,
    gt_model: DotBTModel,
    X: torch.Tensor,
    triplets: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    device: str | None = None,
    batch: int = 4096,
):
    """Fraction of triplets on which ``model`` orders (pos, neg) like ``gt_model``.

    A triplet counts as correct when ``(gt_pos - gt_neg) * (pred_pos - pred_neg) >= 0``,
    so ties count as agreement. Evaluation runs in chunks of ``batch`` triplets.

    Side effects: moves both models to ``device`` and puts them in eval mode.

    Returns:
        The accuracy, or NaN when ``triplets`` is empty (previously a ZeroDivisionError).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    gt_model = gt_model.to(device).eval()
    anchors, pos_idx, neg_idx = triplets

    total = anchors.numel()
    if total == 0:
        return float("nan")

    Xd = X.to(device)
    correct = 0
    with torch.no_grad():
        for start in range(0, total, batch):
            sl = slice(start, start + batch)  # slicing past the end is safe
            x = Xd[anchors[sl]]
            y_pos = Xd[pos_idx[sl]]
            y_neg = Xd[neg_idx[sl]]
            gt_diff = gt_model(x, y_pos) - gt_model(x, y_neg)
            pred_diff = model(x, y_pos) - model(x, y_neg)
            correct += (gt_diff * pred_diff >= 0).sum().item()

    return correct / total


def full_triplet_accuracy_all_pairs(
    model: DotBTModel,
    gt_model: DotBTModel,
    X: torch.Tensor,
    Z: torch.Tensor | None = None,
    device: str | None = None,
):
    """Mean per-anchor ranking agreement over all ``(i, j, k)`` with distinct indices.

    Ground-truth similarities are cosine similarities of ``Z``, or of
    ``gt_model.f(X)`` when ``Z`` is None. Model similarities use ``model.f(X)``.
    For each anchor ``i`` this computes the fraction of ordered pairs ``(j, k)``
    (``j != k``, both ``!= i``) on which ``sim[i, j] - sim[i, k]`` has the same sign
    under both (ties count as agreement), then averages over anchors.

    Returns NaN when there are fewer than three points.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    Xd = X.to(device)
    with torch.no_grad():
        if Z is not None:
            gt_emb = F.normalize(Z.to(device), dim=-1)
        else:
            gt_emb = F.normalize(gt_model.f(Xd), dim=-1)
        model_emb = F.normalize(model.f(Xd), dim=-1)

        sim_gt = gt_emb @ gt_emb.t()
        sim_m = model_emb @ model_emb.t()
        n = X.size(0)
        idx = torch.arange(n, device=device)
        # Loop-invariant part of the mask: pairs with j != k.
        off_diag = idx[:, None] != idx[None, :]

        accuracies = []
        for i in range(n):
            keep = idx != i  # exclude the anchor itself as j or k
            mask = off_diag & keep[:, None] & keep[None, :]
            sm = sim_m[i]
            sg = sim_gt[i]
            diff_m = sm[:, None] - sm[None, :]
            diff_g = sg[:, None] - sg[None, :]
            agree = diff_m * diff_g >= 0
            accuracies.append(agree[mask].float().mean())

        if not accuracies:
            return float("nan")
        return torch.stack(accuracies).mean().item()


def full_triplet_accuracy_hard_pairs(
    model: DotBTModel,
    gt_model: DotBTModel,
    X: torch.Tensor,
    Z: torch.Tensor | None = None,
    hard_quantiles: tuple[float, float] = (0.3, 0.1),
    device: str | None = None,
):
    """Ranking agreement restricted to pairs with a small ground-truth margin.

    Similarities are built as in the all-pairs metric (``Z`` or ``gt_model.f(X)``
    for the ground truth, ``model.f(X)`` for the model). The margin of a valid
    triplet ``(i, j, k)`` (distinct indices) is ``|sim_gt[i, j] - sim_gt[i, k]|``.
    Each ``q`` in ``hard_quantiles`` sets a threshold: the ``q``-quantile of all
    margins across anchors. Per anchor, we take the fraction of pairs with
    margin <= threshold on which model and ground truth agree in sign (ties count
    as agreement). These fractions are averaged over anchors that have any such pair.

    Returns:
        ``{q: accuracy}``; NaN for a quantile with no qualifying pair or when no
        valid triplet exists (``torch.quantile`` would raise on empty input).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    Xd = X.to(device)
    with torch.no_grad():
        if Z is not None:
            gt_emb = F.normalize(Z.to(device), dim=-1)
        else:
            gt_emb = F.normalize(gt_model.f(Xd), dim=-1)
        model_emb = F.normalize(model.f(Xd), dim=-1)

        sim_gt = gt_emb @ gt_emb.t()
        sim_m = model_emb @ model_emb.t()

        n = X.size(0)
        idx = torch.arange(n, device=device)
        off_diag = idx[:, None] != idx[None, :]

        # Keep only the valid (j, k) entries per anchor. Boolean indexing is
        # row-major, so element order matches a full-matrix selection. Storage
        # drops from three float and one bool n x n matrices per anchor to two vectors.
        margins_per_anchor = []
        agree_per_anchor = []
        for i in range(n):
            keep = idx != i
            mask = off_diag & keep[:, None] & keep[None, :]
            sg = sim_gt[i]
            sm = sim_m[i]
            diff_g = (sg[:, None] - sg[None, :])[mask]
            diff_m = (sm[:, None] - sm[None, :])[mask]
            margins_per_anchor.append(diff_g.abs())
            agree_per_anchor.append(diff_m * diff_g >= 0)

        all_margins = torch.cat(margins_per_anchor) if margins_per_anchor else sim_gt.new_zeros(0)
        if all_margins.numel() == 0:
            return {q: float("nan") for q in hard_quantiles}
        # NOTE(review): torch.quantile limits its input size; for large n this
        # may need a sort-based quantile — confirm against the installed torch.
        thresholds = {q: torch.quantile(all_margins, q) for q in hard_quantiles}

        accs = {q: [] for q in hard_quantiles}
        for margins, agree in zip(margins_per_anchor, agree_per_anchor):
            for q, thr in thresholds.items():
                hard = margins <= thr
                if hard.any():
                    accs[q].append(agree[hard].float().mean())

        return {q: (torch.stack(v).mean().item() if v else float("nan")) for q, v in accs.items()}


# ----------------------------- Connectivity -----------------------------


def build_comparison_graph(p_plus_logits: torch.Tensor, p_minus_logits: torch.Tensor) -> torch.Tensor:
    """Symmetrised per-anchor pair weights.

    With ``P+ = softmax(p_plus_logits)`` and ``P- = softmax(p_minus_logits)``
    taken over the last dimension, returns
    ``W[n, i, j] = P+[n, i] * P-[n, j] + P-[n, i] * P+[n, j]``, which is
    symmetric in ``(i, j)``.
    """
    p_plus = p_plus_logits.softmax(dim=-1)
    p_minus = p_minus_logits.softmax(dim=-1)
    outer = p_plus.unsqueeze(-1) * p_minus.unsqueeze(-2)
    return outer + outer.transpose(-1, -2)


def _compute_p_tilde(p_plus_logits: torch.Tensor, p_minus_logits: torch.Tensor) -> torch.Tensor:
    """Comparison-graph weights divided by the number of anchor rows in ``p_plus_logits``."""
    num_anchors = p_plus_logits.shape[0]
    graph = build_comparison_graph(p_plus_logits, p_minus_logits)
    return graph / num_anchors


def _compute_p_tilde_from_p_minus(p_minus_logits: torch.Tensor, reward: torch.Tensor) -> torch.Tensor:
    """``p_tilde`` for P+ logits formed as ``p_minus_logits + reward``."""
    return _compute_p_tilde(p_minus_logits + reward, p_minus_logits)


def _compute_connectivity_ratio(h: torch.Tensor, p_tilde: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    h_centered = h - h.mean(dim=1, keepdim=True)
    h_diff = h_centered.unsqueeze(2) - h_centered.unsqueeze(1)
    numerator = (p_tilde * h_diff.pow(2)).sum()
    variance_per_row = h_centered.pow(2).mean(dim=1)
    denominator = variance_per_row.mean() + epsilon
    return numerator / denominator


def _normalize_h_unit_variance(h: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    h_centered = h - h.mean(dim=1, keepdim=True)
    var_per_row = h_centered.pow(2).mean(dim=1, keepdim=True)
    scale = (var_per_row + epsilon).sqrt()
    return h_centered / scale


def connectivity_degree(
    p_plus_logits: torch.Tensor,
    p_minus_logits: torch.Tensor,
    *,
    X: torch.Tensor,
    hidden: int = 32,
    embed: int = 8,
    num_steps: int = 1000,
    lr: float = 1e-2,
    epsilon: float = 1e-8,
    seed: int = 0,
    verbose: bool = False,
) -> float:
    """Connectivity estimate for the comparison graph induced by (P+, P-).

    Seeds the *global* torch RNG with ``seed``, which also affects later random
    draws by the caller. It then builds ``p_tilde`` from the two logit matrices and
    returns the ratio reached by ``_connectivity_dot_bt`` after ``num_steps``
    optimisation steps.
    """
    torch.manual_seed(seed)
    weights = _compute_p_tilde(p_plus_logits, p_minus_logits)
    return _connectivity_dot_bt(
        weights,
        X,
        hidden=hidden,
        embed=embed,
        num_steps=num_steps,
        lr=lr,
        epsilon=epsilon,
        verbose=verbose,
    )


def _connectivity_dot_bt(
    p_tilde: torch.Tensor,
    X: torch.Tensor,
    hidden: int,
    embed: int,
    num_steps: int,
    lr: float,
    epsilon: float,
    verbose: bool,
) -> float:
    """Minimise the connectivity ratio over differences of two cosine rewards.

    Two freshly initialised ``DotBTModel`` embedders ``f`` and ``g`` (created in
    that order) parametrise ``h = r_f - r_g``, where ``r_*`` is the all-pairs
    cosine-similarity matrix of their embeddings of ``X``. ``h`` is row-normalised
    and the ratio is minimised with Adam for ``num_steps`` steps. The ratio at the
    final parameters is returned as a float.
    """
    device = X.device
    in_dim = X.shape[1]
    model_f = DotBTModel(d=in_dim, hidden=hidden, embed=embed).to(device)
    model_g = DotBTModel(d=in_dim, hidden=hidden, embed=embed).to(device)
    optimizer = torch.optim.Adam([*model_f.parameters(), *model_g.parameters()], lr=lr)

    def gram(model):
        # Cosine-similarity matrix of the model's embeddings of X.
        emb = model.f(X)
        return F.normalize(emb, dim=-1) @ F.normalize(emb, dim=-1).t()

    def current_ratio():
        r_f = gram(model_f)
        r_g = gram(model_g)
        h = _normalize_h_unit_variance(r_f - r_g, epsilon)
        return _compute_connectivity_ratio(h, p_tilde, epsilon)

    for step in range(num_steps):
        optimizer.zero_grad()
        ratio = current_ratio()
        ratio.backward()
        optimizer.step()
        if verbose and (step % 200 == 0 or step == num_steps - 1):
            print(f"  [dot_bt] Step {step}: ratio = {ratio.item():.6f}")

    with torch.no_grad():
        return current_ratio().item()


def optimize_p_minus_connectivity(
    reward: torch.Tensor,
    *,
    X: torch.Tensor,
    hidden: int = 32,
    embed: int = 8,
    seed: int = 0,
    num_outer_steps: int = 1000,
    num_inner_steps: int = 5,
    lr_outer: float = 1e-4,
    lr_inner: float = 1e-2,
    epsilon: float = 1e-8,
    eval_every: int = 10,
    eval_num_steps: int = 1000,
    eval_lr: float = 1e-2,
    eval_verbose: bool = False,
    verbose: bool = False,
) -> torch.Tensor:
    """Optimise P- logits ``ell`` to maximise the connectivity ratio.

    Each outer step runs two phases:
      1. ``num_inner_steps`` Adam steps on embedders ``f`` and ``g`` that *minimise*
         the ratio for the current (frozen) ``ell``.
      2. One Adam step on ``ell`` that *maximises* the ratio with ``h`` frozen.

    Every ``eval_every`` outer steps (and at the last step) the connectivity of
    the current ``ell`` is re-estimated from scratch with ``_connectivity_dot_bt``.
    The best-scoring ``ell`` is kept. Seeds the global torch RNG with ``seed``.

    Returns:
        The ``(m, m)`` P- logits with the highest evaluated connectivity, or the
        final logits when ``eval_every <= 0`` (no evaluation).
    """
    torch.manual_seed(seed)
    device = reward.device
    m = reward.shape[0]
    d = X.shape[1]

    ell = torch.zeros(m, m, device=device, requires_grad=True)
    model_f = DotBTModel(d=d, hidden=hidden, embed=embed).to(device)
    model_g = DotBTModel(d=d, hidden=hidden, embed=embed).to(device)
    optimizer_inner = torch.optim.Adam(list(model_f.parameters()) + list(model_g.parameters()), lr=lr_inner)
    optimizer_outer = torch.optim.Adam([ell], lr=lr_outer)

    def _h() -> torch.Tensor:
        # Row-normalised difference of the f and g cosine-similarity matrices on X.
        z_f = model_f.f(X)
        r_f = F.normalize(z_f, dim=-1) @ F.normalize(z_f, dim=-1).t()
        z_g = model_g.f(X)
        r_g = F.normalize(z_g, dim=-1) @ F.normalize(z_g, dim=-1).t()
        return _normalize_h_unit_variance(r_f - r_g, epsilon)

    best_connectivity = None
    best_ell = None
    for outer_step in range(num_outer_steps):
        # Phase 1: ell is frozen during the inner loop, so p_tilde is loop-invariant.
        p_tilde_inner = _compute_p_tilde_from_p_minus(ell.detach(), reward)
        for _ in range(num_inner_steps):
            optimizer_inner.zero_grad()
            ratio = _compute_connectivity_ratio(_h(), p_tilde_inner, epsilon)
            ratio.backward()
            optimizer_inner.step()

        # Phase 2: gradient ascent on ell with the embedders frozen.
        optimizer_outer.zero_grad()
        p_tilde = _compute_p_tilde_from_p_minus(ell, reward)
        with torch.no_grad():
            h = _h()
        ratio = _compute_connectivity_ratio(h, p_tilde, epsilon)
        ratio_value = ratio.item()
        (-ratio).backward()
        optimizer_outer.step()

        if eval_every > 0 and (outer_step % eval_every == 0 or outer_step == num_outer_steps - 1):
            p_tilde_eval = _compute_p_tilde_from_p_minus(ell.detach(), reward)
            current_connectivity = _connectivity_dot_bt(
                p_tilde_eval, X, hidden, embed, eval_num_steps, eval_lr, epsilon, False
            )
            if best_connectivity is None or current_connectivity > best_connectivity:
                best_connectivity = current_connectivity
                best_ell = ell.detach().clone()
            if eval_verbose:
                print(
                    f"step {outer_step} ratio: {ratio_value:.6f}, "
                    f"connectivity: {current_connectivity:.6f}"
                )

        if verbose and (outer_step % 20 == 0 or outer_step == num_outer_steps - 1):
            print(f"Outer step {outer_step}: lambda_conn = {ratio_value:.6f}")

    if best_ell is None:
        best_ell = ell.detach().clone()
    return best_ell


# ----------------------------- Core helpers -----------------------------


def rank_normalize_to_uniform(reward: torch.Tensor) -> torch.Tensor:
    """Replace each row's values (last dim) by their ranks, mapped linearly onto [-1, 1].

    The smallest entry of a row maps to -1 and the largest to +1. Ties are broken
    by position (stable sort), so the result is deterministic. A row of width <= 1
    maps to zeros; previously 0/0 produced NaN. The output is float32.
    """
    width = reward.shape[-1]
    if width <= 1:
        return torch.zeros_like(reward, dtype=torch.float32)
    # argsort of the sorting permutation yields each element's rank.
    ranks = reward.argsort(dim=-1, stable=True).argsort(dim=-1).float()
    return 2 * ranks / (width - 1) - 1


def compute_reward(
    Z: torch.Tensor,
    *,
    stretch_reward: bool = False,
    reward_scale: float = 1.0,
) -> torch.Tensor:
    """All-pairs cosine similarity of the rows of ``Z``.

    If ``stretch_reward`` is set, each row is rank-normalised onto [-1, 1]. The
    result is then multiplied by ``reward_scale`` when it differs from 1.0.
    """
    unit = F.normalize(Z, dim=-1)
    sims = unit @ unit.t()
    sims = rank_normalize_to_uniform(sims) if stretch_reward else sims
    return sims * reward_scale if reward_scale != 1.0 else sims


def build_p_minus_logits(reward: torch.Tensor, alpha_neg: float = 0.0) -> torch.Tensor:
    """Uniform (all-zero) P- logits plus ``alpha_neg * reward`` (skipped when ``alpha_neg == 0``)."""
    base = uniform_p_minus_logits(reward.size(0), reward.device)
    if alpha_neg == 0:
        return base
    return base + alpha_neg * reward


def initialize_ground_truth(meta: dict, gt_embed: nn.Module) -> DotBTModel:
    """Wrap the ground-truth embedding network in a ``DotBTModel`` by copying its weights.

    ``meta`` must provide ``"d"``, ``"hidden"`` and ``"embed"``.
    """
    sizes = {key: meta[key] for key in ("d", "hidden", "embed")}
    model = DotBTModel(**sizes)
    model.f.load_state_dict(gt_embed.state_dict())
    return model


def build_train_kwargs(args: argparse.Namespace, meta: dict) -> dict:
    """Collect keyword arguments for ``train_bt`` from CLI args and data metadata.

    ``val_triplets`` is left as None here; validation is always checked every epoch.
    """
    kwargs = {
        "hidden": meta["hidden"],
        "embed": meta["embed"],
        "epochs": args.epochs,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "device": resolve_device(args.device),
        "bs": args.bs,
        "val_triplets": None,
        "patience": args.patience,
        "eval_every": 1,
        "full_batch": args.full_batch,
    }
    return kwargs


def run_training_sweep(
    X: torch.Tensor,
    Z: torch.Tensor,
    gt: DotBTModel,
    reward: torch.Tensor,
    p_minus_logits: torch.Tensor,
    num_triplets_list: Iterable[int],
    *,
    val_num: int,
    seed: int,
    train_kwargs: dict,
    progress_callback: Callable[[dict], None] | None = None,
    hard_quantiles: tuple[float, float] | None = None,
):
    """Train one model per training-set size and evaluate it against ``gt``.

    P+ logits are ``p_minus_logits + reward``. One validation set of ``val_num``
    triplets is sampled with seed ``seed + 1234``. For every size in
    ``num_triplets_list`` a training set is sampled with ``seed`` and a model is
    trained via ``train_bt``. The model is then scored with train/val triplet
    accuracy, all-pairs accuracy and, if ``hard_quantiles`` is given, hard-pair
    accuracies.

    Returns:
        A list of per-size result dicts; each is also passed to ``progress_callback``.
    """
    p_plus_logits = compute_optimal_p_plus(reward, p_minus_logits)
    val_triplets = sample_triplets_bt_optimal_from_logits(
        p_plus_logits, p_minus_logits, num_triplets=val_num, seed=seed + 1234
    )
    train_kwargs = {**train_kwargs, "val_triplets": val_triplets}
    device = train_kwargs["device"]
    # Compare the device *type* so that "cuda:1" etc. are synchronised too. CUDA
    # runs asynchronously, so without a sync the timer stops early.
    sync_cuda = device is not None and torch.device(device).type == "cuda"

    def _sync() -> None:
        if sync_cuda:
            torch.cuda.synchronize(device)

    results = []
    for num_triplets in num_triplets_list:
        train_triplets = sample_triplets_bt_optimal_from_logits(
            p_plus_logits, p_minus_logits, num_triplets=num_triplets, seed=seed
        )
        _sync()
        train_start = time.perf_counter()
        model, best_epoch = train_bt(X, train_triplets, **train_kwargs)
        _sync()
        train_time = time.perf_counter() - train_start

        # triplet_accuracy moves the model to `device`; the later metrics rely on that.
        row = {
            "num_triplets": num_triplets,
            "best_epoch": best_epoch,
            "train_acc": triplet_accuracy(model, gt, X, train_triplets, device=device),
            "val_acc": triplet_accuracy(model, gt, X, val_triplets, device=device),
            "full_acc": full_triplet_accuracy_all_pairs(model, gt, X, Z, device=device),
            "train_time_sec": train_time,
        }
        if hard_quantiles is not None:
            hard_metrics = full_triplet_accuracy_hard_pairs(
                model, gt, X, Z, hard_quantiles=hard_quantiles, device=device
            )
            row["hard_acc"] = hard_metrics[hard_quantiles[0]]
            row["very_hard_acc"] = hard_metrics[hard_quantiles[1]]
        results.append(row)
        if progress_callback is not None:
            progress_callback(row)
    return results


# ----------------------------- Experiments -----------------------------


def _setup_experiment(args, seed):
    """Load/generate data for ``seed`` and build the ground-truth model and training kwargs.

    Returns ``(X, Z, gt, meta, train_kwargs)``.
    """
    data_options = dict(
        seed=seed,
        n=args.n,
        d=args.d,
        hidden=args.hidden,
        embed=args.embed,
        force=args.force_data,
    )
    X, Z, gt_embed, meta, _data_dir = load_or_make_data(**data_options)
    return X, Z, initialize_ground_truth(meta, gt_embed), meta, build_train_kwargs(args, meta)


def _maybe_compute_connectivity(args, reward, p_minus_logits, X, meta):
    """Connectivity degree of the (P+, P-) graph when ``args.compute_connectivity`` is set, else None."""
    if getattr(args, 'compute_connectivity', False):
        plus_logits = compute_optimal_p_plus(reward, p_minus_logits)
        return connectivity_degree(
            plus_logits,
            p_minus_logits,
            X=X,
            hidden=meta["hidden"],
            embed=meta["embed"],
        )
    return None


def run_margin_experiment(args: argparse.Namespace) -> dict:
    """Compare the raw cosine reward ("baseline") with its rank-normalised form ("rank").

    For each seed and variant, runs a training sweep with uniform P- logits.
    It records train/val/all-pairs accuracy and hard-pair accuracy at margin
    quantiles 0.3 and 0.1.
    """
    hard_quantiles = (0.3, 0.1)
    variants = (("baseline", False), ("rank", True))
    results = []
    for seed in args.seeds:
        X, Z, gt, meta, train_kwargs = _setup_experiment(args, seed)
        for variant, stretch in variants:
            reward = compute_reward(Z, stretch_reward=stretch, reward_scale=args.reward_scale)
            rows = run_training_sweep(
                X,
                Z,
                gt,
                reward,
                build_p_minus_logits(reward, alpha_neg=0.0),
                args.num_triplets,
                val_num=args.val_num,
                seed=seed,
                train_kwargs=train_kwargs,
                hard_quantiles=hard_quantiles,
            )
            results.extend(
                {
                    "seed": seed,
                    "variant": variant,
                    "num_triplets": row["num_triplets"],
                    "train_acc": row["train_acc"],
                    "val_acc": row["val_acc"],
                    "full_acc": row["full_acc"],
                    "hard_acc": row.get("hard_acc"),
                    "very_hard_acc": row.get("very_hard_acc"),
                }
                for row in rows
            )

    config = {
        "seeds": args.seeds,
        "num_triplets": args.num_triplets,
        "reward_scale": args.reward_scale,
        "hard_quantiles": list(hard_quantiles),
        "data": {"n": args.n, "d": args.d, "hidden": args.hidden, "embed": args.embed},
    }
    return {"experiment": "margin", "config": config, "results": results}


def run_connectivity_alpha_experiment(args: argparse.Namespace) -> dict:
    """Sweep ``alpha``, the coefficient of the reward added to the P- logits.

    For each seed and alpha, runs a training sweep and, when requested, records
    the connectivity degree of the induced comparison graph.
    """
    results = []
    for seed in args.seeds:
        X, Z, gt, meta, train_kwargs = _setup_experiment(args, seed)
        reward = compute_reward(Z, stretch_reward=False, reward_scale=args.reward_scale)

        for alpha in args.alphas:
            p_minus_logits = build_p_minus_logits(reward, alpha_neg=alpha)
            rows = run_training_sweep(
                X,
                Z,
                gt,
                reward,
                p_minus_logits,
                args.num_triplets,
                val_num=args.val_num,
                seed=seed,
                train_kwargs=train_kwargs,
            )
            # Computed after training: it reseeds the global RNG.
            connectivity = _maybe_compute_connectivity(args, reward, p_minus_logits, X, meta)
            results.extend(
                {
                    "seed": seed,
                    "alpha": alpha,
                    "num_triplets": row["num_triplets"],
                    "train_acc": row["train_acc"],
                    "val_acc": row["val_acc"],
                    "full_acc": row["full_acc"],
                    "connectivity": connectivity,
                }
                for row in rows
            )

    config = {
        "alphas": args.alphas,
        "seeds": args.seeds,
        "num_triplets": args.num_triplets,
        "reward_scale": args.reward_scale,
        "compute_connectivity": args.compute_connectivity,
        "data": {"n": args.n, "d": args.d, "hidden": args.hidden, "embed": args.embed},
    }
    return {"experiment": "connectivity_alpha", "config": config, "results": results}


def run_optimize_connectivity_experiment(args: argparse.Namespace) -> dict:
    """Compare uniform P- ("baseline") with connectivity-optimised P- ("optimized") across reward scales.

    The cosine reward is multiplied by each ``beta``. The "optimized" variant is
    included only with ``--include-optimized``.
    """
    variants = ["baseline"] + (["optimized"] if args.include_optimized else [])
    opt_keys = (
        "opt_outer_steps",
        "opt_inner_steps",
        "opt_lr_outer",
        "opt_lr_inner",
        "opt_eval_every",
        "opt_eval_steps",
        "opt_eval_lr",
    )
    results = []
    for seed in args.seeds:
        X, Z, gt, meta, train_kwargs = _setup_experiment(args, seed)
        base_reward = compute_reward(Z, stretch_reward=False, reward_scale=1.0)

        def p_minus_for(variant, reward):
            # Uniform P- for the baseline; otherwise optimise P- for connectivity.
            if variant == "baseline":
                return build_p_minus_logits(reward, alpha_neg=0.0)
            return optimize_p_minus_connectivity(
                reward,
                seed=seed,
                X=X,
                hidden=meta["hidden"],
                embed=meta["embed"],
                num_outer_steps=args.opt_outer_steps,
                num_inner_steps=args.opt_inner_steps,
                lr_outer=args.opt_lr_outer,
                lr_inner=args.opt_lr_inner,
                eval_every=args.opt_eval_every,
                eval_num_steps=args.opt_eval_steps,
                eval_lr=args.opt_eval_lr,
            )

        for beta in args.betas:
            reward = base_reward * beta
            for variant in variants:
                p_minus_logits = p_minus_for(variant, reward)
                rows = run_training_sweep(
                    X,
                    Z,
                    gt,
                    reward,
                    p_minus_logits,
                    args.num_triplets,
                    val_num=args.val_num,
                    seed=seed,
                    train_kwargs=train_kwargs,
                )
                connectivity = _maybe_compute_connectivity(args, reward, p_minus_logits, X, meta)
                results.extend(
                    {
                        "seed": seed,
                        "beta": beta,
                        "variant": variant,
                        "num_triplets": row["num_triplets"],
                        "train_acc": row["train_acc"],
                        "val_acc": row["val_acc"],
                        "full_acc": row["full_acc"],
                        "connectivity": connectivity,
                    }
                    for row in rows
                )

    config = {
        "betas": args.betas,
        "seeds": args.seeds,
        "num_triplets": args.num_triplets,
        "include_optimized": args.include_optimized,
        "compute_connectivity": args.compute_connectivity,
        **{key: getattr(args, key) for key in opt_keys},
        "data": {"n": args.n, "d": args.d, "hidden": args.hidden, "embed": args.embed},
    }
    return {"experiment": "optimize_connectivity", "config": config, "results": results}


# ----------------------------- CLI -----------------------------


def add_data_args(p: argparse.ArgumentParser) -> None:
    """Register the synthetic-data options shared by every subcommand."""
    p.add_argument("--seeds", type=int, nargs="+", default=DEFAULT_SEEDS)
    for flag, default in (("--n", 16), ("--d", 128), ("--hidden", 32), ("--embed", 8)):
        p.add_argument(flag, type=int, default=default)
    p.add_argument("--force-data", action="store_true", help="Regenerate synthetic data for each seed.")


def add_training_args(p: argparse.ArgumentParser) -> None:
    """Register the training options shared by every subcommand."""
    p.add_argument("--num-triplets", type=int, nargs="+", default=DEFAULT_NUM_TRIPLETS)
    scalar_options = (
        ("--val-num", int, 2048),
        ("--epochs", int, 200),
        ("--lr", float, 3e-4),
        ("--weight-decay", float, 1e-4),
        ("--bs", int, 1024),
        ("--patience", int, 20),
    )
    for flag, kind, default in scalar_options:
        p.add_argument(flag, type=kind, default=default)
    p.add_argument("--device", type=str, default="cuda", help="Training device: cuda, cpu, or auto.")
    p.add_argument(
        "--full-batch",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use full-batch GD (default on).",
    )


def add_output_args(p: argparse.ArgumentParser, default_path: Path) -> None:
    """Register ``--output`` (path of the results JSON), defaulting to ``default_path``."""
    default_str = str(default_path)
    p.add_argument("--output", default=default_str, type=str)


def main() -> None:
    """CLI entry point: parse a subcommand, run its experiment, and write the JSON payload."""
    parser = argparse.ArgumentParser(description="Standalone synthetic experiment runner.")
    sub = parser.add_subparsers(dest="command", required=True)

    def new_command(name: str, help_text: str) -> argparse.ArgumentParser:
        # Every subcommand shares the data and training options.
        cmd = sub.add_parser(name, help=help_text)
        add_data_args(cmd)
        add_training_args(cmd)
        return cmd

    margin = new_command("margin", "Margin experiment: baseline vs rank-normalized reward.")
    margin.add_argument("--reward-scale", type=float, default=1.0)
    add_output_args(margin, RESULTS_DIR / "margin_results.json")

    alpha = new_command("connectivity-alpha", "Alpha sweep for P- skewness.")
    alpha.add_argument("--alphas", type=float, nargs="+", default=DEFAULT_ALPHAS)
    alpha.add_argument("--reward-scale", type=float, default=1.0)
    alpha.add_argument("--compute-connectivity", action="store_true")
    add_output_args(alpha, RESULTS_DIR / "connectivity_alpha_results.json")

    optimize = new_command("optimize-connectivity", "Baseline vs optimized P- over beta.")
    optimize.add_argument("--betas", type=float, nargs="+", default=DEFAULT_BETAS)
    optimize.add_argument("--include-optimized", action="store_true")
    optimize.add_argument("--compute-connectivity", action="store_true")
    for flag, kind, default in (
        ("--opt-outer-steps", int, 400),
        ("--opt-inner-steps", int, 200),
        ("--opt-lr-outer", float, 1e-4),
        ("--opt-lr-inner", float, 1e-2),
        ("--opt-eval-every", int, 10),
        ("--opt-eval-steps", int, 1000),
        ("--opt-eval-lr", float, 1e-2),
    ):
        optimize.add_argument(flag, type=kind, default=default)
    add_output_args(optimize, RESULTS_DIR / "optimize_connectivity_results.json")

    runners = {
        "margin": run_margin_experiment,
        "connectivity-alpha": run_connectivity_alpha_experiment,
        "optimize-connectivity": run_optimize_connectivity_experiment,
    }

    args = parser.parse_args()
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    runner = runners.get(args.command)
    if runner is None:
        raise ValueError(f"Unknown command: {args.command}")
    payload = runner(args)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(payload, indent=2))
    print(f"Saved results to {output_path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
