import os
import re
import time
import math
import argparse
import pickle
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
import networkx as nx
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
)
import ot
from ignite.metrics import MaximumMeanDiscrepancy

from Denoiser_A_embedding import DenoiseNetworkA


# =========================
# small utils
# =========================

def set_seed(seed: int):
    """Seed numpy's global RNG and torch (CPU, plus every CUDA device when CUDA is available)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def collate_graphs(batch):
    """Zero-pad a list of square adjacency tensors into one dense batch.

    Returns ``(A, node_mask)``:
      * ``A`` has shape ``(len(batch), n_max, n_max)`` and the dtype of ``batch[0]``;
        graph ``i`` occupies ``A[i, :n_i, :n_i]`` and the rest is zero.
      * ``node_mask`` is a bool ``(len(batch), n_max)`` tensor, True on the first
        ``n_i`` positions of row ``i``.
    """
    sizes = [g.size(0) for g in batch]
    n_max = max(sizes)
    padded = torch.zeros(len(batch), n_max, n_max, dtype=batch[0].dtype)
    valid = torch.zeros(len(batch), n_max, dtype=torch.bool)
    for idx, (adj, n) in enumerate(zip(batch, sizes)):
        padded[idx, :n, :n] = adj
        valid[idx, :n] = True
    return padded, valid

def permute_square(A: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Relabel the nodes of square matrix ``A`` so that ``result[i, j] == A[p[i], p[j]]``."""
    rows_permuted = torch.index_select(A, 0, p)
    return torch.index_select(rows_permuted, 1, p)

def invert_perm(p: torch.Tensor) -> torch.Tensor:
    """Return the inverse permutation ``q`` of ``p`` (``q[p[i]] == i``), with ``p``'s dtype and device."""
    identity = torch.arange(p.numel(), device=p.device)
    inverse = torch.empty_like(p)
    inverse[p] = identity
    return inverse

def linear_coeffs(t: torch.Tensor):
    """Coefficients of the linear interpolant ``I_t = alpha(t) * x0 + beta(t) * x1``.

    Returns ``(alpha, beta, d alpha/dt, d beta/dt) = (1 - t, t, -1, 1)``. The
    derivatives are filled tensors shaped like ``t``, and ``beta`` is ``t`` itself.
    """
    alpha = 1.0 - t
    alpha_dot = torch.full_like(t, -1.0)
    beta_dot = torch.ones_like(t)
    return alpha, t, alpha_dot, beta_dot

def zero_diag_(M: torch.Tensor) -> torch.Tensor:
    """Zero the diagonal of ``M`` in place and return ``M``.

    Works on a single ``(N, N)`` matrix or a batch ``(..., N, N)``: the diagonal
    is taken over the last two dimensions. The default ``M.diagonal()`` would use
    dims (0, 1) and, for a batch, wrongly zero ``M[i, i, :]``.
    """
    M.diagonal(dim1=-2, dim2=-1).zero_()
    return M

def sym_zero_diag_valid(M: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
    """Symmetrise ``M`` from its strict upper triangle, restricted to valid nodes.

    Entries touching an invalid node are zeroed. The strict upper triangle is
    kept and mirrored onto the lower triangle, so values originally below the
    diagonal are discarded. The diagonal ends up zero.

    Args:
        M: ``(N, N)`` matrix with ``node_mask`` of shape ``(N,)``, or a batch
            ``(B, N, N)`` with ``node_mask`` of shape ``(B, N)``.
        node_mask: truthy entries mark valid nodes.
    Returns:
        A new tensor shaped like ``M``.
    """
    nm = node_mask.to(M.dtype)
    # outer product of the node mask -> 1 only where both endpoints are valid
    pair = nm.unsqueeze(-1) * nm.unsqueeze(-2)
    # torch.triu works on the last two dims, so this covers both the 2-D and batched cases;
    # diagonal=1 also leaves the diagonal zero, so no separate diagonal fill is needed
    upper = torch.triu(M * pair, diagonal=1)
    return (upper + upper.transpose(-1, -2)) * pair

def add_masked_symmetric_noise(M: torch.Tensor,
                               node_mask: torch.Tensor,
                               edge_mask: torch.Tensor,
                               sigma: float,
                               clip01: bool = True) -> torch.Tensor:
    """Perturb ``M`` with symmetric Gaussian noise on unobserved pairs only.

    The noise at each pair is scaled by ``1 - edge_mask``. Pairs with
    ``edge_mask == 1`` (observed) therefore get no noise. The noise, the weight
    and the result all go through ``sym_zero_diag_valid``, so the output is
    symmetric with a zero diagonal and zero rows/columns for invalid nodes.
    With ``clip01`` the noisy result is clamped to [0, 1].

    When ``sigma <= 0`` no random numbers are drawn and no clipping is applied;
    the symmetrised ``M`` is returned.
    """
    if sigma <= 0.0:
        return sym_zero_diag_valid(M, node_mask)

    weight = sym_zero_diag_valid((1.0 - edge_mask).to(M.dtype), node_mask)
    noise = sym_zero_diag_valid(torch.randn_like(M), node_mask)
    perturbed = sym_zero_diag_valid(M + sigma * (noise * weight), node_mask)
    if clip01:
        perturbed.clamp_(0.0, 1.0)
    return perturbed

def upper_triu_mask_batched(node_mask: torch.Tensor) -> torch.Tensor:
    """Bool ``(B, N, N)`` mask of strictly-upper-triangular pairs whose two endpoints are valid.

    ``node_mask`` is a ``(B, N)`` boolean tensor of valid nodes.
    """
    _, n = node_mask.shape
    both_valid = node_mask[:, :, None] & node_mask[:, None, :]
    strictly_upper = torch.ones(n, n, dtype=torch.bool, device=node_mask.device).triu(1)
    # (B, N, N) & (N, N) broadcasts over the batch
    return both_valid & strictly_upper

def masked_upper_mse(pred: torch.Tensor,
                     target: torch.Tensor,
                     node_mask: torch.Tensor,
                     edge_mask: torch.Tensor) -> torch.Tensor:
    """Batch-averaged MSE over each graph's masked (unobserved) upper-triangle pairs.

    A pair ``(i, j)`` with ``i < j`` counts when both nodes are valid and
    ``edge_mask[i, j] < 0.5``. Each graph's squared error is averaged over its own
    counted pairs; a graph with none contributes 0. The per-graph values are
    then averaged with equal weight.

    Args:
        pred, target: ``(B, N, N)`` tensors.
        node_mask: ``(B, N)`` boolean mask of valid nodes.
        edge_mask: ``(B, N, N)``; values >= 0.5 are treated as observed and ignored.
    Returns:
        0-d tensor (differentiable w.r.t. ``pred``).
    """
    counted = upper_triu_mask_batched(node_mask) & (edge_mask < 0.5)
    sq_err = (pred - target).pow(2)
    # select instead of multiply so non-finite values outside the counted pairs cannot leak in
    sq_err = torch.where(counted, sq_err, torch.zeros_like(sq_err))
    n_counted = counted.sum(dim=(1, 2))
    # clamp avoids 0/0 for graphs with no counted pairs; their numerator is 0, giving 0
    per_graph = sq_err.sum(dim=(1, 2)) / n_counted.clamp(min=1).to(sq_err.dtype)
    return per_graph.mean()

# =========================
# mmd helpers
# =========================

def _as_2d(X: np.ndarray) -> np.ndarray:
    X = np.asarray(X, dtype=np.float64)
    return X[:, None] if X.ndim == 1 else X

def _pairwise_sq_dists(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between rows: ``out[i, j] = ||X_i - Y_j||^2``.

    1-D inputs are treated as column vectors. The ``|x|^2 + |y|^2 - 2 x.y``
    expansion can produce tiny negative values from round-off.
    """
    X, Y = _as_2d(X), _as_2d(Y)
    x_norms = (X * X).sum(axis=1)[:, np.newaxis]
    y_norms = (Y * Y).sum(axis=1)[np.newaxis, :]
    cross = X @ Y.T
    return x_norms + y_norms - 2.0 * cross

def _median_heuristic_sigma(x: np.ndarray, y: np.ndarray) -> float:
    """Median-heuristic kernel bandwidth for the pooled samples ``x`` and ``y``.

    Returns the median of the strictly positive pairwise Euclidean distances
    among all pooled points. Returns 1.0 when there are no positive distances
    (fewer than two points, or all points identical) or the median is not finite.

    Note: builds the full ``(n + m) x (n + m)`` distance matrix, so memory grows
    quadratically with the number of samples.
    """
    Z = np.concatenate([_as_2d(x), _as_2d(y)], axis=0)
    D2 = _pairwise_sq_dists(Z, Z)
    iu = np.triu_indices(D2.shape[0], k=1)
    # clip round-off negatives before the square root
    dists = np.sqrt(np.maximum(D2[iu], 0.0))
    positive = dists[dists > 0]
    if positive.size == 0:
        # np.median of an empty array warns and yields NaN; fall back explicitly instead
        return 1.0
    med = float(np.median(positive))
    return med if np.isfinite(med) and med > 0 else 1.0

# =========================
# gw helper
# =========================

def gw_distance_simple(gt: np.ndarray, estimation: np.ndarray) -> Tuple[float, float]:
    """Gromov-Wasserstein discrepancy between two matrices used directly as cost matrices.

    Both sides get uniform node weights, and POT's ``square_loss`` is used
    without Armijo line search. Returns ``(value, sqrt(max(value, 0)))``, where
    ``value`` is what ``ot.gromov.gromov_wasserstein2`` returns.
    """
    n_gt, n_est = gt.shape[0], estimation.shape[0]
    p = np.ones((n_gt,)) / n_gt
    q = np.ones((n_est,)) / n_est
    gw_value = ot.gromov.gromov_wasserstein2(gt, estimation, p, q, 'square_loss',
                                             log=False, armijo=False)
    root = float(np.sqrt(max(gw_value, 0.0)))
    return float(gw_value), root

# =========================
# eval helpers used by train and sample
# =========================

def _compute_metric_row_single(A_true_t, A_rec_t, mask_t, plot_path,
                               score_mode="raw", variant="unknown",
                               sample_idx=None, n_steps=None):
    """Score one reconstruction on its masked strict-upper-triangle entries.

    Entries where ``mask_t`` is 0 are the ones scored (``mask_t == 1`` marks
    observed pairs). Returns None when there are no such entries. Otherwise
    returns a metric dict containing:
      * regression errors on raw values: MAE, MSE, Frobenius norm;
      * ranking scores (AP, ROC-AUC) against the truth thresholded at 0.5, NaN
        when sklearn raises ``ValueError`` for the input;
      * recall / F1 with both sides thresholded at 0.5;
      * FN / FP rates, NaN when their denominator is 0.

    ``PredMode`` is "RAW" / "BIN" from ``score_mode``. Any other value infers
    it from whether every masked prediction is exactly 0 or 1. The metric keys
    keep their fixed "[pred=RAW]" / "[pred=BIN]" labels regardless.
    """
    A_true = A_true_t.detach().cpu().numpy()
    A_rec = A_rec_t.detach().cpu().numpy()
    mask = mask_t.detach().cpu().numpy()
    n = A_true.shape[0]
    iu = np.triu_indices(n, k=1)
    masked_upper = (1.0 - mask)[iu] == 1
    if not masked_upper.any():
        return None

    y_true_masked = A_true[iu][masked_upper]
    y_hat_masked = A_rec[iu][masked_upper]

    mode = score_mode.lower()
    if mode == "raw":
        pred_mode = "RAW"
    elif mode == "bin":
        pred_mode = "BIN"
    else:
        is_bin = np.all((y_hat_masked == 0) | (y_hat_masked == 1))
        pred_mode = "BIN" if is_bin else "RAW"

    mae = mean_absolute_error(y_true_masked, y_hat_masked)
    mse = mean_squared_error(y_true_masked, y_hat_masked)
    frob = np.linalg.norm(y_true_masked - y_hat_masked)

    y_bin = (y_true_masked > 0.5).astype(int)
    yhat_bin = (y_hat_masked > 0.5).astype(int)
    rec = recall_score(y_bin, yhat_bin, zero_division=0)
    f1 = f1_score(y_bin, yhat_bin, zero_division=0)

    # narrow catch: only sklearn's input-validation errors become NaN; anything else
    # (including KeyboardInterrupt, which a bare except would swallow) propagates
    try:
        auc = roc_auc_score(y_bin, y_hat_masked)
    except ValueError:
        auc = float("nan")
    try:
        ap = average_precision_score(y_bin, y_hat_masked)
    except ValueError:
        ap = float("nan")

    TP = int(np.logical_and(y_bin == 1, yhat_bin == 1).sum())
    FN = int(np.logical_and(y_bin == 1, yhat_bin == 0).sum())
    FP = int(np.logical_and(y_bin == 0, yhat_bin == 1).sum())
    TN = int(np.logical_and(y_bin == 0, yhat_bin == 0).sum())
    fn_denom = FN + TP
    fp_denom = FP + TN
    fn_rate = (FN / fn_denom) if fn_denom > 0 else float("nan")
    fp_rate = (FP / fp_denom) if fp_denom > 0 else float("nan")

    return {
        "Variant": variant,
        "n_steps": n_steps,
        "sample":  sample_idx,
        "PredMode": pred_mode,
        "NumMasked": int(masked_upper.sum()),
        "MAE [pred=RAW]": mae,
        "MSE [pred=RAW]": mse,
        "FrobNorm [pred=RAW]": frob,
        "AveragePrecision [pred=RAW]": ap,
        "ROC_AUC [pred=RAW]": auc,
        "Rec@0.5 [pred=BIN]": rec,
        "F1@0.5 [pred=BIN]": f1,
        "FN_rate": fn_rate,
        "FP_rate": fp_rate,
        "PlotPath": plot_path,
    }

def _save_five_panel(A_true: torch.Tensor,
                     edge_mask: torch.Tensor,
                     A_step: torch.Tensor,
                     outpath: str,
                     title: str):
    """Save a 1x5 diagnostic figure for one reconstruction to ``outpath`` at 300 dpi.

    Panels, in order:
      * true matrix, mask, reconstruction thresholded at 0.5, and raw
        reconstruction, all on a fixed [0, 1] grey scale;
      * the signed raw difference ``A_step - A_true`` on a symmetric
        blue-white-red scale.
    """
    A1 = A_true.detach().cpu().numpy().astype(np.float32)
    M = edge_mask.detach().cpu().numpy().astype(np.float32)
    Arec = A_step.detach().cpu().numpy().astype(np.float32)
    Arec_bin = (Arec > 0.5).astype(np.float32)
    diff_raw = Arec - A1

    fig, axes = plt.subplots(1, 5, figsize=(16, 4))
    imkw = dict(cmap="Greys", vmin=0.0, vmax=1.0, interpolation="nearest")
    for ax, img, label in zip(axes[:4], (A1, M, Arec_bin, Arec), ("true", "mask", "bin", "raw")):
        ax.imshow(img, **imkw)
        ax.set_title(label)
        ax.axis("off")

    # symmetric limits centre zero error on white; the floor avoids a zero-width colour range
    v = float(max(np.abs(diff_raw).max(), 1e-6))
    im4 = axes[4].imshow(diff_raw, cmap="bwr", vmin=-v, vmax=+v, interpolation="nearest")
    axes[4].set_title("raw delta")
    axes[4].axis("off")
    fig.colorbar(im4, ax=axes[4], fraction=0.046, pad=0.04)

    fig.suptitle(title, fontsize=10)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(outpath, dpi=300)
    plt.close(fig)

def _eval_real_pair_scores(y_true_masked, y_hat_masked):
    """Score reconstructed values against the truth on one graph's masked entries.

    Scores computed:
      * MAE / MSE / Frobenius norm on the raw values;
      * recall, F1 and FN/FP rates with both sides thresholded at 0.5; a rate is
        NaN when its denominator is 0;
      * ROC-AUC and average precision, ranking the raw predictions against the
        thresholded truth; NaN when sklearn raises ``ValueError``.
    """
    y_bin = (y_true_masked > 0.5).astype(int)
    yhat_bin = (y_hat_masked > 0.5).astype(int)

    try:
        auc = roc_auc_score(y_bin, y_hat_masked)
    except ValueError:
        auc = float("nan")
    try:
        ap = average_precision_score(y_bin, y_hat_masked)
    except ValueError:
        ap = float("nan")

    tp = int(np.logical_and(y_bin == 1, yhat_bin == 1).sum())
    fn = int(np.logical_and(y_bin == 1, yhat_bin == 0).sum())
    fp = int(np.logical_and(y_bin == 0, yhat_bin == 1).sum())
    tn = int(np.logical_and(y_bin == 0, yhat_bin == 0).sum())

    return {
        "MAE": mean_absolute_error(y_true_masked, y_hat_masked),
        "MSE": mean_squared_error(y_true_masked, y_hat_masked),
        "FrobNorm": np.linalg.norm(y_true_masked - y_hat_masked),
        "AP": ap,
        "Rec(0.5)": recall_score(y_bin, yhat_bin, zero_division=0),
        "F1(0.5)": f1_score(y_bin, yhat_bin, zero_division=0),
        "ROC_AUC": auc,
        "FN_rate": (fn / (fn + tp)) if (fn + tp) > 0 else float("nan"),
        "FP_rate": (fp / (fp + tn)) if (fp + tn) > 0 else float("nan"),
    }


def _eval_real_mmd(x, y, mmd_sigma, mmd_max_samples, mmd_seed):
    """Gaussian-kernel MMD (ignite ``MaximumMeanDiscrepancy``) between two 1-D float32 samples.

    Subsampling: each side is independently reduced, without replacement, to at
    most ``mmd_max_samples`` values, using seeds ``mmd_seed`` and ``mmd_seed + 1``.

    Kernel variance:
      * the squared median-heuristic bandwidth when ``mmd_sigma == "median"``;
      * ``mmd_sigma ** 2`` for a number;
      * ``mmd_sigma[0] ** 2`` for a non-empty list or tuple;
      * 1.0 otherwise.

    Returns NaN when either side has fewer than two values, rather than passing
    ignite a degenerate one-sample batch.

    NOTE(review): confirm whether ignite's ``compute()`` returns MMD or MMD^2;
    the caller stores it under "MMD2_*" column names.
    """
    if x.size > mmd_max_samples:
        x = x[np.random.default_rng(mmd_seed).choice(x.size, mmd_max_samples, replace=False)]
    if y.size > mmd_max_samples:
        y = y[np.random.default_rng(mmd_seed + 1).choice(y.size, mmd_max_samples, replace=False)]
    if x.size < 2 or y.size < 2:
        return float("nan")

    # isinstance guard: comparing an array-valued sigma with a string would be ambiguous
    if isinstance(mmd_sigma, str) and mmd_sigma == "median":
        s = _median_heuristic_sigma(x.reshape(-1, 1), y.reshape(-1, 1))
        var = float(s * s)
    elif isinstance(mmd_sigma, (int, float)):
        var = float(mmd_sigma) * float(mmd_sigma)
    elif isinstance(mmd_sigma, (list, tuple)) and len(mmd_sigma) > 0:
        var = float(mmd_sigma[0]) * float(mmd_sigma[0])
    else:
        var = 1.0

    metric = MaximumMeanDiscrepancy(var=var)
    metric.update((torch.from_numpy(x).view(-1, 1), torch.from_numpy(y).view(-1, 1)))
    value = metric.compute()
    return float(value.item() if isinstance(value, torch.Tensor) else value)


def evaluate_and_save_real(
    args, A1_list, reconstructed_list, edge_masks, plot_paths, st,
    score_mode: str = "auto",
    compute_gw: bool = False,
    gw_cost_mode: str = "adj",
    gw_entropic: bool = True,
    gw_epsilon: float = 5e-3,
    gw_max_iter: int = 200,
    gw_tol: float = 1e-9,
    compute_mmd: bool = True,
    mmd_kernel: str = "rbf",
    mmd_sigma = "median",
    mmd_on: str = "masked_raw",
    mmd_max_samples: int = 5000,
    mmd_seed: int = 0
):
    """Score a list of reconstructions on their masked entries and write one CSV.

    Inputs and scoring:
      * ``A1_list[i]``, ``reconstructed_list[i]`` and ``edge_masks[i]`` are
        zipped. Strict-upper-triangle entries whose mask value is 0 are scored,
        and graphs with no such entries are skipped.
      * ``PredMode`` is taken from ``score_mode`` ("raw" / "bin"); any other
        value infers it from the first scored graph's predictions.
      * Optional: Gromov-Wasserstein on the full matrices (best effort; NaN on
        solver failure) and MMD on masked (``"masked_raw"``) or all
        (``"full_raw"``) upper-triangle values.

    Output:
      * An "average" row (mean of numeric columns) is appended.
      * Metric columns are renamed with the prediction mode.
      * The table is saved to ``<args.out_dir>/metrics/<st>_<args.name>_real_...csv``.

    Unused parameters: ``gw_cost_mode``, ``gw_entropic``, ``gw_epsilon``,
    ``gw_max_iter``, ``gw_tol`` and ``mmd_kernel`` are accepted but currently
    ignored. GW always uses ``gw_distance_simple`` on the raw matrices, and MMD
    always uses ignite's Gaussian kernel.

    Returns:
        The CSV path, or None when no graph had masked entries.

    Raises:
        ValueError: for an unknown ``mmd_on`` when MMD is computed.
    """
    rows = []
    inferred_mode = None

    for i, (A_true, A_rec, mask) in enumerate(zip(A1_list, reconstructed_list, edge_masks)):
        # detach so tensors that still track gradients can be converted
        A_true_np = A_true.detach().cpu().numpy()
        A_rec_np = A_rec.detach().cpu().numpy()
        mask_np = mask.detach().cpu().numpy()

        iu = np.triu_indices(A_true_np.shape[0], k=1)
        masked_upper = (1.0 - mask_np)[iu] == 1
        if not masked_upper.any():
            continue

        y_true_masked = A_true_np[iu][masked_upper]
        y_hat_masked = A_rec_np[iu][masked_upper]

        if inferred_mode is None:
            mode = score_mode.lower()
            if mode == "raw":
                inferred_mode = "RAW"
            elif mode == "bin":
                inferred_mode = "BIN"
            else:
                is_binary = np.all((y_hat_masked == 0) | (y_hat_masked == 1))
                inferred_mode = "BIN" if is_binary else "RAW"

        scores = _eval_real_pair_scores(y_true_masked, y_hat_masked)

        GW2, GW = float("nan"), float("nan")
        if compute_gw:
            try:
                GW2, GW = gw_distance_simple(A_true_np, A_rec_np)
            except Exception:
                # deliberate best effort: a GW solver failure must not abort the whole evaluation
                GW2, GW = np.nan, np.nan

        mmd_val = float("nan")
        if compute_mmd:
            if mmd_on == "masked_raw":
                x, y = y_true_masked, y_hat_masked
            elif mmd_on == "full_raw":
                x, y = A_true_np[iu], A_rec_np[iu]
            else:
                raise ValueError("bad mmd_on")
            mmd_val = _eval_real_mmd(x.astype(np.float32), y.astype(np.float32),
                                     mmd_sigma, mmd_max_samples, mmd_seed)

        # column order is kept stable for downstream CSV consumers
        rows.append({
            "sample":            i,
            "PredMode":          inferred_mode,
            "NumMasked":         int(masked_upper.sum()),
            "MAE":               scores["MAE"],
            "MSE":               scores["MSE"],
            "FrobNorm":          scores["FrobNorm"],
            "AP":                scores["AP"],
            "Rec(0.5)":          scores["Rec(0.5)"],
            "F1(0.5)":           scores["F1(0.5)"],
            "ROC_AUC":           scores["ROC_AUC"],
            "MMD2_unbiased_maskedRAW": mmd_val,
            "MMD2_biased_maskedRAW":   mmd_val,
            "FN_rate":           scores["FN_rate"],
            "FP_rate":           scores["FP_rate"],
            "GW2":               GW2,
            "GW":                GW,
            "PlotPath":          plot_paths[i] if i < len(plot_paths) else "",
        })

    df = pd.DataFrame(rows)
    if df.empty:
        return None

    avg_row = df.mean(numeric_only=True)
    avg_row["sample"] = "average"
    pred_mode = df["PredMode"].iloc[0] if "PredMode" in df.columns else "RAW"
    avg_row["PredMode"] = pred_mode
    df = pd.concat([df, avg_row.to_frame().T], ignore_index=True)

    renames = {
        "MAE":       f"MAE [pred={pred_mode}]",
        "MSE":       f"MSE [pred={pred_mode}]",
        "FrobNorm":  f"FrobNorm [pred={pred_mode}]",
        "ROC_AUC":   f"ROC_AUC [pred={pred_mode}]",
        "AP":        f"AveragePrecision [pred={pred_mode}]",
        "Rec(0.5)":  "Rec@0.5 [pred=BIN]",
        "F1(0.5)":   "F1@0.5 [pred=BIN]",
        "MMD2_unbiased_maskedRAW": "MMD2_unbiased",
        "MMD2_biased_maskedRAW":   "MMD2_biased",
    }
    df.rename(columns=renames, inplace=True)

    ts = time.strftime("%Y%m%d_%H%M%S")
    out_dir = os.path.join(args.out_dir, "metrics")
    os.makedirs(out_dir, exist_ok=True)
    steps_tag = getattr(args, "n_steps", "NA")
    drop_tag = getattr(args, "drop_prob", "NA")
    name = f"{st}_{args.name}_real_{getattr(args,'epochs','NA')}ep_{steps_tag}steps_{drop_tag}drop_{ts}.csv"
    outpath = os.path.join(out_dir, name)
    df.to_csv(outpath, index=False)
    print(f"saved metrics csv -> {outpath}")
    return outpath

# =========================
# prior loader used by train and sample
# =========================

def _load_priors_from_npy_dir(npy_dir: str, masks: List[torch.Tensor], args):
    if not os.path.isdir(npy_dir):
        raise FileNotFoundError(f"prior dir not found: {npy_dir}")
    files = [f for f in os.listdir(npy_dir) if f.endswith(".npy") or f.endswith(".npz")]
    if len(files) == 0:
        raise FileNotFoundError(f"no numpy files in {npy_dir}")
    files = sorted(files)
    L = len(masks)
    priors, z1d_list = [], []

    for idx in range(L):
        cands = [f for f in files if re.search(rf'(?i)(^|[^0-9])g{idx}([^0-9]|$)', f)]
        pick = cands[0] if cands else (files[idx] if len(files) == L else None)
        if pick is None:
            raise FileNotFoundError(f"cannot match prior for graph {idx} in {npy_dir}")
        path = os.path.join(npy_dir, pick)
        arr = np.load(path, allow_pickle=True)
        if isinstance(arr, np.lib.npyio.NpzFile):
            keys = list(arr.keys())
            if not keys: raise ValueError(f"empty npz {path}")
            arr = arr[keys[0]]
        arr = np.array(arr, dtype=np.float32)

        M_mask = masks[idx]
        M_np = M_mask.cpu().numpy() if torch.is_tensor(M_mask) else np.asarray(M_mask)
        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
            raise ValueError(f"prior not square: {path} shape {arr.shape}")
        if arr.shape != M_np.shape:
            raise ValueError(f"shape mismatch prior {arr.shape} vs mask {M_np.shape} for {path}")

        arr_sym = (arr + arr.T) * 0.5
        np.fill_diagonal(arr_sym, 0.0)
        arr_sym = np.clip(arr_sym, 0.0, 1.0)

        # simple pc one
        Zc = arr_sym - arr_sym.mean(0, keepdims=True)
        v = np.random.randn(Zc.shape[1]); v /= (np.linalg.norm(v) + 1e-12)
        for _ in range(20):
            v = Zc.T @ (Zc @ v); v /= (np.linalg.norm(v) + 1e-12)
        s = Zc @ v
        s -= s.min(); rng = s.max() - s.min()
        z1d = (s / rng) if rng > 1e-12 else np.linspace(0, 1, arr_sym.shape[0])

        priors.append(torch.from_numpy(arr_sym.astype(np.float32)))
        z1d_list.append(torch.from_numpy(z1d.astype(np.float32)))

    return priors, z1d_list

# =========================
# optional val posterior used by train
# =========================

def _posterior_eval_on_val_samples(epoch, val_loader, denoiser, device, args):
    """Run the Euler sampler on the first ``args.val_posterior_k`` validation graphs and save diagnostics.

    Per graph, after reordering nodes by ascending ``z1d``:
      * Start state: the true value on observed pairs (``edge_mask == 1``) and
        the prior elsewhere, plus masked Gaussian noise of std ``args.val_noise_std``.
      * Integration: ``args.n_steps`` Euler steps of size ``1 / n_steps``. The
        drift is restricted to unobserved pairs, and observed pairs are reset to
        the truth after each step.
      * Outputs under ``<out_dir>/val_mmse/<name>_ep<epoch>_<ts>/``: the
        rounded, noisy and clean start states, the final raw reconstruction,
        and a 5-panel plot.

    Afterwards, one metric CSV is written per step listed in
    ``args.val_save_steps`` (the state after that step), plus one for the clean
    start state (``st="A0RAW"``).

    Returns:
        The CSV paths for ``args.val_save_steps`` only.
    """
    ts = time.strftime("%Y%m%d_%H%M%S")
    base = os.path.join(args.out_dir, "val_mmse", f"{args.name}_ep{epoch:04d}_{ts}")
    plot_dir = os.path.join(base, "plots")
    rounded_dir = os.path.join(base, "rounded_A0")
    rounded_raw_dir = os.path.join(base, "rounded_A0_raw")
    recon_dir = os.path.join(base, "recon")  # created for layout compatibility; nothing is written here
    recon_raw_dir = os.path.join(base, "recon_raw")
    for d in (plot_dir, rounded_dir, rounded_raw_dir, recon_dir, recon_raw_dir):
        os.makedirs(d, exist_ok=True)

    save_steps = args.val_save_steps
    raw_by_step: Dict[int, List[torch.Tensor]] = {step: [] for step in save_steps}
    true_graphs, plot_paths, edge_mask_list, initial_raws = [], [], [], []
    processed = 0
    K = args.val_posterior_k

    denoiser.eval()
    with torch.no_grad():
        for A_batch, node_mask_batch, edge_mask_batch, Y_prior_batch, z1d_batch in val_loader:
            A_batch = A_batch.to(device)
            node_mask_batch = node_mask_batch.to(device)
            edge_mask_batch = edge_mask_batch.to(device)
            Y_prior_batch = Y_prior_batch.to(device)
            z1d_batch = z1d_batch.to(device)

            # `j` (not `b`): the old code reused `b` for the drift tensor inside the Euler loop
            for j in range(A_batch.size(0)):
                if processed >= K:
                    break

                # reorder nodes by the prior's 1-D coordinate, as the training loop does
                p = torch.argsort(z1d_batch[j], dim=0)
                z1d = z1d_batch[j].index_select(0, p)
                node_mask = node_mask_batch[j].index_select(0, p)
                A1 = permute_square(A_batch[j], p)
                edge_mask = permute_square(edge_mask_batch[j], p)
                Y_prior = permute_square(Y_prior_batch[j], p)
                unobserved = 1.0 - edge_mask

                A0_clean = edge_mask * A1 + unobserved * Y_prior
                A0_noisy = add_masked_symmetric_noise(
                    M=A0_clean, node_mask=node_mask, edge_mask=edge_mask,
                    sigma=args.val_noise_std, clip01=True
                )

                A0_rounded = (A0_noisy > 0.5).float()
                prefix = f"val_ep{epoch:04d}_sample{processed}"
                np.save(os.path.join(rounded_dir,     f"{prefix}_A0rounded.npy"),      A0_rounded.cpu().numpy())
                np.save(os.path.join(rounded_raw_dir, f"{prefix}_A0raw.npy"),          A0_noisy.cpu().numpy())
                np.save(os.path.join(rounded_raw_dir, f"{prefix}_A0raw_clean.npy"),    A0_clean.cpu().numpy())

                initial_raws.append(A0_clean.cpu().clone())

                A = A0_noisy.clone()
                dt = 1.0 / args.n_steps
                xfeat = torch.zeros(1, z1d.shape[0], 1, device=z1d.device, dtype=z1d.dtype)
                for step in range(args.n_steps):
                    t = torch.full((1,), step * dt, device=device)
                    drift = denoiser(xfeat, A.unsqueeze(0).unsqueeze(1), node_mask.unsqueeze(0), t).squeeze(0)
                    drift = sym_zero_diag_valid(drift, node_mask) * unobserved
                    A = A + dt * drift
                    # pin observed pairs back to the truth after every step
                    A = edge_mask * A1 + unobserved * A
                    A = sym_zero_diag_valid(A, node_mask)
                    if step in raw_by_step:  # O(1) dict lookup instead of scanning save_steps
                        raw_by_step[step].append(A.cpu().clone())

                zero_diag_(A)
                np.save(os.path.join(recon_raw_dir, f"{prefix}_reconstructed_raw.npy"), A.cpu().numpy())
                reconstructed_A = (A > 0.5).float()
                zero_diag_(reconstructed_A)

                # note: the "raw delta" panel shows the *thresholded* reconstruction minus the truth
                diff = (reconstructed_A - A1).cpu()
                plot_path = os.path.join(plot_dir, f"{prefix}_plot.png")
                fig, axes = plt.subplots(1, 5, figsize=(16, 4))
                panels = (
                    (A1.cpu(), "true"),
                    (edge_mask.cpu(), "mask"),
                    ((A1 * edge_mask).cpu(), "masked"),
                    (reconstructed_A.cpu(), "recon"),
                )
                for ax, (img, label) in zip(axes[:4], panels):
                    ax.imshow(img, cmap='Greys')
                    ax.set_title(label)
                    ax.axis("off")
                v = diff.abs().max().item() or 1e-6
                im = axes[4].imshow(diff, cmap='bwr', vmin=-v, vmax=+v)
                axes[4].set_title("raw delta")
                axes[4].axis("off")
                fig.colorbar(im, ax=axes[4], fraction=0.046, pad=0.04)
                fig.suptitle(f"val mmse ep {epoch}", fontsize=10)
                fig.tight_layout(rect=[0, 0.03, 1, 0.95])
                fig.savefig(plot_path, dpi=300)
                plt.close(fig)

                true_graphs.append(A1.cpu().clone())
                edge_mask_list.append(edge_mask.cpu().clone())
                plot_paths.append(plot_path)
                processed += 1

            if processed >= K:
                break

    outpaths = []
    for step in args.val_save_steps:
        out = evaluate_and_save_real(
            args,
            true_graphs,
            raw_by_step[step],
            edge_mask_list,
            plot_paths,
            st=step,
            score_mode="raw"
        )
        if out:
            outpaths.append(out)

    evaluate_and_save_real(
        args,
        true_graphs,
        initial_raws,
        edge_mask_list,
        plot_paths,
        st="A0RAW"
    )
    print(f"val posterior saved under {base}")
    return outpaths

# =========================
# train
# =========================

def _load_pickle(path):
    """Unpickle the object stored at ``path``, closing the file afterwards."""
    # NOTE(review): unpickling can execute code embedded in the file; only load trusted data.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _sort_batch_by_z1d_(A_batch, node_mask, edge_mask, Y_prior, z1d):
    """In place: reorder each graph's nodes by ascending ``z1d`` (the order the posterior sampler uses)."""
    for i in range(A_batch.size(0)):
        p = torch.argsort(z1d[i], dim=0)
        z1d[i] = z1d[i].index_select(0, p)
        node_mask[i] = node_mask[i].index_select(0, p)
        A_batch[i] = permute_square(A_batch[i], p)
        edge_mask[i] = permute_square(edge_mask[i], p)
        Y_prior[i] = permute_square(Y_prior[i], p)


def _flow_matching_loss(denoiser, batch, noise_std, device):
    """Masked flow-matching loss of ``denoiser`` on one collated batch.

    Start state, built the same way as in ``_posterior_eval_on_val_samples``:
      * observed pairs (``edge_mask == 1``) keep the true value;
      * the remaining pairs take the prior;
      * Gaussian noise of std ``noise_std`` is added on unobserved pairs only.

    Model input and target: at ``t ~ U[0, 1)`` the model sees
    ``I_t = (1 - t) A0 + t A1``. Its drift, restricted to unobserved pairs, is
    regressed onto ``A1 - A0`` on unobserved valid strict-upper-triangle pairs.

    Returns:
        ``(loss, batch_size)``.
    """
    A_batch, node_mask, edge_mask, Y_prior, z1d = (x.to(device) for x in batch)
    _sort_batch_by_z1d_(A_batch, node_mask, edge_mask, Y_prior, z1d)
    B = A_batch.size(0)

    # Use the observation mask itself, not the observed values A * edge_mask. Otherwise
    # observed non-edges are treated as unknown, which disagrees with the sampler and the metrics.
    omega = 1.0 - edge_mask
    A0 = edge_mask * A_batch + omega * Y_prior
    A0 = add_masked_symmetric_noise(M=A0, node_mask=node_mask, edge_mask=edge_mask,
                                    sigma=noise_std, clip01=True)

    t = torch.rand(B, device=device)
    a, b, _, _ = linear_coeffs(t)
    I_t = sym_zero_diag_valid(a.view(B, 1, 1) * A0 + b.view(B, 1, 1) * A_batch, node_mask)
    x_feat = torch.zeros_like(z1d).unsqueeze(-1)
    b_pred = denoiser(x_feat, I_t.unsqueeze(1), node_mask, t)
    b_pred = sym_zero_diag_valid(b_pred, node_mask) * omega

    target = sym_zero_diag_valid(A_batch - A0, node_mask)
    return masked_upper_mse(b_pred, target, node_mask, edge_mask), B


def train(args):
    """Train ``DenoiseNetworkA`` with the masked flow-matching objective, validating every epoch.

    Inputs read:
      * graphs from ``args.train_pkl`` / ``args.val_pkl``;
      * masks from ``<dir of train_pkl>/masks_drop<drop_prob>/{train,val}_masks.pkl``;
      * priors from ``args.n2v_prior_train_dir`` / ``args.n2v_prior_val_dir``.

    Outputs written:
      * checkpoints every ``args.ckpt_every`` epochs (when > 0) and after the
        final epoch, under ``<out_dir>/checkpoints/``;
      * posterior diagnostics every ``args.val_posterior_every`` epochs (when > 0);
      * a loss-curve PNG at the end.
    """
    print("start train")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)

    print("load data")
    train_graphs = _load_pickle(args.train_pkl)
    val_graphs = _load_pickle(args.val_pkl)

    mask_dir = os.path.join(os.path.dirname(args.train_pkl), f"masks_drop{args.drop_prob}")
    train_masks_np = _load_pickle(os.path.join(mask_dir, "train_masks.pkl"))
    val_masks_np = _load_pickle(os.path.join(mask_dir, "val_masks.pkl"))
    train_masks = [torch.from_numpy(m).float() for m in train_masks_np]
    val_masks = [torch.from_numpy(m).float() for m in val_masks_np]

    if not getattr(args, 'n2v_prior_train_dir', None):
        raise ValueError("need --n2v_prior_train_dir")
    if not getattr(args, 'n2v_prior_val_dir', None):
        raise ValueError("need --n2v_prior_val_dir")

    print("load priors")
    n2v_priors_train, n2v_z1d_train = _load_priors_from_npy_dir(args.n2v_prior_train_dir, train_masks, args)
    n2v_priors_val, n2v_z1d_val = _load_priors_from_npy_dir(args.n2v_prior_val_dir, val_masks, args)

    class GraphMaskDataset(Dataset):
        """Yields ``(adjacency, edge mask, prior, z1d)`` per graph; networkx graphs are densified."""
        def __init__(self, graphs, masks, priors, z1d):
            self.graphs = graphs
            self.masks = masks
            self.priors = priors
            self.z1d = z1d
        def __len__(self): return len(self.graphs)
        def __getitem__(self, i):
            G = self.graphs[i]
            A = torch.tensor(nx.to_numpy_array(G), dtype=torch.float32) if isinstance(G, nx.Graph) else G.float()
            return A, self.masks[i], self.priors[i].float(), self.z1d[i].float()

    def collate_fn(batch):
        # zero-pad adjacency, mask, prior and z1d to the largest graph in the batch
        As, Ms, Ys, Zs = zip(*batch)
        A_batch, node_mask = collate_graphs(As)
        max_n = A_batch.size(1); B = len(Ms)
        M_p = torch.zeros(B, max_n, max_n, dtype=torch.float32)
        Y_p = torch.zeros(B, max_n, max_n, dtype=torch.float32)
        Z_p = torch.zeros(B, max_n,           dtype=torch.float32)
        for i, (M, Y, z) in enumerate(zip(Ms, Ys, Zs)):
            n = M.size(0)
            M_p[i, :n, :n] = M
            Y_p[i, :n, :n] = Y
            Z_p[i, :n]     = z
        return A_batch, node_mask, M_p, Y_p, Z_p

    train_loader = DataLoader(
        GraphMaskDataset(train_graphs, train_masks, n2v_priors_train, n2v_z1d_train),
        batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        GraphMaskDataset(val_graphs, val_masks, n2v_priors_val, n2v_z1d_val),
        batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
    )

    def _size_of(g):
        if isinstance(g, torch.Tensor): return g.size(0)
        if isinstance(g, nx.Graph):     return g.number_of_nodes()
        raise TypeError("bad graph type")
    Nmax = max(_size_of(g) for g in train_graphs)

    denoiser = DenoiseNetworkA(
        max_feat_num=1,
        max_node_num=Nmax,
        nhid=args.hidden_dim,
        num_layers=args.num_layers,
        num_linears=args.num_linears,
        c_init=args.c_init,
        c_hid=args.c_hid,
        c_final=args.c_final,
        adim=args.hidden_dim,
    ).to(device)
    optimizer = optim.Adam(denoiser.parameters(), lr=args.lr)

    ckpt_root = os.path.join(args.out_dir, "checkpoints")
    run_ts = time.strftime("%Y%m%d_%H%M%S")
    ckpt_dir = os.path.join(ckpt_root, f"{args.name}_{args.drop_prob}drop_{run_ts}")
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, "run_args.txt"), "w") as f:
        f.write(str(vars(args)))

    train_losses, val_losses = [], []

    for epoch in range(1, args.epochs + 1):
        denoiser.train()
        sum_per_graphs, graphs_seen = 0.0, 0
        for batch in train_loader:
            loss, B = _flow_matching_loss(denoiser, batch, args.train_noise_std, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # weight by batch size so the epoch value is a per-graph average
            sum_per_graphs += float(loss.item()) * B
            graphs_seen += B

        train_loss = sum_per_graphs / max(1, graphs_seen)
        train_losses.append(train_loss)

        denoiser.eval()
        val_sum, val_seen = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                l, B = _flow_matching_loss(denoiser, batch, args.val_noise_std, device)
                val_sum += float(l.item()) * B
                val_seen += B

        val_loss = val_sum / max(1, val_seen)
        val_losses.append(val_loss)
        print(f"epoch {epoch}: train {train_loss:.6f}  val {val_loss:.6f}")

        if args.val_posterior_every > 0 and (epoch % args.val_posterior_every == 0):
            _posterior_eval_on_val_samples(epoch, val_loader, denoiser, device, args)

        # ckpt_every <= 0 means "final checkpoint only" instead of a ZeroDivisionError
        periodic = args.ckpt_every > 0 and epoch % args.ckpt_every == 0
        if periodic or epoch == args.epochs:
            ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:04d}.pt")
            torch.save(denoiser.state_dict(), ckpt_path)
            print(f"saved ckpt {ckpt_path}")

    os.makedirs(os.path.join(args.out_dir, "loss_plots"), exist_ok=True)
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', marker='o')
    plt.plot(val_losses, label='val', marker='x')
    plt.title("loss over epochs")
    plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.grid(True); plt.tight_layout()
    ts = time.strftime("%Y%m%d_%H%M%S")
    loss_plot_path = os.path.join(args.out_dir, "loss_plots", f"loss_curve_{ts}.png")
    plt.savefig(loss_plot_path, dpi=300); plt.close()
    print(f"saved plot {loss_plot_path}")

# =========================
# sample
# =========================

def _sample_load_inputs(args, device) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Load the (graph, mask) pairs to reconstruct.

    Either ``--sample_pkl`` + ``--mask_pkl`` (a pickled list of nx.Graph or
    tensors, and a pickled list of numpy masks) or ``--input_graph`` +
    ``--mask_npy`` (one .npy adjacency and one .npy mask) must be given.

    Returns:
        (A1_list, M_list): float32 adjacency matrices and edge masks on ``device``.

    Raises:
        ValueError: if neither input source is fully specified.
    """
    if args.sample_pkl and args.mask_pkl:
        # NOTE(security): pickle.load can execute arbitrary code - only load trusted files.
        with open(args.sample_pkl, 'rb') as fh:
            sample_graphs = pickle.load(fh)
        A1_list = [
            torch.tensor(nx.to_numpy_array(g), dtype=torch.float32, device=device)
            if isinstance(g, nx.Graph) else g.to(dtype=torch.float32, device=device)
            for g in sample_graphs
        ]
        with open(args.mask_pkl, 'rb') as fh:
            raw_masks = pickle.load(fh)
        M_list = [torch.from_numpy(m).to(device).float() for m in raw_masks]
        print(f"loaded {len(A1_list)} graphs")
    elif args.input_graph and args.mask_npy:
        A = np.load(args.input_graph).astype(np.float32)
        M = np.load(args.mask_npy).astype(np.float32)
        A1_list = [torch.from_numpy(A).to(device)]
        M_list = [torch.from_numpy(M).to(device)]
        print("loaded single graph")
    else:
        raise ValueError("need sample inputs")
    return A1_list, M_list


def _sample_parse_steps(steps_csv: str) -> List[int]:
    """Parse ``--steps_csv`` into unique, positive Euler step counts (input order kept).

    An empty string (or one with only separators) yields ``[1]``. Duplicates are
    dropped because each step count keys a single list of reconstructions; a
    repeated value would append two recons per sample and misalign that list
    with the ground-truth list during evaluation.

    Raises:
        ValueError: if a token is not an integer or a value is <= 0.
    """
    if not steps_csv:
        return [1]
    values = list(dict.fromkeys(int(tok) for tok in steps_csv.split(",") if tok.strip()))
    if not values:
        return [1]
    bad = [v for v in values if v <= 0]
    if bad:
        raise ValueError(f"--steps_csv values must be positive integers, got {bad}")
    return values


def _sample_flow_step(denoiser: nn.Module, A: torch.Tensor, x_feat: torch.Tensor,
                      node_mask: torch.Tensor, omega: torch.Tensor, A_obs: torch.Tensor,
                      t_val: float, dt: float, device) -> torch.Tensor:
    """Advance the (permuted) adjacency state by one explicit Euler step.

    The denoiser's prediction is symmetrised, restricted to ``omega`` (entries not
    fixed by the observed edges), added with step ``dt``, clamped to [0, 1], and
    the observed edges ``A_obs`` are re-imposed. Returns a new tensor.
    """
    inp = A.unsqueeze(0).unsqueeze(1)            # (1, 1, n, n) batch/channel dims for the denoiser
    with torch.no_grad():
        t = torch.full((1,), t_val, device=device)
        b = denoiser(x_feat, inp, node_mask.unsqueeze(0), t).squeeze(0)
        b = sym_zero_diag_valid(b, node_mask)
    b = b * omega
    A = A + dt * b
    A.clamp_(0.0, 1.0)
    A = A_obs + omega * A                        # keep observed edges pinned at 1
    return sym_zero_diag_valid(A, node_mask)


def _sample_save_summary_plot(A1: torch.Tensor, edge_mask: torch.Tensor,
                              A_final: torch.Tensor, n_steps: int, plot_path: str) -> None:
    """Save a 1x5 figure: true graph, mask, masked graph, rounded recon, signed error."""
    reconstructed_A = (A_final > 0.5).float()
    zero_diag_(reconstructed_A)
    diff = (reconstructed_A - A1).cpu()
    fig, axes = plt.subplots(1, 5, figsize=(16, 4))
    panels = (
        (A1.cpu(), "true"),
        (edge_mask.cpu(), "mask"),
        ((A1 * edge_mask).cpu(), "masked"),
        (reconstructed_A.cpu(), f"recon {n_steps}"),
    )
    for ax, (img, title) in zip(axes[:4], panels):
        ax.imshow(img, cmap='Greys')
        ax.set_title(title)
        ax.axis("off")
    # Symmetric color limits centred on zero; `or` guards an all-zero diff.
    v = diff.abs().max().item() or 1e-6
    im = axes[4].imshow(diff, cmap='bwr', vmin=-v, vmax=+v)
    axes[4].set_title("raw delta")
    axes[4].axis("off")
    fig.colorbar(im, ax=axes[4], fraction=0.046, pad=0.04)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(plot_path, dpi=300)
    plt.close(fig)


def sample(args):
    """Reconstruct partially observed graphs by integrating the trained denoiser.

    For each (graph, mask) pair, the steps are:

    1. Sort the nodes by the test 1-D embedding.
    2. Build the initial state A0. It keeps the observed edges and uses the test
       prior everywhere else.
    3. Add noise to A0.
    4. Advance A0 with explicit Euler steps for every step count in
       ``--steps_csv``.

    Raw A0 and reconstruction arrays, plus plots, are written under
    ``<out_dir>/mmse_raw/<name>_<drop_prob>drop_<timestamp>``. Metrics are then
    computed against three mask lists: the edge mask, the observed edges A_obs,
    and the true adjacency.

    Raises:
        ValueError: if no input source or ``--n2v_prior_test_dir`` is given, or
            ``--steps_csv`` is invalid.
    """
    print("start sample")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --seed is a shared CLI option; apply it so the (presumably random) A0 noise
    # is reproducible across runs.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        set_seed(seed)

    # Validate step counts before any heavy loading or directory creation.
    n_steps_values = _sample_parse_steps(args.steps_csv)
    max_steps = max(n_steps_values)

    A1_list, M_list = _sample_load_inputs(args, device)

    if not getattr(args, 'n2v_prior_test_dir', None):
        raise ValueError("need --n2v_prior_test_dir")

    graphs_cpu = [a.detach().cpu() for a in A1_list]
    masks_cpu = [m.detach().cpu() for m in M_list]
    test_priors, test_z1d = _load_priors_from_npy_dir(args.n2v_prior_test_dir, masks_cpu, args)
    print(f"loaded {len(test_priors)} priors")

    max_nodes = max(g.shape[0] for g in graphs_cpu)
    model_nodes = args.max_graph_nodes if args.max_graph_nodes > 0 else max_nodes

    denoiser = DenoiseNetworkA(
        max_feat_num=1,
        max_node_num=model_nodes,
        nhid=args.hidden_dim,
        num_layers=args.num_layers,
        num_linears=args.num_linears,
        c_init=args.c_init,
        c_hid=args.c_hid,
        c_final=args.c_final,
        adim=args.hidden_dim,
    ).to(device)
    denoiser.load_state_dict(torch.load(args.ckpt, map_location=device))
    denoiser.eval()

    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(args.out_dir, "mmse_raw", f"{args.name}_{args.drop_prob}drop_{ts}")
    plot_dir = os.path.join(run_dir, "plots")
    a0_dir = os.path.join(run_dir, "A0_rounded")
    a0_raw_dir = os.path.join(run_dir, "A0_raw")
    recon_raw_dir = os.path.join(run_dir, "recon_raw")
    for d in (plot_dir, a0_dir, a0_raw_dir, recon_raw_dir):
        os.makedirs(d, exist_ok=True)
    prefix = os.path.basename(run_dir)

    final_by_steps: Dict[int, List[torch.Tensor]] = {n: [] for n in n_steps_values}
    true_graphs: List[torch.Tensor] = []
    plot_paths: List[str] = []
    edge_mask_list: List[torch.Tensor] = []
    aobs_mask_list: List[torch.Tensor] = []
    atrue_mask_list: List[torch.Tensor] = []
    initial_raws: List[torch.Tensor] = []
    csv_paths: List[str] = []

    for i, (A1, edge_mask) in enumerate(zip(A1_list, M_list)):
        print(f"sample {i+1}/{len(A1_list)}")
        node_mask = torch.ones(A1.size(0), dtype=torch.bool, device=device)
        Y_prior = test_priors[i].to(device).float()
        z_full = test_z1d[i].to(device).float()

        # Work in the node order given by sorting the 1-D embedding; p_inv maps back.
        p = torch.argsort(z_full, dim=0)
        p_inv = invert_perm(p)
        z_full = z_full.index_select(0, p)
        A1_p = permute_square(A1, p)
        mask_p = permute_square(edge_mask, p)
        node_mask = node_mask.index_select(0, p)
        Y_p = permute_square(Y_prior, p)

        # A0 keeps observed edges and takes the prior everywhere else.
        A_obs = sym_zero_diag_valid(A1_p * mask_p, node_mask)
        omega = 1.0 - A_obs
        A0_clean = omega * Y_p + A_obs
        A0_noisy = add_masked_symmetric_noise(M=A0_clean, node_mask=node_mask, edge_mask=A_obs,
                                              sigma=args.noise_std, clip01=True)

        A0_unperm_clean = permute_square(A0_clean, p_inv)
        A0_unperm_noisy = permute_square(A0_noisy, p_inv)
        np.save(os.path.join(a0_raw_dir, f"{prefix}_sample{i}_A0raw.npy"), A0_unperm_noisy.cpu().numpy())
        np.save(os.path.join(a0_raw_dir, f"{prefix}_sample{i}_A0raw_clean.npy"), A0_unperm_clean.cpu().numpy())
        A0_rounded_unperm = (A0_unperm_noisy > 0.5).float()
        np.save(os.path.join(a0_dir, f"{prefix}_sample{i}_A0rounded.npy"), A0_rounded_unperm.cpu().numpy())
        initial_raws.append(A0_unperm_clean.cpu().clone())

        A_obs_unperm = permute_square(A_obs, p_inv)
        aobs_mask_list.append(A_obs_unperm.cpu().clone())
        atrue_mask_list.append(A1.cpu().clone())

        # NOTE(review): the returned metric row is never used; the call is kept only
        # in case it has side effects - consider removing it.
        _compute_metric_row_single(
            A_true_t=A1, A_rec_t=A0_unperm_clean, mask_t=edge_mask,
            plot_path="", score_mode="raw", variant="A0raw", sample_idx=i, n_steps=0
        )

        true_graphs.append(A1.cpu().clone())
        edge_mask_list.append(edge_mask.cpu().clone())

        # Node features are all zeros; only the node count and dtype matter.
        x_feat = torch.zeros(1, z_full.shape[0], 1, device=device, dtype=z_full.dtype)

        # optional traj plots
        if args.traj_plot and i < args.traj_max_samples:
            traj_k = args.traj_k if args.traj_k > 0 else max_steps
            traj_every = max(1, args.traj_every)   # 0 would divide by zero below
            traj_dir = os.path.join(run_dir, f"traj_k{traj_k}", f"sample{i:03d}")
            os.makedirs(traj_dir, exist_ok=True)
            _save_five_panel(A_true=A1, edge_mask=edge_mask, A_step=A0_unperm_clean,
                             outpath=os.path.join(traj_dir, "step000_A0.png"),
                             title="recon t zero a0")

            A_tmp = A0_noisy.clone()
            dt = 1.0 / traj_k
            for step in range(traj_k):
                A_tmp = _sample_flow_step(denoiser, A_tmp, x_feat, node_mask, omega, A_obs,
                                          step * dt, dt, device)
                if (step + 1) % traj_every == 0 or step == traj_k - 1:
                    A_unperm = permute_square(A_tmp, p_inv)
                    panel_path = os.path.join(traj_dir, f"step{step+1:03d}.png")
                    _save_five_panel(A_true=A1, edge_mask=edge_mask, A_step=A_unperm,
                                     outpath=panel_path, title=f"recon step {step+1}")

        # main runs per steps value
        A_final_for_plot = None
        for n_steps in n_steps_values:
            A = A0_noisy.clone()
            dt = 1.0 / n_steps
            for step in range(n_steps):
                A = _sample_flow_step(denoiser, A, x_feat, node_mask, omega, A_obs,
                                      step * dt, dt, device)

            A_final_unperm = permute_square(A, p_inv)
            zero_diag_(A_final_unperm)
            final_by_steps[n_steps].append(A_final_unperm.cpu().clone())
            np.save(os.path.join(recon_raw_dir, f"{prefix}_sample{i}_{n_steps}steps_recon_raw.npy"),
                    A_final_unperm.cpu().numpy())
            if n_steps == max_steps:
                A_final_for_plot = A_final_unperm

        plot_path = os.path.join(plot_dir, f"{prefix}_sample{i}_plot.png")
        _sample_save_summary_plot(A1, edge_mask, A_final_for_plot, max_steps, plot_path)
        print(f"saved plot {plot_path}")
        plot_paths.append(plot_path)

    # Identical metric settings for every evaluation below.
    eval_kwargs = dict(
        score_mode="raw",
        compute_gw=True, gw_cost_mode="adj",
        gw_entropic=False, gw_epsilon=0.2, gw_max_iter=20000, gw_tol=1e-7,
        compute_mmd=True, mmd_kernel="rbf", mmd_sigma="median",
        mmd_on="full_raw", mmd_max_samples=5000,
    )
    # (label suffix, mask list) pairs, evaluated in this order.
    mask_variants = (
        ("", edge_mask_list),
        ("_AobsZero", aobs_mask_list),
        ("_trueZero", atrue_mask_list),
    )

    print("eval per steps")
    for n_steps, recon in final_by_steps.items():
        for suffix, masks in mask_variants:
            out = evaluate_and_save_real(
                args, true_graphs, recon, masks, plot_paths,
                st=f"final_recon_{n_steps}steps{suffix}", **eval_kwargs
            )
            if out:
                csv_paths.append(out)

    print("eval a zero baseline")
    for suffix, masks in mask_variants:
        out = evaluate_and_save_real(
            args, true_graphs, initial_raws, masks, plot_paths,
            st=f"A0raw{suffix}", **eval_kwargs
        )
        if out:
            csv_paths.append(out)

    if csv_paths:
        print("saved csv files:")
        for path in csv_paths:
            print(path)
    print("done")

# =========================
# cli
# =========================

def build_parser():
    """Create the command-line parser.

    The top-level parser owns the options shared by every subcommand (output
    location, run name, seed, model size). Because they live on the top-level
    parser, they must precede the subcommand name on the command line. The
    ``train`` and ``sample`` subcommands each register their handler through
    ``set_defaults(func=...)``.
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    # shared: (flag, type, default)
    shared_opts = (
        ("--out_dir", str, "./out"),
        ("--name", str, "expansion"),
        ("--seed", int, 0),
        ("--hidden_dim", int, 32),
        ("--num_layers", int, 5),
        ("--num_linears", int, 2),
        ("--c_init", int, 2),
        ("--c_hid", int, 8),
        ("--c_final", int, 2),
    )
    for flag, kind, default in shared_opts:
        parser.add_argument(flag, type=kind, default=default)

    # train: (flag, add_argument kwargs)
    train_cmd = subcommands.add_parser("train")
    train_opts = (
        ("--train_pkl", dict(type=str, required=True)),
        ("--val_pkl", dict(type=str, required=True)),
        ("--drop_prob", dict(type=str, default="NA")),
        ("--n2v_prior_train_dir", dict(type=str, required=True)),
        ("--n2v_prior_val_dir", dict(type=str, required=True)),
        ("--batch_size", dict(type=int, default=8)),
        ("--epochs", dict(type=int, default=300)),
        ("--lr", dict(type=float, default=2e-4)),
        ("--train_noise_std", dict(type=float, default=0.1)),
        ("--val_noise_std", dict(type=float, default=0.1)),
        ("--ckpt_every", dict(type=int, default=100)),
        ("--val_posterior_every", dict(type=int, default=0)),
        ("--val_save_steps", dict(type=int, nargs="*", default=[1, 10, 100])),
    )
    for flag, opts in train_opts:
        train_cmd.add_argument(flag, **opts)
    train_cmd.set_defaults(func=train)

    # sample: (flag, add_argument kwargs)
    sample_cmd = subcommands.add_parser("sample")
    sample_opts = (
        ("--ckpt", dict(type=str, required=True)),
        ("--drop_prob", dict(type=str, default="NA")),
        ("--sample_pkl", dict(type=str, default="")),
        ("--mask_pkl", dict(type=str, default="")),
        ("--input_graph", dict(type=str, default="")),
        ("--mask_npy", dict(type=str, default="")),
        ("--n2v_prior_test_dir", dict(type=str, required=True)),
        ("--noise_std", dict(type=float, default=0.1)),
        ("--steps_csv", dict(type=str, default="1,100")),
        ("--max_graph_nodes", dict(type=int, default=0)),
        ("--traj_plot", dict(action="store_true")),
        ("--traj_k", dict(type=int, default=0)),
        ("--traj_every", dict(type=int, default=10)),
        ("--traj_max_samples", dict(type=int, default=3)),
    )
    for flag, opts in sample_opts:
        sample_cmd.add_argument(flag, **opts)
    sample_cmd.set_defaults(func=sample)

    return parser

if __name__ == "__main__":
    # Script entry point: parse the CLI, ensure the output root exists, then
    # dispatch to the handler the chosen subcommand registered through
    # set_defaults(func=...) in build_parser (train or sample).
    # NOTE(review): `parser` and `args` become module-level globals when the file
    # runs as a script; keep these names in case other code in the file reads them.
    parser = build_parser()
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    args.func(args)
