import numpy as np
import pandas as pd


def al_con(ng1, ng2, delta1, delta2):
    """Calculate normalization constant.

    Returns 1 / sqrt((1/(ng1 - 1) + 1/ng2) * |delta1 - delta2 + 1e-10|),
    or the integer 0 when delta1 and delta2 compare exactly equal.

    Parameters:
        ng1: Size of the first group (must differ from 1).
        ng2: Size of the second group (must be non-zero).
        delta1, delta2: Moment parameters whose difference sets the scale.
    """
    if delta1 != delta2:
        # The 1e-10 offset keeps the argument of sqrt away from exactly zero
        # for nearly-equal (but not identical) deltas.
        variance_scale = (1 / (ng1 - 1) + 1 / ng2) * abs(delta1 - delta2 + 1e-10)
        return 1 / np.sqrt(variance_scale)
    return 0


def Pest1(a):
    """Estimate the moment parameters delta1 and delta2 from a single vector.

    Sums are divided by len(a) - 1, so ``a`` is presumably one row of a kernel
    matrix whose self-entry is zero (TODO confirm with callers).

    Parameters:
        a: 1-D numpy array.

    Returns:
        dict with keys 'delta1' (second moment about the row mean) and
        'delta2' (normalized sum of cross-products of centered entries).
    """
    size = len(a)
    denom = size - 1
    mean = np.sum(a) / denom
    deviations = a - mean

    first_moment = np.sum(a ** 2) / denom - mean ** 2
    # Adding mean back to the sum and mean**2 back to the squared sum cancels
    # the contribution of the (zero) self-entry, which deviates by -mean.
    cross_terms = (np.sum(deviations) + mean) ** 2 - np.sum(deviations ** 2) + mean ** 2
    second_moment = cross_terms / ((size - 2) * denom)

    return {'delta1': first_moment, 'delta2': second_moment}


def Pfun(A):
    """Estimate global moment parameters from the entire kernel matrix.

    Only off-diagonal entries are used, so the result does not depend on the
    diagonal (self-similarity) values. For a zero-diagonal matrix this matches
    the earlier formulas, which implicitly assumed a zero diagonal.

    Parameters:
        A: (n, n) array-like kernel matrix with n >= 3 (the normalizers
            contain n - 1 and n - 2).

    Returns:
        dict with
            'delta1': sum of squared deviations of the off-diagonal entries
                from their mean ``mu``, divided by n * (n - 1).
            'delta2': average of (a_ij - mu) * (a_ik - mu) over all ordered
                triples of distinct indices (i, j, k), i.e. the average
                cross-product of two centered entries sharing a row.
    """
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    off_diag = ~np.eye(n, dtype=bool)

    mu = A[off_diag].sum() / (n * (n - 1))
    # Squared deviations with the diagonal masked out, used for both moments.
    centered_sq = np.where(off_diag, (A - mu) ** 2, 0.0)
    delta1 = centered_sq.sum() / (n * (n - 1))

    # Per row i: (sum_{j != i} (a_ij - mu))^2 - sum_{j != i} (a_ij - mu)^2
    # equals the sum over j != k (both != i) of (a_ij - mu)(a_ik - mu).
    row_sums = np.where(off_diag, A, 0.0).sum(axis=1)
    row_sums_sq = centered_sq.sum(axis=1)
    row_contributions = (row_sums - (n - 1) * mu) ** 2 - row_sums_sq
    delta2 = row_contributions.sum() / (n * (n - 1) * (n - 2))

    return {'delta1': delta1, 'delta2': delta2}


def compute_all_statistics(A, index):
    """
    Calculate statistics T for all samples (vector form).

    For each row i of the kernel matrix:
      * pin  = sum of A[i, j] over samples j in i's group / (group size - 1)
      * pout = sum of A[i, j] over samples j in other groups / (their count)
      * the per-row moments delta1/delta2 are the vectorized equivalent of
        Pest1 applied to row i;
      * T_i = (pin - pout) times the al_con-style normalization constant.
        Here the constant is zeroed when |delta1 - delta2| < 1e-10, rather
        than only on exact equality.

    Parameters:
        A: Square kernel matrix (similarity matrix). Row means divide by N - 1
            and the within-group sum includes A[i, i], so the diagonal is
            expected to be zero.
        index: Array-like of sample group labels, one per row of A. Labels are
            compared with ``==``, so group membership and group sizes always
            agree.

    Returns:
        T: 1-D float array of length N. NaN entries (e.g. from a group of size
            one) are replaced by 0.
    """
    A = np.asarray(A)
    labels = np.asarray(index)
    N = len(labels)

    # Per-row moments: vectorized counterpart of Pest1 on each row of A.
    row_sums = np.sum(A, axis=1)
    mu_vec = row_sums / (N - 1)
    delta1_vec = np.sum(A ** 2, axis=1) / (N - 1) - mu_vec ** 2
    centered_sq_sums = np.sum((A - mu_vec[:, np.newaxis]) ** 2, axis=1)
    sum_centered = row_sums - N * mu_vec
    delta2_vec = ((sum_centered + mu_vec) ** 2 - centered_sq_sums + mu_vec ** 2) / ((N - 2) * (N - 1))

    # Group membership from the raw labels; group sizes follow from the same mask.
    same_group = labels[:, np.newaxis] == labels[np.newaxis, :]
    ng1_vec = np.sum(same_group, axis=1)  # size of sample i's own group
    ng2_vec = N - ng1_vec                 # number of samples outside it

    # Degenerate groups (size 1, or no other group) give 0/0 or x/0 here; the
    # resulting NaNs are mapped to 0 below, so the warnings are just noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        pin_vec = np.sum(A * same_group, axis=1) / (ng1_vec - 1)
        pout_vec = np.sum(A * ~same_group, axis=1) / ng2_vec

        delta_diff = np.abs(delta1_vec - delta2_vec + 1e-10)
        al_con_val_vec = 1.0 / np.sqrt((1.0 / (ng1_vec - 1) + 1.0 / ng2_vec) * delta_diff)
        al_con_val_vec[np.abs(delta1_vec - delta2_vec) < 1e-10] = 0.0

        T = (pin_vec - pout_vec) * al_con_val_vec

    return np.where(np.isnan(T), 0, T)


def Qblock_MOD(gi, Gs, Ns, p11, p12):
    """Calculate one row of the group-level covariance table.

    CQ_MOD uses entry j of this row for every pair of samples (i, j) where
    i belongs to group ``gi`` and j belongs to group ``Gs[j]``.

    Parameters:
        gi: Group label the row belongs to.
        Gs: Sequence of group labels; one output entry per label, in order.
        Ns: Group sizes indexed by label (e.g. a pandas Series); their sum is
            used as the total sample size n.
        p11, p12: delta1 and delta2 moment parameters.

    Returns:
        1-D float array of length len(Gs). The entry for gj == gi is the
        same-group term scaled by al_con(...) ** 2. Every other entry is the
        cross-group term, which divides by (p11 - p12).
    """
    total = Ns.sum()
    row = np.zeros(len(Gs))
    for pos, gj in enumerate(Gs):
        size_i = Ns[gi]
        if gi == gj:
            scale = al_con(size_i, total - size_i, p11, p12) ** 2
            row[pos] = scale * (
                (1 / (total - size_i) + 1 / (size_i - 1)) * p12
                - (3 * p12 - p11) / (size_i - 1) ** 2
            )
        else:
            size_j = Ns[gj]
            row[pos] = (
                np.sqrt(size_i - 1) * np.sqrt(size_j - 1)
                * (p11 - (total + 2) * p12)
                / np.sqrt(total - size_i) / np.sqrt(total - size_j)
                / (total - 1) / (p11 - p12)
            )
    return row


def CQ_MOD(index, p11, p12):
    """Calculate covariance matrix Sigma of the statistics T.

    Entry (i, j) depends only on the groups of samples i and j. It is taken
    from the per-group rows produced by Qblock_MOD, and the diagonal is 1.

    Parameters:
        index: Array-like of N group labels. Labels are mapped to group
            numbers 1..K in sorted order, so labels that are already 1..K
            keep their numbering. Arbitrary labels, including ones with gaps,
            are also accepted.
        p11, p12: delta1 and delta2 moment parameters (as returned by Pfun).

    Returns:
        Q: (N, N) float array.
    """
    labels = np.asarray(index)
    N = len(labels)
    if N == 0:
        return np.zeros((0, 0))

    # 0-based group codes; code + 1 is the group number passed to Qblock_MOD.
    uniq, codes = np.unique(labels, return_inverse=True)
    codes = codes.reshape(-1)
    K = len(uniq)
    Gs = np.arange(1, K + 1)
    Ns = pd.Series(codes + 1).value_counts()

    # Row k - 1 holds group k's values against groups 1..K.
    Qtable = np.array([Qblock_MOD(k, Gs, Ns, p11, p12) for k in Gs])

    # Q[i, j] = Qtable[group(i), group(j)]. Within-group off-diagonal pairs
    # take the table's diagonal value; each sample's own entry is set to 1.
    Q = Qtable[np.ix_(codes, codes)]
    np.fill_diagonal(Q, 1.0)
    return Q




from kernel import kernel_matrix, pairwise_distances, get_median_bandwidth


def _bootstrap_max_sq(Sigma, B):
    """Draw B vectors from N(0, Sigma); return each draw's max squared coordinate.

    Sigma is factored once: Cholesky when possible, otherwise via its
    eigendecomposition with eigenvalues clipped at 1e-10. Normals come from
    the global NumPy RNG as one (B, n) block, row by row.
    """
    try:
        factor = np.linalg.cholesky(Sigma)
    except np.linalg.LinAlgError:
        eigenvals, eigenvecs = np.linalg.eigh(Sigma)
        # Scaling column k of the eigenvectors by sqrt(eigenvalue k) gives a
        # factor F with F @ F.T approximating Sigma.
        factor = eigenvecs * np.sqrt(np.maximum(eigenvals, 1e-10))
    z = np.random.randn(B, Sigma.shape[0])
    samples = z @ factor.T  # row b equals factor @ z[b]
    return np.max(samples ** 2, axis=1)


def MODboot(X, Y, kernel="gaussian", alpha=0.05, B=1000, seed=None):
    """
    MOD bootstrap test using multivariate normal distribution

    Steps:
    1. Calculate all statistics T, square them and take maximum
    2. Generate B samples from multivariate normal (mean=0, cov=Sigma),
       square each and take maximum, compute (1-alpha) quantile as critical value
    3. Compare statistic from step 1 with critical value from step 2

    Parameters:
        X: First group of samples (rows are observations; stacked with Y along axis 0)
        Y: Second group of samples
        kernel: Kernel type string, default "gaussian". "laplace" or any name
            ending in "_l1" uses the L1 metric; everything else uses L2.
        alpha: Significance level
        B: Number of bootstrap samples
        seed: Random seed. When given, it reseeds the global NumPy RNG
            (a process-wide side effect).

    Returns:
        result: Dictionary with 'statistic', 'critical_value' and
            'reject_null' (statistic > critical value).
    """
    if seed is not None:
        np.random.seed(seed)

    n1 = X.shape[0]
    n2 = Y.shape[0]
    index = np.concatenate((np.ones(n1), np.full(n2, 2)))
    Z = np.concatenate([X, Y], axis=0)

    metric = "l1" if kernel == "laplace" or kernel.endswith("_l1") else "l2"

    pairwise = pairwise_distances(Z, metric=metric)
    bandwidth = get_median_bandwidth(Z, metric=metric)
    A = kernel_matrix(pairwise, kernel, bandwidth, metric=metric)
    # Exclude self-similarities: the statistics divide row sums by N - 1.
    np.fill_diagonal(A, 0)

    # Step 1: Calculate all statistics T, square and take maximum
    T = compute_all_statistics(A, index)
    stat = np.max(T ** 2)

    # Step 2: Calculate covariance matrix Sigma, regularized for numerical stability
    Ps = Pfun(A)
    Sigma = CQ_MOD(index, Ps['delta1'], Ps['delta2'])
    regularization = 1e-8
    Sigma_reg = Sigma + regularization * np.eye(Sigma.shape[0])

    # Step 3: (1 - alpha) quantile of max squared coordinates under N(0, Sigma_reg)
    bootstrap_stats = _bootstrap_max_sq(Sigma_reg, B)
    thresh = np.quantile(bootstrap_stats, 1 - alpha)

    # Step 4: Compare statistic with critical value
    return {
        'statistic': stat,
        'critical_value': thresh,
        'reject_null': stat > thresh,
    }


