"""
Core statistical computations for CoBET methods.
"""
import numpy as np


def compute_full_T(A, B, W_A, W_B, W_C):
    """
    Compute the full CoBET test statistic T = T1 - 2*T2 + T3.

    Parameters
    ----------
    A, B : np.ndarray, shape (p, n)
        Feature matrices for X and Y (one column per sample)
    W_A, W_B : np.ndarray, shape (p, p)
        Weight matrices for A and B
    W_C : np.ndarray, shape (2p, 2p)
        Combined weight matrix for the stacked features [A; B]

    Returns
    -------
    T : float
        Test statistic value

    Raises
    ------
    ValueError
        If fewer than 4 samples (columns of ``A``) are supplied. The
        normalisers of T2 and T3 contain the factors (n - 2) and (n - 3),
        so the statistic is undefined for n < 4.

    Notes
    -----
    The statistic is a U-statistic with three terms:
    - T1: Second-order term (kernel of order 2)
    - T2: Third-order correction term
    - T3: Fourth-order correction term

    All three terms are assembled from the (n, n) kernel matrices
    ``X.T @ W @ X`` built for A, B and the stacked matrix [A; B].
    """
    n = A.shape[1]

    # Without this guard the zero denominators below would silently turn
    # the result into inf/nan (NumPy scalar division only emits a warning).
    if n < 4:
        raise ValueError(
            "compute_full_T requires at least 4 samples (columns); got n=%d" % n
        )

    # Kernel (Gram-type) matrices, each of shape (n, n)
    KA = (A.T @ W_A) @ A
    KB = (B.T @ W_B) @ B
    C = np.vstack((A, B))  # stacked features, shape (2p, n)
    KC = (C.T @ W_C) @ C

    # Boolean mask selecting the off-diagonal entries (i != j)
    off = ~np.eye(n, dtype=bool)

    # T1: average of K_A(i,j) * K_B(i,j) over ordered pairs i != j
    T1 = (KA[off] * KB[off]).sum() / (n * (n - 1))

    def compute_sums(K):
        """Return the off-diagonal summaries (S1, S2, S3) of kernel matrix K."""
        # S1: sum of all off-diagonal entries
        S1 = K.sum() - np.trace(K)
        # S2: sum over rows of the squared off-diagonal row sum
        row_sums_off = K.sum(axis=1) - np.diag(K)
        S2 = np.sum(row_sums_off ** 2)
        # S3: sum of squared off-diagonal entries (K ** 2 is elementwise)
        S3 = (K ** 2).sum() - np.trace(K ** 2)
        return S1, S2, S3

    S1C, S2C, S3C = compute_sums(KC)
    S1A, S2A, S3A = compute_sums(KA)
    S1B, S2B, S3B = compute_sums(KB)

    # T2: third-order correction; the combined-kernel quantity minus the
    # A-only and B-only quantities
    T2 = ((S2C - S3C) - (S2A - S3A) - (S2B - S3B)) / (2 * n * (n - 1) * (n - 2))

    # T3: fourth-order correction, built the same way from `term`
    def term(S1, S2, S3):
        return (S1 ** 2) - 4 * (S2 - S3) - 2 * S3

    T3_num = term(S1C, S2C, S3C) - term(S1A, S2A, S3A) - term(S1B, S2B, S3B)
    T3 = T3_num / (2 * n * (n - 1) * (n - 2) * (n - 3))

    return T1 - 2 * T2 + T3


def plugin_var_tildeT1(A, B, W_A, W_B, unbiased=True):
    """
    Plug-in estimate of the variance of the leading term ~T1.

    Parameters
    ----------
    A, B : np.ndarray, shape (p, n)
        Feature matrices; each row is re-centered here by subtracting its
        mean across the n columns, so pre-centering is not required
    W_A, W_B : np.ndarray, shape (p, p)
        Weight matrices
    unbiased : bool, default=True
        If true, sample covariances are normalised by n - 1; otherwise by n

    Returns
    -------
    var_T1 : float
        ``2 / (n (n - 1)) * tr(W_A S_A W_A S_A) * tr(W_B S_B W_B S_B)``
        where ``S_A`` and ``S_B`` are the sample covariance matrices

    Notes
    -----
    The estimate is built from the sample covariance matrices of the
    centered features and the supplied weight matrices.
    """
    num_samples = A.shape[1]
    cov_divisor = (num_samples - 1) if unbiased else num_samples

    def weighted_cov_trace(features, weights):
        # Remove the per-row mean, then form the (p, p) sample covariance.
        centered = features - features.mean(axis=1, keepdims=True)
        cov = (centered @ centered.T) / cov_divisor
        # tr(W S W S), evaluated as a left-to-right matrix chain
        return np.trace(weights @ cov @ weights @ cov)

    trace_a = weighted_cov_trace(A, W_A)
    trace_b = weighted_cov_trace(B, W_B)

    scale = 2.0 / (num_samples * (num_samples - 1))
    return scale * trace_a * trace_b


def compute_Z_statistic(A, B, W_A, W_B, W_C, unbiased=True):
    """
    Standardize the CoBET statistic: Z = T / sqrt(Var).

    Parameters
    ----------
    A, B : np.ndarray, shape (p, n)
        Feature matrices
    W_A, W_B : np.ndarray, shape (p, p)
        Weight matrices for A and B
    W_C : np.ndarray, shape (2p, 2p)
        Combined weight matrix
    unbiased : bool, default=True
        Forwarded to ``plugin_var_tildeT1`` to select the covariance
        normalisation

    Returns
    -------
    Z : float
        Standardized test statistic
    T : float
        Test statistic value from ``compute_full_T``
    var_T : float
        Variance estimate from ``plugin_var_tildeT1`` (returned as
        computed, without the floor applied when forming Z)
    """
    statistic = compute_full_T(A, B, W_A, W_B, W_C)
    variance = plugin_var_tildeT1(A, B, W_A, W_B, unbiased=unbiased)

    # Floor the variance at 1e-16 so the square root is never zero
    # (guards the division below).
    floored_variance = max(variance, 1e-16)
    standardized = statistic / np.sqrt(floored_variance)

    return standardized, statistic, variance
