"""
Proposed MMD and Wald Test Implementation

This module provides bootstrap-based testing for treatment effect heterogeneity
using the proposed cross-fitted MMD and Wald-type statistics.
"""

import numpy as np
import torch
from typing import Optional, Tuple, Literal
from numpy.typing import NDArray
import logging

logger = logging.getLogger(__name__)

# Import utility functions
from utils import (
    proposed_mmd_statistic_components,
    wald_intermediate_matrices,
    proposed_wald_type_statistic_components,
    proposed_fast_mmd_setup,
    proposed_fast_wald_setup,
    proposed_fast_wald_step,
    validate_device,
    get_device
)

def _validate_test_inputs(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating],
    test_type: str,
    n_bootstrap: int,
    alpha: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Coerce X, A, Y to NumPy arrays and raise ValueError on invalid arguments."""
    X = np.asarray(X)
    A = np.asarray(A)
    Y = np.asarray(Y)

    if not (len(X) == len(A) == len(Y)):
        raise ValueError(f"Input arrays must have same length: X={len(X)}, A={len(A)}, Y={len(Y)}")

    if not np.all(np.isin(A, [0, 1])):
        raise ValueError("Treatment assignments A must be binary (0 or 1)")

    if test_type not in ("mmd", "wald"):
        raise ValueError(f"test_type must be 'mmd' or 'wald', got {test_type}")

    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be positive, got {n_bootstrap}")

    if not (0 < alpha < 1):
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")

    n_samples = len(X)
    if n_samples < 4:
        raise ValueError(f"Need at least 4 samples for cross-fitting, got {n_samples}")

    # With only one treatment arm there is no treated/control contrast to test
    # (and a propensity model cannot be fit), so fail early with a clear message
    # instead of deep inside the nuisance-model fitting.
    if not (np.any(A == 0) and np.any(A == 1)):
        raise ValueError(
            "Treatment assignments A must contain both treated (1) and control (0) units"
        )

    return X, A, Y


def _centered_multinomial_weights(
    n_draws: int,
    probs: list,
    sort_idx,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Draw multinomial(n_draws, probs) counts minus 1, reordered by sort_idx, as a tensor."""
    counts = np.random.multinomial(n_draws, probs) - 1
    return torch.tensor(counts[sort_idx], dtype=dtype, device=device)


def proposed_mmd_wald_test(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating],
    test_type: Literal["mmd", "wald"] = "mmd",
    sigma_x: Optional[float] = None,
    sigma_y: Optional[float] = None,
    reg: float = 0.001,
    misspecify_propensity_model: bool = False,
    misspecify_outcome_model: bool = False,
    n_bootstrap: int = 1000,
    alpha: float = 0.05,
    device: Optional[torch.device] = None,
    random_state: Optional[int] = None,
    lgbm_random_state_seed: Optional[int] = None,
    propensity: Optional[NDArray[np.floating]] = None,
    fast_wald: bool = True,
) -> Tuple[float, float, bool, dict]:
    """
    Perform proposed MMD or Wald test for treatment effect heterogeneity with bootstrap.

    This function tests the null hypothesis of no treatment effect heterogeneity
    using cross-fitted MMD or Wald-type statistics with bootstrap p-values.

    Args:
        X: Covariates, shape (n_samples, n_features)
        A: Treatment assignments, shape (n_samples,) with values in {0, 1};
           both values must occur.
        Y: Outcomes, shape (n_samples,)
        test_type: Type of test statistic - "mmd" or "wald"
        sigma_x: Bandwidth parameter for covariate kernel
        sigma_y: Bandwidth parameter for outcome kernel
        reg: Regularization parameter, forwarded to the statistic helper as
             ``reg_default``.
        misspecify_propensity_model: If True, uses only last feature for propensity model
        misspecify_outcome_model: If True, uses only last feature for outcome model
        n_bootstrap: Number of bootstrap samples for p-value computation
        alpha: Significance level for hypothesis test
        device: Torch device for computation, passed through validate_device().
        random_state: If given, seeds the *global* NumPy and torch RNGs
                      (np.random.seed / torch.manual_seed) before any work,
                      so it also affects RNG use in the helper functions and
                      in later caller code.
        lgbm_random_state_seed: Random seed for LightGBM propensity score models
        propensity: vector of true propensities if known
        fast_wald: If True (default), uses fast Wald test with fixed basis and pre-computed
                   LU factorization. If False, uses standard Wald test that recomputes basis
                   for each bootstrap sample. Only affects test_type="wald". The fast version
                   is much faster but uses a fixed basis computed from the observed data.

    Returns:
        Tuple of (test_statistic, p_value, rejected, info_dict)
        - test_statistic: Scaled test statistic value (n * mmd or n * wald)
        - p_value: Fraction of bootstrap statistics >= the observed statistic
        - rejected: True if p_value <= alpha
        - info_dict: Dictionary with additional information (raw_mmd, n1, n2,
          eps_used, bootstrap_samples, ...)

    Raises:
        ValueError: If inputs have incompatible lengths, A is not binary or
            contains only one arm, test_type is unknown, n_bootstrap < 1,
            alpha is outside (0, 1), or there are fewer than 4 samples.
            These checks run before any device handling or model fitting.

    Notes:
        - Uses cross-fitted estimation to reduce overfitting bias
        - Bootstrap weights are centered multinomial counts (counts - 1),
          drawn separately for the two index groups of size n1 and n2
        - Wald test provides better finite-sample properties but is computationally heavier
    """
    X, A, Y = _validate_test_inputs(X, A, Y, test_type, n_bootstrap, alpha)
    n_samples = len(X)

    device = validate_device(device)

    if random_state is not None:
        # NOTE: this mutates global RNG state; kept global because the
        # statistic helpers below may draw from the global generators too.
        np.random.seed(random_state)
        torch.manual_seed(random_state)

    mmd, K, L, KCL, C, E, n1, n2, sort_idx1, sort_idx2 = proposed_mmd_statistic_components(
        X, A, Y,
        sigma_x=sigma_x,
        sigma_y=sigma_y,
        reg_default=reg,
        misspecify_propensity_model=misspecify_propensity_model,
        misspecify_outcome_model=misspecify_outcome_model,
        device=device,
        lgbm_random_state_seed=lgbm_random_state_seed,
        propensity=propensity
    )

    test_statistic = n_samples * mmd

    eps_used = None
    lu_pivot = None        # only set by the fast Wald setup
    bootstrap_ops = None   # precomputed operators reused in every bootstrap draw

    if test_type == "mmd":
        MMD_op = proposed_fast_mmd_setup(K, L, C)
        bootstrap_ops = {"MMD_op": MMD_op}
    else:  # "wald" (validated above)
        D1, D2, V1, V2, W1, W2 = wald_intermediate_matrices(C, E, n1, n2)
        I = torch.eye(2 * n_samples + 4, device=device)

        if fast_wald:
            # Pre-compute the LU factorization once (O(n^3)) and reuse it per draw.
            lu_pivot, eps_used, bootstrap_ops = proposed_fast_wald_setup(
                K, C, L, D1, D2, V1, V2, W1, W2, I, eps=None
            )
            # The observed statistic corresponds to weights = 1 for every sample.
            ref_tensor = bootstrap_ops["MMD_op"]
            ones_vec = torch.ones(n_samples, dtype=ref_tensor.dtype, device=ref_tensor.device)
            wald = proposed_fast_wald_step(ones_vec, lu_pivot, bootstrap_ops, eps_used)
        else:
            wald, eps_used = proposed_wald_type_statistic_components(
                K, L, KCL, D1, D2, V1, V2, W1, W2, I, mmd
            )
        test_statistic = n_samples * wald

    # Bootstrap weights must share the dtype of the operator they multiply:
    # `MMD_op @ mult` and `torch.dot` do not promote mixed dtypes, and the
    # observed fast-Wald statistic above is already computed in MMD_op's dtype.
    if bootstrap_ops is not None and "MMD_op" in bootstrap_ops:
        weight_dtype = bootstrap_ops["MMD_op"].dtype
    else:
        # Standard Wald: weights only enter through elementwise products,
        # which promote dtypes, so keep the historical float32.
        weight_dtype = torch.float32

    # Uniform multinomial probabilities are loop-invariant; build them once.
    probs_1 = [1 / n1] * n1
    probs_2 = [1 / n2] * n2

    bootstrap_samples = np.zeros(n_bootstrap)
    for b in range(n_bootstrap):
        if b > 0 and b % 500 == 0:
            logger.debug("Bootstrap progress: %d/%d", b, n_bootstrap)

        # Centered multinomial weights, reordered to match the sorted data.
        mult_1 = _centered_multinomial_weights(n1, probs_1, sort_idx1, weight_dtype, device)
        mult_2 = _centered_multinomial_weights(n2, probs_2, sort_idx2, weight_dtype, device)
        mult = torch.cat([mult_1, mult_2], dim=0)

        if test_type == "mmd":
            mmd_b = torch.dot(mult, bootstrap_ops["MMD_op"] @ mult).item()
            bootstrap_samples[b] = n_samples * mmd_b
        elif fast_wald:
            wald_b = proposed_fast_wald_step(mult, lu_pivot, bootstrap_ops, eps_used)
            bootstrap_samples[b] = n_samples * wald_b
        else:
            # Standard Wald: reweight rows of C and E and recompute the statistic,
            # reusing the regularization eps chosen on the observed data.
            C_ = C.detach() * mult[:, None]
            KCL_ = K @ C_ @ L
            mmd_b = torch.sum(C_ * KCL_).item()

            E_ = E.detach() * mult[:, None]
            D1_b, D2_b, V1_b, V2_b, W1_b, W2_b = wald_intermediate_matrices(C_, E_, n1, n2)
            wald_b, _ = proposed_wald_type_statistic_components(
                K, L, KCL_, D1_b, D2_b, V1_b, V2_b, W1_b, W2_b, I, mmd_b, eps=eps_used
            )
            bootstrap_samples[b] = n_samples * wald_b

    p_value = np.mean(bootstrap_samples >= test_statistic)
    rejected = bool(p_value <= alpha)

    info_dict = {
        'raw_mmd': mmd,
        'n1': n1,
        'n2': n2,
        'eps_used': eps_used,  # stays None for the MMD test
        'bootstrap_samples': bootstrap_samples,
        'test_type': test_type,
        'sigma_x': sigma_x,
        'sigma_y': sigma_y,
        'n_bootstrap': n_bootstrap,
        'fast_wald': fast_wald if test_type == "wald" else None
    }

    return test_statistic, p_value, rejected, info_dict


def proposed_mmd_test(
    X: NDArray[np.floating],
    A: NDArray[np.floating], 
    Y: NDArray[np.floating],
    **kwargs
) -> Tuple[float, float, bool, dict]:
    """
    Run the proposed heterogeneity test using the cross-fitted MMD statistic.

    Shortcut for ``proposed_mmd_wald_test(X, A, Y, test_type="mmd", **kwargs)``.

    Args:
        X: Covariates, shape (n_samples, n_features)
        A: Treatment assignments, shape (n_samples,)
        Y: Outcomes, shape (n_samples,)
        **kwargs: Forwarded unchanged to proposed_mmd_wald_test. Must not
            contain ``test_type``; that would be a duplicate keyword (TypeError).

    Returns:
        The (test_statistic, p_value, rejected, info_dict) tuple produced by
        proposed_mmd_wald_test for the MMD statistic.
    """
    statistic_kind = "mmd"
    return proposed_mmd_wald_test(X, A, Y, test_type=statistic_kind, **kwargs)


def proposed_wald_test(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating], 
    **kwargs
) -> Tuple[float, float, bool, dict]:
    """
    Run the proposed heterogeneity test using the Wald-type statistic.

    Shortcut for ``proposed_mmd_wald_test(X, A, Y, test_type="wald", **kwargs)``.

    Args:
        X: Covariates, shape (n_samples, n_features)
        A: Treatment assignments, shape (n_samples,)
        Y: Outcomes, shape (n_samples,)
        **kwargs: Forwarded unchanged to proposed_mmd_wald_test (for example
            ``fast_wald``). Must not contain ``test_type``; that would be a
            duplicate keyword (TypeError).

    Returns:
        The (test_statistic, p_value, rejected, info_dict) tuple produced by
        proposed_mmd_wald_test for the Wald statistic.
    """
    statistic_kind = "wald"
    return proposed_mmd_wald_test(X, A, Y, test_type=statistic_kind, **kwargs)


def compare_mmd_wald_tests(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating],
    n_bootstrap: int = 1000,
    **kwargs
) -> dict:
    """
    Run the MMD test and then the Wald test on the same data and collect both results.

    Args:
        X: Covariates, shape (n_samples, n_features)
        A: Treatment assignments, shape (n_samples,)
        Y: Outcomes, shape (n_samples,)
        n_bootstrap: Number of bootstrap samples used by each test
        **kwargs: Extra keyword arguments forwarded identically to both tests

    Returns:
        Dict with keys 'mmd' and 'wald'. Each maps to a dict with
        'statistic', 'p_value', 'rejected' and 'info'. There is also a
        top-level 'raw_mmd' key, taken from the MMD run's info dict.
    """
    logger.info("Running comparison of MMD and Wald tests...")

    summary = {}
    # The MMD test runs first and the Wald test second, in that fixed order.
    for label, runner in (("mmd", proposed_mmd_test), ("wald", proposed_wald_test)):
        stat, pval, was_rejected, details = runner(
            X, A, Y, n_bootstrap=n_bootstrap, **kwargs
        )
        summary[label] = {
            'statistic': stat,
            'p_value': pval,
            'rejected': was_rejected,
            'info': details,
        }

    summary['raw_mmd'] = summary['mmd']['info']['raw_mmd']
    return summary


# Example usage
if __name__ == "__main__":
    # Simulate a small observational dataset with a heterogeneous treatment effect.
    np.random.seed(42)
    n = 150

    X = np.random.normal(0, 1, (n, 2))                      # two standard-normal covariates
    propensity_true = 1 / (1 + np.exp(-(0.5 * X[:, 0])))    # logistic in the first covariate
    A = np.random.binomial(1, propensity_true)

    # The effect varies linearly with the first covariate (strong heterogeneity).
    treatment_effect = 1.0 + 0.8 * X[:, 0]
    noise = np.random.normal(0, 0.5, n)
    Y = X[:, 0] + 0.5 * X[:, 1] + A * treatment_effect + noise

    print("Example: Testing for treatment effect heterogeneity")
    print(f"Sample size: {n}")
    print(f"Treatment rate: {np.mean(A):.3f}")

    # A reduced number of bootstrap draws keeps the demo quick.
    results = compare_mmd_wald_tests(X, A, Y, n_bootstrap=200, random_state=42)

    print("\nResults:")
    for label, key in (("MMD", "mmd"), ("Wald", "wald")):
        outcome = results[key]
        print(f"{label} test - Statistic: {outcome['statistic']:.6f}, "
              f"P-value: {outcome['p_value']:.4f}, "
              f"Rejected: {outcome['rejected']}")

    any_rejected = results['mmd']['rejected'] or results['wald']['rejected']
    if any_rejected:
        conclusion = "Evidence of treatment effect heterogeneity detected"
    else:
        conclusion = "No strong evidence of treatment effect heterogeneity"
    print(f"\nConclusion: {conclusion}")