"""
CODITE MMD Test Implementation

This module provides the main test function for CODITE MMD statistical testing.
"""

import numpy as np
import torch
from typing import Optional, Tuple
from numpy.typing import NDArray
from utils import codite_mmd_statistic, fit_lightgbm_model, validate_device
import logging

logger = logging.getLogger(__name__)

def codite_mmd_test(
    X: NDArray[np.floating],
    A: NDArray[np.floating],
    Y: NDArray[np.floating],
    reg: float = 0.001,
    misspecify_propensity_model: bool = False,
    misspecify_outcome_model: bool = False,
    num_perm: int = 100,
    propensity: Optional[NDArray[np.floating]] = None,
    random_state: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> Tuple[float, float]:
    """
    Perform CODITE MMD test for treatment effect heterogeneity.

    This function tests the null hypothesis of no treatment effect heterogeneity
    using the CODITE MMD statistic. The reference distribution is built by
    redrawing treatment assignments from the (given or estimated) propensity
    scores and recomputing the statistic for each draw.

    Args:
        X: Covariates, shape (n_samples, n_features). A 1-D array is treated
           as a single feature when fitting the propensity model.
        A: Treatment assignments, shape (n_samples,) with values in {0, 1}
        Y: Outcomes, shape (n_samples,)
        reg: Regularization parameter for kernel ridge regression (must be > 0)
        misspecify_propensity_model: If True, uses only last feature for propensity model
        misspecify_outcome_model: If True, uses only last feature for outcome model
        num_perm: Number of resampled treatment vectors for p-value computation
        propensity: Pre-computed propensity scores, shape (n_samples,).
                   If None, estimates them with ``fit_lightgbm_model``
        random_state: Seed passed to ``np.random.seed``. NOTE: this reseeds
                   NumPy's *global* generator, which is also the generator
                   used for the treatment resampling below.
        device: Torch device forwarded (after ``validate_device``) to
                   ``codite_mmd_statistic``. Handling of ``None`` is defined
                   by ``validate_device`` in ``utils``.

    Returns:
        Tuple of (test_statistic, p_value)
        - test_statistic: CODITE MMD statistic value on the observed data
        - p_value: Resampling-based p-value, (1 + #{t_k >= t_hat}) / (1 + num_perm)

    Raises:
        ValueError: If inputs have incompatible lengths, A is not binary,
            num_perm < 1, reg is not > 0, or propensity has the wrong length
            or contains NaN.

    Notes:
        - The null hypothesis is no treatment effect heterogeneity
        - Small p-values (< 0.05) suggest evidence of treatment effect heterogeneity
        - Propensity scores are clipped to [1e-9, 1-1e-9] for numerical stability
    """
    # ---- Input validation: all cheap checks run before any device/model work ----
    X = np.asarray(X)
    A = np.asarray(A)
    Y = np.asarray(Y)

    if not (len(X) == len(A) == len(Y)):
        raise ValueError(f"Input arrays must have same length: X={len(X)}, A={len(A)}, Y={len(Y)}")

    if not np.all(np.isin(A, [0, 1])):
        raise ValueError("Treatment assignments A must be binary (0 or 1)")

    if num_perm < 1:
        raise ValueError(f"num_perm must be positive, got {num_perm}")

    # Written as `not (0 < reg)` rather than `reg <= 0` so that NaN is rejected too.
    if not (0 < reg):
        raise ValueError(f"reg must be > 0, got {reg}")

    if propensity is not None:
        # Validate user-supplied scores up front so a bad argument fails fast
        # instead of after the observed statistic has been computed.
        propensity = np.asarray(propensity, dtype=float)
        if len(propensity) != len(A):
            raise ValueError(f"Propensity length {len(propensity)} != A length {len(A)}")
        if np.isnan(propensity).any():
            # np.clip propagates NaN and np.random.binomial would reject it later.
            raise ValueError("Propensity scores must not contain NaN")

    device = validate_device(device)

    if random_state is not None:
        np.random.seed(random_state)

    # Observed statistic on the actual treatment assignment.
    t_hat = codite_mmd_statistic(
        X, A, Y,
        reg_default=reg,
        misspecify_outcome_model=misspecify_outcome_model,
        device=device
    )

    if propensity is not None:
        e_hat = np.clip(propensity, 1e-9, 1.0 - 1e-9)
    else:
        # Guard on ndim first: a 1-D X has no second axis to index.
        if misspecify_propensity_model and X.ndim > 1 and X.shape[1] > 1:
            X_ps = X[:, -1:].copy()  # Use only last feature
        else:
            X_ps = X.copy()

        if X_ps.ndim == 1:
            X_ps = X_ps.reshape(-1, 1)

        # Fit propensity score model
        model_e_hat = fit_lightgbm_model(X_ps, A)
        if model_e_hat is not None:
            # NOTE(review): assumes predict() yields probabilities of A == 1 —
            # confirm against fit_lightgbm_model in utils.
            e_hat = model_e_hat.predict(X_ps)
            e_hat = np.clip(e_hat, 1e-9, 1.0 - 1e-9)
        else:
            logger.warning("Warning: Propensity score model returned None. Defaulting to 0.5")
            e_hat = np.full(A.shape, 0.5)

    # Reference distribution: redraw treatments A_tilde ~ Bernoulli(e_hat)
    # and recompute the statistic for each draw.
    t_k = np.zeros(num_perm)
    for k in range(num_perm):
        A_tilde = np.random.binomial(1, e_hat)
        t_k[k] = codite_mmd_statistic(
            X, A_tilde, Y,
            reg_default=reg,
            misspecify_outcome_model=misspecify_outcome_model,
            device=device
        )

    # Compute p-value (including observed statistic in both numerator and denominator)
    p_value = float((1 + np.sum(t_k >= t_hat)) / (1 + num_perm))

    return t_hat, p_value


def codite_mmd_test_batch(
    datasets: list,
    reg: float = 0.001,
    misspecify_propensity_model: bool = False,
    misspecify_outcome_model: bool = False,
    num_perm: int = 100,
    random_state: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> list:
    """
    Run CODITE MMD test on multiple datasets.

    Each dataset is passed to ``codite_mmd_test`` with the shared settings
    below. A failure while testing one dataset is logged and recorded as
    ``(nan, nan)`` so the remaining datasets are still processed.

    Args:
        datasets: List of tuples (X, A, Y) or (X, A, Y, propensity)
        reg: Regularization parameter
        misspecify_propensity_model: Whether to misspecify propensity model
        misspecify_outcome_model: Whether to misspecify outcome model
        num_perm: Number of permutations
        random_state: Random seed. The same value is forwarded for every
            dataset, so each test starts from the same seed.
        device: Torch device forwarded unchanged to ``codite_mmd_test``
            for every dataset

    Returns:
        List of tuples (test_statistic, p_value) for each dataset, in input
        order; (nan, nan) for datasets whose test raised an exception

    Raises:
        ValueError: If a dataset tuple does not have 3 or 4 elements. This is
            raised immediately and aborts the whole batch.
    """
    results = []

    for i, data in enumerate(datasets):
        if len(data) == 3:
            X, A, Y = data
            propensity = None
        elif len(data) == 4:
            X, A, Y, propensity = data
        else:
            raise ValueError(f"Dataset {i} must have 3 or 4 elements (X, A, Y) or (X, A, Y, propensity)")

        try:
            t_hat, p_value = codite_mmd_test(
                X, A, Y,
                reg=reg,
                misspecify_propensity_model=misspecify_propensity_model,
                misspecify_outcome_model=misspecify_outcome_model,
                num_perm=num_perm,
                propensity=propensity,
                random_state=random_state,
                device=device
            )
        except Exception as e:
            # Deliberate best-effort boundary: one bad dataset must not abort
            # the batch. logger.exception keeps the traceback for diagnosis.
            logger.exception("Error processing dataset %d: %s", i, e)
            results.append((np.nan, np.nan))
        else:
            results.append((t_hat, p_value))

    return results


# Example usage
if __name__ == "__main__":
    # Demo: simulate a small observational dataset whose treatment effect
    # depends on the first covariate, then run the test on it.
    np.random.seed(42)
    n_samples = 200

    # Three standard-normal covariates
    covariates = np.random.normal(0, 1, (n_samples, 3))
    x0 = covariates[:, 0]
    x1 = covariates[:, 1]

    # Logistic treatment assignment driven by the first two covariates
    logits = 0.5 * x0 + 0.3 * x1
    true_propensity = 1 / (1 + np.exp(-logits))
    treatment = np.random.binomial(1, true_propensity)

    # Outcome = baseline + heterogeneous effect for the treated + noise.
    # The noise is drawn after the treatment so the random stream is unchanged.
    effect = 1.0 + 0.5 * x0
    noise = np.random.normal(0, 0.5, n_samples)
    outcome = 2 * x0 + x1 + treatment * effect + noise

    print("Running CODITE MMD test...")
    # A small number of resamples keeps the demo quick
    statistic, pval = codite_mmd_test(
        covariates,
        treatment,
        outcome,
        num_perm=50,
        random_state=42,
    )

    print(f"Test statistic: {statistic:.6f}")
    print(f"P-value: {pval:.4f}")

    verdict = (
        "Reject null hypothesis: Evidence of treatment effect heterogeneity"
        if pval < 0.05
        else "Fail to reject null hypothesis: No strong evidence of heterogeneity"
    )
    print(verdict)