"""
MOD_fuse_raw: MOD-FUSE test using unwhitened statistic (max(T^2)).
Steps:
1. Compute unwhitened stats across bandwidths, then aggregate (log-sum-exp).
2. Use permutation test to obtain critical value.
3. Compare aggregated stat vs critical value.
"""
import numpy as np
from scipy.special import logsumexp
from MOD import compute_all_statistics
from kernel import kernel_matrix, pairwise_distances, compute_bandwidths_from_distances


def _agg_stat(stats, lambda_scale=1.0, lambda_multiplier=1.0):
    """Fuse several statistics into one via a scaled log-mean-exp.

    With ``lam = lambda_scale * lambda_multiplier`` and
    ``N = stats.shape[0]`` the returned value is::

        (logsumexp(lam * stats) - log(N)) / lam

    Parameters
    ----------
    stats : array exposing ``.shape``
        Statistics to combine; only the leading dimension is used for the
        ``log(N)`` normalisation.
    lambda_scale, lambda_multiplier : float
        Factors whose product is the temperature ``lam``.

    Returns
    -------
    float
        The aggregated statistic.
    """
    temperature = lambda_scale * lambda_multiplier
    count = stats.shape[0]
    # Subtracting log(count) converts the log-sum into a log-mean.
    log_mean_exp = logsumexp(temperature * stats) - np.log(count)
    return log_mean_exp / temperature


def _compute_unwhitened_stat(A, index):
    """Return the largest squared value produced by ``compute_all_statistics``.

    ``A`` and ``index`` are forwarded unchanged to
    ``compute_all_statistics``; its result is squared element-wise and the
    maximum of those squares is returned (the "max(T^2)" statistic).
    """
    squared = compute_all_statistics(A, index) ** 2
    return np.max(squared)


def mod_fuse_raw(
    X,
    Y,
    kernels="gaussian",
    number_bandwidths=10,
    B=1000,
    alpha=0.05,
    lambda_multiplier=1.0,
    seed=None,
    return_all=False,
):
    """Run the MOD-FUSE permutation test with the unwhitened statistic.

    One kernel matrix is built on the pooled sample for every
    (kernel, bandwidth) pair. The unwhitened statistic of each matrix is
    computed with ``_compute_unwhitened_stat`` and the values are fused with
    ``_agg_stat``. The same computation is repeated for ``B`` random
    permutations of the group labels, and the ``1 - alpha`` quantile of the
    permuted fused statistics serves as the critical value.

    Parameters
    ----------
    X, Y : arrays exposing ``.shape``
        The two samples. They are swapped internally when ``Y`` has more rows
        than ``X`` so that the larger sample always comes first.
    kernels : str or iterable of str
        Kernel name(s); each must be one of the names in ``valid_kernels``
        below. A single string is treated as a one-element collection.
    number_bandwidths : int
        Must be greater than 1; forwarded to
        ``compute_bandwidths_from_distances``.
    B : int
        Number of label permutations (positive).
    alpha : float
        Test level, strictly between 0 and 1.
    lambda_multiplier : float
        Positive factor applied to the aggregation temperature.
    seed : optional
        Passed to ``np.random.default_rng``.
    return_all : bool
        When true, the per-matrix statistics and the permuted fused
        statistics are included in the result.

    Returns
    -------
    dict
        Keys ``"agg_stat"``, ``"thresh"`` and ``"reject"``; additionally
        ``"stats_per_bw"`` and ``"permuted_agg_stats"`` if ``return_all``.

    Raises
    ------
    AssertionError
        If any of the argument checks fails.
    """
    # Ensure the first sample is the larger one.
    if Y.shape[0] > X.shape[0]:
        X, Y = Y, X
    n1, n2 = X.shape[0], Y.shape[0]
    assert n2 <= n1
    assert n2 >= 2 and n1 >= 2
    assert 0 < alpha < 1
    assert lambda_multiplier > 0
    assert number_bandwidths > 1 and isinstance(number_bandwidths, int)
    assert B > 0 and isinstance(B, int)

    # Allow a bare kernel name in place of a collection.
    if isinstance(kernels, str):
        kernels = (kernels,)

    valid_kernels = (
        "imq", "rq", "gaussian", "matern_0.5_l2", "matern_1.5_l2", "matern_2.5_l2",
        "matern_3.5_l2", "matern_4.5_l2", "laplace", "matern_0.5_l1", "matern_1.5_l1",
        "matern_2.5_l1", "matern_3.5_l1", "matern_4.5_l1"
    )
    for kernel in kernels:
        assert kernel in valid_kernels, f"Invalid kernel: {kernel}"

    rng = np.random.default_rng(seed)

    # Aggregation temperature scale, derived from the smaller sample size.
    min_n = min(n1, n2)
    lambda_scale = np.sqrt(min_n * (min_n - 1))

    # Group labels for the pooled sample: 1 for rows of X, 2 for rows of Y.
    index_orig = np.concatenate((np.ones(n1), np.full(n2, 2)))
    Z = np.concatenate([X, Y], axis=0)

    # Split the requested kernels by the distance metric they are paired with.
    kernels_l1 = [k for k in kernels if k.endswith("_l1") or k == "laplace"]
    kernels_l2 = [k for k in kernels if k.endswith("_l2") or k in ("gaussian", "imq", "rq")]

    # Step 1: one kernel matrix per (kernel, bandwidth) pair, l1 group first.
    A_list = []
    for metric, kernel_group in (("l1", kernels_l1), ("l2", kernels_l2)):
        if not kernel_group:
            continue
        # Distance matrix of the pooled sample under the current metric.
        dist_matrix = pairwise_distances(Z, metric=metric)
        # Bandwidths are derived from the strictly-upper-triangular distances.
        upper = dist_matrix[np.triu_indices(dist_matrix.shape[0], k=1)]
        bandwidths = compute_bandwidths_from_distances(upper, number_bandwidths)

        for kernel_name in kernel_group:
            for bw in bandwidths:
                A = kernel_matrix(dist_matrix, kernel_name, bw, metric=metric)
                np.fill_diagonal(A, 0)
                A_list.append(A)

    def _fused(labels):
        """Return (per-matrix stats, fused stat) for one label assignment."""
        per_matrix = np.array(
            [_compute_unwhitened_stat(A, labels) for A in A_list], dtype=float
        )
        fused = _agg_stat(
            per_matrix, lambda_scale=lambda_scale, lambda_multiplier=lambda_multiplier
        )
        return per_matrix, fused

    # Steps 2-3: statistics for the observed labelling, then their aggregate.
    stats_orig, agg_stat = _fused(index_orig)

    # Step 4: permutation test -- shuffle the labels B times and recompute the
    # fused statistic for each shuffle.
    permuted_agg_stats = np.zeros(B, dtype=float)
    for b in range(B):
        shuffled_labels = rng.permutation(index_orig)
        permuted_agg_stats[b] = _fused(shuffled_labels)[1]

    # Step 5: critical value is the (1 - alpha) quantile of the permuted stats.
    thresh = np.quantile(permuted_agg_stats, 1 - alpha)

    # Step 6: reject when the observed statistic strictly exceeds the threshold.
    result = {
        "agg_stat": agg_stat,
        "thresh": thresh,
        "reject": agg_stat > thresh,
    }
    if return_all:
        result["stats_per_bw"] = stats_orig
        result["permuted_agg_stats"] = permuted_agg_stats
    return result
