import os
import sys

# Add the project root to sys.path so that modules imported with the 'src.' prefix
# (the `from src.mmsbvi...` imports below) can be found.
# The root is taken to be three directories above the directory containing this file.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
if project_root not in sys.path:
    # Prepend so the local project takes precedence over any installed copy.
    sys.path.insert(0, project_root)
import time
import json
from typing import Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from src.mmsbvi.core.types import GridConfig2D, OUProcessParams2D, MMSBProblem2D, IPFP2DConfig
from src.mmsbvi.algorithms.ipfp_2d import solve_mmsb_ipfp_2d


def _parse_int_list_env(name: str, K: int) -> Optional[List[int]]:
    """Parse comma-separated integers from env; support -1 meaning K-1.
    Returns None if env not set or empty.
    """
    val = os.environ.get(name, "").strip()
    if not val:
        return None
    out = []
    for tok in val.split(','):
        tok = tok.strip()
        if not tok:
            continue
        try:
            v = int(tok)
            if v < 0:
                v = K + v
            if 0 <= v < K:
                out.append(v)
        except Exception:
            continue
    return sorted(list(set(out))) if out else None


def _parse_float_list_env(name: str) -> Optional[List[float]]:
    val = os.environ.get(name, "").strip()
    if not val:
        return None
    out = []
    for tok in val.split(','):
        tok = tok.strip()
        if not tok:
            continue
        try:
            out.append(float(tok))
        except Exception:
            continue
    return out if out else None


def _gaussian_pdf(z):
    # Standard normal PDF / 标准正态密度
    return jnp.exp(-0.5 * z * z) / jnp.sqrt(2.0 * jnp.pi)


# ============================================================================
# Comprehensive Evaluation Metrics from Optuna
# ============================================================================

def _trapz2(a: jnp.ndarray, hx: float, hy: float) -> float:
    """2D trapezoidal integration"""
    nx, ny = a.shape
    wx = jnp.ones((nx,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    wy = jnp.ones((ny,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    return float(jnp.sum(a * wx[:, None] * wy[None, :]) * hx * hy)


def compute_hpd_mask_2d(density: jnp.ndarray, hx: float, hy: float,
                        mass: float = 0.9) -> Tuple[jnp.ndarray, float, float, float]:
    """Highest-posterior-density (HPD) region of a gridded 2D density.

    Grid nodes are ranked by density value (highest first) and their
    trapezoidal mass contributions accumulated. The threshold ``tau`` is the
    density value of the first ranked node at which the running mass reaches
    ``mass`` (or of the last node if it never does). The region is every node
    with ``density >= tau``.

    Returns:
        (mask, tau, achieved_mass, area): boolean mask shaped like ``density``,
        the density threshold, the trapezoidal mass inside the mask, and the
        trapezoidal area covered by the mask.
    """
    n_rows, n_cols = density.shape
    row_w = jnp.ones((n_rows,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    col_w = jnp.ones((n_cols,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    node_area = (row_w[:, None] * col_w[None, :]).reshape((-1,)) * float(hx * hy)

    values = density.reshape((-1,))
    ranking = jnp.argsort(-values)
    ranked_values = values[ranking]
    running_mass = jnp.cumsum(ranked_values * node_area[ranking])
    # First ranked node whose cumulative mass is >= the requested mass.
    cut = jnp.clip(jnp.searchsorted(running_mass, mass, side='left'),
                   0, ranked_values.size - 1)
    tau = float(ranked_values[cut])

    region = density >= tau
    indicator = region.astype(density.dtype)
    achieved_mass = _trapz2(density * indicator, hx, hy)
    area = _trapz2(indicator, hx, hy)
    return region, tau, float(achieved_mass), float(area)


def sample_from_density_grid(key: jax.random.PRNGKey, 
                            density: jnp.ndarray, 
                            n_samples: int,
                            grid_x: jnp.ndarray,
                            grid_y: jnp.ndarray) -> jnp.ndarray:
    """Draw grid nodes with probability proportional to a gridded 2D density.

    ``density[i, j]`` is the (unnormalised) weight of node ``(grid_x[i], grid_y[j])``.
    Samples are taken with replacement and are exact grid nodes (no jitter within cells).

    Negative entries (e.g. numerical round-off from a solver) are clipped to zero
    before normalising: ``jax.random.choice`` requires non-negative probabilities,
    and negative weights would otherwise corrupt its cumulative distribution.

    Args:
        key: PRNG key.
        density: array of shape ``(len(grid_x), len(grid_y))``.
        n_samples: number of samples to draw.
        grid_x, grid_y: 1D node coordinates along each axis.

    Returns:
        Array of shape ``(n_samples, 2)`` holding ``(x, y)`` coordinates.
    """
    weights = jnp.maximum(density.reshape(-1), 0.0)
    flat_prob = weights / jnp.sum(weights)

    indices = jax.random.choice(key, flat_prob.size, shape=(n_samples,), p=flat_prob)

    # Row-major flattening: flat index = i * ny + j.
    ny = grid_y.size
    i_coords = indices // ny
    j_coords = indices % ny

    return jnp.column_stack([grid_x[i_coords], grid_y[j_coords]])


def _compute_sliced_wasserstein_distance(samples1: jnp.ndarray, samples2: jnp.ndarray, 
                                       n_projections: int = 512, key: jax.random.PRNGKey = None) -> float:
    """Sliced Wasserstein Distance (SWD): gradient-stable alternative"""
    if key is None:
        key = jax.random.PRNGKey(42)
    
    d = samples1.shape[1]
    directions = jax.random.normal(key, (n_projections, d))
    directions = directions / jnp.linalg.norm(directions, axis=1, keepdims=True)
    
    wasserstein_distances = []
    for i in range(n_projections):
        u = directions[i]
        proj1 = samples1 @ u
        proj2 = samples2 @ u
        
        proj1_sorted = jnp.sort(proj1)
        proj2_sorted = jnp.sort(proj2)
        
        n1, n2 = len(proj1_sorted), len(proj2_sorted)
        if n1 != n2:
            n_common = max(n1, n2)
            t1 = jnp.linspace(0, 1, n1)
            t2 = jnp.linspace(0, 1, n2)
            t_common = jnp.linspace(0, 1, n_common)
            proj1_interp = jnp.interp(t_common, t1, proj1_sorted)
            proj2_interp = jnp.interp(t_common, t2, proj2_sorted)
        else:
            proj1_interp = proj1_sorted
            proj2_interp = proj2_sorted
        
        w2_1d = jnp.mean((proj1_interp - proj2_interp)**2)
        wasserstein_distances.append(w2_1d)
    
    swd = jnp.mean(jnp.array(wasserstein_distances))
    return float(jnp.sqrt(swd))


def _compute_hellinger_distance(density1: jnp.ndarray, density2: jnp.ndarray, 
                               hx: float, hy: float) -> float:
    """Hellinger distance between two gridded densities.

    H^2 = 1/2 * ∫ (sqrt(p) - sqrt(q))^2, integrated with the 2D trapezoidal rule.
    Both densities are floored at 1e-15 before the square root.
    Returns H as a Python float.
    """
    floor = 1e-15
    root_gap = jnp.sqrt(jnp.maximum(density1, floor)) - jnp.sqrt(jnp.maximum(density2, floor))
    h_squared = 0.5 * _trapz2(root_gap ** 2, hx, hy)
    return float(jnp.sqrt(jnp.maximum(h_squared, 0.0)))


def _compute_energy_distance(samples1: jnp.ndarray, samples2: jnp.ndarray,
                             max_samples: int = 400,
                             key: Optional[jax.random.PRNGKey] = None) -> float:
    """Multivariate generalization of distribution-level CRPS (Energy Distance)"""
    X = samples1
    Y = samples2
    n1, n2 = X.shape[0], Y.shape[0]
    if key is None:
        key = jax.random.PRNGKey(0)
    
    if n1 > max_samples:
        key, sub = jax.random.split(key)
        idx = jax.random.choice(sub, n1, shape=(max_samples,), replace=False)
        X = X[idx]
        n1 = X.shape[0]
    if n2 > max_samples:
        key, sub = jax.random.split(key)
        idx = jax.random.choice(sub, n2, shape=(max_samples,), replace=False)
        Y = Y[idx]
        n2 = Y.shape[0]

    def mean_pairwise_norm(A, B):
        chunk = 200
        tot, cnt = 0.0, 0
        for i in range(0, A.shape[0], chunk):
            A_chunk = A[i:i+chunk]
            distances = jnp.linalg.norm(A_chunk[:, None, :] - B[None, :, :], axis=2)
            tot += jnp.sum(distances)
            cnt += distances.size
        return tot / cnt

    term1 = mean_pairwise_norm(X, Y)
    term2 = 0.5 * mean_pairwise_norm(X, X)
    term3 = 0.5 * mean_pairwise_norm(Y, Y)
    ed = term1 - term2 - term3
    return float(jnp.maximum(ed, 0.0))


def _compute_sinkhorn_divergence_2d(samples1: jnp.ndarray, samples2: jnp.ndarray, 
                                   epsilon: float = 0.1, max_iter: int = 100) -> float:
    """Debiased Sinkhorn Divergence"""
    def sinkhorn_ot_cost(x_samples, y_samples, eps, max_iterations):
        n, m = x_samples.shape[0], y_samples.shape[0]
        x_sqnorms = jnp.sum(x_samples**2, axis=1, keepdims=True)
        y_sqnorms = jnp.sum(y_samples**2, axis=1, keepdims=True)
        cost_matrix = x_sqnorms + y_sqnorms.T - 2 * jnp.dot(x_samples, y_samples.T)
        
        cost_percentile_90 = jnp.percentile(cost_matrix, 90)
        cost_scale = jnp.maximum(cost_percentile_90, 1e-6)
        cost_matrix_normalized = cost_matrix / cost_scale
        
        K = jnp.exp(-cost_matrix_normalized / eps)
        a = jnp.ones(n, dtype=jnp.float64) / n
        b = jnp.ones(m, dtype=jnp.float64) / m
        u = jnp.ones(n, dtype=jnp.float64)
        v = jnp.ones(m, dtype=jnp.float64)
        
        for iteration in range(max_iterations):
            u_old = u
            u = a / (K @ v + 1e-12)
            v = b / (K.T @ u + 1e-12)
            
            if jnp.any(~jnp.isfinite(u)) or jnp.any(~jnp.isfinite(v)):
                break
            if iteration > 5 and jnp.max(jnp.abs(u - u_old)) < 1e-6:
                break
        
        transport_plan = u[:, None] * K * v[None, :]
        ot_cost = jnp.sum(transport_plan * cost_matrix)
        return ot_cost
    
    var1 = jnp.mean(jnp.var(samples1, axis=0))
    var2 = jnp.mean(jnp.var(samples2, axis=0))
    med_dist2_approx = jnp.maximum(var1 + var2, 1e-12) * 2.0
    adaptive_eps = jnp.clip(0.1 * med_dist2_approx, 1e-3, 1.0)
    
    try:
        ot_pq = sinkhorn_ot_cost(samples1, samples2, adaptive_eps, max_iter)
        ot_pp = sinkhorn_ot_cost(samples1, samples1, adaptive_eps, max_iter)
        ot_qq = sinkhorn_ot_cost(samples2, samples2, adaptive_eps, max_iter)
        sinkhorn_div = ot_pq - 0.5 * ot_pp - 0.5 * ot_qq
        raw_result = jnp.maximum(sinkhorn_div, 0.0)
        return float(raw_result)
    except Exception as e:
        return float('inf')


def _matrix_sqrt_jax(A):
    """JAX-compatible matrix square root using eigendecomposition"""
    # Eigendecomposition: A = Q * diag(λ) * Q^T
    eigenvalues, eigenvectors = jnp.linalg.eigh(A)
    # Ensure positive eigenvalues (numerical stability)
    eigenvalues = jnp.maximum(eigenvalues, 1e-12)
    # sqrt(A) = Q * diag(sqrt(λ)) * Q^T
    sqrt_eigenvalues = jnp.sqrt(eigenvalues)
    A_sqrt = eigenvectors @ jnp.diag(sqrt_eigenvalues) @ eigenvectors.T
    return A_sqrt


def _compute_wasserstein_2d_legacy(samples1: jnp.ndarray, samples2: jnp.ndarray) -> float:
    """Closed-form 2-Wasserstein distance between Gaussian fits of two sample sets.

    Each set is summarised by its sample mean and sample covariance (``jnp.cov``),
    with a 1e-8 ridge added to the diagonal. Then
        W2^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S2^{1/2} S1 S2^{1/2})^{1/2}).
    Kept as a legacy reference metric; returns W2 as a Python float.
    """
    ridge = 1e-8

    def _gaussian_fit(points):
        covariance = jnp.cov(points.T)
        covariance = covariance + ridge * jnp.eye(covariance.shape[0])
        return jnp.mean(points, axis=0), covariance

    mean_a, cov_a = _gaussian_fit(samples1)
    mean_b, cov_b = _gaussian_fit(samples2)

    root_b = _matrix_sqrt_jax(cov_b)
    cross_root = _matrix_sqrt_jax(root_b @ cov_a @ root_b)
    bures_term = jnp.trace(cov_a + cov_b - 2 * cross_root)

    w2_squared = jnp.linalg.norm(mean_a - mean_b)**2 + bures_term
    return float(jnp.sqrt(jnp.maximum(w2_squared, 0.0)))


# ============================================================================
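# PHYSICS-BASED EVALUATION METRICS (physical-consistency metrics)
# ============================================================================

# Removed: the "physical consistency" / "causal ratio" metrics. They relied on an
# incorrect trajectory assumption and only added noise to the evaluation.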


def make_abs_observation_bimodal_sequence(
    grid: GridConfig2D,
    times: jnp.ndarray,
    amp: float = 3.0,
    sx: float = 0.7,
    sy: float = 0.7,
    weights=(0.5, 0.5),
):
    """Build a time sequence of symmetric bimodal densities, as induced by an |x|-like observation.

    At each time t the density is a two-component mixture (via
    ``make_mixture_gaussian_2d``). The components are centred at (-mu(t), 0) and
    (+mu(t), 0), both with axis scales (sx, sy), where

        mu(t) = amp * |cos(2*pi * (t - t_0) / (t_end - t_0))|

    with t_0 = times[0] and t_end = times[-1]. 1e-12 is added to the span to
    avoid division by zero.

    Returns:
        list with one normalised 2D density per entry of ``times``.
    """
    t_start = times[0]
    span = float(times[-1] - t_start + 1e-12)
    component_scales = [(sx, sy), (sx, sy)]
    out = []
    for t in times:
        phase = float((t - t_start) / span)
        mu = amp * jnp.abs(jnp.cos(2.0 * jnp.pi * phase))
        centers = [(-mu, 0.0), (mu, 0.0)]
        out.append(make_mixture_gaussian_2d(grid, centers, component_scales, weights))
    return out


def make_mixture_gaussian_2d(grid: GridConfig2D, centers, scales, weights):
    """Axis-factorised Gaussian mixture evaluated on a 2D grid.

    Evaluates sum_k w_k * N(x; cx_k, sx_k^2) * N(y; cy_k, sy_k^2) at every grid
    node (x varies along axis 0, y along axis 1). The result is rescaled so its
    2D trapezoidal integral is 1; a 1e-15 guard is added to the mass before
    dividing.
    """
    xs = grid.points_x[:, None]
    ys = grid.points_y[None, :]
    mixture = jnp.zeros((grid.n_points_x, grid.n_points_y), dtype=jnp.float64)
    for (cx, cy), (sx, sy), w in zip(centers, scales, weights):
        marginal_x = _gaussian_pdf((xs - cx) / sx) / jnp.abs(sx)
        marginal_y = _gaussian_pdf((ys - cy) / sy) / jnp.abs(sy)
        mixture = mixture + w * (marginal_x * marginal_y)

    # Normalise with the 2D trapezoidal rule.
    row_w = jnp.ones((grid.n_points_x,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    col_w = jnp.ones((grid.n_points_y,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    total_mass = jnp.sum(mixture * row_w[:, None] * col_w[None, :]) * grid.spacing_x * grid.spacing_y
    return mixture / (total_mass + 1e-15)


def make_ou_consistent_gmm_sequence(
    grid: GridConfig2D,
    times: jnp.ndarray,
    centers0,
    scales0,
    weights,
    ou: OUProcessParams2D,
):
    """Generate 2D Gaussian-mixture marginals consistent with a separable OU process.

    Each mixture component evolves independently along x and y. It uses the
    closed-form OU moments for dX = theta (mu_eq - X) dt + sigma dW started from
    N(m0, s0^2):

        mean(t) = mu_eq + (m0 - mu_eq) * exp(-theta t)
        var(t)  = s0^2 * exp(-2 theta t) + sigma^2 * (1 - exp(-2 theta t)) / (2 theta)

    The noise term is evaluated with ``expm1`` so it stays accurate for small
    ``theta * t``. ``theta == 0`` uses its exact Brownian limit ``sigma^2 t``;
    the previous clamp ``max(theta, 1e-12)`` left the variance stuck at s0^2
    there and went negative for theta < 0. The evolved components are rendered
    with ``make_mixture_gaussian_2d`` using the fixed ``weights``.

    Args:
        grid: 2D grid the densities are evaluated on.
        times: evaluation times; each t is used directly as the time elapsed
            since the initial (t = 0) state.
        centers0: per-component initial (cx, cy).
        scales0: per-component initial (sx, sy) standard deviations.
        weights: mixture weights, held fixed over time.
        ou: per-axis OU parameters (mean reversion, diffusion, equilibrium mean).

    Returns:
        list with one normalised 2D density per entry of ``times``.
    """
    import math

    theta_x = float(ou.mean_reversion_x)
    theta_y = float(ou.mean_reversion_y)
    sigma_x = float(ou.diffusion_x)
    sigma_y = float(ou.diffusion_y)
    mu_eq_x = float(ou.equilibrium_mean_x)
    mu_eq_y = float(ou.equilibrium_mean_y)

    def _ou_moments_1d(t, m0, s0, theta, sigma, mu_eq):
        """Mean and standard deviation at time t of a 1D OU process started from N(m0, s0^2)."""
        decay = math.exp(-theta * t)
        mean = mu_eq + (m0 - mu_eq) * decay
        if theta == 0.0:
            noise_var = sigma ** 2 * t  # Brownian limit
        else:
            # sigma^2 (1 - e^{-2 theta t}) / (2 theta), valid for either sign of theta.
            noise_var = sigma ** 2 * (-math.expm1(-2.0 * theta * t)) / (2.0 * theta)
        var = (s0 ** 2) * decay ** 2 + noise_var
        return float(mean), math.sqrt(max(var, 1e-12))

    def evolve_params(t, c0, s0):
        cx0, cy0 = c0
        sx0, sy0 = s0
        mx, sx = _ou_moments_1d(t, cx0, sx0, theta_x, sigma_x, mu_eq_x)
        my, sy = _ou_moments_1d(t, cy0, sy0, theta_y, sigma_y, mu_eq_y)
        return (mx, my), (sx, sy)

    seq = []
    for t in list(times):
        centers_t = []
        scales_t = []
        for c0, s0 in zip(centers0, scales0):
            c_t, s_t = evolve_params(float(t), c0, s0)
            centers_t.append(c_t)
            scales_t.append(s_t)
        seq.append(make_mixture_gaussian_2d(grid, centers_t, scales_t, weights))
    return seq


def _parse_seed_list_env(name: str) -> Optional[List[int]]:
    """Parse comma-separated integer PRNG seeds from environment variable ``name``.

    Seeds are arbitrary integers, so no index range check or negative
    wrap-around is applied. Non-integer tokens are skipped and duplicates are
    dropped, keeping first-seen order.

    Returns:
        The seeds, or ``None`` when the variable is unset, blank, or has no valid seed.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return None
    seeds: List[int] = []
    for tok in raw.split(','):
        tok = tok.strip()
        if not tok:
            continue
        try:
            seed = int(tok)
        except ValueError:
            continue  # not an integer: ignore this token
        if seed not in seeds:
            seeds.append(seed)
    return seeds or None


def run_geometric_prior_ablation_experiment():
    """Geometric-prior ablation: drift mismatch vs. diffusion mismatch.

    Tests the "prior is geometry" hypothesis: a mismatched diffusion term should
    hurt more than a mismatched drift term.

    Every experiment shares the same observed marginals (generated from the
    ground-truth OU process) and the same IPFP configuration. Only the reference
    OU process in ``MMSBProblem2D`` changes:

    A. Correct prior:          theta=1.5,  sigma=0.4, mu=(0, 0)      (ground truth)
    B. Drift mismatch:         theta=0.75, sigma=0.4, mu=(0.8, -0.6)
    C. Diffusion mismatch:     theta=1.5,  sigma=0.8, mu=(0, 0)
    D. Anisotropic diffusion:  theta=1.5,  sigma_y = 2 sigma_x, sigma_x^2 + sigma_y^2 kept at the true value
    E. Near-Brownian drift:    theta=0.05, sigma=0.4, mu=(0, 0)
    plus optional sweeps from MMSBVI_SWEEP_SIGMA / MMSBVI_SWEEP_ANISO.

    Environment variables:
        MMSBVI_FAST=1   smaller grid, fewer marginals, iterations and samples.
        MMSBVI_TIGHT=1  stricter/longer IPFP settings (takes precedence over FAST for the solver config).
        MMSBVI_ANCHORS  comma-separated hard-marginal indices (default: first and last).
        MMSBVI_SEEDS    comma-separated integer PRNG seeds for sample-based metrics (default: 42).

    Returns:
        dict mapping each experiment name to its run diagnostics and evaluation metrics.
    """
    print("GEOMETRIC PRIOR ABLATION EXPERIMENT - FIXED VERSION")
    print("=" * 60)
    print("理论验证：先验即几何 - 扩散 vs 漂移的重要性")
    print("修复: 时间范围3.0 + 适度参数错配 + 平衡收敛条件")
    print("=" * 60)
    
    # Optional fast mode for lightweight validation: set MMSBVI_FAST=1.
    fast = os.environ.get('MMSBVI_FAST', '0') == '1'
    # 2D grid on [-6, 6]^2 (downsampled in fast mode).
    nx, ny = (64, 64) if fast else (128, 128)
    grid = GridConfig2D.create(nx, ny, (-6.0, 6.0), (-6.0, 6.0))

    # Ground-truth OU parameters used to generate the observed marginals.
    ou_true = OUProcessParams2D(
        mean_reversion_x=1.5,    # theta_x
        diffusion_x=0.4,         # sigma_x
        equilibrium_mean_x=0.0,  # mu_x
        mean_reversion_y=1.5,    # theta_y
        diffusion_y=0.4,         # sigma_y
        equilibrium_mean_y=0.0,  # mu_y
    )

    # K observation times on [0, 3]; with theta = 1.5 this spans 4.5 relaxation times (1/theta).
    K = 5 if fast else 8
    times = jnp.linspace(0.0, 3.0, K)
    # Hard-marginal (anchor) indices: default {0, K-1}; override with MMSBVI_ANCHORS, e.g. "0,-1" or "0,4,7".
    env_anchors = _parse_int_list_env('MMSBVI_ANCHORS', K)
    if env_anchors is not None and len(env_anchors) > 0:
        anchor_indices = set(env_anchors)
    else:
        anchor_indices = {0, K - 1} if K >= 2 else set(range(K))
    anchor_mask = [i in anchor_indices for i in range(K)]

    # Observations shared by every experiment, generated from the TRUE OU process.
    print(f"Generating unified GMM sequence with TRUE OU process...")
    print(f"   Ground truth: θ=({ou_true.mean_reversion_x}, {ou_true.mean_reversion_y}), σ=({ou_true.diffusion_x}, {ou_true.diffusion_y})")
    
    initial_centers = [(-2.0, 1.0), (2.0, -1.0)]  # two modes at t = 0
    initial_scales = [(0.6, 0.5), (0.5, 0.6)]     # anisotropic per-component scales
    weights = [0.5, 0.5]                           # equal mixture weights
    
    observed_unified = make_ou_consistent_gmm_sequence(
        grid, times, initial_centers, initial_scales, weights, ou_true
    )
    
    print(f"   Generated {K} marginals, each with {len(initial_centers)} components")

    # ---- Reference measures for the ablation ----

    # A. Correct prior (ground truth).
    ou_correct = OUProcessParams2D(
        mean_reversion_x=1.5, diffusion_x=0.4, equilibrium_mean_x=0.0,
        mean_reversion_y=1.5, diffusion_y=0.4, equilibrium_mean_y=0.0,
    )
    
    # B. Drift mismatch (wrong dynamics, correct diffusion): theta halved, mu shifted.
    ou_drift_wrong = OUProcessParams2D(
        mean_reversion_x=0.75, diffusion_x=0.4, equilibrium_mean_x=0.8,
        mean_reversion_y=0.75, diffusion_y=0.4, equilibrium_mean_y=-0.6,
    )
    
    # C. Diffusion mismatch (wrong geometry, correct drift): sigma doubled.
    ou_diffusion_wrong = OUProcessParams2D(
        mean_reversion_x=1.5, diffusion_x=0.8, equilibrium_mean_x=0.0,
        mean_reversion_y=1.5, diffusion_y=0.8, equilibrium_mean_y=0.0,
    )
    
    # D. Anisotropic diffusion: sigma_y = 2 sigma_x with sigma_x^2 + sigma_y^2 equal to the true
    #    value, i.e. sigma_x^2 = trace / 5.
    sigma_trace = float(ou_true.diffusion_x**2 + ou_true.diffusion_y**2)
    sigma_x_aniso = float(jnp.sqrt(max(sigma_trace / 5.0, 1e-12)))
    sigma_y_aniso = float(2.0 * sigma_x_aniso)
    ou_aniso_diff = OUProcessParams2D(
        mean_reversion_x=1.5, diffusion_x=sigma_x_aniso, equilibrium_mean_x=0.0,
        mean_reversion_y=1.5, diffusion_y=sigma_y_aniso, equilibrium_mean_y=0.0,
    )
    
    # E. Near-Brownian drift: small but positive theta, correct sigma.
    theta_small = 0.05
    ou_brownian_drift = OUProcessParams2D(
        mean_reversion_x=theta_small, diffusion_x=0.4, equilibrium_mean_x=0.0,
        mean_reversion_y=theta_small, diffusion_y=0.4, equilibrium_mean_y=0.0,
    )

    # Shared IPFP configuration; MMSBVI_TIGHT=1 takes precedence over fast mode here.
    tight = os.environ.get('MMSBVI_TIGHT', '0') == '1'
    cfg_unified = IPFP2DConfig(
        max_iterations=(3000 if tight else (300 if fast else 1500)),
        tolerance=(3e-4 if tight else (1e-3 if fast else 5e-4)),
        check_interval=20,
        use_anderson=True,
        epsilon_scaling=True,
        initial_epsilon=0.5154818968807595,
        eps_decay_high=0.9070680010191174,
        eps_decay_low=0.837639888660412,
        min_epsilon=(2e-4 if tight else (5e-4 if fast else 5e-5)),
        error_threshold=(1.5e-3 if tight else (5e-3 if fast else 2e-3)),
        verbose=True,
        compiled_loop=True,
        compiled_max_iterations=(3000 if tight else (300 if fast else 1500)),
        compiled_check_interval=20,
        use_pallas_kernels=False,
        anchor_mask=anchor_mask,
    )
    
    experiments = [
        ("A_Correct_Prior", ou_correct, "✅ 正确先验 (Ground Truth)"),
        ("B_Drift_Mismatch", ou_drift_wrong, "❌ 漂移适度错配 (θ减少50%, μ适度偏移, σ正确)"),
        ("C_Diffusion_Mismatch", ou_diffusion_wrong, "❌ 扩散适度错配 (σ增加100%, θ & μ正确)"),
        ("D_AnisoDiff_Mismatch", ou_aniso_diff, "❌ 扩散各向异性错配 (保持 σ_x^2+σ_y^2 不变, σ_y=2σ_x)"),
        ("E_Brownian_Drift_Mismatch", ou_brownian_drift, "❌ 漂移近布朗错配 (θ→小, σ 正确)"),
    ]

    # Optional sweeps: MMSBVI_SWEEP_SIGMA="0.5,0.75,1,1.25,1.5"; MMSBVI_SWEEP_ANISO="1,1.5,2,2.5,3"
    def _append_sigma_scale_sweep(exps: List[Tuple[str, OUProcessParams2D, str]], true_ou: OUProcessParams2D):
        """Append one experiment per isotropic sigma scale factor c (sigma -> c * sigma)."""
        scales = _parse_float_list_env('MMSBVI_SWEEP_SIGMA')
        if not scales:
            return exps
        out = list(exps)
        for c in scales:
            ou = OUProcessParams2D(
                mean_reversion_x=true_ou.mean_reversion_x,
                diffusion_x=float(c) * true_ou.diffusion_x,
                equilibrium_mean_x=true_ou.equilibrium_mean_x,
                mean_reversion_y=true_ou.mean_reversion_y,
                diffusion_y=float(c) * true_ou.diffusion_y,
                equilibrium_mean_y=true_ou.equilibrium_mean_y,
            )
            out.append((f"C_sigma_{c:.2f}", ou, f"σ 等比缩放 c={c:.2f} (几何缩放)"))
        return out

    def _append_anisotropy_sweep(exps: List[Tuple[str, OUProcessParams2D, str]], true_ou: OUProcessParams2D):
        """Append one experiment per ratio r = sigma_y / sigma_x, keeping sigma_x^2 + sigma_y^2 fixed."""
        ratios = _parse_float_list_env('MMSBVI_SWEEP_ANISO')
        if not ratios:
            return exps
        out = list(exps)
        S = float(true_ou.diffusion_x**2 + true_ou.diffusion_y**2)
        for r in ratios:
            r = float(r)
            if r <= 0:
                continue
            sx = float(jnp.sqrt(max(S / (1.0 + r*r), 1e-12)))
            sy = float(r * sx)
            ou = OUProcessParams2D(
                mean_reversion_x=true_ou.mean_reversion_x,
                diffusion_x=sx,
                equilibrium_mean_x=true_ou.equilibrium_mean_x,
                mean_reversion_y=true_ou.mean_reversion_y,
                diffusion_y=sy,
                equilibrium_mean_y=true_ou.equilibrium_mean_y,
            )
            out.append((f"D_aniso_r{r:.2f}", ou, f"σ 各向异性比 r={r:.2f}, 保持 Tr(ΣΣᵀ)"))
        return out

    experiments = _append_sigma_scale_sweep(experiments, ou_true)
    experiments = _append_anisotropy_sweep(experiments, ou_true)

    # ---- Loop invariants shared by every experiment's evaluation ----
    hx, hy = grid.spacing_x, grid.spacing_y
    # 2D trapezoidal weights.
    wx = jnp.ones((nx,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    wy = jnp.ones((ny,), dtype=jnp.float64).at[0].set(0.5).at[-1].set(0.5)
    X, Y = jnp.meshgrid(grid.points_x, grid.points_y, indexing='ij')
    true_var_x = (ou_true.diffusion_x**2) / (2 * ou_true.mean_reversion_x)
    true_var_y = (ou_true.diffusion_y**2) / (2 * ou_true.mean_reversion_y)
    # Seeds are plain integers (e.g. "0,1,42"), not indices, so they are not range-limited.
    seeds = _parse_seed_list_env('MMSBVI_SEEDS') or [42]
    n_samples = 200 if fast else 1000
    n_proj = 128 if fast else 512
    sinkhorn_iter = 50 if fast else 200

    results = {}
    
    print(f"\n Running ablation experiments with unified parameters...")
    print(f"   Anchor indices: {sorted(list(anchor_indices))} (of {K} time points)")
    
    for exp_name, ou_params, description in experiments:
        print(f"\n {exp_name}: {description}")
        print(f"   OU params: θ=({ou_params.mean_reversion_x:.1f},{ou_params.mean_reversion_y:.1f}), " +
              f"σ=({ou_params.diffusion_x:.1f},{ou_params.diffusion_y:.1f}), " +
              f"μ=({ou_params.equilibrium_mean_x:.1f},{ou_params.equilibrium_mean_y:.1f})")
        
        # Diagnostic: stationary OU variance sigma^2 / (2 theta) of this prior vs. the ground truth.
        equilibrium_var_x = (ou_params.diffusion_x**2) / (2 * ou_params.mean_reversion_x)
        equilibrium_var_y = (ou_params.diffusion_y**2) / (2 * ou_params.mean_reversion_y)
        
        print(f"    Equilibrium variance: σ²/(2θ) = ({equilibrium_var_x:.3f}, {equilibrium_var_y:.3f})")
        print(f"    True data variance:  σ²/(2θ) = ({true_var_x:.3f}, {true_var_y:.3f})")
        print(f"    Variance ratio vs true: ({equilibrium_var_x/true_var_x:.2f}x, {equilibrium_var_y/true_var_y:.2f}x)")
        
        # Only the reference OU process differs between experiments; data and solver config are shared.
        problem = MMSBProblem2D(
            observation_times=times,
            ou_params=ou_params,
            grid=grid,
            observed_marginals=observed_unified,
        )
        
        t0 = time.time()
        sol = solve_mmsb_ipfp_2d(problem, cfg_unified)
        t1 = time.time()
        
        # Trapezoidal L1 marginal errors (split into anchor / held-out times) and cross-entropy NLL.
        l1_errors = []
        l1_errors_anchor = []
        l1_errors_holdout = []
        sum_nll = 0.0
        
        for k in range(K):
            comp = sol.path_densities[k].astype(jnp.float64)
            tgt = observed_unified[k]
            l1 = jnp.sum(jnp.abs(comp - tgt) * wx[:, None] * wy[None, :]) * hx * hy
            l1_errors.append(float(l1))
            if anchor_mask[k]:
                l1_errors_anchor.append(float(l1))
            else:
                l1_errors_holdout.append(float(l1))
            
            # NLL = -∫ target * log(model), with the model floored to avoid log(0).
            tiny = jnp.finfo(comp.dtype).tiny * 1e10
            tiny = jnp.maximum(tiny, jnp.asarray(1e-30, comp.dtype))
            nll_k = -jnp.sum(tgt * jnp.log(jnp.maximum(comp, tiny)) * wx[:, None] * wy[None, :]) * hx * hy
            sum_nll += float(nll_k)
        
        avg_l1 = float(jnp.mean(jnp.array(l1_errors)))
        avg_l1_anchor = float(jnp.mean(jnp.array(l1_errors_anchor))) if l1_errors_anchor else float('nan')
        avg_l1_holdout = float(jnp.mean(jnp.array(l1_errors_holdout))) if l1_errors_holdout else float('nan')
        
        # Comprehensive evaluation metrics (borrowed from the Optuna study).
        print(f"   Computing comprehensive evaluation metrics stolen from Optuna...")
        print(f"   This includes: SWD, Sinkhorn, Hellinger, CRPS, HPD coverage, RMSE, LogScore...")
        
        # Sample-based metrics, averaged over seeds.
        swd_vals, sink_vals, w2_legacy_vals = [], [], []
        crps_vals = []
        for seed in seeds:
            key = jax.random.PRNGKey(int(seed))
            all_pred_samples = []
            all_true_samples = []
            for k in range(K):
                key, subkey = jax.random.split(key)
                model_samples = sample_from_density_grid(subkey, sol.path_densities[k].astype(jnp.float64), n_samples, grid.points_x, grid.points_y)
                all_pred_samples.append(model_samples)
                key, subkey = jax.random.split(key)
                target_samples = sample_from_density_grid(subkey, observed_unified[k].astype(jnp.float64), n_samples, grid.points_x, grid.points_y)
                all_true_samples.append(target_samples)
            # Pool the samples from all time points.
            pred_samples_flat = jnp.concatenate(all_pred_samples, axis=0)
            true_samples_flat = jnp.concatenate(all_true_samples, axis=0)
            try:
                swd_vals.append(_compute_sliced_wasserstein_distance(pred_samples_flat, true_samples_flat, n_projections=n_proj, key=key))
            except Exception:
                swd_vals.append(float('inf'))
            sink_vals.append(_compute_sinkhorn_divergence_2d(pred_samples_flat, true_samples_flat, epsilon=0.1, max_iter=sinkhorn_iter))
            w2_legacy_vals.append(_compute_wasserstein_2d_legacy(pred_samples_flat, true_samples_flat))
            # Energy distance per time point, then averaged over time.
            eds = []
            for k in range(K):
                key, sub_t = jax.random.split(key)
                eds.append(_compute_energy_distance(all_pred_samples[k], all_true_samples[k], max_samples=400, key=sub_t))
            crps_vals.append(float(jnp.mean(jnp.array(eds))))
        swd_value = float(jnp.mean(jnp.array(swd_vals)))
        sinkhorn_divergence = float(jnp.mean(jnp.array(sink_vals)))
        legacy_wasserstein = float(jnp.mean(jnp.array(w2_legacy_vals)))
        crps_value = float(jnp.mean(jnp.array(crps_vals)))
        print(f"     SWD computed (avg over {len(seeds)} seeds): {swd_value:.6f}")
        print(f"     Sinkhorn divergence computed (avg): {sinkhorn_divergence:.6f}")
        print(f"     Legacy Wasserstein computed (avg): {legacy_wasserstein:.6f}")
        print(f"     CRPS (Energy Distance) computed (avg): {crps_value:.6f}")
        
        # Density-level Hellinger distance (all times and held-out times only).
        hellinger_distances = []
        hellinger_holdout = []
        for k in range(K):
            model_density = sol.path_densities[k].astype(jnp.float64)
            target_density = observed_unified[k].astype(jnp.float64)
            hellinger_d = _compute_hellinger_distance(model_density, target_density, hx, hy)
            hellinger_distances.append(hellinger_d)
            if not anchor_mask[k]:
                hellinger_holdout.append(hellinger_d)
        hellinger_value = float(jnp.mean(jnp.array(hellinger_distances)))
        hellinger_holdout_value = float(jnp.mean(jnp.array(hellinger_holdout))) if hellinger_holdout else float('nan')
        print(f"     Hellinger distance computed: {hellinger_value:.6f}")
        if hellinger_holdout:
            print(f"     Hellinger (holdout only): {hellinger_holdout_value:.6f}")
        
        # 90% HPD coverage: target mass inside the model's HPD region, plus region areas.
        hpd_coverages = []
        hpd_areas = []
        hpd_areas_target = []
        
        for k in range(K):
            model_density = sol.path_densities[k].astype(jnp.float64)
            target_density = observed_unified[k].astype(jnp.float64)
            
            # Renormalise the model density if its trapezoidal mass drifted from 1.
            m_mass = _trapz2(model_density, hx, hy)
            if abs(m_mass - 1.0) > 1e-6:
                model_density = model_density / (m_mass + 1e-15)
            
            mask_m, _, _, area_m = compute_hpd_mask_2d(model_density, hx, hy, mass=0.9)
            
            cov_hpd = _trapz2(target_density * mask_m.astype(target_density.dtype), hx, hy)
            hpd_coverages.append(float(cov_hpd))
            hpd_areas.append(float(area_m))
            
            # Target HPD area (diagnostic only).
            _, _, _, area_t = compute_hpd_mask_2d(target_density, hx, hy, mass=0.9)
            hpd_areas_target.append(float(area_t))
        
        hpd_coverage_90 = float(jnp.mean(jnp.array(hpd_coverages)))
        hpd_area_90 = float(jnp.mean(jnp.array(hpd_areas)))
        target_hpd_area_90 = float(jnp.mean(jnp.array(hpd_areas_target)))
        print(f"     HPD metrics computed: Coverage={hpd_coverage_90:.3f}, Area={hpd_area_90:.3f}")
        
        # Point-prediction metrics: compare the mean of the model density with the target mean.
        pred_means = []
        true_centers = []
        for k in range(K):
            density = sol.path_densities[k].astype(jnp.float64)
            pred_means.append([_trapz2(X * density, hx, hy), _trapz2(Y * density, hx, hy)])
            
            target_density = observed_unified[k].astype(jnp.float64)
            true_centers.append([_trapz2(X * target_density, hx, hy), _trapz2(Y * target_density, hx, hy)])
        
        pred_means = jnp.array(pred_means)  # (K, 2)
        true_centers = jnp.array(true_centers)  # (K, 2)
        
        rmse = float(jnp.sqrt(jnp.mean((pred_means - true_centers)**2)))
        mae = float(jnp.mean(jnp.abs(pred_means - true_centers)))
        print(f"     Point prediction metrics: RMSE={rmse:.6f}, MAE={mae:.6f}")
        
        # Log score: log model density at the grid node nearest the target mean, averaged over time.
        # NOTE(review): for the bimodal targets used here the mean can lie between the modes,
        # i.e. in a low-density region; interpret this metric with care.
        log_scores = []
        for k in range(K):
            density = sol.path_densities[k].astype(jnp.float64)
            true_point = true_centers[k]
            
            x_idx = jnp.argmin(jnp.abs(grid.points_x - true_point[0]))
            y_idx = jnp.argmin(jnp.abs(grid.points_y - true_point[1]))
            
            point_density = density[x_idx, y_idx]
            log_score = jnp.log(jnp.maximum(point_density, 1e-15))
            log_scores.append(float(log_score))
        
        unified_log_score = float(jnp.mean(jnp.array(log_scores)))
        print(f"     Log Score computed: {unified_log_score:.2f}")
        print(f"   All comprehensive metrics successfully computed!")
        
        # (Removed) physical-consistency / causal-ratio metrics were methodologically unsound;
        # only the held-out metrics are used as the decision criterion.
        
        results[exp_name] = {
            "description": description,
            "anchor_mask": anchor_mask,
            "ou_params": {
                "mean_reversion_x": float(ou_params.mean_reversion_x),
                "diffusion_x": float(ou_params.diffusion_x),
                "equilibrium_mean_x": float(ou_params.equilibrium_mean_x),
                "mean_reversion_y": float(ou_params.mean_reversion_y),
                "diffusion_y": float(ou_params.diffusion_y),
                "equilibrium_mean_y": float(ou_params.equilibrium_mean_y),
            },
            "runtime_sec": t1 - t0,
            "final_error": float(sol.final_error),
            "n_iterations": int(sol.n_iterations),
            "l1_errors": l1_errors,
            "avg_l1_error": avg_l1,
            "avg_l1_error_anchor": avg_l1_anchor,
            "avg_l1_error_holdout": avg_l1_holdout,
            "sum_nll": sum_nll,
            "converged": sol.n_iterations < cfg_unified.max_iterations,
            # Comprehensive metrics ("unified_wasserstein" holds the Sinkhorn divergence).
            "unified_rmse": rmse,
            "unified_mae": mae,
            "unified_crps": crps_value,
            "unified_wasserstein": sinkhorn_divergence,
            "unified_swd": swd_value,
            "unified_hellinger": hellinger_value,
            "unified_hellinger_holdout": hellinger_holdout_value,
            "unified_legacy_wasserstein": legacy_wasserstein,
            "unified_log_score": unified_log_score,
            "unified_hpd_coverage_90": hpd_coverage_90,
            "unified_hpd_area_90": hpd_area_90,
            "unified_target_hpd_area_90": target_hpd_area_90,
        }
        
        print(f"   Basic Results: L1={avg_l1:.4f} (anchor={avg_l1_anchor:.4f}, holdout={avg_l1_holdout:.4f}), NLL={sum_nll:.2f}, Runtime={t1-t0:.1f}s, " +
              f"Iterations={sol.n_iterations}, Converged={results[exp_name]['converged']}")
        print(f"    KEY COMPREHENSIVE RESULTS:")
        print(f"     SWD: {swd_value:.4f}, Sinkhorn: {sinkhorn_divergence:.4f}, Hellinger: {hellinger_value:.6f}")
        print(f"     CRPS: {crps_value:.4f}, HPD Coverage: {hpd_coverage_90:.3f}, RMSE: {rmse:.4f}")
        print(f"     Legacy Wasserstein: {legacy_wasserstein:.4f}, LogScore: {unified_log_score:.2f}")
    
    return results


def analyze_geometric_prior_results(results, *, margin=1e-6):
    """Analyze the "prior is geometry" (先验即几何) hypothesis on holdout metrics.

    Compares the three required experiments ``A_Correct_Prior``,
    ``B_Drift_Mismatch`` and ``C_Diffusion_Mismatch``. The comparison uses their
    holdout (non-anchor time) average L1 error and Hellinger distance. The
    hypothesis counts as confirmed when C's degradation relative to A exceeds
    B's degradation relative to A by more than ``margin``, for *both* metrics.

    Args:
        results: Mapping from experiment name to a metrics dict.
            - The three experiments above must be present, and each must
              contain ``"avg_l1_error"``.
            - ``"avg_l1_error_holdout"`` and ``"unified_hellinger_holdout"``
              are optional. Missing or ``None`` values are treated as NaN.
            - Any other experiments are only printed for reference.
        margin: Minimum amount by which ΔC-A must exceed ΔB-A for a metric to
            count as "dominant". Defaults to 1e-6.

    Returns:
        A dict containing:
            - the holdout metrics;
            - the per-metric dominance flags and deltas;
            - the overall ``hypothesis_confirmed`` flag;
            - the all-time L1 errors;
            - a textual conclusion.

    Raises:
        KeyError: If one of the three required experiments, or its
            ``"avg_l1_error"`` entry, is missing.
    """
    nan = float('nan')
    core_experiments = ("A_Correct_Prior", "B_Drift_Mismatch", "C_Diffusion_Mismatch")

    def _to_float(value):
        # Missing/None metrics become NaN so printing never crashes. Any
        # comparison involving NaN is False, so the corresponding dominance
        # flag is (conservatively) False.
        return nan if value is None else float(value)

    def get_holdout(name):
        r = results[name]
        return (_to_float(r.get("avg_l1_error_holdout")),
                _to_float(r.get("unified_hellinger_holdout")))

    print("\n" + "="*80)
    print("HOLDOUT-BASED GEOMETRIC PRIOR ANALYSIS")
    print("="*80)

    a_l1, a_h = get_holdout("A_Correct_Prior")
    b_l1, b_h = get_holdout("B_Drift_Mismatch")
    c_l1, c_h = get_holdout("C_Diffusion_Mismatch")

    print("HOLDOUT (非锚点) 指标：")
    print(f"   A: L1={a_l1:.6f}, H={a_h:.6f}")
    print(f"   B: L1={b_l1:.6f}, H={b_h:.6f}")
    print(f"   C: L1={c_l1:.6f}, H={c_h:.6f}")

    # Diffusion mismatch (C) must hurt more than drift mismatch (B), both
    # measured relative to the correct prior (A).
    dom_l1 = (c_l1 - a_l1) > (b_l1 - a_l1) + margin
    dom_h = (c_h - a_h) > (b_h - a_h) + margin
    hypothesis_confirmed = bool(dom_l1 and dom_h)

    print("\n GEOMETRIC DOMINANCE (by holdout):")
    print(f"   L1 dominance: {dom_l1}  (ΔC-A={c_l1-a_l1:+.4f} vs ΔB-A={b_l1-a_l1:+.4f})")
    print(f"   H  dominance: {dom_h}   (ΔC-A={c_h-a_h:+.4f} vs ΔB-A={b_h-a_h:+.4f})")

    # Report any additional experiments for reference only; they do not
    # influence the verdict.
    for k in sorted(results.keys()):
        if k in core_experiments:
            continue
        l1 = results[k].get("avg_l1_error_holdout", None)
        h = results[k].get("unified_hellinger_holdout", None)
        if l1 is not None or h is not None:
            print(f"   {k}: L1={_to_float(l1):.6f}, H={_to_float(h):.6f}")

    l1_A_all = results["A_Correct_Prior"]["avg_l1_error"]
    l1_B_all = results["B_Drift_Mismatch"]["avg_l1_error"]
    l1_C_all = results["C_Diffusion_Mismatch"]["avg_l1_error"]
    print("\n ALL-TIME L1 (参考):")
    print(f"   A={l1_A_all:.6f}, B={l1_B_all:.6f}, C={l1_C_all:.6f}")

    analysis = {
        "theory": "先验即几何 - Geometric Prior Hypothesis",
        "holdout_metrics": {
            "A": {"l1": a_l1, "hellinger": a_h},
            "B": {"l1": b_l1, "hellinger": b_h},
            "C": {"l1": c_l1, "hellinger": c_h},
        },
        "dominance": {
            "l1": dom_l1,
            "hellinger": dom_h,
            "delta_l1": {"C_minus_A": c_l1 - a_l1, "B_minus_A": b_l1 - a_l1},
            "delta_h": {"C_minus_A": c_h - a_h, "B_minus_A": b_h - a_h},
        },
        "hypothesis_confirmed": hypothesis_confirmed,
        "legacy_l1_errors": {"A": l1_A_all, "B": l1_B_all, "C": l1_C_all},
        "conclusion": "几何先验假设得到验证" if hypothesis_confirmed else "几何先验影响需进一步测试",
    }
    return analysis


def save_ablation_results(results, analysis, output_dir="results/geometric_prior_ablation"):
    """Save complete ablation experiment results as a JSON file.

    Wraps ``results`` and ``analysis`` together with experiment metadata and a
    local timestamp. The result is written to ``<output_dir>/complete_results.json``.

    Args:
        results: Per-experiment metrics dict. Values must be JSON-serializable.
        analysis: Analysis dict. Values must be JSON-serializable.
        output_dir: Directory to write into. It is created if missing. The
            default matches the historical hard-coded location.

    Returns:
        The combined dict that was written to disk.

    Raises:
        TypeError: If any value is not JSON-serializable. In that case no
            results file is written or truncated.
    """
    os.makedirs(output_dir, exist_ok=True)

    combined_results = {
        "experiment_type": "geometric_prior_ablation",
        "theory": "先验即几何 - Prior determines geometry",
        "unified_config": "best.json optimized parameters",
        "experiments": results,
        "analysis": analysis,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }

    # Serialize before opening the file. Streaming json.dump into an
    # already-truncated file would leave a corrupt, partial result behind
    # (and destroy any previous good one) if serialization failed midway.
    payload = json.dumps(combined_results, indent=2)

    output_path = os.path.join(output_dir, "complete_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(payload)

    print(f"\n Complete results saved to: {output_path}")
    return combined_results


# Modified main entry point
def run_experiment():
    """Run the geometric prior ablation end to end.

    The pipeline runs the experiments, analyzes them, and persists the
    combined output. It returns the dict produced by the save step.
    """
    experiment_results = run_geometric_prior_ablation_experiment()
    return save_ablation_results(
        experiment_results,
        analyze_geometric_prior_results(experiment_results),
    )


if __name__ == "__main__":
    # Switch JAX to 64-bit floats before any experiment code executes.
    jax.config.update("jax_enable_x64", True)
    for banner_line in (
        " Starting Geometric Prior Ablation Experiment...",
        "   Theory: '先验即几何' - Prior determines algorithmic performance",
    ):
        print(banner_line)

    complete_results = run_experiment()

    # Final summary: restate the theory and the verdict.
    analysis = complete_results["analysis"]
    if analysis['hypothesis_confirmed']:
        verdict_lines = (
            "    扩散项决定几何结构，比漂移项影响更根本！",
            "    MMSBVI的成功关键在于正确的几何先验选择",
        )
    else:
        verdict_lines = ("    需要进一步加大参数错配来明确验证",)

    print(" FINAL CONCLUSION:")
    print(f"   {analysis['theory']}")
    print(f"   Hypothesis Confirmed: {analysis['hypothesis_confirmed']}")
    for verdict_line in verdict_lines:
        print(verdict_line)
