"""
2D Gaussian OU Kernel (Separable)
二维可分OU高斯核

高性能实现要点（遵循JAX.md最佳实践）：
- 使用向量化广播与log-sum-exp避免Python循环
- 可分核：K2D = Kx ⊗ Ky，应用核时两次一维卷积，避免(nxy)^2内存/算力
- 全程float64确保SB数值稳定；必要处裁剪与归一化
"""

from functools import partial
from typing import Tuple, Dict

import jax
import jax.numpy as jnp
from jax import jit, lax
try:
    # Pallas is available from jax.experimental import pallas
    from jax.experimental import pallas as pl
    _PALLAS_AVAILABLE = True
except Exception:
    _PALLAS_AVAILABLE = False

from ..core.types import (
    Density2D, Grid2D_X, Grid2D_Y, Scalar,
    OUProcessParams2D, GridConfig2D, OUProcessParams,
)
from ..constants import MIN_DENSITY, LOG_STABILITY
from ..utils.precision import compute_dtype, storage_dtype, matmul_precision
from .gaussian_kernel_1d import compute_log_transition_kernel_1d


def _weights_trapz_1d(n: int, dtype: jnp.dtype = jnp.float64) -> jnp.ndarray:
    """1D trapezoid weights vector of length n. 端点0.5，其余1.0"""
    w = jnp.ones((n,), dtype=dtype)
    w = w.at[0].set(0.5).at[-1].set(0.5)
    return w


@jit
def jax_trapz_2d(values: jnp.ndarray, hx: float, hy: float) -> float:
    """Integrate a 2D array of samples with the tensor-product trapezoid rule.

    Args:
        values: Samples of shape ``(nx, ny)``.
        hx: Spacing applied along axis 0.
        hy: Spacing applied along axis 1.

    Returns:
        ``hx * hy * sum(values * w)`` where ``w`` is the outer product of the
        per-axis trapezoid weights (0.5 at the endpoints, 1.0 elsewhere).
    """
    n_rows, n_cols = values.shape
    row_weights = _weights_trapz_1d(n_rows, dtype=values.dtype)
    col_weights = _weights_trapz_1d(n_cols, dtype=values.dtype)
    # 2D weights are the outer product of the two 1D weight vectors.
    weighted = values * jnp.outer(row_weights, col_weights)
    return hx * hy * jnp.sum(weighted)


@jit
def _build_K1d_from_log(logK1d: jnp.ndarray, h: float) -> jnp.ndarray:
    """Turn a log-kernel into a column-normalised kernel matrix.

    Steps: exponentiate, floor every entry at ``LOG_STABILITY``, then divide
    each column by its trapezoid-rule integral over the target axis.

    Args:
        logK1d: Log-kernel of shape ``(n_tgt, n_src)``.
        h: Grid spacing along the target axis.

    Returns:
        Array of shape ``(n_tgt, n_src)`` whose columns satisfy
        ``h * sum_tgt(w * K) ~= 1`` with ``w`` the trapezoid weights.
    """
    kernel = jnp.exp(logK1d)
    # Floor tiny entries; entries that are not strictly greater than the floor
    # (including NaN) are replaced by the floor value.
    floor = jnp.asarray(LOG_STABILITY, kernel.dtype)
    kernel = jnp.where(kernel > floor, kernel, floor)
    # Trapezoid integral of every column over the target axis.
    trapz_weights = _weights_trapz_1d(kernel.shape[0], dtype=kernel.dtype)
    column_mass = h * jnp.sum(kernel * trapz_weights[:, None], axis=0)
    # The tiny additive constant guards against division by zero.
    return kernel / (column_mass[None, :] + 1e-30)


def _pallas_build_K1d_from_log(logK1d: jnp.ndarray, h: float) -> jnp.ndarray:
    """Pallas version of ``exp`` + column-wise trapezoid normalisation.

    Computes ``K = exp(logK1d)`` and divides every column by its trapezoid
    integral ``h * sum_tgt(w * K)`` inside a single Pallas kernel that operates
    on the whole matrix at once.

    Falls back to ``_build_K1d_from_log`` when Pallas could not be imported.
    NOTE: unlike that fallback, this kernel does not floor the exponentiated
    values at ``LOG_STABILITY``, so the two paths can differ for entries that
    underflow.

    NOTE(review): compiled GPU/TPU lowerings may impose block-shape
    constraints on a whole-array kernel - confirm for the grid sizes in use.

    Args:
        logK1d: Log-kernel of shape ``(n_tgt, n_src)``.
        h: Grid spacing along the target axis.

    Returns:
        Column-normalised kernel with the same shape and dtype as ``logK1d``.
    """
    if not _PALLAS_AVAILABLE:
        return _build_K1d_from_log(logK1d, h)

    n_tgt, _ = logK1d.shape
    dtype = logK1d.dtype
    # Fold the spacing into the weights so the kernel needs no scalar operand;
    # shape (n_tgt, 1) broadcasts against every column.
    weights_col = (
        jnp.asarray(h, dtype) * _weights_trapz_1d(n_tgt, dtype=dtype)
    )[:, None]

    def kernel(logK_ref, w_ref, out_ref):
        # Pallas passes all input refs first, followed by the output refs.
        K = jnp.exp(logK_ref[...])
        col_mass = jnp.sum(K * w_ref[...], axis=0, keepdims=True)
        out_ref[...] = K * jnp.reciprocal(col_mass + jnp.asarray(1e-30, dtype))

    # JAX arrays are immutable: the normalised kernel is the value returned by
    # pallas_call, not a pre-allocated buffer filled in place.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(logK1d.shape, dtype),
        # Pallas has no compiled CPU path; use the interpreter there.
        interpret=jax.default_backend() == "cpu",
    )(logK1d, weights_col)


def _pallas_build_K1d_from_log_batched(
    logK1d_all: jnp.ndarray,  # (M, n_tgt, n_src)
    h: float,
    block_rows: int = 128,
) -> jnp.ndarray:
    """Batched Pallas kernel for ``exp`` + column-wise trapezoid normalisation.

    One Pallas program is launched per batch entry (grid ``(M,)``). Each
    program walks the target axis in static chunks of ``block_rows`` rows:
    a first pass accumulates the trapezoid column integrals, a second pass
    recomputes ``exp`` and writes the normalised values. Recomputing avoids
    reading back values the same program has just stored.

    Falls back to a ``vmap`` of ``_build_K1d_from_log`` when Pallas could not
    be imported. NOTE: unlike that fallback, this kernel does not floor the
    exponentiated values at ``LOG_STABILITY``.

    Args:
        logK1d_all: Log-kernels of shape ``(M, n_tgt, n_src)``.
        h: Grid spacing along the target axis.
        block_rows: Number of target rows processed per chunk (>= 1).

    Returns:
        Column-normalised kernels, same shape and dtype as ``logK1d_all``.

    Raises:
        ValueError: If ``block_rows`` is smaller than 1 on the Pallas path.
    """
    if not _PALLAS_AVAILABLE:
        # Fallback: map the reference implementation over the batch axis.
        return jax.vmap(lambda L: _build_K1d_from_log(L, h), in_axes=0)(logK1d_all)

    if block_rows < 1:
        raise ValueError(f"block_rows must be >= 1, got {block_rows}")

    n_batch, n_tgt, n_src = logK1d_all.shape
    dtype = logK1d_all.dtype
    # Spacing folded into the weights; shape (n_tgt, 1) broadcasts over columns.
    weights_col = (
        jnp.asarray(h, dtype) * _weights_trapz_1d(n_tgt, dtype=dtype)
    )[:, None]
    # Chunk bounds are plain Python ints so the loops unroll at trace time
    # (a traced bound cannot be handed to ``range``).
    row_blocks = [
        (r0, min(r0 + block_rows, n_tgt)) for r0 in range(0, n_tgt, block_rows)
    ]

    def kernel(logK_ref, w_ref, out_ref):
        # Pallas passes all input refs first, followed by the output refs.
        m = pl.program_id(0)  # batch index handled by this program
        # Pass 1: trapezoid integral of every column.
        col_mass = jnp.zeros((1, n_src), dtype)
        for r0, r1 in row_blocks:
            chunk = jnp.exp(logK_ref[m, r0:r1, :])
            col_mass = col_mass + jnp.sum(
                chunk * w_ref[r0:r1, :], axis=0, keepdims=True
            )
        inv = jnp.reciprocal(col_mass + jnp.asarray(1e-30, dtype))
        # Pass 2: write the normalised values.
        for r0, r1 in row_blocks:
            out_ref[m, r0:r1, :] = jnp.exp(logK_ref[m, r0:r1, :]) * inv

    # The result is the value returned by pallas_call; JAX arrays cannot be
    # filled in place through a pre-allocated buffer.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(logK1d_all.shape, dtype),
        grid=(n_batch,),
        # Pallas has no compiled CPU path; use the interpreter there.
        interpret=jax.default_backend() == "cpu",
    )(logK1d_all, weights_col)


def _pallas_fused_clip_trapz_normalize_2d(
    density: jnp.ndarray,  # (nx, ny)
    hx: float,
    hy: float,
    min_density: float,
) -> jnp.ndarray:
    """Clip a 2D density from below and rescale it to unit trapezoid mass.

    Computes ``rho = max(density, min_density)``, its 2D trapezoid integral
    ``mass`` and returns ``rho / (mass + 1e-15)``. When Pallas is importable
    the three steps run inside one whole-array Pallas kernel; otherwise the
    same computation is done with plain JAX operations.

    NOTE(review): compiled GPU/TPU lowerings may impose block-shape
    constraints on a whole-array kernel - confirm for the grid sizes in use.

    Args:
        density: Density samples of shape ``(nx, ny)``.
        hx: Grid spacing along axis 0.
        hy: Grid spacing along axis 1.
        min_density: Lower bound applied element-wise before normalising.

    Returns:
        Clipped and normalised density, same shape and dtype as ``density``.
    """
    if not _PALLAS_AVAILABLE:
        comp_dtype = density.dtype
        rho = jnp.maximum(density, jnp.asarray(min_density, comp_dtype))
        mass = jax_trapz_2d(rho, hx, hy)
        return rho / (mass + jnp.asarray(1e-15, comp_dtype))

    nx, ny = density.shape
    dtype = density.dtype
    # Fold the cell area hx*hy into the axis-0 weights so the kernel needs no
    # scalar operands; (nx, 1) and (1, ny) broadcast to the full grid.
    cell_area = jnp.asarray(hx, dtype) * jnp.asarray(hy, dtype)
    wx_col = (cell_area * _weights_trapz_1d(nx, dtype=dtype))[:, None]
    wy_row = _weights_trapz_1d(ny, dtype=dtype)[None, :]
    # Passed as a (1, 1) operand so a traced lower bound also works.
    floor = jnp.full((1, 1), min_density, dtype)

    def kernel(rho_ref, wx_ref, wy_ref, floor_ref, out_ref):
        # Pallas passes all input refs first, followed by the output refs.
        rho = jnp.maximum(rho_ref[...], floor_ref[...])
        mass = jnp.sum(rho * wx_ref[...] * wy_ref[...])
        out_ref[...] = rho / (mass + jnp.asarray(1e-15, dtype))

    # The result is the value returned by pallas_call; JAX arrays cannot be
    # filled in place through a pre-allocated buffer.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(density.shape, dtype),
        # Pallas has no compiled CPU path; use the interpreter there.
        interpret=jax.default_backend() == "cpu",
    )(density, wx_col, wy_row, floor)


def _pallas_fused_clip_trapz_normalize_2d_tiled(
    density: jnp.ndarray,  # (nx, ny)
    hx: float,
    hy: float,
    min_density: float,
    tile_i: int = 64,
    tile_j: int = 64,
) -> jnp.ndarray:
    """Tiled clip + trapezoid mass + normalise for a 2D density.

    Two stages:
      1. A Pallas kernel on a ``(gx, gy)`` grid clips one ``tile_i x tile_j``
         tile per program and writes that tile's weighted mass to ``partials``.
      2. The partial masses are summed and the clipped density is scaled by
         ``1 / (mass + 1e-15)`` with a plain JAX multiply (an element-wise
         scale needs no custom kernel).

    The inputs are zero-padded up to a whole number of tiles so no masking is
    needed inside the kernel; padded cells carry zero quadrature weight and
    are cropped from the result.

    Falls back to ``_pallas_fused_clip_trapz_normalize_2d`` when Pallas could
    not be imported.

    Args:
        density: Density samples of shape ``(nx, ny)``.
        hx: Grid spacing along axis 0.
        hy: Grid spacing along axis 1.
        min_density: Lower bound applied element-wise before normalising.
        tile_i: Tile extent along axis 0 (>= 1).
        tile_j: Tile extent along axis 1 (>= 1).

    Returns:
        Clipped and normalised density, same shape and dtype as ``density``.

    Raises:
        ValueError: If a tile extent is smaller than 1 on the Pallas path.
    """
    if not _PALLAS_AVAILABLE:
        return _pallas_fused_clip_trapz_normalize_2d(density, hx, hy, min_density)

    if tile_i < 1 or tile_j < 1:
        raise ValueError(
            f"tile sizes must be >= 1, got tile_i={tile_i}, tile_j={tile_j}"
        )

    nx, ny = density.shape
    dtype = density.dtype

    # Number of tiles per axis (ceil division) and the padding that rounds the
    # array up to a whole number of tiles.
    gx = (nx + tile_i - 1) // tile_i
    gy = (ny + tile_j - 1) // tile_j
    pad_i = gx * tile_i - nx
    pad_j = gy * tile_j - ny

    cell_area = jnp.asarray(hx, dtype) * jnp.asarray(hy, dtype)
    # Zero weights on the padding keep padded cells out of the mass.
    wx_col = jnp.pad(cell_area * _weights_trapz_1d(nx, dtype=dtype), (0, pad_i))[:, None]
    wy_row = jnp.pad(_weights_trapz_1d(ny, dtype=dtype), (0, pad_j))[None, :]
    rho_padded = jnp.pad(density, ((0, pad_i), (0, pad_j)))
    floor = jnp.full((1, 1), min_density, dtype)

    # Stage 1: clip each tile and record its weighted mass.
    def kernel_stage1(rho_ref, wx_ref, wy_ref, floor_ref, out_ref, partials_ref):
        # Pallas passes all input refs first, followed by the output refs.
        bi = pl.program_id(0)
        bj = pl.program_id(1)
        # Tile offsets depend on the program id, so dynamic slices of static
        # size are used instead of Python-level bounds checks.
        rows = pl.ds(bi * tile_i, tile_i)
        cols = pl.ds(bj * tile_j, tile_j)
        tile = jnp.maximum(rho_ref[rows, cols], floor_ref[...])
        out_ref[rows, cols] = tile
        partials_ref[bi, bj] = jnp.sum(tile * wx_ref[rows, :] * wy_ref[:, cols])

    clipped, partials = pl.pallas_call(
        kernel_stage1,
        out_shape=(
            jax.ShapeDtypeStruct(rho_padded.shape, dtype),
            jax.ShapeDtypeStruct((gx, gy), dtype),
        ),
        grid=(gx, gy),
        # Pallas has no compiled CPU path; use the interpreter there.
        interpret=jax.default_backend() == "cpu",
    )(rho_padded, wx_col, wy_row, floor)

    # Stage 2: global mass and element-wise normalisation on the cropped grid.
    mass = jnp.sum(partials)
    inv = jnp.reciprocal(mass + jnp.asarray(1e-15, dtype))
    return clipped[:nx, :ny] * inv


def compute_log_transition_kernels_2d_separable(
    x_target: Grid2D_X,
    x_source: Grid2D_X,
    y_target: Grid2D_Y,
    y_source: Grid2D_Y,
    dt: Scalar,
    ou_params: OUProcessParams2D,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the per-axis log transition kernels of a separable 2D OU process.

    Each axis is treated as an independent 1D OU process: the ``*_x`` / ``*_y``
    fields of ``ou_params`` are repackaged into an ``OUProcessParams`` and
    handed to ``compute_log_transition_kernel_1d``.

    Args:
        x_target: Target grid points along x.
        x_source: Source grid points along x.
        y_target: Target grid points along y.
        y_source: Source grid points along y.
        dt: Time step passed to the 1D kernel routine.
        ou_params: 2D OU parameters holding per-axis mean reversion,
            diffusion and equilibrium mean.

    Returns:
        ``(logKx, logKy)`` cast to ``float64``; expected shapes are
        ``(nx_tgt, nx_src)`` and ``(ny_tgt, ny_src)``.
    """

    def _axis_log_kernel(target, source, theta, sigma, mu):
        # Build the 1D parameter bundle for this axis and evaluate its kernel.
        axis_params = OUProcessParams(
            mean_reversion=theta,
            diffusion=sigma,
            equilibrium_mean=mu,
        )
        log_kernel = compute_log_transition_kernel_1d(target, source, dt, axis_params)
        return log_kernel.astype(jnp.float64)

    log_kernel_x = _axis_log_kernel(
        x_target,
        x_source,
        ou_params.mean_reversion_x,
        ou_params.diffusion_x,
        ou_params.equilibrium_mean_x,
    )
    log_kernel_y = _axis_log_kernel(
        y_target,
        y_source,
        ou_params.mean_reversion_y,
        ou_params.diffusion_y,
        ou_params.equilibrium_mean_y,
    )
    return log_kernel_x, log_kernel_y


@partial(jit, static_argnames=("use_pallas_fused_norm",))
def apply_ou_kernel_2d_separable(
    density: Density2D,
    dt: Scalar,
    ou_params: OUProcessParams2D,
    grid: GridConfig2D,
    use_pallas_fused_norm: bool = False,
) -> Density2D:
    """Apply the separable 2D OU kernel: rho' = (Kx (x) Ky) rho.

    The kernel is applied as two 1D passes (first along x, then along y), so
    the full ``(nx*ny)^2`` 2D kernel is never formed. The result is clipped
    from below at ``MIN_DENSITY`` and rescaled to unit 2D trapezoid mass.

    Args:
        density: Input density of shape ``(nx, ny)``.
        dt: Time step forwarded to the 1D kernel routine.
        ou_params: Per-axis OU parameters.
        grid: Grid configuration; ``points_x``, ``points_y``, ``spacing_x``
            and ``spacing_y`` are read.
        use_pallas_fused_norm: If True, clip + normalise through
            ``_pallas_fused_clip_trapz_normalize_2d``. It selects a Python
            branch, so it is a static (compile-time) argument and must be a
            plain Python bool.

    Returns:
        Evolved density cast to ``storage_dtype()``.
    """
    comp_dtype = compute_dtype()
    store_dtype = storage_dtype()
    rho = density.astype(comp_dtype)
    hx = jnp.asarray(grid.spacing_x, comp_dtype)
    hy = jnp.asarray(grid.spacing_y, comp_dtype)

    # Build the 1D log-kernels (source grid == target grid on both axes).
    logKx, logKy = compute_log_transition_kernels_2d_separable(
        grid.points_x, grid.points_x, grid.points_y, grid.points_y, dt, ou_params
    )
    # The two matmuls are the operations the precision setting is meant for,
    # so they must be traced inside the context manager.
    with matmul_precision("high"):
        Kx = _build_K1d_from_log(logKx.astype(comp_dtype), hx)
        Ky = _build_K1d_from_log(logKy.astype(comp_dtype), hy)

        # Along x: (nx, nx) @ (nx, ny), scaled by hx.
        tmp = (Kx @ rho) * hx
        # Along y: (nx, ny) @ (ny, ny)^T, scaled by hy.
        evolved = (tmp @ Ky.T) * hy

    # Positivity and normalisation (optionally via the fused Pallas helper).
    if use_pallas_fused_norm:
        evolved = _pallas_fused_clip_trapz_normalize_2d(
            evolved, hx, hy, float(MIN_DENSITY)
        )
    else:
        evolved = jnp.maximum(evolved, jnp.asarray(MIN_DENSITY, comp_dtype))
        mass = jax_trapz_2d(evolved, hx, hy)
        evolved = evolved / (mass + jnp.asarray(1e-15, comp_dtype))
    # Cast to the storage dtype for the returned array.
    return evolved.astype(store_dtype)


def validate_ou_kernel_properties_2d(
    grid: GridConfig2D,
    dt: Scalar,
    ou_params: OUProcessParams2D,
    tolerance: float = 1e-10,
) -> Dict[str, float]:
    """Validate the separable 2D kernel through its two 1D factors.

    Builds ``Kx`` and ``Ky`` exactly as the application path does
    (``_build_K1d_from_log``) and reports, per axis, the largest deviation of
    the trapezoid column integrals from 1 and the smallest kernel entry.
    Because ``_build_K1d_from_log`` normalises columns itself, the reported
    conservation error reflects what is left after that normalisation.

    Args:
        grid: Grid configuration; points and spacings of both axes are read.
        dt: Time step forwarded to the 1D kernel routine.
        ou_params: Per-axis OU parameters.
        tolerance: Largest conservation error accepted for
            ``"probability_conservation_satisfied"``.

    Returns:
        Dict with the keys
        ``probability_conservation_error_x`` / ``_y`` (max |column integral - 1|),
        ``min_kernel_value_x`` / ``_y``,
        ``positivity_satisfied`` (1.0 if both minima are >= 0, else 0.0) and
        ``probability_conservation_satisfied`` (1.0 if both errors are
        <= ``tolerance``, else 0.0).
    """
    hx = grid.spacing_x
    hy = grid.spacing_y
    logKx, logKy = compute_log_transition_kernels_2d_separable(
        grid.points_x, grid.points_x, grid.points_y, grid.points_y, dt, ou_params
    )
    Kx = _build_K1d_from_log(logKx, hx)
    Ky = _build_K1d_from_log(logKy, hy)

    # Column integrals (should equal 1).
    wx = _weights_trapz_1d(Kx.shape[0])
    wy = _weights_trapz_1d(Ky.shape[0])
    colsum_x = hx * jnp.sum(Kx * wx[:, None], axis=0)
    colsum_y = hy * jnp.sum(Ky * wy[:, None], axis=0)
    err_x = float(jnp.max(jnp.abs(colsum_x - 1.0)))
    err_y = float(jnp.max(jnp.abs(colsum_y - 1.0)))

    # Positivity.
    min_x = float(jnp.min(Kx))
    min_y = float(jnp.min(Ky))

    return {
        "probability_conservation_error_x": err_x,
        "probability_conservation_error_y": err_y,
        "min_kernel_value_x": min_x,
        "min_kernel_value_y": min_y,
        "positivity_satisfied": float(min(min_x, min_y) >= 0.0),
        # Pass/fail flag so the ``tolerance`` argument actually takes effect.
        "probability_conservation_satisfied": float(max(err_x, err_y) <= tolerance),
    }


def estimate_kernel_bandwidth_2d(
    dt: Scalar,
    ou_params: OUProcessParams2D,
) -> Tuple[Scalar, Scalar]:
    """Estimate the transition standard deviations along x and y.

    Per axis the OU variance is ``sigma^2 / (2*theta) * (1 - exp(-2*theta*dt))``;
    for ``theta <= 1e-10`` the Brownian limit ``sigma^2 * dt`` is used instead.

    Args:
        dt: Time step.
        ou_params: Object providing ``mean_reversion_x/_y`` (theta) and
            ``diffusion_x/_y`` (sigma).

    Returns:
        ``(sigma_x, sigma_y)``: square roots of the per-axis variances.
    """

    def var_axis(theta, sigma):
        is_ou = theta > 1e-10
        # Substitute a harmless theta on the Brownian branch so the OU formula
        # never divides by zero (which would raise for Python floats and
        # yield NaN/inf, and NaN gradients, for arrays).
        safe_theta = jnp.where(is_ou, theta, 1.0)
        # -expm1(-x) equals 1 - exp(-x) without cancellation for small x.
        var_ou = sigma**2 * (-jnp.expm1(-2.0 * safe_theta * dt)) / (2.0 * safe_theta)
        var_bm = sigma**2 * dt
        return jnp.where(is_ou, var_ou, var_bm)

    vx = var_axis(ou_params.mean_reversion_x, ou_params.diffusion_x)
    vy = var_axis(ou_params.mean_reversion_y, ou_params.diffusion_y)
    return jnp.sqrt(vx), jnp.sqrt(vy)


# Backward-compatible aliases: each ``*_fixed`` name is bound to the very same
# function object as the corresponding 2D entry point, so both spellings
# behave identically.
apply_ou_kernel_2d_separable_fixed = apply_ou_kernel_2d_separable
compute_log_transition_kernels_2d_separable_fixed = compute_log_transition_kernels_2d_separable
estimate_kernel_bandwidth_2d_fixed = estimate_kernel_bandwidth_2d

