"""
2D Multi-marginal IPFP Algorithm (Separable OU)
二维多边际IPFP算法（可分OU核）

数值要点：
- 前后向消息在对数域计算：利用可分核沿 x、y 各做一次 1D log-sum-exp 收缩，避免 (nx·ny)^2 复杂度
- 更新在对数域：φ_k ← φ_k + ε (log ρ_target − log ρ_current)
- 梯形归一化与数值裁剪
- 精度：CPU 用 float64；其他后端用 compute_dtype()（通常 float32）计算，输出路径密度以 float32 存储

性能要点：
- 使用 vmap 向量化批量构造核，减少 Python 循环与编译开销
- 使用 fori_loop 按 ii=0..K-2 映射到 i=(K-2-ii) 的反向更新
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Callable

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.lib import xla_bridge as xb
from functools import partial

from ..core.types import (
    Density2D, Potential2D, GridConfig2D, Scalar,
    MMSBProblem2D, MMSBSolution2D, IPFP2DState, IPFP2DConfig,
)
from ..constants import MIN_DENSITY
from ..utils.precision import compute_dtype, storage_dtype, matmul_precision
from ..solvers.gaussian_kernel_2d import (
    compute_log_transition_kernels_2d_separable,
)
from ..solvers.gaussian_kernel_2d import _pallas_build_K1d_from_log as _pallas_build_K1d_from_log_unsafe
from ..solvers.gaussian_kernel_2d import _pallas_build_K1d_from_log_batched as _pallas_build_K1d_from_log_batched_unsafe
from ..solvers.gaussian_kernel_2d import _pallas_fused_clip_trapz_normalize_2d as _pallas_fused_clip_trapz_normalize_2d_unsafe
from ..solvers.gaussian_kernel_2d import _pallas_fused_clip_trapz_normalize_2d_tiled as _pallas_fused_clip_trapz_normalize_2d_tiled_unsafe


# ------------------------------
# Utilities / 工具
# ------------------------------

def _weights_trapz_1d(n: int, dtype: jnp.dtype) -> jnp.ndarray:
    """Return 1D trapezoid-rule weights of length n: 1 inside, 0.5 at both ends.

    For n == 1 the single weight is 0.5.
    """
    weights = jnp.ones((n,), dtype=dtype)
    weights = weights.at[0].set(0.5)
    weights = weights.at[-1].set(0.5)
    return weights


def _is_cuda_backend() -> bool:
    """Report whether the active XLA backend is CUDA.

    Checks whether the backend's ``platform_version`` string mentions "CUDA"
    (case-insensitive). Metal/ROCm/CPU backends, and any failure while querying
    the backend, yield False.
    """
    try:
        version_text = str(getattr(xb.get_backend(), "platform_version", ""))
    except Exception:
        return False
    return "CUDA" in version_text.upper()


@jit
def _trapz2(values: jnp.ndarray, hx: Scalar, hy: Scalar) -> Scalar:
    """2D trapezoid-rule integral of ``values`` with grid spacings hx, hy.

    On CUDA the reduction is requested in float64 for stability (JAX falls back
    to float32 unless x64 mode is enabled); other backends keep the input dtype.
    The result is cast back to ``values.dtype``.
    """
    acc_dtype = values.dtype
    if _is_cuda_backend():
        acc_dtype = jnp.float64
    grid_vals = values.astype(acc_dtype)
    n_rows, n_cols = grid_vals.shape
    row_w = _weights_trapz_1d(n_rows, acc_dtype)
    col_w = _weights_trapz_1d(n_cols, acc_dtype)
    weights = row_w[:, None] * col_w[None, :]
    area = jnp.asarray(hx, acc_dtype) * jnp.asarray(hy, acc_dtype)
    integral = area * jnp.sum(grid_vals * weights)
    return integral.astype(values.dtype)


@partial(jit, static_argnames=["dtype"])
def _build_K1d_from_log(logK: jnp.ndarray, h: Scalar, dtype: jnp.dtype) -> jnp.ndarray:
    """Turn a 1D log-kernel into a column-normalized kernel.

    Entries are exponentiated in ``dtype``. Each column j is divided by its
    trapezoid integral over the row (target) axis, ``h * sum_i w_i K[i, j]``,
    plus 1e-30 to avoid division by zero.
    """
    kernel = jnp.exp(logK.astype(dtype))
    row_weights = _weights_trapz_1d(kernel.shape[0], dtype)
    column_mass = h * jnp.sum(kernel * row_weights[:, None], axis=0)
    return kernel / (column_mass[None, :] + jnp.asarray(1e-30, dtype))


def _precompute_kernels(problem: MMSBProblem2D, use_pallas: bool = False, block_rows: Optional[int] = None) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Build the stacked, column-normalized 1D kernels for every time interval.

    Returns ``(Kx_all, Ky_all)`` with shapes (K-1, nx, nx) and (K-1, ny, ny).
    Log-kernels for all intervals are produced in one ``jax.vmap`` call. They are
    then exponentiated and normalized, either by the batched Pallas kernel
    (``use_pallas``; ``block_rows`` is forwarded only when given) or by a
    vmapped ``_build_K1d_from_log``.

    Precision: float64 on CPU (robustness), ``compute_dtype()`` elsewhere.
    """
    dtype = jnp.float64 if jax.default_backend() == "cpu" else compute_dtype()
    grid = problem.grid
    h_x = jnp.asarray(grid.spacing_x, dtype)
    h_y = jnp.asarray(grid.spacing_y, dtype)
    dt_values = jnp.asarray(problem.time_intervals, dtype)

    def log_kernels_for(dt):
        return compute_log_transition_kernels_2d_separable(
            grid.points_x, grid.points_x, grid.points_y, grid.points_y, dt, problem.ou_params
        )

    log_kx, log_ky = jax.vmap(log_kernels_for, in_axes=(0,))(dt_values)

    if use_pallas:
        # Pallas 2D-grid batched column normalization; block_rows is optional.
        extra_args = () if block_rows is None else (block_rows,)

        def build(log_stack, h):
            return _pallas_build_K1d_from_log_batched_unsafe(log_stack.astype(dtype), h, *extra_args)
    else:
        def build(log_stack, h):
            return jax.vmap(lambda L: _build_K1d_from_log(L, h, dtype), in_axes=0)(log_stack)

    with matmul_precision("high"):
        kx_stack = build(log_kx, h_x)
        ky_stack = build(log_ky, h_y)
    return kx_stack, ky_stack


def _initialize_state(problem: MMSBProblem2D) -> IPFP2DState:
    """Build the initial IPFP state (potentials and placeholder marginals).

    Behaviour depends on whether the problem has observed marginals.

    - With observed marginals: each one is clipped below at MIN_DENSITY and
      trapezoid-normalized. The initial potential is its log with the mean
      removed, which speeds up convergence. The placeholder marginal is the
      normalized marginal itself.
    - Without observed marginals: potentials start at zero. When
      ``y_observations`` (2D) and ``C`` are provided, the mean-centered Gaussian
      observation log-likelihood is added. This is a warm start only. The
      placeholder marginals are uniform densities.

    Precision: float64 on CPU, ``compute_dtype()`` on other backends.
    """
    comp_dtype = jnp.float64 if jax.default_backend() == "cpu" else compute_dtype()
    nx, ny = problem.grid.n_points_x, problem.grid.n_points_y
    K = problem.n_marginals
    hx = jnp.asarray(problem.grid.spacing_x, comp_dtype)
    hy = jnp.asarray(problem.grid.spacing_y, comp_dtype)
    norm_eps = jnp.asarray(1e-15, comp_dtype)
    potentials: List[Potential2D] = []
    marginals: List[Density2D] = []

    if problem.observed_marginals is not None:
        floor = jnp.asarray(MIN_DENSITY, comp_dtype)
        # A single pass yields both the warm-start potential and the clipped,
        # normalized placeholder marginal.
        for k in range(K):
            m = jnp.maximum(jnp.asarray(problem.observed_marginals[k], comp_dtype), floor)
            m = m / (_trapz2(m, hx, hy) + norm_eps)
            phi = jnp.log(m)
            potentials.append(phi - jnp.mean(phi))
            marginals.append(m)
    else:
        uniform = jnp.ones((nx, ny), dtype=comp_dtype)
        uniform = uniform / (_trapz2(uniform, hx, hy) + norm_eps)
        y_obs = problem.y_observations
        # Only a (K, m) observation array is used for the likelihood warm start.
        use_obs = (y_obs is not None) and (problem.C is not None) and (y_obs.ndim == 2)
        if use_obs:
            C = jnp.asarray(problem.C)
            R = jnp.asarray(problem.R)
        for k in range(K):
            phi0 = jnp.zeros((nx, ny), dtype=comp_dtype)
            if use_obs:
                loglik = _gaussian_log_likelihood_grid(problem.grid, C, R, jnp.asarray(y_obs[k]), comp_dtype)
                phi0 = phi0 + loglik - jnp.mean(loglik)
            potentials.append(phi0)
        marginals = [uniform for _ in range(K)]
    return IPFP2DState(potentials=potentials, marginals=marginals, iteration=0, error=jnp.inf, converged=False)


@jit
def _log_trapz2_from_log(log_values: jnp.ndarray, hx: Scalar, hy: Scalar) -> Scalar:
    """Return ``log ∫∫ exp(log_values) dx dy`` using 2D trapezoid weights.

    The sum uses the max-shift (log-sum-exp) trick. On CUDA the reduction is
    requested in float64 (effective only when JAX x64 mode is enabled). Other
    backends keep the input dtype. The result is cast back to
    ``log_values.dtype``.

    A slice whose maximum is not finite (e.g. entirely -inf after underflow) is
    handled without NaN: the shift falls back to 0 instead of computing
    ``(-inf) - (-inf)``.
    """
    rdtype = jnp.float64 if _is_cuda_backend() else log_values.dtype
    nx, ny = log_values.shape
    w2 = _weights_trapz_1d(nx, rdtype)[:, None] * _weights_trapz_1d(ny, rdtype)[None, :]
    max_log = jnp.max(log_values)
    shift = jnp.where(jnp.isfinite(max_log), max_log, jnp.zeros_like(max_log))
    vals = jnp.exp((log_values - shift).astype(rdtype))
    s = jnp.sum(vals * w2)
    # Use the smallest normal of the dtype actually produced. A literal like
    # 1e-300 underflows to 0 in float32 and would not guard log(0).
    tiny = jnp.asarray(jnp.finfo(s.dtype).tiny, s.dtype)
    res = (
        jnp.log(s + tiny)
        + shift.astype(s.dtype)
        + jnp.log(jnp.asarray(hx, s.dtype))
        + jnp.log(jnp.asarray(hy, s.dtype))
    )
    return res.astype(log_values.dtype)


@jit
def _forward_messages_log(
    pots: jnp.ndarray,   # (K, nx, ny)
    Kx_all: jnp.ndarray, # (K-1, nx, nx)
    Ky_all: jnp.ndarray, # (K-1, ny, ny)
    hx: Scalar,
    hy: Scalar,
) -> jnp.ndarray:
    """Compute forward messages in log-space; returns log fwd of shape (K, nx, ny).

    Kernels are indexed ``K[target, source]``; ``_build_K1d_from_log`` normalizes
    each column over the target (row) axis. One step is the log-space form of
    ``fwd_i = exp(φ_i) * (Kx @ fwd_{i-1} @ Ky.T) * hx * hy``. Both axes
    therefore contract over the *source* index of their kernel.
    """
    K = pots.shape[0]
    log_hx = jnp.log(hx)
    log_hy = jnp.log(hy)
    logKx_all = jnp.log(Kx_all)
    logKy_all = jnp.log(Ky_all)

    log_fwd = jnp.full_like(pots, -jnp.inf).at[0].set(pots[0])

    def body(i, log_fwd_in):
        log_prev = log_fwd_in[i - 1]  # (nx, ny)
        logKx = logKx_all[i - 1]      # (nx_tgt, nx_src)
        logKy = logKy_all[i - 1]      # (ny_tgt, ny_src)
        # x axis: tmp[a, j] = log sum_s Kx[a, s] exp(prev[s, j]) + log hx   (Kx @ prev)
        log_tmp = jax.scipy.special.logsumexp(logKx[:, :, None] + log_prev[None, :, :], axis=1) + log_hx
        # y axis: conv[a, b] = log sum_t Ky[b, t] exp(tmp[a, t]) + log hy   (tmp @ Ky.T)
        log_conv = jax.scipy.special.logsumexp(logKy[None, :, :] + log_tmp[:, None, :], axis=2) + log_hy
        return log_fwd_in.at[i].set(pots[i] + log_conv)

    return lax.fori_loop(1, K, body, log_fwd)


@jit
def _backward_messages_log(
    pots: jnp.ndarray,   # (K, nx, ny)
    Kx_all: jnp.ndarray, # (K-1, nx, nx)
    Ky_all: jnp.ndarray, # (K-1, ny, ny)
    hx: Scalar,
    hy: Scalar,
) -> jnp.ndarray:
    """Compute backward messages in log-space; returns log bwd of shape (K, nx, ny).

    With kernels indexed ``K[target, source]``, one step is the log-space form of
    ``bwd_i = exp(φ_i) * (Kx.T @ bwd_{i+1} @ Ky) * hx * hy``. Both axes
    therefore contract over the *target* index of their kernel.
    """
    K = pots.shape[0]
    log_hx = jnp.log(hx)
    log_hy = jnp.log(hy)
    logKxT_all = jnp.transpose(jnp.log(Kx_all), (0, 2, 1))
    logKyT_all = jnp.transpose(jnp.log(Ky_all), (0, 2, 1))

    log_bwd = jnp.full_like(pots, -jnp.inf).at[K - 1].set(pots[K - 1])

    # fori_loop runs ii = 0..K-2; map it to i = K-2-ii so the sweep goes backwards.
    def bwd_body(ii, log_bwd_in):
        i = (K - 2) - ii
        log_next = log_bwd_in[i + 1]  # (nx, ny)
        logKxT = logKxT_all[i]        # (nx_src, nx_tgt)
        logKyT = logKyT_all[i]        # (ny_src, ny_tgt)
        # x axis: tmp[s, j] = log sum_a Kx[a, s] exp(next[a, j]) + log hx   (Kx.T @ next)
        log_tmp = jax.scipy.special.logsumexp(logKxT[:, :, None] + log_next[None, :, :], axis=1) + log_hx
        # y axis: conv[s, t] = log sum_b Ky[b, t] exp(tmp[s, b]) + log hy   (tmp @ Ky)
        log_conv = jax.scipy.special.logsumexp(logKyT[None, :, :] + log_tmp[:, None, :], axis=2) + log_hy
        return log_bwd_in.at[i].set(pots[i] + log_conv)

    return lax.fori_loop(0, K - 1, bwd_body, log_bwd)


@jit
def _compute_log_current_marginal(
    k: int,
    potentials: List[Potential2D],
    Kx_all: jnp.ndarray,
    Ky_all: jnp.ndarray,
    hx: Scalar,
    hy: Scalar,
) -> jnp.ndarray:
    """Return the normalized log-marginal log ρ_k for a single time index k.

    Uses the same forward/backward passes and the same log-space trapezoid
    normalization (``_log_trapz2_from_log``) as
    ``_compute_all_log_current_marginals``, so it agrees with slice k of that
    result. That includes the CUDA reduction dtype and the non-finite-max guard.
    """
    comp_dtype = potentials[0].dtype
    pots = jnp.stack(potentials, axis=0).astype(comp_dtype)
    log_fwd = _forward_messages_log(pots, Kx_all, Ky_all, hx, hy)
    log_bwd = _backward_messages_log(pots, Kx_all, Ky_all, hx, hy)
    # Both messages already include φ_k; subtract it once so it is counted once.
    log_unnorm = log_fwd[k] + log_bwd[k] - pots[k]
    return log_unnorm - _log_trapz2_from_log(log_unnorm, hx, hy)


@jit
def _compute_all_log_current_marginals(
    pots: jnp.ndarray,           # (K, nx, ny)
    Kx_all: jnp.ndarray,
    Ky_all: jnp.ndarray,
    hx: Scalar,
    hy: Scalar,
) -> jnp.ndarray:
    """Return normalized log-marginals log ρ_k for every k, shape (K, nx, ny).

    One forward and one backward pass serve all time steps. Each slice is then
    normalized in log-space with ``_log_trapz2_from_log``.
    """
    fwd = _forward_messages_log(pots, Kx_all, Ky_all, hx, hy)
    bwd = _backward_messages_log(pots, Kx_all, Ky_all, hx, hy)
    # Both messages already include φ_k, so remove one copy.
    unnormalized = fwd + bwd - pots
    log_norms = jax.vmap(lambda sl: _log_trapz2_from_log(sl, hx, hy))(unnormalized)  # (K,)
    return unnormalized - log_norms[:, None, None]


@jit
def _update_potentials_batch(
    pots_stack: jnp.ndarray,       # (K, nx, ny)
    logs_all: jnp.ndarray,         # (K, nx, ny)
    targets_stack: jnp.ndarray,    # (K, nx, ny)
    eps_t: Scalar,
    fixed_mask: jnp.ndarray,       # (K,), bool
) -> jnp.ndarray:
    """Vectorized potential update for all time steps.

    Computes ``φ_k ← φ_k + ε (log max(ρ_target, MIN_DENSITY) − log ρ_current)``.
    The result is clipped to [-80, 80] and then centered per k (mean removed).
    Entries where ``fixed_mask`` is True keep their previous potential.
    """
    dt = pots_stack.dtype
    step = jnp.asarray(eps_t, dt)
    floor = jnp.asarray(MIN_DENSITY, dt)
    proposed = pots_stack + step * (jnp.log(jnp.maximum(targets_stack, floor)) - logs_all)
    # Clip first, then remove the per-k mean.
    proposed = jnp.clip(proposed, -80.0, 80.0)
    proposed = proposed - jnp.mean(proposed, axis=(1, 2), keepdims=True)
    return jnp.where(fixed_mask[:, None, None], pots_stack, proposed)


@jit
def _rel_change_trapz2(new_pots: jnp.ndarray, old_pots: jnp.ndarray, hx: Scalar, hy: Scalar) -> Scalar:
    """Relative L2 change between two potentials under the 2D trapezoid rule.

    Returns ``sqrt(∫ Δφ² / (∫ φ_new² + 1e-15))``. For a 3D stack (K, nx, ny) the
    squares are first summed over the leading axis.
    """
    delta = new_pots - old_pots
    sq_delta = delta * delta
    sq_new = new_pots * new_pots
    if new_pots.ndim == 3:
        sq_delta = jnp.sum(sq_delta, axis=0)
        sq_new = jnp.sum(sq_new, axis=0)
    numerator = _trapz2(sq_delta, hx, hy)
    denominator = _trapz2(sq_new, hx, hy)
    return jnp.sqrt(numerator / (denominator + jnp.asarray(1e-15, new_pots.dtype)))


def _gaussian_log_likelihood_grid(
    grid: GridConfig2D,
    C: jnp.ndarray,
    R: jnp.ndarray,
    y: jnp.ndarray,
    dtype: jnp.dtype,
) -> jnp.ndarray:
    """Evaluate the Gaussian observation log-likelihood ``log N(y; C x, R)`` on the 2D grid.

    The result has shape (nx, ny) and is defined only up to an additive constant:
    the ``-0.5 log det R`` and normalization terms are dropped.

    ``R`` may be given in three forms:
    - a scalar (Python number or 0-d array): isotropic variance;
    - a 1-D array: diagonal variances;
    - a 2-D array: full covariance, regularized with 1e-6·I before inversion.

    ``C``, ``R``, ``y`` and the grid points are converted to ``dtype`` first, so
    Python scalars and lists are accepted too.
    """
    X = jnp.stack(
        jnp.meshgrid(jnp.asarray(grid.points_x), jnp.asarray(grid.points_y), indexing="ij"),
        axis=-1,
    ).astype(dtype)  # (nx, ny, 2)
    y = jnp.asarray(y, dtype=dtype)
    C = jnp.asarray(C, dtype=dtype)
    R = jnp.asarray(R, dtype=dtype)
    # Predicted observation and residual at every grid point.
    pred = jnp.einsum("md,ijd->ijm", C, X)  # (nx, ny, m)
    resid = y[None, None, :] - pred         # (nx, ny, m)
    if R.ndim == 0:
        invR = 1.0 / (R + jnp.asarray(1e-30, dtype))
        quad = invR * jnp.sum(resid * resid, axis=-1)
    elif R.ndim == 1:
        invR = 1.0 / (R + jnp.asarray(1e-30, dtype))
        quad = jnp.sum(invR[None, None, :] * resid * resid, axis=-1)
    else:
        invR = jnp.linalg.inv(R + jnp.asarray(1e-6, dtype) * jnp.eye(R.shape[0], dtype=dtype))
        quad = jnp.einsum("ijm,mn,ijn->ij", resid, invR, resid)
    return -0.5 * quad  # constant terms omitted


def _solve_ipfp_compiled(
    pots_init: jnp.ndarray,            # (K, nx, ny)
    targets_stack: jnp.ndarray,        # (K, nx, ny)
    fixed_mask: jnp.ndarray,           # (K,) bool
    anchor_mask: jnp.ndarray,          # (K,) bool
    Kx_all: jnp.ndarray,
    Ky_all: jnp.ndarray,
    hx: Scalar,
    hy: Scalar,
    cfg: IPFP2DConfig,
) -> Tuple[jnp.ndarray, Scalar, int, bool]:
    """Run the whole IPFP loop inside ``lax.fori_loop`` with in-graph checks.

    Every iteration, until convergence is flagged, recomputes all current
    log-marginals and applies ``_update_potentials_batch``.

    Every ``check_iv``-th iteration (``i % check_iv == 0``, including i == 0)
    additionally does the following:
      - Measures the relative potential change w.r.t. the potentials stored at
        the previous check.
      - Measures the max trapezoid L1 error between the current marginals and
        ``targets_stack``.
      - Optionally applies a simplified Anderson AA(1) mixing
        (``cfg.use_anderson``).
      - Updates ε. It decays by ``eps_decay_high`` early, and by
        ``eps_decay_low`` once past ``max_iters // 2`` or once the error is
        below ``2 * error_threshold``. If the error got worse than at the last
        check, ε is instead halved. ε is always floored at ``min_epsilon``.
      - Flags convergence when both the relative change and the max error are
        below ``max(tolerance, 1e-7)``. After that, iterations are no-ops.

    Unanchored time steps (``anchor_mask`` False) use the current density as
    their update target. The L1 error is still measured against
    ``targets_stack``.

    Returns:
        A tuple with four entries:
        - final potentials stack;
        - max L1 error from the last check;
        - number of iterations (``conv_iter + 1`` if converged, else
          ``max_iters``);
        - converged flag.
        The last two go through ``int()`` / ``bool()``, which need concrete
        values, so this function itself cannot be traced under ``jit``.
    """
    max_iters = int(cfg.compiled_max_iterations or cfg.max_iterations)
    check_iv = int(cfg.compiled_check_interval or cfg.check_interval)

    comp_dtype = pots_init.dtype
    tol_eff = jnp.asarray(max(cfg.tolerance, 1e-7), comp_dtype)
    eps_min = jnp.asarray(cfg.min_epsilon, comp_dtype)
    eps_decay_high = jnp.asarray(cfg.eps_decay_high, comp_dtype)
    eps_decay_low = jnp.asarray(cfg.eps_decay_low, comp_dtype)
    err_thresh = jnp.asarray(cfg.error_threshold, comp_dtype)

    # carry = (pots, prev_pots at last check, eps, converged, conv_iter, last_err)
    def body(i, carry):
        pots, prev_pots, eps, converged, conv_iter, last_err = carry

        def do_update(carry):
            pots, prev_pots, eps, converged, conv_iter, last_err = carry
            logs_all = _compute_all_log_current_marginals(pots, Kx_all, Ky_all, hx, hy)
            # Dynamic targets: anchored steps use the provided target, unanchored steps the current density.
            cur = jnp.exp(logs_all)
            targets_eff = jnp.where(anchor_mask[:, None, None], targets_stack, cur)
            new_pots = _update_potentials_batch(pots, logs_all, targets_eff, eps, fixed_mask)

            def on_check(ops):
                pots_new, pots_cur, prev, logs_local, eps_in, last_err_in, i_in = ops
                # Relative change w.r.t. the potentials stored at the previous check.
                rel = _rel_change_trapz2(pots_new, prev, hx, hy)
                # Per-time-step L1 marginal error (vectorized), always against targets_stack.
                cur = jnp.exp(logs_local)
                def l1_pair(c, t):
                    return _trapz2(jnp.abs(c - t), hx, hy)
                l1s = jax.vmap(l1_pair)(cur, targets_stack)
                max_err = jnp.max(l1s)
                # Simplified Anderson AA(1) mixing.
                def aa1(x_new, x_cur, x_prev):
                    r_new = (x_new - x_cur).reshape((-1,))
                    d_prev = (x_cur - x_prev).reshape((-1,))
                    denom = jnp.sum(d_prev * d_prev) + jnp.asarray(1e-30, comp_dtype)
                    alpha = jnp.clip(jnp.sum(r_new * d_prev) / denom, 0.0, 1.0)
                    return x_new - alpha * (x_cur - x_prev)
                mixed_pots = lax.cond(
                    jnp.asarray(cfg.use_anderson),
                    lambda _: aa1(pots_new, pots_cur, prev),
                    lambda _: pots_new,
                    operand=None,
                )
                # Two-stage ε schedule: fast decay early, slower decay once past half the
                # iterations or once the error is within 2x error_threshold.
                stage_switch = jnp.logical_or(i_in > (max_iters // 2), max_err < (err_thresh * 2.0))
                decay_sel = jnp.where(stage_switch, eps_decay_low, eps_decay_high)
                eps_new = jnp.maximum(eps_in * decay_sel, eps_min)
                # Simple backtracking: halve ε if the error got worse since the last check.
                eps_new = jnp.where(max_err > last_err_in, jnp.maximum(eps_in * jnp.asarray(0.5, comp_dtype), eps_min), eps_new)
                # Convergence test; conv_iter records the first iteration at which it holds.
                is_conv = jnp.logical_and(rel < tol_eff, max_err < tol_eff)
                conv_iter_new = jnp.where(jnp.logical_and(is_conv, conv_iter < 0), i_in, conv_iter)
                # Cast so both lax.cond branches return identical dtypes.
                np_cast = mixed_pots.astype(comp_dtype)
                return (np_cast, np_cast, eps_new.astype(comp_dtype), jnp.logical_or(converged, is_conv), conv_iter_new, max_err.astype(comp_dtype))

            do_check = (i % check_iv) == 0
            ops = (new_pots, pots, prev_pots, logs_all, eps, last_err, i)
            # Non-check iterations take the update and keep prev_pots/eps/flags/last_err unchanged.
            pots, prev_pots, eps, converged, conv_iter, last_err = lax.cond(
                do_check,
                on_check,
                lambda op: (op[0].astype(comp_dtype), op[2].astype(comp_dtype), eps, converged, conv_iter, op[5].astype(comp_dtype)),
                ops,
            )
            return (pots, prev_pots, eps, converged, conv_iter, last_err)

        # Once converged, remaining iterations pass the carry through unchanged.
        return lax.cond(converged, lambda c: c, do_update, carry)

    init_eps = jnp.asarray(cfg.initial_epsilon, comp_dtype)
    # Initial last_err = 1.0 is the reference for the first backtracking comparison.
    init_carry = (pots_init, pots_init, init_eps, jnp.array(False), jnp.array(-1), jnp.asarray(1.0, comp_dtype))
    pots_final, _, _, converged, conv_iter, last_err = lax.fori_loop(0, max_iters, body, init_carry)
    n_done = jnp.where(conv_iter >= 0, conv_iter + 1, jnp.asarray(max_iters))
    return pots_final, last_err, int(n_done), bool(converged)


@jit
def _compute_errors(
    state: IPFP2DState,
    problem: MMSBProblem2D,
    Kx_all: jnp.ndarray,
    Ky_all: jnp.ndarray,
) -> Dict[str, Scalar]:
    """Return per-time-step trapezoid L1 errors ``{"l1_marginal_k": ∫|ρ_k − ρ_obs,k|}``.

    Returns an empty dict when the problem has no observed marginals. All
    current marginals come from a single forward/backward pass.
    """
    if problem.observed_marginals is None:
        return {}
    dtype = state.potentials[0].dtype
    h_x = jnp.asarray(problem.grid.spacing_x, dtype)
    h_y = jnp.asarray(problem.grid.spacing_y, dtype)
    log_marginals = _compute_all_log_current_marginals(
        jnp.stack(state.potentials, axis=0), Kx_all, Ky_all, h_x, h_y
    )
    return {
        f"l1_marginal_{k}": _trapz2(
            jnp.abs(jnp.exp(log_marginals[k]) - problem.observed_marginals[k].astype(dtype)),
            h_x,
            h_y,
        )
        for k in range(problem.n_marginals)
    }


def _ipfp_iteration(
    state: IPFP2DState,
    problem: MMSBProblem2D,
    Kx_all: jnp.ndarray,
    Ky_all: jnp.ndarray,
    config: IPFP2DConfig,
    eps_t: float,
) -> IPFP2DState:
    """Run one non-compiled IPFP sweep and return the state with updated potentials.

    A single forward/backward pass yields log ρ_k for all k. The potentials are
    then updated jointly by ``_update_potentials_batch``.

    Targets are chosen to match the compiled path.
    - ``observation_mode`` other than "anchors": no step is anchored; every
      target is the current density.
    - "anchors" with ``config.anchor_mask`` set: anchored steps use the provided
      marginals, unanchored steps the current density.
    - "anchors" without a mask: all steps are anchored to the provided marginals.

    The provided marginals are ``problem.observed_marginals``, or
    ``state.marginals`` when there are none.
    """
    K = problem.n_marginals
    comp_dtype = state.potentials[0].dtype
    hx = jnp.asarray(problem.grid.spacing_x, comp_dtype)
    hy = jnp.asarray(problem.grid.spacing_y, comp_dtype)

    fixed_mask = jnp.array(
        config.fixed_potential_mask if config.fixed_potential_mask is not None else [False] * K,
        dtype=bool,
    )

    pots_stack = jnp.stack(state.potentials, axis=0)

    # One forward/backward pass gives log ρ_current for every k.
    logs_all = _compute_all_log_current_marginals(pots_stack, Kx_all, Ky_all, hx, hy)
    cur_dens = jnp.exp(logs_all)

    # Explicit `is not None`: `x or y` would truth-test x, which raises if the
    # observed marginals are given as a stacked array.
    target_list = problem.observed_marginals if problem.observed_marginals is not None else state.marginals
    targets_provided = jnp.stack([t.astype(comp_dtype) for t in target_list], axis=0)

    if getattr(config, "observation_mode", "anchors") != "anchors":
        anchor_mask = jnp.zeros((K,), dtype=bool)
    elif getattr(config, "anchor_mask", None) is not None:
        anchor_mask = jnp.array(list(config.anchor_mask), dtype=bool)
    else:
        anchor_mask = None

    if anchor_mask is None:
        targets_stack = targets_provided
    else:
        targets_stack = jnp.where(anchor_mask[:, None, None], targets_provided, cur_dens)

    new_pots_stack = _update_potentials_batch(pots_stack, logs_all, targets_stack, eps_t, fixed_mask)
    new_pots: List[Potential2D] = [new_pots_stack[i] for i in range(K)]

    return state.update(potentials=new_pots)


def solve_mmsb_ipfp_2d(
    problem: MMSBProblem2D,
    config: Optional[IPFP2DConfig] = None,
    progress_callback: Optional[Callable[[int, float], None]] = None,
) -> MMSBSolution2D:
    """Solve the 2D multi-marginal Schrödinger bridge problem with IPFP.

    ``config.compiled_loop`` selects one of two execution paths.
    - Compiled: the whole loop runs inside ``_solve_ipfp_compiled``
      (``lax.fori_loop``).
    - Python loop: one ``_ipfp_iteration`` per step. Every ``check_interval``
      steps it checks convergence, records errors, and calls the optional
      ``progress_callback``.

    Args:
        problem: 2D MMSB problem definition.
        config: solver configuration; ``IPFP2DConfig()`` when None.
        progress_callback: called as ``callback(iteration, max_marginal_l1_error)``
            at every non-converged check of the Python loop. Exceptions it raises
            are reported as ``RuntimeWarning`` and do not stop the solve.

    Returns:
        MMSBSolution2D with:
        - potentials;
        - per-time path densities (float64 on CPU, float32 elsewhere);
        - relative-change history;
        - final error;
        - number of iterations performed.

    Raises:
        ValueError: ``observation_mode`` is "anchors" (the default) but
            ``problem.observed_marginals`` is None.
    """
    import warnings

    if config is None:
        config = IPFP2DConfig()

    observation_mode = getattr(config, "observation_mode", "anchors")
    if (observation_mode == "anchors") and (problem.observed_marginals is None):
        raise ValueError("IPFP2DConfig.observation_mode='anchors' requires observed_marginals")

    state = _initialize_state(problem)
    Kx_all, Ky_all = _precompute_kernels(
        problem,
        use_pallas=config.use_pallas_kernels,
        block_rows=config.pallas_block_rows,
    )

    on_cpu = jax.default_backend() == "cpu"
    eps_cur = config.initial_epsilon
    convergence = []
    marginal_hist: List[Dict[str, Scalar]] = []

    if getattr(config, "compiled_loop", False):
        # Fully compiled main loop (maximum performance).
        comp_dtype = jnp.float64 if on_cpu else compute_dtype()
        hx = jnp.asarray(problem.grid.spacing_x, comp_dtype)
        hy = jnp.asarray(problem.grid.spacing_y, comp_dtype)
        Kx_c = Kx_all.astype(comp_dtype)
        Ky_c = Ky_all.astype(comp_dtype)
        pots_init = jnp.stack(state.potentials, axis=0).astype(comp_dtype)
        if problem.observed_marginals is not None:
            targets_stack = jnp.stack([m.astype(comp_dtype) for m in problem.observed_marginals], axis=0)
        else:
            logs0 = _compute_all_log_current_marginals(pots_init, Kx_c, Ky_c, hx, hy)
            targets_stack = jnp.exp(logs0)
        fixed_mask = jnp.array(
            config.fixed_potential_mask if config.fixed_potential_mask is not None else [False] * problem.n_marginals,
            dtype=bool,
        )
        # Anchor mask: all True by default in "anchors" mode (legacy behavior), all False otherwise.
        if observation_mode == "anchors":
            anchor_mask = jnp.array(
                config.anchor_mask if config.anchor_mask is not None else [True] * problem.n_marginals,
                dtype=bool,
            )
        else:
            anchor_mask = jnp.zeros((problem.n_marginals,), dtype=bool)
        new_pots, final_err, n_done, converged = _solve_ipfp_compiled(
            pots_init,
            targets_stack,
            fixed_mask,
            anchor_mask,
            Kx_c,
            Ky_c,
            hx,
            hy,
            config,
        )
        state = state.update(
            potentials=[new_pots[i] for i in range(new_pots.shape[0])],
            iteration=int(n_done),
            converged=bool(converged),
            error=final_err,
        )
    else:
        comp_dtype = state.potentials[0].dtype
        hx = jnp.asarray(problem.grid.spacing_x, comp_dtype)
        hy = jnp.asarray(problem.grid.spacing_y, comp_dtype)
        tol_eff = max(config.tolerance, 1e-7)
        for it in range(config.max_iterations):
            is_check = it % config.check_interval == 0
            # Snapshot the potentials before the update; jnp.stack makes a new array.
            old_stack = jnp.stack(state.potentials, axis=0) if is_check else None

            # Epsilon scaling.
            if config.epsilon_scaling:
                if is_check and it > 0:
                    # NOTE(review): the decay is chosen from the last *relative potential
                    # change*, whereas the compiled path uses the max marginal L1 error.
                    # Confirm which is intended.
                    last_err = convergence[-1] if convergence else 1.0
                    decay = config.eps_decay_low if last_err < config.error_threshold else config.eps_decay_high
                    eps_cur = max(eps_cur * decay, config.min_epsilon)
                elif it == 0:
                    eps_cur = config.initial_epsilon

            state = _ipfp_iteration(state, problem, Kx_all, Ky_all, config, eps_cur)

            if is_check:
                # Relative potential change under the 2D trapezoid norm.
                rel = _rel_change_trapz2(jnp.stack(state.potentials, axis=0), old_stack, hx, hy)
                convergence.append(rel)

                errs = _compute_errors(state, problem, Kx_all, Ky_all)
                marginal_hist.append(errs)
                max_err = jnp.max(jnp.array(list(errs.values()))) if errs else jnp.asarray(0.0, rel.dtype)
                # Record the current max marginal error so the final error is never empty.
                state = state.update(error=max_err)

                if (rel < tol_eff) and (max_err < tol_eff):
                    # Count the iteration that just ran, matching the compiled path (conv_iter + 1).
                    state = state.update(converged=True, error=max_err, iteration=it + 1)
                    break
                if progress_callback is not None:
                    try:
                        progress_callback(int(it), float(max_err))
                    except Exception as exc:
                        # Best effort: a faulty callback must not abort the solve, but don't hide it.
                        warnings.warn(
                            f"progress_callback raised {exc!r}; continuing",
                            RuntimeWarning,
                            stacklevel=2,
                        )

            state = state.update(iteration=it + 1)

        # If not converged, make the final error the last max marginal error.
        if not state.converged:
            errs = _compute_errors(state, problem, Kx_all, Ky_all)
            if errs:
                state = state.update(error=jnp.max(jnp.array(list(errs.values()))))

    # Extract path densities (one pass for all k).
    comp_dtype = state.potentials[0].dtype
    hx = jnp.asarray(problem.grid.spacing_x, comp_dtype)
    hy = jnp.asarray(problem.grid.spacing_y, comp_dtype)
    pots_stack = jnp.stack(state.potentials, axis=0)
    dens_all = jnp.exp(_compute_all_log_current_marginals(pots_stack, Kx_all, Ky_all, hx, hy))
    # Evaluation-friendly output precision: float32 off-CPU (bfloat16 quantization would inflate L1).
    out_dtype = jnp.float64 if on_cpu else jnp.float32
    path_densities: List[Density2D]
    if config.use_pallas_kernels:
        # Fused Pallas kernel re-clips and re-normalizes each k; optional tiled variant.
        tiled = config.pallas_norm_tiled
        if tiled:
            ti = int(config.pallas_tile_i) if config.pallas_tile_i is not None else 64
            tj = int(config.pallas_tile_j) if config.pallas_tile_j is not None else 64
        path_densities = []
        for k in range(problem.n_marginals):
            dens_k = dens_all[k].astype(comp_dtype)
            if tiled:
                dens_k = _pallas_fused_clip_trapz_normalize_2d_tiled_unsafe(
                    dens_k,
                    hx,
                    hy,
                    float(MIN_DENSITY),
                    ti,
                    tj,
                )
            else:
                dens_k = _pallas_fused_clip_trapz_normalize_2d_unsafe(
                    dens_k,
                    hx,
                    hy,
                    float(MIN_DENSITY),
                )
            path_densities.append(dens_k.astype(out_dtype))
    else:
        path_densities = [dens_all[k].astype(out_dtype) for k in range(problem.n_marginals)]

    return MMSBSolution2D(
        potentials=state.potentials,
        path_densities=path_densities,
        velocities=None,
        convergence_history=[float(c) for c in convergence],
        final_error=float(state.error) if isinstance(state.error, jnp.ndarray) else state.error,
        n_iterations=state.iteration,
    )

