import warnings
import numpy as np
import torch
from sklearn.model_selection import train_test_split
import math
from typing import Tuple, Optional, Union
from numpy.typing import NDArray
import logging
try:
    import optuna.integration.lightgbm as lgb
    if not hasattr(lgb, 'LightGBMTuner'):
        import lightgbm as lgb # Fallback to standard
        HAS_OPTUNA_TUNER = False
    else:
        HAS_OPTUNA_TUNER = True
except ImportError:
    import lightgbm as lgb
    HAS_OPTUNA_TUNER = False

# Module-level logger shared by every function in this module.
logger = logging.getLogger(name=__name__)

def get_device() -> torch.device:
    """
    Return the default torch device: CUDA when it is available, otherwise the CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def validate_device(device: Optional[Union[torch.device, str]] = None) -> torch.device:
    """
    Validate device availability and fall back to an available device if necessary.

    Args:
        device: Torch device to validate, or any spec ``torch.device`` accepts
            (e.g. ``"cpu"``, ``"cuda:1"``). If None, uses ``get_device()``.

    Returns:
        A torch device that is actually available:
        - an unavailable CUDA device falls back to CPU;
        - a CUDA index beyond ``torch.cuda.device_count()`` falls back to ``cuda:0``;
        - an unavailable MPS device falls back to CPU;
        - anything else is returned unchanged.
    """
    if device is None:
        device = get_device()
    elif not isinstance(device, torch.device):
        # Accept the string/index forms torch itself understands instead of
        # failing on `.type` below.
        device = torch.device(device)

    if device.type == 'cuda':
        if not torch.cuda.is_available():
            logger.warning(f"CUDA device {device} not available. Falling back to CPU.")
            return torch.device('cpu')
        # Check that a specific CUDA device index actually exists.
        if device.index is not None and device.index >= torch.cuda.device_count():
            logger.warning(f"CUDA device {device} not available. Using cuda:0 instead.")
            return torch.device('cuda:0')
    elif device.type == 'mps':
        # Older torch builds have no `torch.backends.mps` at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is None or not mps_backend.is_available():
            logger.warning("MPS device not available. Falling back to CPU.")
            return torch.device('cpu')

    return device

def _pairwise_sq_distances(
    X1: NDArray[np.floating], 
    X2: Optional[NDArray[np.floating]] = None, 
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Helper function to compute pairwise squared Euclidean distances using torch.

    Args:
        X1: First set of points, shape (n_samples, n_features) or (n_samples,).
            Array-likes and torch tensors are accepted; 1-D input is treated as
            a single feature column.
        X2: Second set of points, shape (m_samples, n_features) or (m_samples,).
            If None, uses X1 for both sets.
        device: Torch device to use. If None, uses get_device()

    Returns:
        float32 tensor of squared distances, shape (n_samples, m_samples), on ``device``.
    """
    device = validate_device(device)

    def _as_2d_tensor(points) -> torch.Tensor:
        # Convert to a float32 tensor on the target device, then make it 2-D.
        if isinstance(points, torch.Tensor):
            # torch.tensor(tensor) warns and always copies; .to() only copies if needed.
            tensor = points.detach().to(device=device, dtype=torch.float32)
        else:
            tensor = torch.as_tensor(np.asarray(points), dtype=torch.float32, device=device)
        return tensor.reshape(-1, 1) if tensor.ndim == 1 else tensor

    X1_torch = _as_2d_tensor(X1)
    # For self-distances, reuse the converted tensor instead of transferring X1 twice.
    X2_torch = X1_torch if X2 is None else _as_2d_tensor(X2)

    return torch.cdist(X1_torch, X2_torch, p=2).pow(2)


def median_bandwidth(
    data: NDArray[np.floating], 
    device: Optional[torch.device] = None
) -> float:
    """
    Calculate the median-heuristic bandwidth for a Gaussian kernel.

    The bandwidth is the square root of the median of the squared pairwise
    distances over distinct pairs (strict upper triangle). ``torch.median``
    returns the lower of the two middle values for an even number of pairs.

    Args:
        data: Input data array, shape (n_samples, n_features) or (n_samples,)
        device: Torch device for computation. If None, uses get_device()

    Returns:
        Median bandwidth as a float. Falls back to 1.0 (with a warning) when there
        are fewer than two samples, when the result is not finite (e.g. NaN
        in the data), or when the computation raises.
    """
    device = validate_device(device)

    data_shape = np.shape(data)
    n_samples = data_shape[0] if data_shape else 0
    if n_samples < 2:
        # No distinct pairs: the median of an empty set would be NaN.
        logger.warning(f"median_bandwidth needs at least 2 samples. Defaulting to bandwidth 1.0. Data shape: {data_shape}")
        return 1.0

    try:
        sq_dists_matrix = _pairwise_sq_distances(data, device=device)
        # Strict upper triangle: every unordered pair once, no zero self-distances.
        upper_triangle_mask = torch.triu(torch.ones_like(sq_dists_matrix, dtype=torch.bool), diagonal=1)
        pairwise_sq_dists = sq_dists_matrix[upper_triangle_mask]

        median_dist = torch.median(pairwise_sq_dists)
        bandwidth = torch.sqrt(median_dist).item()
    except Exception as e:
        # Deliberate best-effort fallback (e.g. out-of-memory for very large inputs).
        logger.warning(f"Error in median_bandwidth: {e}. Defaulting to bandwidth 1.0. Data shape: {data_shape}")
        return 1.0

    if not math.isfinite(bandwidth):
        logger.warning(f"Non-finite median bandwidth ({bandwidth}). Defaulting to bandwidth 1.0. Data shape: {data_shape}")
        return 1.0
    return bandwidth


def gaussian_kernel_gram(
    X1: NDArray[np.floating], 
    X2: Optional[NDArray[np.floating]] = None, 
    sigma: float = 1.0, 
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Build the Gaussian (RBF) Gram matrix ``exp(-||x - y||^2 / (2 * sigma^2))``.

    Args:
        X1: Points for the rows, shape (n_samples, n_features) or (n_samples,)
        X2: Points for the columns, shape (m_samples, n_features) or (m_samples,).
            When omitted, X1 is paired with itself.
        sigma: Kernel bandwidth (standard deviation of the RBF).
        device: Torch device for computation. If None, uses get_device()

    Returns:
        Kernel tensor of shape (n_samples, m_samples).
    """
    target_device = validate_device(device)
    squared = _pairwise_sq_distances(X1, X2, device=target_device)
    denominator = 2 * (sigma**2)
    return torch.exp(-squared / denominator)


# --- Nuisance Estimation - Propensity Score Model ---
def fit_lightgbm_model(
    X_train: NDArray[np.floating],
    y_train: NDArray[np.floating],
    test_size: float = 0.2,
    objective: str = "binary",
    metric: str = "binary_logloss",
    custom_obj: Optional[callable] = None,
    custom_eval: Optional[callable] = None,
    weights_train: Optional[NDArray[np.floating]] = None,
    weights_val: Optional[NDArray[np.floating]] = None,
    random_state_lgb_split: Optional[int] = None,
    early_stopping_rounds_lgb: int = 10
) -> Optional[lgb.Booster]:
    """
    Fit a LightGBM model on an internal train/validation split and return the booster.

    The rows are split with ``train_test_split``; the validation part drives early
    stopping. When Optuna's ``LightGBMTuner`` is available, it tunes the
    hyper-parameters. Otherwise a fixed configuration is trained with ``lgb.train``.

    Args:
        X_train: Features, shape (n_samples, n_features) or (n_samples,)
        y_train: Targets, shape (n_samples,) (raveled before use)
        test_size: Fraction of rows held out for validation / early stopping
        objective: LightGBM objective function name
        metric: LightGBM evaluation metric name
        custom_obj: Currently NOT used; a warning is logged if provided
        custom_eval: Currently NOT used; a warning is logged if provided
        weights_train: Per-row sample weights aligned with ``X_train`` (length
            n_samples). They are split together with the data, so each weight stays
            attached to its row in both the training and validation parts.
        weights_val: Optional explicit weights for the internal validation part.
            When given, they replace the validation share of ``weights_train`` and
            must match the size of the validation split.
        random_state_lgb_split: Random state for the train/validation split
        early_stopping_rounds_lgb: Early stopping rounds

    Returns:
        Trained LightGBM booster, or None if splitting or training raised.
    """
    X_train = X_train.reshape([-1, 1]) if len(X_train.shape) == 1 else X_train
    y_train = y_train.ravel()

    if custom_obj is not None or custom_eval is not None:
        logger.warning("fit_lightgbm_model ignores custom_obj/custom_eval; using objective/metric instead.")

    params = {
        "objective": objective,
        "metric": metric,
        "verbosity": -1,
        "boosting_type": "gbdt",
    }

    try:
        # Split the weights along with X and y. train_test_split applies one shared
        # permutation to every array, so each weight stays aligned with its row.
        arrays = [X_train, y_train]
        if weights_train is not None:
            arrays.append(np.asarray(weights_train).ravel())
        split = train_test_split(*arrays, test_size=test_size, random_state=random_state_lgb_split)
        train_x, val_x, train_y, val_y = split[:4]
        w_train, w_val = (split[4], split[5]) if weights_train is not None else (None, None)
        if weights_val is not None:
            w_val = weights_val

        train_data = lgb.Dataset(train_x, train_y, weight=w_train)
        val_data = lgb.Dataset(val_x, val_y, weight=w_val)

        if HAS_OPTUNA_TUNER:
            tuner = lgb.LightGBMTuner(
                params,
                train_data,
                valid_sets=[val_data],
                callbacks=[lgb.early_stopping(early_stopping_rounds_lgb, verbose=False)],
                show_progress_bar=False,
                optuna_callbacks=None
            )
            tuner.run()
            return tuner.get_best_booster()

        # Fallback: standard training without Optuna. The L2 penalty discourages
        # hard 0/1 propensities.
        params.update({"learning_rate": 0.05, "num_leaves": 31, "reg_lambda": 1.0})
        return lgb.train(
            params,
            train_data,
            num_boost_round=1000,
            valid_sets=[val_data],
            callbacks=[lgb.early_stopping(early_stopping_rounds_lgb, verbose=False)]
        )
    except Exception as e:  # Small data can make the split, LightGBM or Optuna raise
        logger.error(f"Error during LightGBM tuning: {e}. Returning None for booster.")
        return None

# --- Nuisance Estimation - Component Matrices (Cross-fitted version) ---
def intermediate_C_and_E_matrices(
    X1: NDArray[np.floating], 
    A1: NDArray[np.floating], 
    X2: NDArray[np.floating], 
    A2: NDArray[np.floating], 
    pi1: NDArray[np.floating], 
    pi2: NDArray[np.floating], 
    K_for_betas: torch.Tensor, 
    reg_default: float = 0.001
) -> Tuple[NDArray[np.floating], NDArray[np.floating]]:
    """
    Computes C and E matrices for cross-fitted MMD estimation.

    Row-layout requirement: within each fold, treated rows must come first. The
    slicing of ``K_for_betas`` below treats fold-1 rows ``[:n1t]`` as treated and
    ``[n1t:n1]`` as control, and fold-2 rows ``[n1:n1+n2t]`` / ``[n1+n2t:n1+n2]``
    likewise. ``n1t``/``n2t`` are ``sum(A)``, so this presumes binary (0/1)
    treatments (TODO confirm for callers).

    For each fold and arm, a ridge-regularized kernel solve gives weights that map
    the other fold's points onto that arm's points. For example, with a symmetric
    ``K_for_betas``, ``beta_2t = K[X1, X2t] (K[X2t, X2t] + reg I)^{-1}`` has shape
    (n1, n2t). C carries inverse-propensity terms on its diagonal blocks and
    beta-weighted terms off the diagonal. E has zero diagonal blocks and scaled
    betas off the diagonal.

    Args:
        X1: Covariates for fold 1, shape (n1, n_features); only len() is used
        A1: Treatment assignments for fold 1, shape (n1,), treated rows first
        X2: Covariates for fold 2, shape (n2, n_features); only len() is used
        A2: Treatment assignments for fold 2, shape (n2,), treated rows first
        pi1: Propensity scores for fold 1, shape (n1,); clipped to [1e-9, 1 - 1e-9]
        pi2: Propensity scores for fold 2, shape (n2,); clipped to [1e-9, 1 - 1e-9]
        K_for_betas: Kernel matrix for regression, shape (n1+n2, n1+n2), rows and
            columns ordered [fold 1; fold 2]. The solves run on its device.
        reg_default: Ridge term added to the diagonal of every solved kernel block

    Returns:
        Tuple of NumPy arrays (C, E), both shape (n1+n2, n1+n2)
    """
    n1, n2 = len(X1), len(X2)
    
    # Treated / control counts per fold (counts only if A is 0/1).
    n1t, n1c = np.sum(A1), n1 - np.sum(A1)
    n2t, n2c = np.sum(A2), n2 - np.sum(A2)

    # reg = 1 / (n1 + n2) # Adaptive regularization # too aggressive
    reg = reg_default
    # reg = 1.0 / (50*np.sqrt(n1 + n2))  # Adaptive regularization

    # Keep propensities away from 0 and 1 so the inverse-propensity ratios stay finite.
    pi1 = np.clip(pi1, 1e-9, 1.0 - 1e-9)
    pi2 = np.clip(pi2, 1e-9, 1.0 - 1e-9)

    # Inverse-propensity ratios: A / pi for the treated arm, (1 - A) / (1 - pi) for control.
    ratio_1t = A1 / pi1
    ratio_1c = (1 - A1) / (1 - pi1)
    ratio_2t = A2 / pi2
    ratio_2c = (1 - A2) / (1 - pi2)

    # Diagonal (within-fold) blocks of C.
    C11 = np.diag((1 / (2 * n1)) * (ratio_1t - ratio_1c))
    C22 = np.diag((1 / (2 * n2)) * (ratio_2t - ratio_2c))

    # Per-row scalings of the cross-fold beta blocks.
    coef_1t = (1 / (2 * n1)) * (1 - ratio_1t)
    coef_1c = -(1 / (2 * n1)) * (1 - ratio_1c)
    coef_2t = (1 / (2 * n2)) * (1 - ratio_2t)
    coef_2c = -(1 / (2 * n2)) * (1 - ratio_2c)
    
    # Solve for the beta values with regularization
    # (integer counts are needed for slicing and torch.eye).
    n1t, n1c, n2t, n2c = int(n1t), int(n1c), int(n2t), int(n2c)
    
    # beta_2t: fold-1 points onto fold-2 treated points, shape (n1, n2t).
    beta_2t = torch.linalg.solve(
        K_for_betas[n1:n1+n2t, n1:n1+n2t] + reg * torch.eye(n2t, device=K_for_betas.device),
        K_for_betas[:n1, n1:n1+n2t].T
    ).T.cpu().numpy()

    # beta_2c: fold-1 points onto fold-2 control points, shape (n1, n2c).
    beta_2c = torch.linalg.solve(
        K_for_betas[n1+n2t:n1+n2, n1+n2t:n1+n2] + reg * torch.eye(n2c, device=K_for_betas.device),
        K_for_betas[:n1, n1+n2t:].T
    ).T.cpu().numpy()

    # beta_1t: fold-2 points onto fold-1 treated points, shape (n2, n1t).
    beta_1t = torch.linalg.solve(
        K_for_betas[:n1t, :n1t] + reg * torch.eye(n1t, device=K_for_betas.device),
        K_for_betas[n1:, :n1t].T
    ).T.cpu().numpy()

    # beta_1c: fold-2 points onto fold-1 control points, shape (n2, n1c).
    beta_1c = torch.linalg.solve(
        K_for_betas[n1t:n1, n1t:n1] + reg * torch.eye(n1c, device=K_for_betas.device),
        K_for_betas[n1:, n1t:n1].T
    ).T.cpu().numpy()

    # Off-diagonal blocks of C: each row of the betas is scaled by that row's coefficient.
    C12 = np.hstack((coef_1t[:, None] * beta_2t, coef_1c[:, None] * beta_2c))
    C21 = np.hstack((coef_2t[:, None] * beta_1t, coef_2c[:, None] * beta_1c))
    
    # E has zero within-fold blocks and (signed, scaled) betas across folds.
    E11 = np.zeros((n1, n1))
    E22 = np.zeros((n2, n2))
    E12 = (1 / (2 * n1)) * np.hstack((beta_2t, beta_2c))
    E21 = -(1 / (2 * n2)) * np.hstack((beta_1t, beta_1c))

    C = np.block([[C11, C12], [C21, C22]])
    E = np.block([[E11, E12], [E21, E22]])
    return C, E


def wald_intermediate_matrices(
    C: torch.Tensor, 
    E: torch.Tensor, 
    n1: int, 
    n2: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the fold-restricted, rescaled matrices used by the Wald-type statistic.

    D1/V1 keep only the fold-1 rows (``[:n1]``) of C/E, and D2/V2 only the
    fold-2 rows (``[n1:]``); every other row is zero. The scalings are:
    D_k = sqrt(2 * n_k) * C rows, V_k = sqrt(2 / n_k) * E rows, and
    W_k = D_k - n_k * V_k.

    Args:
        C: C matrix from intermediate computation, shape (n1+n2, n1+n2)
        E: E matrix from intermediate computation, shape (n1+n2, n1+n2)
        n1: Number of samples in fold 1
        n2: Number of samples in fold 2

    Returns:
        Tuple of (D1, D2, V1, V2, W1, W2), each shaped like C / E.
    """
    fold1_rows, fold2_rows = slice(None, n1), slice(n1, None)

    def _rows_only(src: torch.Tensor, rows: slice, factor: float) -> torch.Tensor:
        # Zero everywhere except `rows`, which hold factor * src[rows].
        out = torch.zeros_like(src)
        out[rows, :] = factor * src[rows, :]
        return out

    D1 = _rows_only(C, fold1_rows, math.sqrt(2 * n1))
    D2 = _rows_only(C, fold2_rows, math.sqrt(2 * n2))
    V1 = _rows_only(E, fold1_rows, math.sqrt(2 / n1))
    V2 = _rows_only(E, fold2_rows, math.sqrt(2 / n2))

    return D1, D2, V1, V2, D1 - n1 * V1, D2 - n2 * V2


def proposed_mmd_statistic_components(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating],
    sigma_x: Optional[float] = None,
    sigma_y: Optional[float] = None,
    reg_default: float = 0.001,
    misspecify_propensity_model: bool = False,
    misspecify_outcome_model: bool = False,
    device: Optional[torch.device] = None,
    lgbm_random_state_seed: Optional[int] = None,
    propensity: Optional[NDArray[np.floating]] = None
) -> Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int, NDArray, NDArray]:
    """
    Computes MMD and matrices K, L, C, E for cross-fitted MMD.

    The first ``n // 2`` rows form fold 1 and the rest form fold 2. Within each
    fold, rows are sorted by ``A`` in descending order (treated first for binary
    ``A``). All returned matrices use that sorted order: fold-1 rows, then
    fold-2 rows.

    Args:
        X: Covariates, shape (n_samples, n_features) or (n_samples,)
        A: Treatment assignments, shape (n_samples,)
        Y: Outcomes, shape (n_samples,)
        sigma_x: Bandwidth for covariate kernel (median heuristic if None)
        sigma_y: Bandwidth for outcome kernel (median heuristic if None)
        reg_default: Ridge regularization for the kernel regressions
        misspecify_propensity_model: Fit the propensity model on the last feature only
        misspecify_outcome_model: Use only the last feature for the outcome
            regressions (ignored when X has a single feature)
        device: Torch device for computation
        lgbm_random_state_seed: Random seed for the LightGBM train/validation split
        propensity: Known propensities aligned with the ORIGINAL row order, shape
            (n_samples,). When given, no propensity model is fitted.

    Returns:
        Tuple of (mmd_value, K_torch, L_torch, KCL_torch, C_torch, E_torch, n1, n2,
        sort_idx1, sort_idx2). ``sort_idxk[i]`` is the within-fold index of the
        row at sorted position ``i``.

    Raises:
        ValueError: If fewer than 4 samples are given, or ``propensity`` has the
            wrong length.
    """
    device = validate_device(device)

    def _checked_bandwidth(value: float, name: str) -> float:
        """Replace a degenerate (tiny or NaN) bandwidth with 1.0, logging a warning."""
        if value <= 1e-6 or np.isnan(value):
            logger.warning(f"Invalid {name}={value}, defaulting to 1.0")
            return 1.0
        return value

    n = len(X)
    if n < 4:
        raise ValueError(f"Not enough data for 2-fold split (n={n}). Need at least 4.")
    n_half = n // 2

    # Simple split for cross-fitting (not stratified, but common for this type of estimator)
    X1_orig, X2_orig = X[:n_half], X[n_half:]
    A1_orig, A2_orig = A[:n_half], A[n_half:]
    Y1_orig, Y2_orig = Y[:n_half], Y[n_half:]

    # A descending sort puts A == 1 rows first in each fold. This is the block
    # layout intermediate_C_and_E_matrices slices K_for_betas by.
    sort_idx1 = np.argsort(A1_orig)[::-1]
    X1_sorted, A1_sorted, Y1_sorted = X1_orig[sort_idx1], A1_orig[sort_idx1], Y1_orig[sort_idx1]

    sort_idx2 = np.argsort(A2_orig)[::-1]
    X2_sorted, A2_sorted, Y2_sorted = X2_orig[sort_idx2], A2_orig[sort_idx2], Y2_orig[sort_idx2]

    n1, n2 = len(X1_orig), len(X2_orig)

    X_full_sorted = np.concatenate((X1_sorted, X2_sorted))
    Y_full_sorted = np.concatenate((Y1_sorted, Y2_sorted))

    if sigma_x is None:
        sigma_x = _checked_bandwidth(median_bandwidth(X_full_sorted, device=device), "sigma_x")
    if sigma_y is None:
        sigma_y = _checked_bandwidth(median_bandwidth(Y_full_sorted, device=device), "sigma_y")

    # Propensity scores
    if propensity is not None:
        if len(propensity) != n:
            raise ValueError(f"Propensity length {len(propensity)} != data length {n}")

        logger.info("Using provided propensity scores")
        propensity_full = np.clip(propensity, 1e-9, 1.0 - 1e-9)

        # Split like the data, then apply the same within-fold sort as X, A, Y.
        pi1_sorted = propensity_full[:n_half][sort_idx1]
        pi2_sorted = propensity_full[n_half:][sort_idx2]
    else:
        logger.debug("Estimating propensity scores via cross-fitting")

        def _get_ps_features(X_train, X_pred, misspec_flag):
            """Return 2-D feature copies, keeping only the last column when misspecified."""
            X_train_mod, X_pred_mod = X_train.copy(), X_pred.copy()
            # Make the features 2-D first so the column slice below also works for 1-D X.
            if X_train_mod.ndim == 1:
                X_train_mod = X_train_mod.reshape(-1, 1)
            if X_pred_mod.ndim == 1:
                X_pred_mod = X_pred_mod.reshape(-1, 1)
            if misspec_flag:
                # Deliberate misspecification: the model only sees the last feature.
                if X_train_mod.shape[1] > 0:
                    X_train_mod = X_train_mod[:, -1:]
                if X_pred_mod.shape[1] > 0:
                    X_pred_mod = X_pred_mod[:, -1:]
            return X_train_mod, X_pred_mod

        def _predict_propensity(X_fit, A_fit, X_pred, fallback_shape, label):
            """Fit on one fold and predict on the other; fall back to 0.5 if fitting failed."""
            X_fit_mod, X_pred_mod = _get_ps_features(X_fit, X_pred, misspecify_propensity_model)
            model = fit_lightgbm_model(
                X_fit_mod, A_fit,
                random_state_lgb_split=lgbm_random_state_seed
            )
            if model is None:
                logger.warning(f"{label} model returned None. Defaulting to 0.5")
                return np.full(fallback_shape, 0.5)
            return model.predict(X_pred_mod)

        # Train on fold 2 and predict on (sorted) fold 1, then vice versa.
        pi1_sorted = _predict_propensity(X2_orig, A2_orig, X1_sorted, A1_sorted.shape, "pi1")
        pi2_sorted = _predict_propensity(X1_orig, A1_orig, X2_sorted, A2_sorted.shape, "pi2")

    K_torch = gaussian_kernel_gram(X_full_sorted, sigma=sigma_x, device=device)

    # Outcome-model misspecification: the kernel ridge regressions behind C and E
    # only see the last covariate. This only makes sense with more than one feature.
    K_for_betas = K_torch
    if misspecify_outcome_model and X_full_sorted.ndim > 1 and X_full_sorted.shape[1] > 1:
        X_om_miss = X_full_sorted[:, -1:].copy()
        # The regression only has access to X_om_miss, so its bandwidth comes from it too.
        sigma_x_om_miss = _checked_bandwidth(
            median_bandwidth(X_om_miss, device=device), "sigma_x_om_miss"
        )
        K_for_betas = gaussian_kernel_gram(X_om_miss, sigma=sigma_x_om_miss, device=device)

    C, E = intermediate_C_and_E_matrices(
        X1_sorted, A1_sorted, X2_sorted, A2_sorted,
        pi1_sorted, pi2_sorted, K_for_betas, reg_default
    )

    C_torch = torch.tensor(C, dtype=torch.float32).to(device)
    E_torch = torch.tensor(E, dtype=torch.float32).to(device)
    L_torch = gaussian_kernel_gram(Y_full_sorted, sigma=sigma_y, device=device)

    KCL_torch = K_torch @ C_torch @ L_torch
    mmd_value = torch.sum(C_torch * KCL_torch).item()  # The (unscaled) MMD^2_n

    return mmd_value, K_torch, L_torch, KCL_torch, C_torch, E_torch, n1, n2, sort_idx1, sort_idx2


def proposed_wald_type_statistic_components(
    K: torch.Tensor, 
    L: torch.Tensor, 
    KCL: torch.Tensor, 
    D1: torch.Tensor, 
    D2: torch.Tensor, 
    V1: torch.Tensor, 
    V2: torch.Tensor, 
    W1: torch.Tensor, 
    W2: torch.Tensor, 
    I: torch.Tensor, 
    mmd_value: float, 
    eps: Optional[float] = None
) -> Tuple[float, float]:
    """
    Compute the regularized Wald-type test statistic.

    The statistic is ``mmd_value / eps - ((1 - eps) / eps) * (c^T T) z``, where
    ``z`` solves ``(eps * I + (1 - eps) * U^T T) z = U^T G c``. The three
    (2n + 4)-sized pieces are assembled below from K, L, KCL and the D/V/W matrices.

    Args:
        K: Covariate kernel matrix, (n, n)
        L: Outcome kernel matrix, (n, n)
        KCL: Product K @ C @ L
        D1, D2, V1, V2, W1, W2: Intermediate matrices from wald_intermediate_matrices
        I: Identity matrix of size (2n + 4)
        mmd_value: MMD statistic value
        eps: Regularization weight (the statistic divides by it). ``None`` or
            ``0`` selects the adaptive value ``tr(U^T T) / (3 + tr(U^T T))``.

    Returns:
        Tuple of (wald_statistic, eps_used)
    """
    LD1T = L @ D1.T
    LD2T = L @ D2.T
    S1TGS1 = K * (D1 @ LD1T)
    S1TGS2 = K * (D1 @ LD2T)
    S2TGS1 = K * (D2 @ LD1T)
    S2TGS2 = K * (D2 @ LD2T)

    # -- Vector computations --
    KD1L = K @ D1 @ L
    KD2L = K @ D2 @ L
    S1T_G_d1 = torch.sum(D1 * KD1L, dim=1, keepdim=True)
    S1T_G_d2 = torch.sum(D1 * KD2L, dim=1, keepdim=True)
    S2T_G_d1 = torch.sum(D2 * KD1L, dim=1, keepdim=True)
    S2T_G_d2 = torch.sum(D2 * KD2L, dim=1, keepdim=True)

    KV1L = K @ V1 @ L
    KV2L = K @ V2 @ L
    S1T_G_v1 = torch.sum(D1 * KV1L, dim=1, keepdim=True)
    S1T_G_v2 = torch.sum(D1 * KV2L, dim=1, keepdim=True)
    S2T_G_v1 = torch.sum(D2 * KV1L, dim=1, keepdim=True)
    S2T_G_v2 = torch.sum(D2 * KV2L, dim=1, keepdim=True)

    KW1L = K @ W1 @ L
    KW2L = K @ W2 @ L
    S1T_G_w1 = torch.sum(D1 * KW1L, dim=1, keepdim=True)
    S1T_G_w2 = torch.sum(D1 * KW2L, dim=1, keepdim=True)
    S2T_G_w1 = torch.sum(D2 * KW1L, dim=1, keepdim=True)
    S2T_G_w2 = torch.sum(D2 * KW2L, dim=1, keepdim=True)

    # -- Scalar (Frobenius inner product) computations --
    D1_KV1L_frob = torch.sum(D1 * KV1L)
    D2_KV1L_frob = torch.sum(D2 * KV1L)
    V1_KV1L_frob = torch.sum(V1 * KV1L)
    V2_KV1L_frob = torch.sum(V2 * KV1L)

    D1_KV2L_frob = torch.sum(D1 * KV2L)
    D2_KV2L_frob = torch.sum(D2 * KV2L)
    V1_KV2L_frob = torch.sum(V1 * KV2L)
    V2_KV2L_frob = torch.sum(V2 * KV2L)

    D1_KW1L_frob = torch.sum(D1 * KW1L)
    D2_KW1L_frob = torch.sum(D2 * KW1L)
    V1_KW1L_frob = torch.sum(V1 * KW1L)
    V2_KW1L_frob = torch.sum(V2 * KW1L)

    D1_KW2L_frob = torch.sum(D1 * KW2L)
    D2_KW2L_frob = torch.sum(D2 * KW2L)
    V1_KW2L_frob = torch.sum(V1 * KW2L)
    V2_KW2L_frob = torch.sum(V2 * KW2L)

    # -- Assemble UTT = U^T T (size (2n+4) x (2n+4)) --
    UT_T_top_left = torch.cat([
        torch.cat([S1TGS1, S1TGS2], dim=1),
        torch.cat([S2TGS1, S2TGS2], dim=1)
    ], dim=0)  # 2n x 2n

    UT_T_top_right = torch.cat([
        torch.cat([S1T_G_v1, S2T_G_v1], dim=0),
        torch.cat([S1T_G_v2, S2T_G_v2], dim=0),
        torch.cat([S1T_G_w1, S2T_G_w1], dim=0),
        torch.cat([S1T_G_w2, S2T_G_w2], dim=0)
    ], dim=1)  # 2n x 4

    UT_T_bottom_left = torch.cat([
        torch.cat([-S1T_G_d1, -S2T_G_d1], dim=0),
        torch.cat([-S1T_G_d2, -S2T_G_d2], dim=0),
        torch.cat([-S1T_G_v1, -S2T_G_v1], dim=0),
        torch.cat([-S1T_G_v2, -S2T_G_v2], dim=0)
    ], dim=1).T  # 4 x 2n

    UT_T_bottom_right = torch.cat([
        torch.stack([-D1_KV1L_frob, -D1_KV2L_frob, -D1_KW1L_frob, -D1_KW2L_frob]).unsqueeze(0),
        torch.stack([-D2_KV1L_frob, -D2_KV2L_frob, -D2_KW1L_frob, -D2_KW2L_frob]).unsqueeze(0),
        torch.stack([-V1_KV1L_frob, -V1_KV2L_frob, -V1_KW1L_frob, -V1_KW2L_frob]).unsqueeze(0),
        torch.stack([-V2_KV1L_frob, -V2_KV2L_frob, -V2_KW1L_frob, -V2_KW2L_frob]).unsqueeze(0)
    ], dim=0)  # 4 x 4

    UT_T = torch.cat([
        torch.cat([UT_T_top_left, UT_T_top_right], dim=1),
        torch.cat([UT_T_bottom_left, UT_T_bottom_right], dim=1)
    ], dim=0)  # (2n+4) x (2n+4)

    # -- Assemble cT_T = c^T T (size 1 x (2n + 4)) --
    S1T_G_c = torch.sum(D1 * KCL, dim=1)
    S2T_G_c = torch.sum(D2 * KCL, dim=1)

    V1_KCL_frob = torch.sum(V1 * KCL)
    V2_KCL_frob = torch.sum(V2 * KCL)
    W1_KCL_frob = torch.sum(W1 * KCL)
    W2_KCL_frob = torch.sum(W2 * KCL)

    cT_T = torch.cat((
        S1T_G_c,
        S2T_G_c,
        torch.stack([V1_KCL_frob, V2_KCL_frob, W1_KCL_frob, W2_KCL_frob])
    ))

    # -- Assemble UT_G_c = U^T G c (size (2n + 4) x 1) --
    D1_KCL_frob = torch.sum(S1T_G_c)
    D2_KCL_frob = torch.sum(S2T_G_c)

    UT_G_c = torch.cat((
        S1T_G_c,
        S2T_G_c,
        torch.stack([-D1_KCL_frob, -D2_KCL_frob, -V1_KCL_frob, -V2_KCL_frob])
    ))

    # -- Compute Wald-type value --
    # Both eps=None and eps=0 (which would divide by zero below) select the
    # adaptive value, giving ~33% weight to the identity.
    if eps is None or eps == 0:
        tr_val = torch.trace(UT_T).item()
        eps = tr_val / (3 + tr_val)

    z = torch.linalg.solve(
        (eps * I) + ((1 - eps) * UT_T),
        UT_G_c
    )

    term2 = ((1 - eps) / eps) * torch.dot(cT_T, z)
    term1 = (1 / eps) * mmd_value

    return (term1 - term2).item(), eps


def proposed_fast_mmd_setup(
    K: torch.Tensor,
    L: torch.Tensor,
    C: torch.Tensor
) -> torch.Tensor:
    """
    Build the quadratic-form operator used by the fast O(n^2) MMD bootstrap.

    With bootstrap weights ``xi`` applied as ``diag(xi) @ C``, the statistic
    ``Tr(diag(xi) K diag(xi) C L C^T)`` equals ``xi @ Phi @ xi`` with
    ``Phi = K * (C @ L @ C.T)`` (element-wise product). Building Phi costs one
    O(n^3) matrix product; each bootstrap draw is then O(n^2).

    Returns:
        Phi, with the same shape as K.
    """
    # Tr(D K D M) = xi^T (K ∘ M^T) xi. M = C L C^T is symmetric whenever L is,
    # so the element-wise product with M itself is used.
    smoothed_outcome_gram = (C @ L) @ C.T
    return torch.mul(K, smoothed_outcome_gram)

def proposed_fast_wald_setup(
    K: torch.Tensor,
    C: torch.Tensor, # C matrix: only used to build the bootstrap operators below
    L: torch.Tensor,
    D1: torch.Tensor,
    D2: torch.Tensor,
    V1: torch.Tensor,
    V2: torch.Tensor,
    W1: torch.Tensor,
    W2: torch.Tensor,
    I: torch.Tensor,
    eps: Optional[float] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], float, dict]:
    """
    Pre-compute fixed covariance structure, LU factorization, and projection operators
    for fast Wald bootstrap.

    ``UT_T`` is assembled with the same expressions as in
    ``proposed_wald_type_statistic_components``; it depends only on K, L and the
    D/V/W matrices, not on C. ``eps * I + (1 - eps) * UT_T`` is then LU-factorized
    once. The returned operators let ``proposed_fast_wald_step`` evaluate the
    statistic for bootstrap weights ``xi`` (applied as ``diag(xi) @ C``) using
    only matrix-vector products.

    Args:
        K: Covariate kernel matrix (assumed n x n). The scalar projections below
            read ``K[k, i]`` where the derivation needs ``K[i, k]``, i.e. they
            assume K is symmetric.
        C: C matrix (n x n assumed).
        L: Outcome kernel matrix (n x n assumed).
        D1, D2, V1, V2, W1, W2: Presumably the outputs of ``wald_intermediate_matrices``.
        I: Identity matrix of size (2n + 4).
        eps: Regularization weight. If None, uses ``tr(UT_T) / (3 + tr(UT_T))``.
            NOTE(review): unlike ``proposed_wald_type_statistic_components``, which
            tests ``if not eps``, an explicit ``eps=0.0`` is used as-is here.
            ``proposed_fast_wald_step`` then divides by it.

    Returns:
        ``(lu_pivot, eps, bootstrap_ops)``, where:
        - ``lu_pivot`` is the ``(LU, pivots)`` pair from ``torch.linalg.lu_factor``;
        - ``bootstrap_ops`` holds ``"MMD_op"``, ``"H_d1"`` and ``"H_d2"`` (matrices)
          and ``"h_v1"``, ``"h_v2"``, ``"h_w1"``, ``"h_w2"`` (vectors).
    """
    # --- 1. Compute Fixed Covariance (Same as original slow setup) ---
    LD1T = L @ D1.T
    LD2T = L @ D2.T
    
    S1TGS1 = K * (D1 @ LD1T)
    S1TGS2 = K * (D1 @ LD2T)
    S2TGS1 = K * (D2 @ LD1T)
    S2TGS2 = K * (D2 @ LD2T)
    
    # 2n x 2n block
    UT_T_top_left = torch.cat([
        torch.cat([S1TGS1, S1TGS2], dim=1),
        torch.cat([S2TGS1, S2TGS2], dim=1)
    ], dim=0)

    # Use existing basis matrices for covariance construction
    KD1L = K @ D1 @ L
    KD2L = K @ D2 @ L
    S1T_G_d1 = torch.sum(D1 * KD1L, dim=1, keepdim=True)
    S1T_G_d2 = torch.sum(D1 * KD2L, dim=1, keepdim=True)
    S2T_G_d1 = torch.sum(D2 * KD1L, dim=1, keepdim=True)
    S2T_G_d2 = torch.sum(D2 * KD2L, dim=1, keepdim=True)

    KV1L = K @ V1 @ L
    KV2L = K @ V2 @ L
    S1T_G_v1 = torch.sum(D1 * KV1L, dim=1, keepdim=True)
    S1T_G_v2 = torch.sum(D1 * KV2L, dim=1, keepdim=True)
    S2T_G_v1 = torch.sum(D2 * KV1L, dim=1, keepdim=True)
    S2T_G_v2 = torch.sum(D2 * KV2L, dim=1, keepdim=True)

    KW1L = K @ W1 @ L
    KW2L = K @ W2 @ L
    S1T_G_w1 = torch.sum(D1 * KW1L, dim=1, keepdim=True)
    S1T_G_w2 = torch.sum(D1 * KW2L, dim=1, keepdim=True)
    S2T_G_w1 = torch.sum(D2 * KW1L, dim=1, keepdim=True)
    S2T_G_w2 = torch.sum(D2 * KW2L, dim=1, keepdim=True)

    # 2n x 4 block
    UT_T_top_right = torch.cat([
        torch.cat([S1T_G_v1, S2T_G_v1], dim=0),
        torch.cat([S1T_G_v2, S2T_G_v2], dim=0),
        torch.cat([S1T_G_w1, S2T_G_w1], dim=0),
        torch.cat([S1T_G_w2, S2T_G_w2], dim=0)
    ], dim=1)

    # 4 x 2n block
    UT_T_bottom_left = torch.cat([
        torch.cat([-S1T_G_d1, -S2T_G_d1], dim=0),
        torch.cat([-S1T_G_d2, -S2T_G_d2], dim=0),
        torch.cat([-S1T_G_v1, -S2T_G_v1], dim=0),
        torch.cat([-S1T_G_v2, -S2T_G_v2], dim=0)
    ], dim=1).T

    # Scalar Frobenius products
    D1_KV1L_frob = torch.sum(D1 * KV1L)
    D2_KV1L_frob = torch.sum(D2 * KV1L)
    V1_KV1L_frob = torch.sum(V1 * KV1L)
    V2_KV1L_frob = torch.sum(V2 * KV1L)
    D1_KV2L_frob = torch.sum(D1 * KV2L)
    D2_KV2L_frob = torch.sum(D2 * KV2L)
    V1_KV2L_frob = torch.sum(V1 * KV2L)
    V2_KV2L_frob = torch.sum(V2 * KV2L)
    D1_KW1L_frob = torch.sum(D1 * KW1L)
    D2_KW1L_frob = torch.sum(D2 * KW1L)
    V1_KW1L_frob = torch.sum(V1 * KW1L)
    V2_KW1L_frob = torch.sum(V2 * KW1L)
    D1_KW2L_frob = torch.sum(D1 * KW2L)
    D2_KW2L_frob = torch.sum(D2 * KW2L)
    V1_KW2L_frob = torch.sum(V1 * KW2L)
    V2_KW2L_frob = torch.sum(V2 * KW2L)

    # 4 x 4 block
    UT_T_bottom_right = torch.cat([
        torch.stack([-D1_KV1L_frob, -D1_KV2L_frob, -D1_KW1L_frob, -D1_KW2L_frob]).unsqueeze(0),
        torch.stack([-D2_KV1L_frob, -D2_KV2L_frob, -D2_KW1L_frob, -D2_KW2L_frob]).unsqueeze(0),
        torch.stack([-V1_KV1L_frob, -V1_KV2L_frob, -V1_KW1L_frob, -V1_KW2L_frob]).unsqueeze(0),
        torch.stack([-V2_KV1L_frob, -V2_KV2L_frob, -V2_KW1L_frob, -V2_KW2L_frob]).unsqueeze(0)
    ], dim=0)

    # (2n+4) x (2n+4)
    UT_T = torch.cat([
        torch.cat([UT_T_top_left, UT_T_top_right], dim=1),
        torch.cat([UT_T_bottom_left, UT_T_bottom_right], dim=1)
    ], dim=0)

    # Adaptive eps: ~33% weight on the identity. Only None triggers it (see docstring).
    if eps is None:
        tr_val = torch.trace(UT_T).item()
        eps = tr_val / (3 + tr_val)

    # Factorize once so each bootstrap draw only needs a triangular solve.
    operator_matrix = (eps * I) + ((1 - eps) * UT_T)
    lu_pivot = torch.linalg.lu_factor(operator_matrix)

    # --- 2. Compute Bootstrap Operators (The Fast Part) ---
    
    # MMD Operator
    MMD_op = proposed_fast_mmd_setup(K, L, C)

    # Common Matrix M = C @ L
    M = C @ L
    
    # Vector Projection Operators (Phi_vec @ xi -> vector)
    # Replaces: sum(D * KCL, dim=1)
    # Formula: K * (D @ M.T), since
    #   sum_j D[i,j] (K diag(xi) M)[i,j] = sum_k (K * (D @ M.T))[i,k] xi[k]
    H_d1 = K * (D1 @ M.T)
    H_d2 = K * (D2 @ M.T)

    # Scalar Projection Operators (phi_vec @ xi -> scalar)
    # Replaces: sum(V * KCL)
    # Formula: row_sums( (M @ V.T) * K ) -- uses K[k, i] for K[i, k], so relies on K symmetric
    def get_scalar_proj_vec(Z):
        # returns vector h such that h.dot(xi) = sum(Z * K(diag(xi)C)L)
        term1 = M @ Z.T
        return torch.sum(term1 * K, dim=1)
    
    h_v1 = get_scalar_proj_vec(V1)
    h_v2 = get_scalar_proj_vec(V2)
    h_w1 = get_scalar_proj_vec(W1)
    h_w2 = get_scalar_proj_vec(W2)

    bootstrap_ops = {
        "MMD_op": MMD_op,
        "H_d1": H_d1, "H_d2": H_d2,
        "h_v1": h_v1, "h_v2": h_v2,
        "h_w1": h_w1, "h_w2": h_w2
    }

    return lu_pivot, eps, bootstrap_ops

def proposed_fast_wald_step(
    xi: torch.Tensor,
    lu_pivot: Tuple[torch.Tensor, torch.Tensor],
    ops: dict,
    eps: float
) -> float:
    """
    Evaluate the Wald-type statistic for one set of bootstrap weights in O(n^2).

    Args:
        xi: Bootstrap multiplier weights, shape (n,). A tensor or array-like. It
            is converted to the dtype/device of ``ops["MMD_op"]``, so e.g. float64
            or CPU weights work with float32/GPU operators.
        lu_pivot: ``(LU, pivots)`` from ``torch.linalg.lu_factor`` of the
            regularized (2n+4) x (2n+4) operator.
        ops: Precomputed operators with keys "MMD_op", "H_d1", "H_d2",
            "h_v1", "h_v2", "h_w1", "h_w2".
        eps: Regularization weight used when the operator was factorized.

    Returns:
        The bootstrap Wald-type statistic as a Python float.
    """
    mmd_op = ops["MMD_op"]
    # torch.dot / matmul require matching dtypes and devices, so align xi with the operators.
    xi = torch.as_tensor(xi, dtype=mmd_op.dtype, device=mmd_op.device)

    # 1. MMD as a quadratic form: xi^T @ MMD_op @ xi (all-ones xi recovers the original MMD^2).
    mmd_val = torch.dot(xi, mmd_op @ xi).item()

    # 2. Vector projections, O(n^2).
    S1T_G_c = ops["H_d1"] @ xi
    S2T_G_c = ops["H_d2"] @ xi

    # 3. Scalar projections, O(n).
    V1_KCL_frob = torch.dot(ops["h_v1"], xi)
    V2_KCL_frob = torch.dot(ops["h_v2"], xi)
    W1_KCL_frob = torch.dot(ops["h_w1"], xi)
    W2_KCL_frob = torch.dot(ops["h_w2"], xi)

    # Assemble the right-hand-side vectors exactly as in the non-bootstrap statistic.
    cT_T = torch.cat((
        S1T_G_c,
        S2T_G_c,
        torch.stack([V1_KCL_frob, V2_KCL_frob, W1_KCL_frob, W2_KCL_frob])
    ))

    # The D-terms of U^T G c are the sums of the vector projections computed above.
    UT_G_c = torch.cat((
        S1T_G_c,
        S2T_G_c,
        torch.stack([-torch.sum(S1T_G_c), -torch.sum(S2T_G_c),
                     -V1_KCL_frob, -V2_KCL_frob])
    ))

    # Solve with the precomputed LU factors, O(n^2).
    z = torch.linalg.lu_solve(lu_pivot[0], lu_pivot[1], UT_G_c.unsqueeze(1)).squeeze()

    term2 = ((1 - eps) / eps) * torch.dot(cT_T, z)
    term1 = (1 / eps) * mmd_val

    return (term1 - term2).item()


def codite_mmd_statistic(
    X: NDArray[np.floating], 
    A: NDArray[np.floating], 
    Y: NDArray[np.floating], 
    reg_default: float = 0.001, 
    misspecify_outcome_model: bool = False,
    device: Optional[torch.device] = None
) -> float:
    """
    Compute the CODITE MMD statistic.

    Rows are sorted so treated units (A == 1) come first. Kernel ridge smoothers
    for each arm, ``M = K(X, X_arm) (K(X_arm, X_arm) + reg I)^{-1}``, are then
    combined with the outcome Gram matrix L as
    ``(tr(Mc L_cc Mc^T) - 2 tr(Mc L_ct Mt^T) + tr(Mt L_tt Mt^T)) / n``.

    Args:
        X: Covariates, shape (n_samples, n_features) or (n_samples,)
        A: Treatment assignments, shape (n_samples,), presumably 0/1
        Y: Outcomes, shape (n_samples,)
        reg_default: Ridge regularization used for both arms
        misspecify_outcome_model: Use only the last feature for the outcome
            regressions (ignored when X has a single feature)
        device: Torch device for computation. If None, uses get_device()

    Returns:
        CODITE MMD statistic value
    """
    device = validate_device(device)

    def _checked_bandwidth(value: float, name: str) -> float:
        """Replace a degenerate (tiny or NaN) bandwidth with 1.0, logging a warning."""
        # A zero/NaN bandwidth would make the Gaussian kernel divide by zero and
        # turn the whole statistic into NaN.
        if value <= 1e-6 or np.isnan(value):
            logger.warning(f"Invalid {name}={value}, defaulting to 1.0")
            return 1.0
        return value

    X = np.asarray(X)
    A = np.asarray(A)
    Y = np.asarray(Y)
    if X.ndim == 1:
        # A 1-D covariate vector is a single feature; the column slicing below needs 2-D.
        X = X.reshape(-1, 1)

    n = len(X)
    nt = int(np.sum(A))
    nc = n - nt

    # The descending sort puts treated rows first: for 0/1 A, rows [:nt] are treated
    # and rows [nt:] are control.
    sort_idx = np.argsort(A)[::-1]
    X_sorted, Y_sorted = X[sort_idx], Y[sort_idx]

    sigma_y = _checked_bandwidth(median_bandwidth(Y_sorted, device=device), "sigma_y")
    L = gaussian_kernel_gram(Y_sorted, sigma=sigma_y, device=device)

    if misspecify_outcome_model and X_sorted.shape[1] > 1:
        # The outcome regressions (and their bandwidths) only see the last feature.
        X_sorted = X_sorted[:, -1:].copy()

    Xt = X_sorted[:nt, :]
    Xc = X_sorted[nt:, :]

    sigma_xt = _checked_bandwidth(median_bandwidth(Xt, device=device), "sigma_xt")
    sigma_xc = _checked_bandwidth(median_bandwidth(Xc, device=device), "sigma_xc")

    K_X_Xt = gaussian_kernel_gram(X_sorted, Xt, sigma=sigma_xt, device=device)  # K(X_eval, X_t)
    K_X_Xc = gaussian_kernel_gram(X_sorted, Xc, sigma=sigma_xc, device=device)  # K(X_eval, X_c)

    reg_t = reg_default
    reg_c = reg_default

    # K_X_Xt[:nt] is K(X_t, X_t) because X_sorted[:nt] is X_t (likewise for control).
    Mt = torch.linalg.solve(K_X_Xt[:nt, :] + reg_t * torch.eye(nt, device=device),
                            K_X_Xt.T).T  # (n, n_t)
    Mc = torch.linalg.solve(K_X_Xc[nt:, :] + reg_c * torch.eye(nc, device=device),
                            K_X_Xc.T).T  # (n, n_c)

    term1 = torch.trace(Mc @ L[nt:, nt:] @ Mc.T)
    term2 = torch.trace(Mc @ L[nt:, :nt] @ Mt.T)
    term3 = torch.trace(Mt @ L[:nt, :nt] @ Mt.T)

    return ((term1 - 2 * term2 + term3) / n).item()