import itertools
import warnings

import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
import shap
import matplotlib.pyplot as plt

# The functions below expect an already-trained tree-based model (e.g. a gradient-boosted tree, GBT), passed in as the `model` argument.


def get_shap_values(model, X):
    """
    Build a SHAP TreeExplainer for `model` and return its interaction values on X.

    TODO: may need adjusting for model types that TreeExplainer does not handle.

    Parameters:
        model: Trained tree-based model (e.g., XGBoost, LightGBM, etc.)
        X: Input data (numpy array or pandas DataFrame)

    Returns:
        Whatever `TreeExplainer.shap_interaction_values(X)` returns. This is
        typically an array of shape (n_samples, n_features, n_features); for
        multi-output models it may be a list with one such array per output
        (get_shap_moderator handles the list case).
    """
    tree_explainer = shap.TreeExplainer(model)
    return tree_explainer.shap_interaction_values(X)

def _create_target_from_shap_interactions(shap_interaction_values, i, j):
    """
    Create target variable from SHAP interaction values for a feature pair.

    Parameters:
        shap_interaction_values: SHAP interaction values array
        i, j: Feature indices

    Returns:
        y_target: Combined SHAP interaction values
    """
    return (
        shap_interaction_values[:, i, j]
        + shap_interaction_values[:, j, i]
        + shap_interaction_values[:, i, i]
        + shap_interaction_values[:, j, j]
    )


def _create_feature_matrix(X, i, j, linear=False):
    """
    Build the OLS design matrix for the feature pair (i, j).

    Each feature is standardized (mean 0, std 1) with its own StandardScaler
    so that the fitted coefficients -- in particular the interaction
    coefficient beta3 -- share a unit-free scale and can be compared across
    feature pairs.

    Parameters:
        X: Input feature matrix (numpy array, pandas DataFrame, or any 2-D
           array-like accepted by np.asarray)
        i, j: Column indices of the two features
        linear: If True, omit the interaction column (main-effects-only model)

    Returns:
        X_with_const: 2-D ndarray whose columns are
            [const, z_i, z_j]              when linear=True
            [const, z_i, z_j, z_i * z_j]   otherwise
    """
    if isinstance(X, pd.DataFrame):
        X = X.values
    X = np.asarray(X)

    # StandardScaler expects (n_samples, n_features) input, hence reshape/flatten.
    z_i = StandardScaler().fit_transform(X[:, i].reshape(-1, 1)).flatten()
    z_j = StandardScaler().fit_transform(X[:, j].reshape(-1, 1)).flatten()

    columns = [z_i, z_j]
    if not linear:
        # Product of the standardized features is the interaction (moderation) term.
        columns.append(z_i * z_j)
    X_features = np.column_stack(columns)

    # has_constant="add" guarantees the intercept column is always prepended.
    # With statsmodels' default ("skip") it is silently omitted whenever some
    # column is already a non-zero constant -- e.g. the product of two
    # identical (or complementary) balanced binary features standardizes to
    # all +1 / all -1 -- which would shift every positional coefficient that
    # _extract_model_statistics reads.
    return sm.add_constant(X_features, has_constant="add")

def _extract_model_statistics(model_res: sm.OLS, i: int, j: int, y_target: np.ndarray, linear: bool = False) -> dict:
    """
    Extract comprehensive statistics from fitted OLS model.

    Parameters:
        model_res: Fitted statsmodels OLS result
        i, j: Feature indices
        y_target: Target variable
        linear: Whether this is a linear model (no interaction term)

    Returns:
        dict: Dictionary containing model statistics (coefficients, p-values, etc.)
    """
    stats = {
        "feature_i_idx": i,
        "feature_j_idx": j,
        "r2_score": model_res.rsquared,
        "adjusted_r2": model_res.rsquared_adj,
        "beta0_intercept": model_res.params[0],
        "beta1_coef_i": model_res.params[1],
        "beta2_coef_j": model_res.params[2],
        "mean_shapint_value": np.mean(y_target),
        "std_shapint_value": np.std(y_target),
        "se_beta0": model_res.bse[0],
        "se_beta1": model_res.bse[1],
        "se_beta2": model_res.bse[2],
        "t_beta0": model_res.tvalues[0],
        "t_beta1": model_res.tvalues[1],
        "t_beta2": model_res.tvalues[2],
        "p_beta0": model_res.pvalues[0],
        "p_beta1": model_res.pvalues[1],
        "p_beta2": model_res.pvalues[2],
        "beta0_significant": model_res.pvalues[0] < 0.05,
        "beta1_significant": model_res.pvalues[1] < 0.05,
        "beta2_significant": model_res.pvalues[2] < 0.05,
        "f_statistic": model_res.fvalue,
        "f_pvalue": model_res.f_pvalue,
    }
    
    # Add interaction term statistics only if not linear model
    if not linear and len(model_res.params) > 3:
        stats.update({
            "beta3_coef_interaction": model_res.params[3],
            "se_beta3": model_res.bse[3],
            "t_beta3": model_res.tvalues[3],
            "p_beta3": model_res.pvalues[3],
            "beta3_significant": model_res.pvalues[3] < 0.05,
        })
    
    return stats


def _fit_single_interaction_model(shap_interaction_values, X, i, j):
    """
    Fit the OLS moderator model (intercept, z_i, z_j, z_i * z_j) for one feature pair.

    Fitting is best-effort: any failure is reported as a RuntimeWarning and
    turned into ``None`` so that a single problematic pair does not abort a
    sweep over all pairs.

    Parameters:
        shap_interaction_values: SHAP interaction values array
        X: Input feature matrix
        i, j: Feature indices

    Returns:
        tuple: (model_result, y_target) if successful, None if failed
    """
    try:
        y_target = _create_target_from_shap_interactions(shap_interaction_values, i, j)
        X_features = _create_feature_matrix(X, i, j)
        model_res = sm.OLS(y_target, X_features).fit()
    except Exception as exc:
        # Deliberately broad (best-effort per pair), but no longer silent.
        warnings.warn(
            f"Interaction OLS fit failed for feature pair ({i}, {j}): {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    return model_res, y_target

def _fit_single_linear_model(shap_interaction_values, X, i, j):
    """
    Fit the main-effects-only OLS model (intercept, z_i, z_j) for one feature pair.

    Fitting is best-effort: any failure is reported as a RuntimeWarning and
    turned into ``None`` so that a single problematic pair does not abort a
    sweep over all pairs.

    Parameters:
        shap_interaction_values: SHAP interaction values array
        X: Input feature matrix
        i, j: Feature indices

    Returns:
        tuple: (model_result, y_target) if successful, None if failed
    """
    try:
        y_target = _create_target_from_shap_interactions(shap_interaction_values, i, j)
        X_features = _create_feature_matrix(X, i, j, linear=True)
        model_res = sm.OLS(y_target, X_features).fit()
    except Exception as exc:
        # Deliberately broad (best-effort per pair), but no longer silent.
        warnings.warn(
            f"Linear OLS fit failed for feature pair ({i}, {j}): {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    return model_res, y_target


def get_feature_names(X):
    """
    Return column names for X together with their count.

    Objects exposing a ``columns`` attribute (e.g. a pandas DataFrame) supply
    their own names; anything else gets generated names ``x_0, x_1, ...``
    based on ``X.shape[1]``.

    Parameters:
        X: Input data (numpy array or pandas DataFrame)

    Returns:
        feature_names: List of feature names
        n_features: Number of features
    """
    try:
        columns = X.columns
    except AttributeError:
        names = [f"x_{col}" for col in range(X.shape[1])]
    else:
        names = list(columns)
    return names, len(names)

def get_moderator_from_int_values(shap_interaction_values, feature_names, X, model_type='both'):
    """
    Fit moderator regressions on SHAP interaction values for every feature pair (i < j).

    Parameters:
        shap_interaction_values: SHAP interaction values array of shape
                                 (n_samples, n_features, n_features). For
                                 multiclass output, pass the array of a single
                                 class (the caller selects it).
        feature_names: List of feature names
        X: Input feature matrix (numpy array or pandas DataFrame) with the
           same rows as shap_interaction_values
        model_type: 'interaction', 'linear', or 'both' - which models to fit

    Returns:
        If model_type='both': tuple of (interaction_results_df, linear_results_df)
        Otherwise: single pandas DataFrame with regression results.
        Pairs whose fit failed are omitted.

    Raises:
        ValueError: if model_type is unknown, if shap_interaction_values is
            not 3-D, or if its shape does not match feature_names / X.
    """
    if model_type not in ('interaction', 'linear', 'both'):
        raise ValueError(
            f"model_type must be 'interaction', 'linear' or 'both', got {model_type!r}"
        )

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    # Without them a malformed input makes every per-pair fit fail, and those
    # failures are swallowed, yielding silently empty results.
    shap_interaction_values = np.asarray(shap_interaction_values)
    n_features = len(feature_names)
    if shap_interaction_values.ndim != 3:
        raise ValueError(
            "shap_interaction_values must have shape (n_samples, n_features, n_features), "
            f"got shape {shap_interaction_values.shape}; for multiclass output pass a "
            "single class's array"
        )
    if shap_interaction_values.shape[1:] != (n_features, n_features):
        raise ValueError("Mismatch in number of features")
    if np.shape(X)[0] != shap_interaction_values.shape[0]:
        raise ValueError(
            f"X has {np.shape(X)[0]} rows but shap_interaction_values has "
            f"{shap_interaction_values.shape[0]} samples"
        )

    fit_interaction = model_type in ('interaction', 'both')
    fit_linear = model_type in ('linear', 'both')
    interaction_results = []
    linear_results = []

    def _record(fit_result, linear, sink, i, j):
        # Turn one successful fit into a stats row labelled with feature names.
        if fit_result is None:
            return
        model_res, y_target = fit_result
        stats_dict = _extract_model_statistics(model_res, i, j, y_target, linear=linear)
        stats_dict.update({
            "feature_i": feature_names[i],
            "feature_j": feature_names[j],
        })
        sink.append(stats_dict)

    # combinations() yields (i, j) with i < j in the same order as nested loops.
    for i, j in itertools.combinations(range(n_features), 2):
        if fit_interaction:
            fit_result = _fit_single_interaction_model(shap_interaction_values, X, i, j)
            _record(fit_result, False, interaction_results, i, j)
        if fit_linear:
            fit_result = _fit_single_linear_model(shap_interaction_values, X, i, j)
            _record(fit_result, True, linear_results, i, j)

    if model_type == 'both':
        return pd.DataFrame(interaction_results), pd.DataFrame(linear_results)
    if model_type == 'interaction':
        return pd.DataFrame(interaction_results)
    return pd.DataFrame(linear_results)


def get_shap_moderator(model, X, model_type='both', class_index=0):
    """
    Compute SHAP interaction values and fit moderator regressions for all feature pairs.

    Parameters:
        model: trained tree-based model (used by get_shap_values)
        X: numpy array or pandas DataFrame of input features
        model_type: 'interaction', 'linear', or 'both' - which models to fit
        class_index: When get_shap_values returns a list (one interaction
                     array per model output, e.g. multiclass), the element to
                     analyse. Defaults to 0, the previously hard-coded choice.

    Returns:
        If model_type='both': tuple of (interaction_results_df, linear_results_df)
        Otherwise: single pandas DataFrame with regression results
    """
    interaction_values = get_shap_values(model, X)
    if isinstance(interaction_values, list):
        # Multi-output case: pick one output's interaction array.
        interaction_values = interaction_values[class_index]

    feature_names, _ = get_feature_names(X)
    return get_moderator_from_int_values(interaction_values, feature_names, X, model_type)


def compare_interaction_vs_linear(interaction_df, linear_df):
    """
    Keep feature pairs whose interaction model has a strictly higher R² than
    the linear model, ranked by the absolute interaction coefficient |beta3|.

    NOTE(review): when both frames come from get_moderator_from_int_values,
    the interaction model is the linear model plus one extra regressor, so its
    R² is never lower; this filter then mainly drops exact ties. Compare
    'adjusted_r2' instead if a penalized comparison is intended.

    Parameters:
        interaction_df: DataFrame with interaction model results
        linear_df: DataFrame with linear model results

    Returns:
        comparison_df: DataFrame with columns i, j, feature_i, feature_j,
            r2_linear, r2_interaction, r2_improvement, beta3_coef_interaction,
            abs_beta3, beta3_pvalue; sorted by abs_beta3 descending. Empty
            (with those columns) when either input frame is empty.
    """
    result_columns = [
        'i', 'j', 'feature_i', 'feature_j',
        'r2_linear', 'r2_interaction', 'r2_improvement',
        'beta3_coef_interaction', 'abs_beta3', 'beta3_pvalue',
    ]
    # A DataFrame built from an empty list of results has no columns at all,
    # so the column selections below would raise KeyError.
    if interaction_df.empty or linear_df.empty:
        return pd.DataFrame(columns=result_columns)

    merged = pd.merge(
        interaction_df[['feature_i_idx', 'feature_j_idx', 'feature_i', 'feature_j',
                        'r2_score', 'beta3_coef_interaction', 'p_beta3']],
        linear_df[['feature_i_idx', 'feature_j_idx', 'r2_score']],
        on=['feature_i_idx', 'feature_j_idx'],
        suffixes=('_interaction', '_linear'),
    )

    better = merged[merged['r2_score_interaction'] > merged['r2_score_linear']].copy()
    # R² gain is kept for reference; ranking uses |beta3|.
    better['r2_improvement'] = better['r2_score_interaction'] - better['r2_score_linear']
    better['abs_beta3'] = np.abs(better['beta3_coef_interaction'])

    result_df = better[[
        'feature_i_idx', 'feature_j_idx', 'feature_i', 'feature_j',
        'r2_score_linear', 'r2_score_interaction', 'r2_improvement',
        'beta3_coef_interaction', 'abs_beta3', 'p_beta3',
    ]].rename(columns={
        'feature_i_idx': 'i',
        'feature_j_idx': 'j',
        'r2_score_linear': 'r2_linear',
        'r2_score_interaction': 'r2_interaction',
        'p_beta3': 'beta3_pvalue',
    })

    return result_df.sort_values('abs_beta3', ascending=False).reset_index(drop=True)


def get_improved_interactions(model, X):
    """
    Run both moderator models on every feature pair and return the pairs
    where adding the interaction term raises R², ranked by |beta3|.

    Parameters:
        model: trained tree-based model
        X: numpy array or pandas DataFrame of input features

    Returns:
        comparison_df: DataFrame with pairs where interaction R² > linear R²
    """
    moderator_results = get_shap_moderator(model, X, model_type='both')
    interaction_df, linear_df = moderator_results
    comparison = compare_interaction_vs_linear(interaction_df, linear_df)
    return comparison


