import argparse, json, os, numpy as np, pandas as pd, torch, torch.nn as nn
from joblib import load
import glob
import math

# ------------ Model ------------
class SurrogateMLP(nn.Module):
    """Fully connected surrogate with two hidden layers and LeakyReLU activations.

    Layout: Linear(d_in, d_hidden) -> LeakyReLU -> Linear(d_hidden, d_hidden)
    -> LeakyReLU -> Linear(d_hidden, d_out).  All layers live in ``self.net``,
    so the state-dict keys are ``net.0.*``, ``net.2.*`` and ``net.4.*``.
    """

    def __init__(self, d_in=86, d_hidden=2048, d_out=1):
        super().__init__()
        layers = [
            nn.Linear(d_in, d_hidden),
            nn.LeakyReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.LeakyReLU(),
            nn.Linear(d_hidden, d_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map standardized inputs ``z`` to the standardized prediction."""
        return self.net(z)

def load_artifacts(outdir, device):
    """Load the trained surrogate and its preprocessing artifacts from ``outdir``.

    Files read:
      - ``surrogate.pt``: checkpoint dict with ``state_dict`` and, optionally,
        ``d_in`` / ``d_hidden`` (defaults 86 / 2048 when absent).
      - ``x_scaler.joblib`` / ``y_scaler.joblib``: fitted scalers.
      - ``elem_cols.txt``: one column name per line; blank lines are skipped.

    Returns:
        ``(model, x_scaler, y_scaler, elem_cols, d_hidden)`` with the model
        moved to ``device`` and switched to eval mode.
    """
    def in_outdir(filename):
        return os.path.join(outdir, filename)

    # NOTE: torch.load and joblib.load both unpickle; only point this at
    # artifact directories you trust.
    checkpoint = torch.load(in_outdir("surrogate.pt"), map_location="cpu")
    d_in = checkpoint.get("d_in", 86)
    d_hidden = checkpoint.get("d_hidden", 2048)

    model = SurrogateMLP(d_in, d_hidden, 1)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device).eval()

    x_scaler = load(in_outdir("x_scaler.joblib"))
    y_scaler = load(in_outdir("y_scaler.joblib"))

    elem_cols = []
    with open(in_outdir("elem_cols.txt")) as f:
        for raw_line in f:
            name = raw_line.strip()
            if name:
                elem_cols.append(name)

    return model, x_scaler, y_scaler, elem_cols, d_hidden

# ------------ MALA in z-space (sample from exp(beta f)) ------------

def reflect_to_box(z, z_lo, z_hi):
    """Fold ``z`` back into the box ``[z_lo, z_hi]`` element-wise.

    Each coordinate is mirrored once across the lower face if it undershoots,
    then once across the upper face if it overshoots; anything still outside
    after that single bounce per face is clipped to the nearest bound.
    ``z_lo`` / ``z_hi`` must broadcast against ``z``.
    """
    # Mirror undershoots across the lower face.
    mirrored = torch.where(z < z_lo, 2 * z_lo - z, z)
    # Mirror overshoots (including ones the first bounce created) across the upper face.
    mirrored = torch.where(mirrored > z_hi, 2 * z_hi - mirrored, mirrored)
    # A double bounce can still land outside the box: clip what remains.
    floored = torch.maximum(mirrored, z_lo)
    return torch.minimum(floored, z_hi)

def project_radial(z, rad_cap):
    """Scale vectors whose Euclidean norm exceeds ``rad_cap`` back onto the ball.

    The norm is taken over the LAST axis (the feature axis), so ``z`` may be
    ``[N, d]`` or carry extra leading batch axes such as ``[R, M, d]``.  For
    2-D input this is identical to reducing over ``dim=1``.

    Args:
        z: tensor whose last axis holds the feature vector.
        rad_cap: maximum allowed norm; ``None`` or a non-positive value
            disables the cap and returns ``z`` unchanged.

    Returns:
        Tensor of the same shape as ``z``; vectors already inside the ball are
        left untouched, the others are rescaled to norm ``rad_cap``.
    """
    if rad_cap is None or rad_cap <= 0:
        return z
    # dim=-1 (not dim=1): with a [R, M, d] batch, dim=1 would reduce over the
    # chains axis instead of the features.
    r = torch.linalg.norm(z, dim=-1, keepdim=True)
    # The small epsilon keeps the ratio finite for the zero vector.
    scale = torch.minimum(torch.ones_like(r), rad_cap / (r + 1e-12))
    return z * scale

def grad_logp_beta(model, z, beta=1.0):
    """Evaluate the tempered log-target ``beta * f(z)`` and its gradient in ``z``.

    ``f`` is column 0 of ``model(z)``.  The input is detached first, so no
    gradient flows back into whatever produced ``z``.

    Returns:
        ``(logp, grad)``: ``logp`` has one entry per row of ``z`` and ``grad``
        has the shape of ``z``; both are detached.
    """
    z_req = z.detach().requires_grad_(True)
    logp = beta * model(z_req)[:, 0]
    grad_z = torch.autograd.grad(logp.sum(), z_req)[0]
    return logp.detach(), grad_z.detach()

def mala_step_zboxed(z, logp_and_grad_fn, eps, z_lo, z_hi, rad_cap=None):
    """One Metropolis-adjusted Langevin step with box reflection and radial cap.

    Proposal: ``z' = z + (eps^2 / 2) * grad(z) + eps * xi`` with ``xi ~ N(0, I)``,
    i.e. a Gaussian with mean ``z + (eps^2 / 2) * grad(z)`` and covariance
    ``eps^2 I``.  The draw is then reflected into ``[z_lo, z_hi]`` and capped
    at radius ``rad_cap``.

    Args:
        z: current states, ``[N, d]``.
        logp_and_grad_fn: callable ``z -> (logp [N], grad [N, d])``.
        eps: Langevin step size (standard deviation of the proposal noise).
        z_lo, z_hi: box bounds broadcastable against ``z``.
        rad_cap: optional norm cap (``None`` or ``<= 0`` disables it).

    Returns:
        ``(z_new, accept_rate)``: detached next states and the fraction of
        rows whose proposal was accepted.

    Note:
        The forward/reverse proposal densities are evaluated at the reflected
        and capped proposal, so near the boundary the acceptance ratio is an
        approximation of the exact Metropolis-Hastings ratio.
    """
    z = z.detach()
    logp, grad = logp_and_grad_fn(z)
    half_var = 0.5 * (eps ** 2)
    mean = z + half_var * grad

    # Propose, then push the proposal back into the admissible region.
    prop = mean + eps * torch.randn_like(z)
    prop = reflect_to_box(prop, z_lo, z_hi)
    prop = project_radial(prop, rad_cap)

    # Reverse-kernel terms.
    logp_p, grad_p = logp_and_grad_fn(prop)
    mean_p = prop + half_var * grad_p

    # log N(x; m, eps^2 I) up to a constant shared by both directions.  The
    # 1/eps^2 factor is essential: without it the proposal correction is
    # shrunk by eps^2 and the chain no longer targets exp(logp).
    inv_two_var = 0.5 / max(eps ** 2, 1e-30)

    def log_gauss(x, m):
        return -inv_two_var * ((x - m) ** 2).sum(dim=1)

    log_q_fwd = log_gauss(prop, mean)
    log_q_rev = log_gauss(z, mean_p)

    log_alpha = (logp_p - logp) + (log_q_rev - log_q_fwd)
    # A NaN log_alpha compares False, i.e. the move is rejected.
    accept = torch.log(torch.rand_like(log_alpha)) < log_alpha
    # torch.where (rather than a 0/1 blend) keeps a non-finite proposal from
    # leaking into rows that were rejected.
    z_new = torch.where(accept.unsqueeze(1), prop, z)

    # Safety: the incoming state itself may lie outside the bounds.
    z_new = reflect_to_box(z_new, z_lo, z_hi)
    z_new = project_radial(z_new, rad_cap)
    return z_new.detach(), float(accept.float().mean().item())

def run_mala_stable(
    model, x_scaler, X_init_raw, steps, base_eps, beta, device,
    Z_vis=None, z_plo=1.0, z_phi=99.0, z_rad_phi=99.5,
    adapt=True, target_acc=0.57, adapt_interval=10, eps_min=1e-4, eps_max=0.5
):
    """
    Run a batch of MALA chains in standardized (z) space targeting exp(beta * f).

    - model: predicts y_z from z
    - X_init_raw: raw-space starting points, one chain per row
    - base_eps: pre-scaled step size; the initial eps is base_eps / sqrt(d)
    - Z_vis: standardized reference features used to derive the per-feature
      percentile box (z_plo / z_phi) and the norm cap (z_rad_phi percentile);
      if None, the box is infinite and no radial cap is applied
    - adapt: every `adapt_interval` steps, multiply eps by
      exp(0.1 * (recent mean acceptance - target_acc)) and clip it to
      [eps_min, eps_max]

    Returns (X_out, mean_accept, final_eps), with X_out mapped back to raw
    space through x_scaler.inverse_transform.
    """
    n_features = X_init_raw.shape[1]
    eps = base_eps / (n_features ** 0.5)

    z_start = x_scaler.transform(X_init_raw).astype(np.float32)
    z = torch.tensor(z_start, dtype=torch.float32, device=device)

    if Z_vis is None:
        # Per-feature +/-inf vectors: reflection becomes a no-op.
        z_lo = torch.full_like(z, -float('inf')).mean(dim=0, keepdim=False)
        z_hi = torch.full_like(z,  float('inf')).mean(dim=0, keepdim=False)
        rad_cap = None
    else:
        lower_np = np.percentile(Z_vis, z_plo, axis=0).astype(np.float32)
        upper_np = np.percentile(Z_vis, z_phi, axis=0).astype(np.float32)
        z_lo = torch.from_numpy(lower_np).to(device)
        z_hi = torch.from_numpy(upper_np).to(device)
        # Radial cap taken from the norms of the reference rows.
        row_norms = np.linalg.norm(Z_vis, axis=1)
        rad_cap = float(np.percentile(row_norms, z_rad_phi))

    def logp_and_grad(zz):
        return grad_logp_beta(model, zz, beta=beta)

    acc_hist = []
    for step in range(1, steps + 1):
        z, step_acc = mala_step_zboxed(z, logp_and_grad, eps, z_lo, z_hi, rad_cap)
        acc_hist.append(step_acc)
        if adapt and (step % adapt_interval == 0):
            # Log-scale step update with a small fixed gain, kept within bounds.
            recent_acc = np.mean(acc_hist[-adapt_interval:])
            candidate = eps * np.exp((recent_acc - target_acc) * 0.1)
            eps = float(np.clip(candidate, eps_min, eps_max))

    X_out = x_scaler.inverse_transform(z.detach().cpu().numpy())
    return X_out, float(np.mean(acc_hist)), eps

# ------------ Tempered SMC (β-anneal 0→β*, with MALA rejuvenation) ------------
def systematic_resample_np(w_np, rng):
    """Systematic resampling: draw ``N`` ancestor indices from weights ``w_np``.

    One uniform offset ``u`` defines the evenly spaced positions
    ``(u + i) / N``; each position selects the first index whose cumulative
    (normalised) weight exceeds it.

    Args:
        w_np: array of ``N`` non-negative weights.  They are renormalised in
            float64, so they need not sum to exactly one (float32 softmax
            output typically does not).
        rng: object exposing ``random()`` returning a float in ``[0, 1)``,
            e.g. ``np.random.RandomState``.

    Returns:
        ``int64`` array of ``N`` indices into ``w_np``.

    Raises:
        ValueError: if the weights sum to a non-finite or non-positive value.
    """
    w = np.asarray(w_np, dtype=np.float64).ravel()
    n = w.size
    if n == 0:
        return np.zeros(0, dtype=np.int64)

    positions = (rng.random() + np.arange(n)) / n
    cdf = np.cumsum(w)
    total = cdf[-1]
    if not np.isfinite(total) or total <= 0.0:
        raise ValueError("Resampling weights must have a finite, positive sum.")
    cdf /= total
    # Pin the end of the CDF so round-off can never leave the last position
    # without a bucket (the original walk ran off the array in that case).
    cdf[-1] = 1.0

    # side="right" gives the first j with cdf[j] > position, matching the
    # strict comparison of the classic two-pointer walk.
    idx = np.searchsorted(cdf, positions, side="right")
    return np.minimum(idx, n - 1).astype(np.int64)

def run_smc_tempered(model, x_scaler, X0_raw, n_stages, n_particles, ess_frac,
                     mala_steps, base_eps, beta, device, Z_vis=None,
                     z_plo=1.0, z_phi=99.0, z_rad_phi=99.5, seed=0):
    """
    Tempered SMC in standardized (z) space with MALA rejuvenation.

    The inverse temperature walks the linear ladder
    linspace(0, beta, n_stages + 1).  At each rung the particles are
    re-weighted by exp(delta_beta * f(z)), resampled systematically when the
    ESS falls below ess_frac * n_particles, and then moved with `mala_steps`
    MALA steps targeting exp(b * f).

    - X0_raw: raw-space seeds; tiled (or truncated) to exactly n_particles rows
    - base_eps: pre-scaled step size; eps = base_eps / sqrt(d)
    - Z_vis: standardized reference features for the percentile box and the
      radial cap; if None, neither constraint is applied
    - seed: seeds the NumPy RandomState used for resampling

    Returns the final particle positions in raw space (weights are not
    returned).
    """
    dim = X0_raw.shape[1]
    eps = base_eps / (dim ** 0.5)
    rng = np.random.RandomState(seed)

    # Bring the seed set to exactly n_particles rows.
    if len(X0_raw) >= n_particles:
        X0_raw = X0_raw[:n_particles]
    else:
        n_copies = int(np.ceil(n_particles / len(X0_raw)))
        X0_raw = np.vstack([X0_raw] * n_copies)[:n_particles]

    z_start = x_scaler.transform(X0_raw).astype(np.float32)
    z = torch.tensor(z_start, dtype=torch.float32, device=device)

    if Z_vis is None:
        z_lo = torch.full_like(z, -float('inf')).mean(dim=0, keepdim=False)
        z_hi = torch.full_like(z,  float('inf')).mean(dim=0, keepdim=False)
        rad_cap = None
    else:
        lower_np = np.percentile(Z_vis, z_plo, axis=0).astype(np.float32)
        upper_np = np.percentile(Z_vis, z_phi, axis=0).astype(np.float32)
        z_lo = torch.from_numpy(lower_np).to(device)
        z_hi = torch.from_numpy(upper_np).to(device)
        rad_cap = float(np.percentile(np.linalg.norm(Z_vis, axis=1), z_rad_phi))

    betas = np.linspace(0, beta, n_stages + 1)
    logw = torch.zeros(n_particles, device=device)

    b_prev = 0
    for b in betas:
        with torch.no_grad():
            f = model(z)[:, 0]
        # Incremental weight for moving from b_prev to b.
        logw = logw + (b - b_prev) * f
        b_prev = b

        # Normalise the weights, measure the ESS, resample if it is too low.
        w = torch.softmax(logw - torch.max(logw), dim=0)
        ess = 1.0 / float((w * w).sum().item())
        print("ess", ess)
        if ess < ess_frac * n_particles:
            print("Resampling...")
            ancestors = systematic_resample_np(w.detach().cpu().numpy(), rng)
            z = z[ancestors]
            logw = torch.zeros_like(logw)

        # Rejuvenate with MALA at the current temperature.
        def logp_and_grad(zz):
            return grad_logp_beta(model, zz, beta=b)

        for _ in range(mala_steps):
            z, _ = mala_step_zboxed(z, logp_and_grad, eps, z_lo, z_hi, rad_cap)

    X_out = x_scaler.inverse_transform(z.detach().cpu().numpy())
    return X_out



# ------------ SMC with checkpoint-annealing π_k(z) ∝ exp(f_{θ_k}(z)) ------------
def _load_model_state(path, model, map_location=None):
    """Load the weights stored in checkpoint file ``path`` into ``model`` in place.

    Accepts either a dict that wraps the weights under ``"state_dict"`` (the
    layout ``load_artifacts`` reads from ``surrogate.pt``) or a bare state
    dict.  NOTE(review): the per-epoch checkpoint layout is not shown in this
    file; confirm it matches one of these two forms.
    """
    # NOTE: torch.load unpickles; only load checkpoints you trust.
    ckpt = torch.load(path, map_location=map_location)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        state = ckpt
    model.load_state_dict(state)
    model.eval()
    return model


def run_smc_ckpt(ckpt_paths, model, x_scaler, n_particles, ess_frac,
                 mala_steps, base_eps, device, Z_vis=None, z_plo=1.0, z_phi=99.0,
                 z_rad_phi=99.5, seed=0, init_mode="gaussian", X_vis=None):
    """
    SMC whose intermediate targets are successive checkpoints of the model:
    pi_k(z) ∝ exp(f_{theta_k}(z)).

    ckpt_paths: list of checkpoint files in annealing order (at least two);
                each is loaded INTO `model`, so `model` ends up holding the
                last checkpoint's weights
    init_mode: 'gaussian' -> z0 ~ N(0,I), weighted by the N(0,I) log-density
                             of the raw draw (taken before reflection/capping)
               'visible'  -> sample rows of X_vis with replacement
                             (no q0 term; logw starts at f_0)
    base_eps:  pre-scaled step size; eps = base_eps / sqrt(d)
    Z_vis:     standardized reference features for the percentile box and the
               radial cap; if None, neither constraint is applied
    seed:      seeds the NumPy RandomState used for 'visible' init and for
               resampling

    Per checkpoint k >= 1: re-weight by exp(f_k - f_{k-1}), resample
    systematically when ESS < ess_frac * n_particles, then run `mala_steps`
    MALA steps at pi_k.  A final max(1, mala_steps // 2) MALA steps are run at
    the last checkpoint.

    Returns the final particle positions in raw space.
    """
    assert len(ckpt_paths) >= 2, "Need at least two checkpoints for annealing."
    d = x_scaler.mean_.shape[0]
    eps = base_eps / (d ** 0.5)
    rng = np.random.RandomState(seed)

    # Z-box + radial cap from Z_vis if provided
    if Z_vis is not None:
        z_lo_np = np.percentile(Z_vis, z_plo, axis=0).astype(np.float32)
        z_hi_np = np.percentile(Z_vis, z_phi, axis=0).astype(np.float32)
        z_lo = torch.from_numpy(z_lo_np).to(device)
        z_hi = torch.from_numpy(z_hi_np).to(device)
        rad_cap = float(np.percentile(np.linalg.norm(Z_vis, axis=1), z_rad_phi))
    else:
        # broadcastable sentinels
        z_lo = torch.full((d,), -float('inf'), device=device)
        z_hi = torch.full((d,),  float('inf'), device=device)
        rad_cap = None

    # Initialize particles in Z
    if init_mode == "gaussian":
        z = torch.randn(n_particles, d, device=device)
        logq0 = -0.5 * (z.pow(2).sum(dim=1)) - 0.5 * d * math.log(2 * math.pi)
        z = reflect_to_box(z, z_lo, z_hi)
        z = project_radial(z, rad_cap)
    else:  # 'visible'
        assert X_vis is not None and len(X_vis) > 0, "Need X_vis for init_mode='visible'."
        idx = rng.choice(len(X_vis), size=n_particles, replace=True)
        Z0 = x_scaler.transform(X_vis[idx]).astype(np.float32)
        z = torch.tensor(Z0, dtype=torch.float32, device=device)
        logq0 = torch.zeros(n_particles, device=device)

    # The weights of `model` are swapped in place per checkpoint, so a single
    # closure over `model` always evaluates the currently loaded target.
    def lg(xx):
        return grad_logp_beta(model, xx, beta=1.0)

    # Load first checkpoint and weight: w0 ∝ γ0/q0
    _load_model_state(ckpt_paths[0], model, map_location=device)
    with torch.no_grad():
        f_prev = model(z)[:, 0]
    logw = f_prev - logq0  # importance from q0 to γ0

    # Anneal across checkpoints
    for ck in ckpt_paths[1:]:
        _load_model_state(ck, model, map_location=device)
        with torch.no_grad():
            f_curr = model(z)[:, 0]
        logw = logw + (f_curr - f_prev)  # w *= γ_k / γ_{k-1}

        # normalize weights, ESS, resample
        w = torch.softmax(logw - torch.max(logw), dim=0)
        ess = 1.0 / float(torch.sum(w * w).item())
        print("ess", ess)
        if ess < ess_frac * n_particles:
            print("Resampling...")
            idx = systematic_resample_np(w.detach().cpu().numpy(), rng)
            z = z[idx]
            logw = torch.zeros_like(logw)

        # rejuvenate at current π_k with MALA
        for _ in range(mala_steps):
            z, _ = mala_step_zboxed(z, lg, eps, z_lo, z_hi, rad_cap)

        # Particles may have moved or been resampled: refresh f under the
        # current checkpoint so the next increment starts from these positions.
        with torch.no_grad():
            f_prev = model(z)[:, 0]

    # Final rejuvenation at π_K
    for _ in range(max(1, mala_steps // 2)):
        z, _ = mala_step_zboxed(z, lg, eps, z_lo, z_hi, rad_cap)

    X_out = x_scaler.inverse_transform(z.detach().cpu().numpy())
    return X_out


def log_q0_stdnormal(z):
    """Row-wise log-density of the standard normal N(0, I) for ``z`` of shape [N, d]."""
    n_dims = z.shape[1]
    sq_norm = (z ** 2).sum(dim=1)
    return -0.5 * sq_norm - 0.5 * n_dims * math.log(2 * math.pi)

def grad_log_gamma(model, z, beta):
    """Log-density (up to a constant) and score of the geometric bridge γ_β.

    log γ_β(z) = β f(z) + (1-β) log q0(z) with q0 = N(0, I), so
    ∇ log γ_β(z) = β ∇f(z) - (1-β) z.  ``f`` is column 0 of ``model(z)``.

    Returns:
        ``(logp, grad)``, both detached; ``logp`` omits the normalising
        constant of q0, which cancels in MALA acceptance ratios.
    """
    z_req = z.detach().requires_grad_(True)
    f = model(z_req)[:, 0]
    grad_f = torch.autograd.grad(f.sum(), z_req)[0]

    prior_weight = 1.0 - beta
    score = beta * grad_f - prior_weight * z_req
    log_gamma = beta * f - 0.5 * prior_weight * (z_req ** 2).sum(dim=1)
    return log_gamma.detach(), score.detach()

@torch.no_grad()
def _conditional_ess_frac(dlogw):
    """Return (Σw)² / (N · Σw²) for incremental log-weights ``dlogw`` of shape [N].

    The log-weights are shifted by their maximum before exponentiating, and a
    1e-12 term guards the denominator.  The result lies in (0, 1]: 1 for
    uniform increments, about 1/N when a single particle dominates.
    """
    weights = torch.exp(dlogw - torch.max(dlogw))
    total = weights.sum()
    total_sq = (weights * weights).sum()
    cess = (total * total) / (total_sq + 1e-12)
    n_particles = float(dlogw.numel())
    return float(cess / n_particles)

@torch.no_grad()
def _find_next_beta(b_prev, f_curr, z, target_frac, max_beta=1.0, iters=25):
    """Bisection for the next β so that cESS/N ≈ target_frac.

    The incremental log-weight for a move from ``b_prev`` to β is
    Δlogw = (β - b_prev) * (f - log q0), with q0 the standard normal.

    Args:
        b_prev: current inverse temperature.
        f_curr: model values at the particles, [N].
        z: particle positions, [N, d].
        target_frac: desired cESS/N for the step.
        max_beta: upper end of the ladder; returned directly when a jump to it
            already satisfies the target (or when ``b_prev`` has reached it).
        iters: number of bisection iterations.

    Returns:
        The next β in ``(b_prev, max_beta]``.  The upper bracket is returned,
        so the achieved cESS/N sits at or just below the target and the ladder
        always advances.
    """
    max_beta = float(max_beta)
    if b_prev >= max_beta:
        return max_beta

    g = f_curr - log_q0_stdnormal(z)  # [N]

    # If even the full jump keeps cESS above target, go straight to the end.
    # (Uses max_beta rather than a hard-coded 1.0 so the parameter is honoured.)
    if _conditional_ess_frac((max_beta - b_prev) * g) >= target_frac:
        return max_beta

    # cESS decreases as β grows (unless g is constant), so bisect on β.
    lo, hi = b_prev, max_beta
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if _conditional_ess_frac((mid - b_prev) * g) < target_frac:
            hi = mid
        else:
            lo = mid
    return hi

def run_ais_autotemp(
    model, x_scaler,
    n_particles, n_stages, mala_steps_per_stage, base_eps, device,
    Z_vis, z_plo=1.0, z_phi=99.0, z_rad_phi=99.5,
    target_cess_frac=0.7,  # per-stage cESS/N target
    eps_min=1e-4, eps_max=0.5, adapt_eps=False, target_acc=0.57, adapt_interval=10,
    seed=0
):
    """
    Annealed importance sampling with adaptively chosen temperatures and a
    fixed number of stages.

    - Start: z ~ N(0,I) drawn from a NumPy RandomState(seed), then reflected
      into the percentile box of Z_vis and capped in norm.
    - Stage t = 1..n_stages:
        * β_t comes from bisection so that cESS/N ≈ target_cess_frac relative
          to β_{t-1}; the last stage is forced to β = 1.0.
        * log-weights grow by (β_t - β_{t-1}) * [ f(z) - log q0(z) ].
        * `mala_steps_per_stage` MALA steps at γ_{β_t}, with reflection and
          radial cap; when adapt_eps is set, eps is rescaled every
          `adapt_interval` steps towards target_acc and clipped to
          [eps_min, eps_max].
      Once β reaches 1 before the last stage, later stages add no weight and
      only keep rejuvenating at β = 1.
    - The bounds change the target to the truncated/capped domain, so the
      weights no longer give an unbiased logZ.

    Returns: X_out [N,d_raw], logw [N], betas (list), mean_accept (float), final_eps (float)
    """
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # --- Gaussian start ---
    rng = np.random.RandomState(seed)
    d = x_scaler.mean_.shape[0]
    eps = base_eps / (d ** 0.5)
    z = torch.from_numpy(rng.randn(n_particles, d).astype(np.float32)).to(device)

    # --- Box and radial cap from the reference set ---
    lower_np = np.percentile(Z_vis, z_plo, axis=0).astype(np.float32)
    upper_np = np.percentile(Z_vis, z_phi, axis=0).astype(np.float32)
    z_lo = torch.from_numpy(lower_np).to(device)
    z_hi = torch.from_numpy(upper_np).to(device)
    rad_cap = float(np.percentile(np.linalg.norm(Z_vis, axis=1), z_rad_phi))

    # Push the start into the admissible region.
    z = project_radial(reflect_to_box(z, z_lo, z_hi), rad_cap)

    logw = torch.zeros(n_particles, device=device)
    betas = [0.0]
    acc_hist = []

    for stage in range(1, n_stages + 1):
        b_prev = betas[-1]

        # 1) pick the next temperature
        with torch.no_grad():
            f_curr = model(z)[:, 0]
        if stage == n_stages:
            b_t = 1.0  # the final stage always lands exactly on 1.0
        else:
            b_t = _find_next_beta(b_prev, f_curr, z, target_cess_frac, max_beta=1.0)
        d_beta = max(0.0, b_t - b_prev)
        betas.append(b_t)

        # 2) importance-weight increment
        if d_beta > 0.0:
            logw += d_beta * (f_curr - log_q0_stdnormal(z))

        # 3) MALA moves at γ_{β_t}
        def logp_and_grad(zz):
            return grad_log_gamma(model, zz, beta=b_t)

        stage_acc = []
        for k in range(1, mala_steps_per_stage + 1):
            z, step_acc = mala_step_zboxed(z, logp_and_grad, eps, z_lo, z_hi, rad_cap)
            stage_acc.append(step_acc)
            if adapt_eps and (k % adapt_interval == 0):
                # In-stage step-size adaptation on the log scale.
                recent_acc = float(np.mean(stage_acc[-adapt_interval:]))
                candidate = eps * np.exp((recent_acc - target_acc) * 0.1)
                eps = float(np.clip(candidate, eps_min, eps_max))

        if stage_acc:
            acc_hist.append(np.mean(stage_acc))
        else:
            acc_hist.append(1.0)

    mean_acc = float(np.mean(acc_hist)) if acc_hist else 1.0

    # Back to raw space for the caller.
    X_out = x_scaler.inverse_transform(z.detach().cpu().numpy())
    return X_out, logw.detach().cpu().numpy(), betas, mean_acc, eps


def make_beta_ladder(R, beta_max=1.0, scheme="linear", slope=4.0):
    """
    Build an increasing float32 ladder betas[0..R-1] with betas[0] = 0 and
    betas[-1] = beta_max.

    scheme="linear": linspace(0, beta_max, R)
    any other value (e.g. "geom"): a logistic curve sampled on
        linspace(-slope, slope, R), which packs rungs more densely near both
        ends as 'slope' grows; the endpoints are pinned to 0 and 1 before
        scaling by beta_max
    R < 2 returns the single rung [beta_max].
    """
    if R < 2:
        return np.array([beta_max], dtype=np.float32)

    if scheme == "linear":
        ladder = np.linspace(0.0, float(beta_max), R, dtype=np.float64)
        return ladder.astype(np.float32)

    # Logistic spacing: squash an even grid through the sigmoid, pin the ends,
    # then stretch to [0, beta_max].
    grid = np.linspace(-slope, slope, R, dtype=np.float64)
    ladder = 1.0 / (1.0 + np.exp(-grid))
    ladder[0], ladder[-1] = 0.0, 1.0
    ladder *= float(beta_max)
    return ladder.astype(np.float32)

def run_parallel_tempering(
    model, x_scaler,
    n_replicas, chains_per_replica, iters, mala_steps_per_iter,
    base_eps, device, Z_vis,
    z_plo=1.0, z_phi=99.0, z_rad_phi=99.5,
    beta_max=1.0, ladder_scheme="geom", ladder_slope=4.0,
    swap_interval=1,  # how often (in PT iterations) to attempt swaps
    adapt_ladder=False, target_swap=0.20, ladder_lr=0.05,  # simple slope adaptation
    seed=0
):
    """
    Parallel Tempering with reflective-box MALA.

    Shapes:
      - Replicas: r=0..R-1 with betas[r] increasing, betas[-1]=beta_max (usually 1.0).
      - Batch: keep 'M = chains_per_replica' chains per replica (vectorized).

    Each iteration runs `mala_steps_per_iter` MALA steps per replica at its own
    β, then (every `swap_interval` iterations) attempts swaps between adjacent
    replicas, alternating even and odd pairings.  With `adapt_ladder`, the
    logistic ladder slope is nudged from the cumulative mean swap rate.

    Side effect: seeds torch's global RNG with `seed`.

    Returns:
      X_cold [M, d_raw]  — final states from the cold replica (β=beta_max) in X-space
      stats dict         — diagnostics (swap rates, accept rates, ladder)

    Raises:
      ValueError if n_replicas or chains_per_replica is < 1.
    """
    R, M = n_replicas, chains_per_replica
    if R < 1 or M < 1:
        raise ValueError("n_replicas and chains_per_replica must both be >= 1.")

    rng = np.random.RandomState(seed)
    torch.manual_seed(seed)

    d = x_scaler.mean_.shape[0]
    eps = base_eps / (d ** 0.5)

    # --- Bounds from visible set ---
    z_lo_np = np.percentile(Z_vis, z_plo, axis=0).astype(np.float32)
    z_hi_np = np.percentile(Z_vis, z_phi, axis=0).astype(np.float32)
    z_lo = torch.from_numpy(z_lo_np).to(device)
    z_hi = torch.from_numpy(z_hi_np).to(device)
    rad_cap = float(np.percentile(np.linalg.norm(Z_vis, axis=1), z_rad_phi))

    # --- Init (Gaussian → reflect/cap) ---
    # Constrain in the flat [R*M, d] layout so the radial cap is computed per
    # chain over the feature axis; on the 3-D tensor a dim=1 norm would reduce
    # over the chains axis instead.
    z = torch.from_numpy(rng.randn(R, M, d).astype(np.float32)).to(device)
    z_flat = z.reshape(R * M, d)
    z_flat = reflect_to_box(z_flat, z_lo, z_hi)
    z_flat = project_radial(z_flat, rad_cap)
    z = z_flat.reshape(R, M, d)

    # --- Ladder ---
    betas = make_beta_ladder(R, beta_max=beta_max, scheme=ladder_scheme, slope=ladder_slope)

    # Stats
    mala_acc_hist = [[] for _ in range(R)]
    swap_attempts = np.zeros(R - 1, dtype=np.int64)
    swap_accepts = np.zeros(R - 1, dtype=np.int64)

    def local_logp_and_grad_fn(beta_scalar):
        beta_value = float(beta_scalar)

        def fn(z_block):  # z_block: [M,d]
            return grad_logp_beta(model, z_block, beta=beta_value)
        return fn

    for t in range(1, iters + 1):
        # 1) Local MALA updates per replica (vectorized over chains)
        for r in range(R):
            lg = local_logp_and_grad_fn(betas[r])
            for _ in range(mala_steps_per_iter):
                z_r, acc = mala_step_zboxed(z[r], lg, eps, z_lo, z_hi, rad_cap)
                z[r] = z_r
                mala_acc_hist[r].append(acc)

        # 2) Replica-exchange (adjacent) every swap_interval iterations
        if (swap_interval > 0) and (t % swap_interval == 0) and (R > 1):
            # f for all states once; pairs within one parity are disjoint, so
            # the values stay valid for every pair visited in this sweep.
            with torch.no_grad():
                f_all = model(z.reshape(R * M, d))[:, 0].reshape(R, M)  # [R,M]

            # Alternate even/odd pairings each attempt for better mixing:
            # start=0 -> (0,1),(2,3)...; start=1 -> (1,2),(3,4)...
            start = (t // swap_interval) % 2
            for r in range(start, R - 1, 2):
                d_beta = float(betas[r + 1]) - float(betas[r])
                f_lo = f_all[r]       # [M]
                f_hi = f_all[r + 1]   # [M]

                # Swap acceptance per chain j: α = exp((β_hi-β_lo)*(f_lo - f_hi))
                delta = d_beta * (f_lo - f_hi)   # [M]
                # clamp to avoid inf/NaN in extreme tails
                delta = torch.clamp(delta, min=-100.0, max=100.0)
                log_u = torch.log(torch.rand_like(delta))
                accept_mask = log_u < torch.minimum(delta, torch.zeros_like(delta))
                accept_mask = accept_mask.unsqueeze(1)  # [M,1] for broadcasting over d

                # Build both new blocks before writing either one back.
                z_lo_new = torch.where(accept_mask, z[r + 1], z[r])
                z_hi_new = torch.where(accept_mask, z[r], z[r + 1])
                z[r] = z_lo_new
                z[r + 1] = z_hi_new

                swap_attempts[r] += M
                swap_accepts[r] += int(accept_mask.float().sum().item())

        # 3) Optional: adapt the ladder shape via 'slope'
        if adapt_ladder and (t % max(5, swap_interval) == 0) and (R > 2):
            # Cumulative adjacent swap rates (not a recent window).
            rates = np.divide(swap_accepts, np.maximum(1, swap_attempts))
            mean_rate = float(np.mean(rates)) if rates.size > 0 else 0.0
            # Multiplicative update: the slope GROWS when the mean swap rate is
            # below target_swap and SHRINKS when it is above, clipped to
            # [1, 20].  The slope only affects the logistic ladder; for
            # scheme="linear" the rebuilt ladder is unchanged.
            # NOTE(review): the previous comment described the opposite
            # direction; confirm which behaviour is intended.
            ladder_slope = float(np.clip(ladder_slope * np.exp((target_swap - mean_rate) * ladder_lr), 1.0, 20.0))
            betas = make_beta_ladder(R, beta_max=float(beta_max), scheme=ladder_scheme, slope=ladder_slope)

    # Gather outputs from the cold (β=beta_max) replica
    z_cold = z[-1]  # [M,d]
    X_cold = x_scaler.inverse_transform(z_cold.detach().cpu().numpy())

    # Diagnostics (acceptance is tracked per replica index and reported under
    # that replica's final β).
    mala_acc = {float(betas[r]): float(np.mean(mala_acc_hist[r])) if mala_acc_hist[r] else 1.0 for r in range(R)}
    swap_rates = np.divide(swap_accepts, np.maximum(1, swap_attempts)).tolist()

    stats = {
        "betas": betas.tolist(),
        "mala_acc_by_beta": mala_acc,
        "swap_rates_adjacent": swap_rates,
        "swap_attempts": swap_attempts.tolist(),
        "swap_accepts": swap_accepts.tolist(),
        "ladder_slope": ladder_slope,
    }
    return X_cold, stats

def _largest_divisor_leq(n, limit):
    """Return the largest divisor of ``n`` not exceeding ``limit`` (1 if none is found)."""
    candidates = range(min(limit, n), 0, -1)
    return next((c for c in candidates if n % c == 0), 1)

# ------------ Main ------------
def main():
    """CLI entry point: sample proposals with the chosen method and save them.

    Steps: parse arguments, map the unified compute budget (particles, stages,
    steps per stage) onto each method's knobs, load the surrogate artifacts
    and the CSV, build the "visible" subset (rows with critical_temp at or
    below the --visible-frac quantile), run the selected sampler in
    standardized space, score the proposals and write the top K to CSV.

    When --out is omitted the file goes to runs/proposals_{method}_{seed}.csv.
    """
    ap = argparse.ArgumentParser("MALA / SMC / SMC-CKPT / AIS / PT (optimize in Z; de-norm at end)")

    # Core
    ap.add_argument("--method", choices=["mala","smc","smc_ckpt","ais_auto","pt"], default="pt")
    ap.add_argument("--artifacts", default="runs/run3")
    ap.add_argument("--csv", default="unique_m.csv")
    ap.add_argument("--seed", type=int, default=5)
    ap.add_argument("--out", default=None, help="If omitted, auto: runs/proposals_{method}_{seed}.csv")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--use-oracle", action="store_true")

    # Data split & physics
    ap.add_argument("--visible-frac", type=float, default=1.0)  # fraction by Tc kept as "visible"
    ap.add_argument("--beta", type=float, default=1.0)          # target exp(beta f)
    ap.add_argument("--step-size", type=float, default=0.05,    # lr / MALA eps / SMC eps (pre 1/sqrt(d))
                    help="Base step (scaled by 1/sqrt(d))")

    # One unified compute budget
    ap.add_argument("--particles", type=int, default=500, help="Unified: particles / batch size")
    ap.add_argument("--stages", type=int, default=450, help="Unified: number of stages")
    ap.add_argument("--steps-per-stage", type=int, default=5, help="Unified: MALA steps per stage")

    # Stability knobs in Z
    ap.add_argument("--z-plo", type=float, default=1.0)
    ap.add_argument("--z-phi", type=float, default=99.0)
    ap.add_argument("--z-rad-phi", type=float, default=99.5)
    ap.add_argument("--adapt", action="store_true", help="Adapt eps during MALA/AIS burn-in")

    # Minimal method-specific (kept small, sensible defaults inside code)
    ap.add_argument("--ess-frac", type=float, default=0.5, help="SMC/AIS effective sample size trigger")
    ap.add_argument("--ckpt-glob", type=str, default="runs/epoch_*.pt")
    ap.add_argument("--ckpts", type=str, default=None)
    ap.add_argument("--ckpt-init", choices=["gaussian","visible"], default="gaussian")
    ap.add_argument("--pt-replicas", type=int, default=16, help="If not a divisor of particles, auto-adjust")
    ap.add_argument("--pt-beta-max", type=float, default=1.0)
    ap.add_argument("--pt-ladder", choices=["linear","geom"], default="geom")
    ap.add_argument("--pt-swap-interval", type=int, default=1)
    ap.add_argument("--pt-adapt-ladder", action="store_true")

    args = ap.parse_args()

    # Fail before the (potentially long) sampling run rather than after it:
    # the oracle scorer is looked up by name and is not defined in this module.
    if args.use_oracle and not callable(globals().get("score_with_oracle")):
        ap.error("--use-oracle requires a score_with_oracle(X) function, which is not available.")

    # Honour the documented default for --out (it used to stay None, in which
    # case DataFrame.to_csv returns a string and writes no file).
    if args.out is None:
        args.out = os.path.join("runs", f"proposals_{args.method}_{args.seed}.csv")

    device = torch.device(args.device)
    # The MALA kernels draw from torch's global RNG; seed it so --seed makes
    # every method reproducible, not only the ones that seed it themselves.
    torch.manual_seed(args.seed)

    # ---- Unified compute budget -> method knobs ----
    particles = args.particles
    stages = args.stages
    steps_per_stage = args.steps_per_stage
    total_steps = stages * steps_per_stage

    # SMC family
    args.smc_particles   = particles
    args.smc_stages      = stages
    args.smc_mala_steps  = steps_per_stage
    args.smc_ess_frac    = args.ess_frac

    # AIS
    args.ais_auto_stages     = stages
    args.ais_auto_steps      = steps_per_stage
    args.ais_auto_cess_frac  = args.ess_frac
    args.ais_auto_adapt_eps  = args.adapt
    args.ais_auto_target_acc = 0.57  # sane default

    # MALA (batch size == particles; steps == stages * steps_per_stage)
    args.n_starts = particles
    args.steps    = total_steps

    # PT: replicas * chains = particles; iters * pt_mala_steps = total_steps
    # choose pt_mala_steps=steps_per_stage so iters=stages (clean mental model)
    args.pt_mala_steps = steps_per_stage
    args.pt_iters      = stages

    replicas = _largest_divisor_leq(particles, args.pt_replicas)
    chains   = max(1, particles // replicas)
    args.pt_replicas = replicas
    args.pt_chains   = chains

    model, x_scaler, y_scaler, elem_cols, _ = load_artifacts(args.artifacts, device)
    df = pd.read_csv(args.csv)
    missing_cols = [c for c in elem_cols if c not in df.columns]
    if missing_cols:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(f"CSV columns don't match elem_cols.txt (missing: {missing_cols[:5]})")
    X = df[elem_cols].to_numpy(np.float32)
    y = df["critical_temp"].to_numpy(np.float32)

    # Visible split: rows whose Tc is at or below the --visible-frac quantile
    # (the default 1.0 keeps every row).
    thresh = np.quantile(y, args.visible_frac)
    vis_mask = (y <= thresh)
    X_vis, y_vis = X[vis_mask], y[vis_mask]
    if len(X_vis) == 0:
        raise ValueError("Visible set is empty.")
    Z_vis = x_scaler.transform(X_vis).astype(np.float32)
    # Starting points for MALA: the highest-Tc visible rows.
    idx_sorted = np.argsort(y_vis)[::-1]
    n0 = min(args.n_starts, len(idx_sorted))
    X_init = X_vis[idx_sorted[:n0]]

    # ---- Run selected method (unified compute applied earlier) ----

    if args.method == "mala":
        X_prop, acc, final_eps = run_mala_stable(
            model, x_scaler, X_init,
            steps=args.steps,                    # steps = stages * steps_per_stage
            base_eps=args.step_size,
            beta=args.beta,
            device=device,
            Z_vis=Z_vis,
            z_plo=args.z_plo, z_phi=args.z_phi, z_rad_phi=args.z_rad_phi,
            adapt=args.adapt
        )
        print(f"MALA acceptance ≈ {acc:.3f} (final eps={final_eps:.4g})")

    elif args.method == "smc":
        rng = np.random.RandomState(args.seed)
        idx = rng.choice(len(X_vis), size=args.smc_particles, replace=True)
        X0 = X_vis[idx]
        X_prop = run_smc_tempered(
            model, x_scaler, X0,
            n_stages=args.smc_stages,
            n_particles=args.smc_particles,
            ess_frac=args.smc_ess_frac,
            mala_steps=args.smc_mala_steps,
            base_eps=args.step_size,
            beta=args.beta,
            device=device,
            Z_vis=Z_vis,
            z_plo=args.z_plo, z_phi=args.z_phi, z_rad_phi=args.z_rad_phi,
            seed=args.seed
        )

    elif args.method == "smc_ckpt":
        # Build checkpoint list and pick exactly `stages` evenly spaced
        paths = []
        if args.ckpt_glob: paths += sorted(glob.glob(args.ckpt_glob))
        if args.ckpts:     paths += [p.strip() for p in args.ckpts.split(",") if p.strip()]
        paths = sorted(dict.fromkeys(paths))
        if len(paths) < 2:
            raise ValueError("Provide at least two checkpoints via --ckpt-glob and/or --ckpts.")
        if args.smc_stages < 2:
            raise ValueError("smc_ckpt needs --stages >= 2 (at least two checkpoints to anneal between).")
        if args.smc_stages > len(paths):
            raise ValueError(f"Requested --stages={args.smc_stages} but only {len(paths)} checkpoints found.")
        idx = np.linspace(0, len(paths) - 1, num=args.smc_stages, dtype=int)
        subset = [paths[i] for i in idx]
        print(f"SMC-CKPT using {len(subset)} staged checkpoints")

        X_prop = run_smc_ckpt(
            ckpt_paths=subset,
            model=model, x_scaler=x_scaler,
            n_particles=args.smc_particles,
            ess_frac=args.smc_ess_frac,
            mala_steps=args.smc_mala_steps,
            base_eps=args.step_size,
            device=device,
            Z_vis=Z_vis,
            z_plo=args.z_plo, z_phi=args.z_phi, z_rad_phi=args.z_rad_phi,
            seed=args.seed,
            init_mode=args.ckpt_init,
            X_vis=X_vis
        )

    elif args.method == "ais_auto":
        X_prop, logw, betas, mean_acc, final_eps = run_ais_autotemp(
            model, x_scaler,
            n_particles=args.smc_particles,
            n_stages=args.ais_auto_stages,
            mala_steps_per_stage=args.ais_auto_steps,
            base_eps=args.step_size,
            device=device,
            Z_vis=Z_vis,
            z_plo=args.z_plo, z_phi=args.z_phi, z_rad_phi=args.z_rad_phi,
            target_cess_frac=args.ais_auto_cess_frac,
            adapt_eps=args.ais_auto_adapt_eps,
            target_acc=args.ais_auto_target_acc,
            seed=args.seed
        )
        print(f"AIS mean MALA acceptance ≈ {mean_acc:.3f} (final eps={final_eps:.4g})")

    elif args.method == "pt":
        # PT returns the cold-replica batch of shape [chains, d].
        X_prop, stats = run_parallel_tempering(
            model, x_scaler,
            n_replicas=args.pt_replicas,
            chains_per_replica=args.pt_chains,
            iters=args.pt_iters,                       # = stages
            mala_steps_per_iter=args.pt_mala_steps,    # = steps_per_stage
            base_eps=args.step_size,
            device=device,
            Z_vis=Z_vis,
            z_plo=args.z_plo, z_phi=args.z_phi, z_rad_phi=args.z_rad_phi,
            beta_max=args.pt_beta_max,
            ladder_scheme=args.pt_ladder,
            ladder_slope=4.0,          # fixed sane default; keeps CLI small
            swap_interval=args.pt_swap_interval,
            adapt_ladder=args.pt_adapt_ladder,
            target_swap=0.20,          # fixed sane default
            ladder_lr=0.05,            # fixed sane default
            seed=args.seed
        )

    else:
        raise ValueError(f"Unknown method: {args.method}")

    # Number of proposals to keep (never more than were produced).
    K = min(args.n_starts, len(X_prop))

    # ---- Scoring & CSV ----
    if args.use_oracle:
        y_pred, s_norm, s_pct = score_with_oracle(X_prop)
        top = np.argsort(-y_pred)[:K]
        out_df = pd.DataFrame(X_prop[top], columns=elem_cols)
        out_df["predicted_tc"] = y_pred[top]
        out_df["score_norm01"] = s_norm[top]
        out_df["score_percentile"] = s_pct[top]
    else:
        # Score with the surrogate itself and convert back to target units.
        Z_prop = x_scaler.transform(X_prop).astype(np.float32)
        with torch.no_grad():
            y_z = model(torch.from_numpy(Z_prop).to(device)).cpu().numpy()
        y_kelvin = y_scaler.inverse_transform(y_z).ravel()
        top = np.argsort(-y_kelvin)[:K]
        out_df = pd.DataFrame(X_prop[top], columns=elem_cols)
        out_df["predicted_tc"] = y_kelvin[top]

    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_df.to_csv(args.out, index=False)
    print(f"Saved {len(out_df)} proposals → {args.out}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    
    