#!/usr/bin/env python3
import argparse
import glob
import json
import math
import os
import random
from typing import List, Tuple
import numpy as np
import torch
from model import FCNN, ackley

# --------- utilities ---------
def standardize(x, mean, std):
    """Return ``x`` shifted by ``mean`` and divided by ``std``."""
    centered = x - mean
    return centered / std

def load_model(ckpt_path, model):
    """Load checkpoint weights into ``model`` and put it in eval mode.

    The file at ``ckpt_path`` is deserialized onto the CPU and must contain a
    ``'model_state_dict'`` entry. ``model`` is updated in place and returned.
    """
    # NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    weights = checkpoint['model_state_dict']
    model.load_state_dict(weights)
    model.eval()
    return model

def log_gamma_from_model(model, x, mean, std, alpha=0.1, scale=5.0, eps=1e-8):
    """Evaluate the unnormalized log-density ``log γ(x)`` implied by ``model``.

    Args:
        model: callable applied to the standardized inputs; its output is
            squeezed on the last axis to give one value per row.
        x: [N, d] UNstandardized points.
        mean, std: statistics used to standardize ``x`` before the forward pass.
        alpha: divisor applied to the log-ratio (acts like a temperature).
        scale: constant whose log is subtracted from ``log y``.
        eps: lower clamp on the model output so the logarithm stays finite.

    Returns:
        [N] tensor ``(log(max(y, eps)) - log(scale)) / alpha``.
    """
    # Same transform as the module-level standardization helper.
    x_std = (x - mean) / std
    y = model(x_std).squeeze(-1)               # predicts y_trans
    y = torch.clamp(y, min=eps)                # ensure positive
    # `scale` is a Python number, so math.log is used (a bare `log` is not
    # defined in this module and raised NameError).
    return (torch.log(y) - math.log(scale)) / alpha

def grad_log_gamma(model, x, mean, std, alpha=0.1, scale=5.0, eps=1e-8):
    """Return ``log γ(x)`` together with its gradient with respect to ``x``.

    ``x`` is detached and turned into a fresh leaf, so the caller's tensor is
    not attached to the graph. Both outputs are detached: the log-density has
    shape [N] and the gradient has the shape of ``x``.
    """
    leaf = x.detach().requires_grad_(True)
    log_gamma = log_gamma_from_model(model, leaf, mean, std, alpha, scale, eps)
    # Rows are independent, so differentiating the sum yields per-row gradients.
    grads = torch.autograd.grad(log_gamma.sum(), leaf)
    return log_gamma.detach(), grads[0].detach()

def systematic_resample(weights):
    """Draw ``N`` ancestor indices by systematic resampling.

    A single uniform offset ``u`` is drawn and the positions ``(i + u) / N``
    are located in the cumulative distribution of ``weights``.

    Args:
        weights: 1-D tensor of non-negative weights. They do not need to sum
            to one; the cumulative sum is renormalized by its last entry.

    Returns:
        1-D integer tensor of length ``N`` with every index in ``[0, N - 1]``.
    """
    n = weights.shape[0]
    if n == 0:
        return torch.zeros(0, dtype=torch.long, device=weights.device)

    offsets = torch.arange(n, device=weights.device) + torch.rand(1, device=weights.device)
    cdf = torch.cumsum(weights, dim=0)
    # Force the CDF to end exactly at 1 so rounding in the cumulative sum (or
    # unnormalized input) cannot leave a position beyond the last boundary.
    cdf = cdf / cdf[-1]
    positions = (offsets / n).to(cdf.dtype)
    idx = torch.searchsorted(cdf, positions)
    # Guard the remaining corner case of a position rounding above 1.
    return torch.clamp(idx, max=n - 1)

def clamp_cube(x, L=10.0):
    """Clip every coordinate of ``x`` into ``[-L, L]`` (keeps support in the cube)."""
    lower, upper = -L, L
    return torch.clamp(x, min=lower, max=upper)

# --------- MALA move (several steps) ---------
def mala_move(x, loggamma_fn, step_size=1e-2, n_steps=5, L=10.0, target_acc=0.57):
    """Apply ``n_steps`` Metropolis-adjusted Langevin steps to every row of ``x``.

    Args:
        x: [N, d] current states.
        loggamma_fn: callable mapping [N, d] states to ``(log γ [N], grad [N, d])``.
        step_size: drift step; the proposal is ``N(x + step_size * grad, 2 * step_size * I)``.
        n_steps: number of MALA steps. With 0 the states are returned unchanged.
        L: accepted for interface compatibility; proposals are not clamped here.
        target_acc: accepted for interface compatibility; no adaptation is done
            inside this function (callers may adapt ``step_size`` themselves).

    Returns:
        ``(x, logg, acc_rate)``: final states, ``loggamma_fn``'s log-density at
        those states, and the acceptance rate averaged over particles and steps
        (``0.0`` when ``n_steps`` is 0).
    """
    def log_q(x_to, mu_from):
        # Log-density (up to a constant) of N(mu_from, 2 * step_size * I) at x_to.
        diff = x_to - mu_from
        return -0.25 / step_size * (diff * diff).sum(dim=1)

    n_particles = x.shape[0]
    noise_std = (2 * step_size) ** 0.5
    logg, grad = loggamma_fn(x)  # current
    accepts = 0.0
    for _ in range(n_steps):
        mu = x + step_size * grad
        x_prop = mu + torch.randn_like(x) * noise_std

        # evaluate at proposal
        logg_prop, grad_prop = loggamma_fn(x_prop)
        mu_prop = x_prop + step_size * grad_prop

        # MH log acceptance: target ratio plus reverse/forward proposal ratio
        logacc = (logg_prop - logg) + (log_q(x, mu_prop) - log_q(x_prop, mu))
        accept = torch.log(torch.rand(n_particles, device=x.device)) < logacc
        accepts += accept.float().mean().item()

        # update states; cached log-density and gradient follow the accepted state
        row_mask = accept.view(-1, 1)
        x = torch.where(row_mask, x_prop, x)
        logg = torch.where(accept, logg_prop, logg)
        grad = torch.where(row_mask, grad_prop, grad)

    # Avoid a ZeroDivisionError when no steps were requested.
    acc_rate = accepts / n_steps if n_steps > 0 else 0.0
    return x, logg, acc_rate

# --------- SMC driver ---------
def smc_with_checkpoints(
    model_ctor, ckpt_paths, mean, std, N=4096, L=10.0, ess_frac=0.5,
    move_steps=5, step_size=1e-2, device='cuda', alpha=0.1, scale=5.0):
    """Run SMC through the sequence of targets defined by successive checkpoints.

    Particles start uniform on ``[-10, 10]^d`` and are importance-weighted
    against the first checkpoint's ``γ_0``. For each later checkpoint the
    weights are updated by ``log γ_k(x) - log γ_{k-1}(x)``, particles are
    resampled when the ESS falls below ``ess_frac * N``, and a MALA move
    targeting ``γ_k`` is applied. A last MALA move and a final resample produce
    the returned equally-weighted set.

    Args:
        model_ctor: zero-argument callable building the network.
        ckpt_paths: ordered checkpoint paths; the first defines ``γ_0``.
        mean, std: input standardization statistics.
        N: number of particles.
        L: forwarded to ``mala_move``.
        ess_frac: resampling threshold as a fraction of ``N``.
        move_steps, step_size: MALA settings for every rejuvenation.
        device: device for particles and model.
        alpha, scale: forwarded to the log-density evaluation.

    Returns:
        [N, d] tensor of resampled particles.
    """
    mean_d, std_d = mean.to(device), std.to(device)
    d = mean_d.numel()
    R = 10.0
    x = torch.empty(N, d, device=device).uniform_(-R, R)
    logq0 = -d * math.log(2 * R)  # constant within the box

    model = model_ctor().to(device)
    model = load_model(ckpt_paths[0], model)

    @torch.no_grad()
    def eval_logg(xx):
        # Weight updates need values only, so no autograd graph is kept.
        return log_gamma_from_model(model, xx, mean_d, std_d, alpha, scale)

    def logg_and_grad(xx):
        # Reads whichever weights `model` currently holds (it is reloaded per stage).
        return grad_log_gamma(model, xx, mean_d, std_d, alpha, scale)

    def normalized_weights(logw_):
        shifted = torch.exp(logw_ - torch.max(logw_))
        w_ = shifted / shifted.sum()
        return w_, 1.0 / (w_ * w_).sum().item()

    def rejuvenate(xx):
        return mala_move(xx, logg_and_grad, step_size=step_size, n_steps=move_steps, L=L)

    # --- stage 0: importance weights w ∝ γ0 / q0 ---
    logg = eval_logg(x)
    logw = logg - logq0
    w, ess = normalized_weights(logw)
    if ess < ess_frac * N:
        x = x[systematic_resample(w)]
        logw = torch.zeros(N, device=device)  # weights reset post-resample
        # rejuvenate at π0 and carry forward log γ0 at the moved particles
        x, logg, _ = rejuvenate(x)

    # --- stages 1..K-1 ---
    for ckpt in ckpt_paths[1:]:
        model = load_model(ckpt, model)  # move target to π_k

        # incremental weights: log w += log γ_k(x) - log γ_{k-1}(x)
        logg_new = eval_logg(x)
        logw = logw + (logg_new - logg)

        w, ess = normalized_weights(logw)
        if ess < ess_frac * N:
            x = x[systematic_resample(w)]
            logw = torch.zeros(N, device=device)

        # The move runs every stage (resampled or not); the returned log γ_k
        # matches the moved particles and feeds the next increment.
        x, logg, _ = rejuvenate(x)

    # --- one more rejuvenation at the final target (weights unchanged) ---
    x, logg, _ = rejuvenate(x)

    # --- return equally-weighted samples (final resample) ---
    w, _ = normalized_weights(logw)
    return x[systematic_resample(w)]


def smc_tempered_fixed(
    model_ctor, ckpt_path, mean, std,
    betas=np.linspace(0,1,10),                    # betas or linspace(0,1,K)
    N=4096, ess_frac=0.5,
    move_steps=5, step_size=1e-2, target_acc=0.57,
    device='cuda', alpha=0.1, scale=5.0, R=10.0, L=10.0
):
    """
    Bridges π_{β}(x) ∝ γ(x)^β with a FIXED schedule β_0=0 < ... < β_{K-1}=1.
    Weights: logw += (β_k - β_{k-1}) * logγ(x)  (computed BEFORE resampling).
    Rejuvenation: MALA targeting π_{β_k}.

    Particles start uniform on ``[-R, R]^d``. At each stage they are resampled
    when the ESS drops below ``ess_frac * N`` and then moved with ``move_steps``
    MALA steps. No evidence estimate is produced.

    Args:
        model_ctor: zero-argument callable building the network.
        ckpt_path: checkpoint defining γ.
        mean, std: input standardization statistics.
        betas: strictly increasing temperatures (any iterable of numbers).
        N, ess_frac: particle count and resampling threshold fraction.
        move_steps, step_size, target_acc, L: forwarded to ``mala_move``.
        device, alpha, scale: evaluation device and log-density parameters.
        R: half-width of the initialization box.

    Returns:
        [N, d] tensor of particles after a final equal-weight resample.

    Raises:
        AssertionError: if ``betas`` is not strictly increasing.
    """
    mean_d, std_d = mean.to(device), std.to(device)
    d = mean_d.numel()
    # Plain floats keep numpy scalar types out of the tensor arithmetic below.
    schedule = [float(b) for b in betas]

    # init from uniform box (π_{β=0})
    x = torch.empty(N, d, device=device).uniform_(-R, R)
    logw = torch.zeros(N, device=device)

    # load final model once
    model = model_ctor().to(device)
    model = load_model(ckpt_path, model)

    @torch.no_grad()
    def eval_logg(xx):
        # base logγ(x), not scaled by β
        return log_gamma_from_model(model, xx, mean_d, std_d, alpha, scale)

    def normalized_weights(logw_):
        shifted = torch.exp(logw_ - torch.max(logw_))
        w_ = shifted / shifted.sum()
        return w_, (1.0 / (w_ * w_).sum()).item()

    logg = eval_logg(x)

    for beta_prev, beta_k in zip(schedule, schedule[1:]):
        d_beta = beta_k - beta_prev
        assert d_beta > 0, "betas must be strictly increasing"

        # incremental weights evaluated at the current particles
        logw = logw + d_beta * logg

        w, ess = normalized_weights(logw)
        if ess < ess_frac * N:
            x = x[systematic_resample(w)]
            logw = torch.zeros(N, device=device)  # weights reset after resampling

        # rejuvenation at π_{β_k}: scale base log-density and gradient by β_k
        def lg_fn_state(xx):
            lg, g = grad_log_gamma(model, xx, mean_d, std_d, alpha, scale)
            return beta_k * lg, beta_k * g

        x, _, _ = mala_move(x, lg_fn_state, step_size=step_size,
                            n_steps=move_steps, L=L, target_acc=target_acc)

        # refresh base logγ for the NEXT incremental weight
        logg = eval_logg(x)

    # final equal-weight resample to return samples ~ π_{β=1}
    w, _ = normalized_weights(logw)
    return x[systematic_resample(w)]

def mala_sample_final(
    model_ctor,
    ckpt_path,
    mean,
    std,
    N=4096,
    T=5000,            # total MALA steps
    burn_in=1000,      # unused: only the final states are returned
    thin=10,           # unused: only the final states are returned
    step_size=1e-2,    # same notion as in mala_move (proposal variance = 2*step_size)
    device='cuda',
    alpha=0.1,
    scale=5.0,
    R=10.0,            # init box half-width
    L=10.0,            # unused: proposals are not clamped
    adapt=False,       # set True to enable light adaptation
    target_acc=0.57,
    adapt_every=25,
    adapt_gain=0.1,
    seed=None,
):
    """
    Runs N independent MALA chains targeting π(x) ∝ γ(x) from the FINAL model checkpoint.

    Chains start uniform on ``[-R, R]^d`` and run for ``T`` steps. When
    ``adapt`` is True, every ``adapt_every`` steps the step size is multiplied
    by ``exp(adapt_gain * (window_acceptance - target_acc))``.

    ``burn_in``, ``thin`` and ``L`` are kept for interface compatibility; since
    only the last states are returned they do not affect the result.

    Returns:
      last_x: [N, d] final states of the N chains.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    mean_d, std_d = mean.to(device), std.to(device)
    d = mean_d.numel()

    # init from uniform box
    x = torch.empty(N, d, device=device).uniform_(-R, R)

    # load final model
    model = model_ctor().to(device)
    model = load_model(ckpt_path, model)

    def lg_fn(xx):
        return grad_log_gamma(model, xx, mean_d, std_d, alpha, scale)

    def log_q(x_to, mu_from, step):
        # Log-density (up to a constant) of N(mu_from, 2*step*I) at x_to.
        diff = x_to - mu_from
        return -0.25 / step * (diff * diff).sum(dim=1)

    logg, grad = lg_fn(x)
    accepts_window = 0.0

    for t in range(1, T + 1):
        # proposal
        mu = x + step_size * grad
        x_prop = mu + torch.randn_like(x) * math.sqrt(2.0 * step_size)

        # evaluate proposal
        logg_prop, grad_prop = lg_fn(x_prop)
        mu_prop = x_prop + step_size * grad_prop

        # MH log-accept
        logacc = (logg_prop - logg) + (
            log_q(x, mu_prop, step_size) - log_q(x_prop, mu, step_size)
        )
        accept = torch.log(torch.rand(N, device=x.device)) < logacc

        # update states
        row_mask = accept.view(-1, 1)
        x = torch.where(row_mask, x_prop, x)
        logg = torch.where(accept, logg_prop, logg)
        grad = torch.where(row_mask, grad_prop, grad)

        # Acceptance is tracked only when it drives adaptation.
        if adapt:
            accepts_window += accept.float().mean().item()
            if t % adapt_every == 0:
                acc_rate = accepts_window / adapt_every
                step_size = step_size * math.exp(adapt_gain * (acc_rate - target_acc))
                accepts_window = 0.0

    return x.detach().clone()


def ais_auto_temp_fixed(
    model_ctor,
    ckpt_path,
    mean,
    std,
    N=4096,
    K=10,                         # total stages including β=0 and β=1
    ess_target_frac=0.8,          # maintain ESS ≥ this fraction * N after each update
    move_steps=5,
    step_size=1e-2,
    target_acc=0.57,
    device='cuda',
    alpha=0.1,
    scale=5.0,
    R=10.0,
    L=10.0,
    bisection_tol=1e-4,
    max_bisect_iters=50,
):
    """
    Annealed Importance Sampling with *automatic* temperature increments (Δβ) chosen
    on-the-fly by bisection to keep the *total* ESS after each increment above a target,
    while using exactly K stages (β_0=0 < ... < β_{K-1}=1).

    - No resampling (that's AIS).
    - At each stage k, we:
        1) choose Δβ by bisection so that ESS(logw + Δβ*logγ(x)) ≈ ess_target_frac*N;
           the last stage always takes whatever remains to reach β=1;
        2) update weights: logw += Δβ * logγ(x_current);
        3) run MALA targeting π_{β_k} to move particles.

    With K < 2 no annealing step runs and the untouched uniform draws are returned
    with zero log-weights.

    Returns:
      dict with:
        - 'x':           [N, d] final particles (after β=1 moves)
        - 'logw':        [N] final log-weights
        - 'betas':       list of β's (length K)
        - 'logZ_rel':    float, log Z(β=1) - log Z(β=0)
        - 'logZ_abs':    float, adds log((2R)^d) since π_0 is uniform on [-R,R]^d
        - 'acc_hist':    list of acceptance rates per stage (post-move)
    """
    mean_d, std_d = mean.to(device), std.to(device)
    d = mean_d.numel()

    # init particles from π_0 (uniform on box)
    x = torch.empty(N, d, device=device).uniform_(-R, R)
    logw = torch.zeros(N, device=device)

    # model & base logγ(x) (no β scaling here)
    model = model_ctor().to(device)
    model = load_model(ckpt_path, model)

    @torch.no_grad()
    def eval_logg(xx):
        return log_gamma_from_model(model, xx, mean_d, std_d, alpha, scale)

    def ess_from_logw(logw_):
        w = torch.exp(logw_ - torch.max(logw_))
        w = w / w.sum()
        return (1.0 / (w * w).sum()).item()

    def choose_delta_beta(beta_curr, logw_curr, logg_curr, steps_left):
        """Pick Δβ for the next increment.

        steps_left: number of increments remaining *after* this one. When it is
        0 the whole remaining gap to β=1 is taken.
        """
        remaining = 1.0 - beta_curr
        if steps_left == 0:
            return remaining
        # Nothing (numerically) left to anneal.
        if remaining < 1e-12:
            return remaining

        target_ess = ess_target_frac * N
        # If jumping straight to β=1 keeps ESS above target, take the full step.
        if ess_from_logw(logw_curr + remaining * logg_curr) >= target_ess:
            return remaining

        # Bisection, assuming ESS decreases with Δβ (usually true): `lo` always
        # satisfies the ESS target, `hi` always violates it.
        lo, hi = 0.0, remaining
        for _ in range(max_bisect_iters):
            mid = 0.5 * (lo + hi)
            if ess_from_logw(logw_curr + mid * logg_curr) >= target_ess:
                lo = mid
            else:
                hi = mid
            if (hi - lo) < bisection_tol:
                break
        return max(lo, 0.0)

    # current base logγ at x
    logg = eval_logg(x)

    betas = [0.0]
    beta = 0.0
    acc_hist = []

    # Run K-1 annealing updates (since β_0=0 is already counted)
    for stage in range(1, K):
        steps_left_after = (K - 1) - stage  # increments remaining after this one
        d_beta = choose_delta_beta(beta, logw, logg, steps_left_after)
        beta_next = beta + d_beta
        betas.append(beta_next)

        # --- AIS incremental weights: logw += Δβ * logγ(x_current)
        logw = logw + d_beta * logg

        # --- MALA move targeting π_{β_next}
        def lg_fn_state(xx):
            # base logγ and grad, then scale by β
            lg, g = grad_log_gamma(model, xx, mean_d, std_d, alpha, scale)
            return beta_next * lg, beta_next * g

        x, _, acc = mala_move(
            x, lg_fn_state, step_size=step_size,
            n_steps=move_steps, L=L, target_acc=target_acc
        )
        acc_hist.append(acc)

        # refresh base logγ(x) for next stage
        logg = eval_logg(x)
        beta = beta_next

    # AIS estimate of Z(β=1)/Z(β=0) is (1/N) * sum_i exp(logw_i)
    logZ_rel = torch.logsumexp(logw, dim=0) - math.log(N)

    # Since π_0 is uniform on [-R,R]^d with volume (2R)^d, log Z(β=0) = log((2R)^d)
    logZ_abs = logZ_rel + d * math.log(2.0 * R)

    return {
        'x': x.detach(),
        'logw': logw.detach(),
        'betas': betas,
        'logZ_rel': float(logZ_rel.item()),
        'logZ_abs': float(logZ_abs.item()),
        'acc_hist': acc_hist,
    }


def parallel_tempering_mala(
    model_ctor,
    ckpt_path,
    mean,
    std,
    N=4096,                 # total chains across ALL replicas (compute-matching knob)
    K=10,                   # number of replicas (match your “stages”)
    n_stages=10,            # number of sweep “stages” (each does move_steps + swaps)
    move_steps=5,           # per-replica MALA steps per stage (match your other code)
    base_step_size=1e-2,    # base MALA step size (scaled per β below)
    step_scale_power=1.0,   # step_size_i = base_step_size / max(beta_i,1e-6)**step_scale_power
    target_acc=0.57,
    device='cuda',
    alpha=0.1,
    scale=5.0,
    R=10.0,
    L=10.0,
    betas=None,             # optional list/1D tensor of length K with betas (ascending, ends at 1.0)
    ladder_kind='power',    # used if betas is None: 'lin' | 'geom' | 'power'
    ladder_param=4.0,       # exponent for 'power'; clamped to [1e-6, 0.999] for 'geom'
    seed=None,
):
    """
    Parallel Tempering with MALA within-replica moves and adjacent replica exchange.

    Compute matching:
      - Total chains = N (same N as your AIS/SMC).
      - K replicas, each holds ~N/K chains.
      - Each stage: every replica does `move_steps` of MALA, then we propose swaps
        between adjacent replicas (even pairs on even stages, odd pairs on odd stages).
      => Total grad calls per stage ≈ (N * move_steps), same scaling as your tempering schemes.

    NOTE(review): chains are initialized in [-R, R]^d, but nothing visible in this
    function confines later states to that box (``L`` is only forwarded to
    ``mala_move``); confirm this is intended for the β≈0 replicas, whose step
    size is ``base_step_size / 1e-6**step_scale_power``.

    Returns:
      {
        'cold_samples': [Ncold, d]  # final β=1 states
        'betas': list of β values (len K, ascending, last = 1.0)
        'swap_accept_mean': float   # mean over the pairs that were attempted
        'swap_accept_by_pair': [K-1] list of mean acceptances for each adjacent pair
        'mala_acc_by_replica': [K] mean MALA acceptance per replica
      }
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    mean_d, std_d = mean.to(device), std.to(device)
    d = mean_d.numel()

    # --- Build temperature ladder if not provided ---
    if betas is None:
        if ladder_kind == 'lin':
            betas = torch.linspace(0.0, 1.0, K)
        elif ladder_kind == 'geom':
            # β = 1 - (1 - t)**(1 - r) with r clamped into [1e-6, 0.999]
            t = torch.linspace(0, 1, K)
            r = max(min(float(ladder_param), 0.999), 1e-6)
            betas = 1.0 - (1.0 - t)**(1.0 - r)
        else:  # 'power' (default): β = t**p, concentrates near 0 for p>1; near 1 for 0<p<1
            t = torch.linspace(0.0, 1.0, K)
            betas = t**float(ladder_param)
        # Pin the endpoints; β=1 is written last so a single replica (K=1)
        # targets the true distribution rather than β=0.
        betas[0] = 0.0
        betas[-1] = 1.0
    else:
        betas = torch.as_tensor(betas, dtype=torch.float32, device='cpu')
        assert betas.numel() == K, "betas must have length K"
        assert torch.all(betas[:-1] <= betas[1:]), "betas must be nondecreasing"
        assert abs(float(betas[0])) < 1e-12 and abs(float(betas[-1]-1.0)) < 1e-6, "betas must start at 0 and end at 1"
    beta_values = betas.tolist()   # Python floats, read once instead of per replica per stage
    betas = betas.to(device)

    # --- Split N across K replicas (balanced; works even if N%K!=0) ---
    base = N // K
    rem = N - base * K
    n_per = [base + (1 if i < rem else 0) for i in range(K)]
    assert sum(n_per) == N and all(n > 0 for n in n_per)

    # --- Model & hooks ---
    model = model_ctor().to(device)
    model = load_model(ckpt_path, model)

    def lg_and_grad(xx):
        return grad_log_gamma(model, xx, mean_d, std_d, alpha, scale)  # base (β=1) logγ and grad

    @torch.no_grad()
    def eval_logg(xx):
        return log_gamma_from_model(model, xx, mean_d, std_d, alpha, scale)  # base (β=1) logγ only

    # --- Initialize replicas ---
    x = [torch.empty(n_per[i], d, device=device).uniform_(-R, R) for i in range(K)]

    mala_acc_sums = torch.zeros(K, device=device)
    mala_acc_cnts = torch.zeros(K, device=device)
    swap_acc_sums = torch.zeros(K-1, device=device)   # per pair stats
    swap_acc_cnts = torch.zeros(K-1, device=device)

    # helper: one replica MALA on the β-tempered target
    def one_replica_mala(xi, beta_i, step_size_i):
        def lg_fn_state(xx):
            lg, g = lg_and_grad(xx)
            return beta_i * lg, beta_i * g
        xi, _, acc = mala_move(
            xi, lg_fn_state,
            step_size=step_size_i, n_steps=move_steps, L=L, target_acc=target_acc
        )
        return xi, acc

    # --- Main PT sweeps ---
    tiny = 1e-6
    for s in range(n_stages):
        # 1) Within-replica MALA
        for i, beta_i in enumerate(beta_values):
            step_i = base_step_size / max(beta_i, tiny)**step_scale_power
            x[i], acc = one_replica_mala(x[i], beta_i, step_i)
            mala_acc_sums[i] += acc
            mala_acc_cnts[i] += 1.0

        # 2) Compute base logγ for swap decisions
        logg_base = [eval_logg(x[i]) for i in range(K)]  # each [n_i]

        # 3) Adjacent swaps (even/odd alternation)
        # even pairs: (0,1), (2,3), ...   odd pairs: (1,2), (3,4), ...
        for i in range(s % 2, K - 1, 2):
            ni = min(x[i].shape[0], x[i+1].shape[0])
            if ni == 0:
                continue
            # the first ni chains of each replica are paired up
            gi = logg_base[i][:ni]
            gj = logg_base[i+1][:ni]

            # log α = (β_j - β_i) * ( logγ(x_i) - logγ(x_j) )
            log_alpha = (betas[i+1] - betas[i]) * (gi - gj)
            u = torch.log(torch.rand(ni, device=device))
            accept = (u < log_alpha)

            # Exchange states where accepted. Both replacements are built
            # before either replica is overwritten. Each replica joins at most
            # one pair per stage and logγ is recomputed next stage, so the
            # cached energies need no update.
            row_mask = accept.view(-1, 1)
            xi = x[i][:ni]
            xj = x[i+1][:ni]
            new_i = torch.where(row_mask, xj, xi)
            new_j = torch.where(row_mask, xi, xj)
            x[i][:ni] = new_i
            x[i+1][:ni] = new_j

            # stats
            swap_acc_sums[i] += accept.float().mean()
            swap_acc_cnts[i] += 1.0

    # --- Collect outputs ---
    cold_idx = K - 1  # β=1 replica
    cold_samples = x[cold_idx].detach().clone()

    mala_acc_by_replica = (mala_acc_sums / torch.clamp_min(mala_acc_cnts, 1)).tolist()
    pair_means = swap_acc_sums / torch.clamp_min(swap_acc_cnts, 1)
    pair_acc = pair_means.tolist()
    # Average only over pairs that were actually proposed, so untouched pairs
    # (e.g. odd pairs when n_stages == 1) do not count as zero acceptance.
    attempted = swap_acc_cnts > 0
    overall_swap_acc = float(pair_means[attempted].mean().item()) if bool(attempted.any()) else 0.0

    return {
        'cold_samples': cold_samples,
        'betas': betas.detach().tolist(),
        'swap_accept_mean': overall_swap_acc,
        'swap_accept_by_pair': pair_acc,
        'mala_acc_by_replica': mala_acc_by_replica,
    }





# ----------------------------
# Helpers
# ----------------------------

def set_all_seeds(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def find_checkpoints(ckpt_dir: str) -> Tuple[List[str], str]:
    """Return ``(sorted_ckpts, final_ckpt)`` for the training checkpoints in ``ckpt_dir``.

    Only ``*.pt`` files whose name contains ``ckpt_epoch_`` followed by an
    integer epoch are considered; other ``.pt`` files (including
    ``ckpt_epoch_*`` names with a non-numeric suffix) are ignored. Paths are
    ordered by epoch number and ``final_ckpt`` is the highest epoch.

    Raises:
        FileNotFoundError: if no such checkpoint exists (at least one is required).
    """
    marker = "ckpt_epoch_"
    found = []
    for path in glob.glob(os.path.join(ckpt_dir, "*.pt")):
        name = os.path.basename(path)
        if marker not in name:
            continue
        epoch_text = name.split(marker)[1].split(".pt")[0]
        try:
            epoch = int(epoch_text)
        except ValueError:
            # e.g. "ckpt_epoch_best.pt": not a numbered training checkpoint
            continue
        found.append((epoch, path))

    if not found:
        raise FileNotFoundError(f"No ckpt_epoch_*.pt found in {ckpt_dir}")

    # Numeric epoch order; ties fall back to alphabetical path order.
    found.sort()
    sorted_files = [path for _, path in found]
    return sorted_files, sorted_files[-1]


def load_stats(ckpt_dir: str, device: str):
    """Load the ``"mean"`` and ``"std"`` tensors from ``<ckpt_dir>/stats.pt``.

    Both tensors are moved to ``device`` and returned as ``(mean, std)``.

    Raises:
        FileNotFoundError: if ``stats.pt`` is absent from ``ckpt_dir``.
    """
    stats_path = os.path.join(ckpt_dir, "stats.pt")
    if os.path.exists(stats_path):
        stats = torch.load(stats_path, map_location=device)
        return stats["mean"].to(device), stats["std"].to(device)
    raise FileNotFoundError(f"Missing stats at {stats_path}")


def device_from_arg(s: str) -> torch.device:
    """Turn a CLI device string into a ``torch.device``.

    ``"auto"`` selects CUDA when it is available and the CPU otherwise; any
    other string is handed to ``torch.device`` unchanged.
    """
    if s != "auto":
        return torch.device(s)
    chosen = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(chosen)


def energies_from_samples(samples, ackley_func):
    """Evaluate ``ackley_func.f_func`` on ``samples`` and return a NumPy array.

    ``samples`` may be a torch tensor or anything ``np.asarray`` accepts; it is
    converted to a float32 CPU tensor before evaluation, which runs without
    gradient tracking.
    """
    if not isinstance(samples, torch.Tensor):
        points = torch.from_numpy(np.asarray(samples)).float()
    else:
        points = samples.detach().cpu().float()
    with torch.no_grad():
        energies = ackley_func.f_func(points)
    return energies.detach().cpu().numpy()


def summarize_energy(E_np):
    """Summarize a 1D energy array as ``{"min", "mean", "std"}`` Python floats.

    The standard deviation is the sample estimate (``ddof=1``).
    """
    lowest = np.min(E_np)
    average = np.mean(E_np)
    spread = np.std(E_np, ddof=1)
    return dict(min=float(lowest), mean=float(average), std=float(spread))


def equal_weight_resample(x: torch.Tensor, logw: torch.Tensor) -> torch.Tensor:
    """Resample the rows of ``x`` according to the log-weights ``logw``.

    The weights are exponentiated after subtracting their maximum (for
    numerical stability), normalized, and passed to ``systematic_resample``.
    No gradients are tracked.
    """
    with torch.no_grad():
        shifted = logw - torch.max(logw)
        probs = torch.exp(shifted)
        probs = probs / probs.sum()
        return x[systematic_resample(probs)]


# ----------------------------
# Subcommand runners
# ----------------------------

def run_smc_ckpt(args):
    """CLI runner: SMC across the training checkpoints found in ``args.ckpt_dir``."""
    device = device_from_arg(args.device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    mean, std = load_stats(args.ckpt_dir, device)
    checkpoint_paths = find_checkpoints(args.ckpt_dir)[0]

    def model_ctor():
        return FCNN(n_dim=args.n_dim)

    sampler_kwargs = dict(
        model_ctor=model_ctor,
        ckpt_paths=checkpoint_paths,
        mean=mean,
        std=std,
        N=args.N,
        L=args.L,
        ess_frac=args.ess_frac,
        move_steps=args.move_steps,
        step_size=args.step_size,
        device=str(device),
        alpha=args.alpha,
        scale=args.scale,
    )
    samples = smc_with_checkpoints(**sampler_kwargs)

    postprocess_and_save(samples, args, device)


def run_smc_temp(args):
    """CLI runner: tempered SMC on the final checkpoint with a linear β schedule.

    The schedule has ``args.K`` points; when ``args.K`` is None it defaults to
    the number of training checkpoints found in ``args.ckpt_dir``.
    """
    dev = device_from_arg(args.device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    mean, std = load_stats(args.ckpt_dir, dev)
    # One directory scan provides both the final checkpoint and the default K.
    ckpts, final_ckpt = find_checkpoints(args.ckpt_dir)
    model_ctor = lambda: FCNN(n_dim=args.n_dim)

    K = len(ckpts) if args.K is None else args.K
    betas = np.linspace(0.0, 1.0, K)

    samples = smc_tempered_fixed(
        model_ctor=model_ctor,
        ckpt_path=final_ckpt,
        mean=mean,
        std=std,
        betas=betas,
        N=args.N,
        ess_frac=args.ess_frac,
        move_steps=args.move_steps,
        step_size=args.step_size,
        target_acc=args.target_acc,
        device=str(dev),
        alpha=args.alpha,
        scale=args.scale,
        R=args.R,
        L=args.L,
    )

    postprocess_and_save(samples, args, dev)


def run_mcmc_mala(args):
    """CLI runner: independent MALA chains on the final checkpoint."""
    device = device_from_arg(args.device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    mean, std = load_stats(args.ckpt_dir, device)
    final_ckpt = find_checkpoints(args.ckpt_dir)[1]

    def model_ctor():
        return FCNN(n_dim=args.n_dim)

    sampler_kwargs = dict(
        model_ctor=model_ctor,
        ckpt_path=final_ckpt,
        mean=mean,
        std=std,
        N=args.N,
        T=args.T,
        burn_in=args.burn_in,
        thin=args.thin,
        step_size=args.step_size,
        device=str(device),
        alpha=args.alpha,
        scale=args.scale,
        R=args.R,
        L=args.L,
        adapt=args.adapt,
        target_acc=args.target_acc,
        adapt_every=args.adapt_every,
        adapt_gain=args.adapt_gain,
        seed=args.seed,
    )
    samples = mala_sample_final(**sampler_kwargs)

    postprocess_and_save(samples, args, device)


def run_ais_auto(args):
    """CLI runner: AIS with automatically chosen temperature increments.

    The particles are saved as returned (weighted, with their log-weights in
    the extras) unless ``args.equal_resample`` asks for an equal-weight resample.
    """
    device = device_from_arg(args.device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    mean, std = load_stats(args.ckpt_dir, device)
    final_ckpt = find_checkpoints(args.ckpt_dir)[1]

    def model_ctor():
        return FCNN(n_dim=args.n_dim)

    sampler_kwargs = dict(
        model_ctor=model_ctor,
        ckpt_path=final_ckpt,
        mean=mean,
        std=std,
        N=args.N,
        K=args.K,
        ess_target_frac=args.ess_target_frac,
        move_steps=args.move_steps,
        step_size=args.step_size,
        target_acc=args.target_acc,
        device=str(device),
        alpha=args.alpha,
        scale=args.scale,
        R=args.R,
        L=args.L,
        bisection_tol=args.bisect_tol,
        max_bisect_iters=args.bisect_max_iters,
    )
    result = ais_auto_temp_fixed(**sampler_kwargs)

    # Weighted set by default; optionally collapse it to equal weights.
    samples = result["x"]
    if args.equal_resample:
        samples = equal_weight_resample(result["x"].to(device), result["logw"].to(device))

    postprocess_and_save(samples, args, device, extra=result)


def run_pt_mala(args):
    """CLI runner: parallel tempering with MALA moves on the final checkpoint.

    ``K`` defaults to the number of training checkpoints when ``args.K`` is
    None, and the number of sweeps defaults to ``K`` when ``args.n_stages`` is
    None. The cold-replica samples are saved with the sampler diagnostics.
    """
    dev = device_from_arg(args.device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    mean, std = load_stats(args.ckpt_dir, dev)
    # One directory scan provides both the final checkpoint and the default K.
    ckpts, final_ckpt = find_checkpoints(args.ckpt_dir)
    K = args.K if args.K is not None else len(ckpts)

    model_ctor = lambda: FCNN(n_dim=args.n_dim)

    out = parallel_tempering_mala(
        model_ctor=model_ctor,
        ckpt_path=final_ckpt,
        mean=mean,
        std=std,
        N=args.N,
        K=K,
        n_stages=args.n_stages if args.n_stages is not None else K,
        move_steps=args.move_steps,
        base_step_size=args.base_step_size,
        step_scale_power=args.step_scale_power,
        target_acc=args.target_acc,
        device=str(dev),
        alpha=args.alpha,
        scale=args.scale,
        R=args.R,
        L=args.L,
        betas=None,  # can expose if you want manual ladder
        ladder_kind=args.ladder_kind,
        ladder_param=args.ladder_param,
        seed=args.seed,
    )

    samples = out["cold_samples"]
    postprocess_and_save(samples, args, dev, extra=out)


# ----------------------------
# Common post-processing
# ----------------------------

def _detach_to_cpu(obj):
    """Return ``obj`` with every tensor detached and moved to CPU.

    Recurses into dicts and into plain lists/tuples; any other object
    (including tuple subclasses such as namedtuples) is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: _detach_to_cpu(v) for k, v in obj.items()}
    if type(obj) in (list, tuple):
        return type(obj)(_detach_to_cpu(v) for v in obj)
    return obj


def postprocess_and_save(samples: torch.Tensor, args, dev: torch.device, extra=None):
    """Optionally summarize Ackley energy and/or save samples.

    Args:
        samples: Sample tensor to summarize and/or save.
        args: Namespace providing ``summarize``, ``n_dim`` and ``out``.
        dev: Compute device; accepted for call-site symmetry but not used
            by this function.
        extra: Optional dict of additional sampler outputs. When saving, it
            is stored under ``"extras"`` with all tensors (including ones
            nested in lists, tuples or dicts) detached and moved to CPU.
            Nothing is dropped from it.

    Side effects:
        Prints a JSON ``energy_stats`` report when ``args.summarize`` is
        truthy; writes a ``torch.save`` payload to ``args.out`` (creating the
        parent directory if needed) when ``args.out`` is truthy.
    """
    # summarize energies if requested
    if args.summarize:
        ack = ackley(n_dim=args.n_dim)
        E = energies_from_samples(samples, ack)
        stats = summarize_energy(E)
        print(json.dumps({"energy_stats": stats}, indent=2))

    if args.out:
        # dirname is "" for a bare filename; fall back to the current dir.
        os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
        payload = {"samples": samples.detach().cpu()}
        if extra is not None:
            # Keep every entry, but make the file device-independent and free
            # of autograd state, also for tensors nested inside containers.
            payload["extras"] = {k: _detach_to_cpu(v) for k, v in extra.items()}
        torch.save(payload, args.out)
        print(f"Saved outputs to {args.out}")


# ----------------------------
# Argparse
# ----------------------------

def build_parser():
    """Build the top-level argument parser with one subparser per sampler.

    Every subcommand (``smc-ckpt``, ``smc-temp``, ``mcmc-mala``, ``ais-auto``,
    ``pt-mala``) receives the same set of shared options followed by its own
    specific ones, and stores its runner function under ``func`` so the
    caller can dispatch with ``args.func(args)``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="CLI to run SMC/AIS/MCMC/PT samplers against a learned Ackley energy."
    )
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    # Options shared by every subcommand, in the order they are registered.
    shared_options = {
        "--ckpt-dir": dict(type=str, default="tanh_model",
                           help="Directory containing ckpt_epoch_*.pt and stats.pt"),
        "--n-dim": dict(type=int, default=10, dest="n_dim",
                        help="Problem dimensionality for FCNN/Ackley."),
        "--device": dict(type=str, default="auto", choices=["auto", "cpu", "cuda"],
                         help="Compute device."),
        "--N": dict(type=int, default=10_000, help="Number of particles/chains."),
        "--move-steps": dict(type=int, default=5, dest="move_steps",
                             help="MALA steps per stage."),
        "--step-size": dict(type=float, default=1e-2, dest="step_size",
                            help="MALA step-size parameter."),
        "--alpha": dict(type=float, default=0.1,
                        help="Target transform power used during training."),
        "--scale": dict(type=float, default=5.0,
                        help="Target transform scale used during training."),
        "--R": dict(type=float, default=10.0, help="Init box half-width."),
        "--L": dict(type=float, default=10.0,
                    help="(Optional) clamp half-width for proposals."),
        "--target-acc": dict(type=float, default=0.57, dest="target_acc",
                             help="Target acceptance for adaptations."),
        "--seed": dict(type=int, default=None, help="Random seed."),
        "--out": dict(type=str, default=None,
                      help="Path to save a torch .pt with samples and extras."),
        "--summarize": dict(action="store_true",
                            help="Compute and print energy stats under true Ackley."),
    }

    def register(name, summary, handler):
        """Create subcommand ``name`` with the shared options and its runner."""
        cmd = subcommands.add_parser(name, help=summary)
        for flag, spec in shared_options.items():
            cmd.add_argument(flag, **spec)
        cmd.set_defaults(func=handler)
        return cmd

    # smc-ckpt
    smc_ckpt = register("smc-ckpt", "SMC over training checkpoints.", run_smc_ckpt)
    smc_ckpt.add_argument("--ess-frac", type=float, default=1.1, dest="ess_frac",
                          help="Resample when ESS < ess_frac * N.")

    # smc-temp
    smc_temp = register("smc-temp", "SMC with fixed tempering schedule.", run_smc_temp)
    smc_temp.add_argument("--K", type=int, default=None,
                          help="Number of stages (default: number of training ckpts).")
    smc_temp.add_argument("--ess-frac", type=float, default=0.5, dest="ess_frac",
                          help="Resample when ESS < ess_frac * N.")

    # mcmc-mala
    mcmc = register("mcmc-mala", "Plain MALA sampling.", run_mcmc_mala)
    mcmc.add_argument("--T", type=int, default=None,
                      help="Total MALA steps. If None, uses (#ckpts * move_steps).")
    mcmc.add_argument("--burn-in", type=int, default=0, dest="burn_in", help="Burn-in steps.")
    mcmc.add_argument("--thin", type=int, default=1,
                      help="Keep 1 of every 'thin' samples after burn-in.")
    mcmc.add_argument("--adapt", action="store_true", help="Enable light step-size adaptation.")
    mcmc.add_argument("--adapt-every", type=int, default=25, dest="adapt_every",
                      help="Adaptation window.")
    mcmc.add_argument("--adapt-gain", type=float, default=0.1, dest="adapt_gain",
                      help="Adaptation gain.")

    # ais-auto
    ais = register("ais-auto", "AIS with automatic Δβ via ESS bisection.", run_ais_auto)
    ais.add_argument("--K", type=int, default=10, help="Total stages including β=0 and β=1.")
    ais.add_argument("--ess-target-frac", type=float, default=0.8, dest="ess_target_frac",
                     help="Keep ESS ≥ this fraction * N after each increment.")
    ais.add_argument("--bisect-tol", type=float, default=1e-4, dest="bisect_tol",
                     help="Bisection tolerance.")
    ais.add_argument("--bisect-max-iters", type=int, default=50, dest="bisect_max_iters",
                     help="Max bisection iterations.")
    ais.add_argument("--equal-resample", action="store_true",
                     help="Equal-weight resample the AIS output.")

    # pt-mala
    pt = register("pt-mala", "Parallel Tempering with MALA.", run_pt_mala)
    pt.add_argument("--K", type=int, default=None,
                    help="Number of replicas (default: number of training ckpts).")
    pt.add_argument("--n-stages", type=int, default=None,
                    help="Number of PT sweeps (default: K).")
    pt.add_argument("--base-step-size", type=float, default=1e-2, dest="base_step_size",
                    help="Base MALA step size.")
    pt.add_argument("--step-scale-power", type=float, default=1.0, dest="step_scale_power",
                    help="Replica step size scaling: step_i = base / max(beta_i,1e-6)^power.")
    pt.add_argument("--ladder-kind", type=str, default="power", choices=["power", "lin", "geom"],
                    help="Temperature ladder kind if betas not supplied.")
    pt.add_argument("--ladder-param", type=float, default=4.0,
                    help="Power exponent (power) or ratio (geom).")

    return parser


# ----------------------------
# Entry
# ----------------------------

def main():
    """Parse the command line and run the selected sampler subcommand."""
    args = build_parser().parse_args()

    # For mcmc-mala without an explicit --T, match the compute budget of the
    # staged samplers: (#checkpoints) * (MALA steps per stage).
    wants_default_steps = args.cmd == "mcmc-mala" and args.T is None
    if wants_default_steps:
        checkpoint_list, _final = find_checkpoints(args.ckpt_dir)
        args.T = len(checkpoint_list) * args.move_steps

    # Each subparser stored its runner under ``func``.
    runner = args.func
    runner(args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
