"""
XGBoost comparison experiment: Full features vs HKAN retained features (5-fold cross-validation)
Testing the effectiveness of HKAN feature selection
"""

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.metrics import accuracy_score, roc_auc_score
import warnings

from config import get_config
from data_loader import set_seed

# Silence every warning category for the whole process (library
# deprecation/convergence chatter would otherwise flood the CV logs).
warnings.simplefilter('ignore')


def load_heart_data(data_path="heart_disease.csv"):
    """Load the heart disease dataset from a CSV file.

    Every column except ``target`` is treated as a feature. ``target`` is
    binarised so that any value > 0 becomes 1 (disease) and everything else 0.

    Args:
        data_path: Path to the CSV file; it must contain a ``target`` column.

    Returns:
        Tuple ``(X, y, feature_names)``:
        - ``X`` is the feature matrix (``DataFrame.values`` of the feature columns);
        - ``y`` is an int array of 0/1 labels;
        - ``feature_names`` lists the feature column names in file order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If the file has no ``target`` column.
    """
    import os

    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    data = pd.read_csv(data_path)
    if 'target' not in data.columns:
        raise KeyError(f"Column 'target' not found in {data_path}")

    # Select features by name rather than position so the label can never
    # leak into X, even when 'target' is not the last column of the file.
    feature_names = [col for col in data.columns if col != 'target']
    X = data[feature_names].values
    # Binary classification: disease (1) vs no disease (0).
    y = (data['target'].values > 0).astype(int)

    return X, y, feature_names


def train_xgboost_fold(X_train, X_test, y_train, y_test, verbose=False):
    """Tune, fit and evaluate XGBoost on a single outer CV fold.

    Hyper-parameters are chosen by ``GridSearchCV`` with 3-fold internal
    cross-validation on the training split only, scored by ROC AUC. The
    best estimator (refit on the whole training split, the GridSearchCV
    default) is then scored on the test split.

    Args:
        X_train: Feature matrix of the training split.
        X_test: Feature matrix of the test split.
        y_train: Binary labels of the training split.
        y_test: Binary labels of the test split.
        verbose: Print progress and the selected max_depth / learning_rate.

    Returns:
        Tuple ``(test_auc, test_acc, best_model)``.
    """
    # Deliberately small grid (3 * 2 * 2 * 2 = 24 combinations) to keep the
    # nested search fast.
    param_grid = {
        'max_depth': [3, 4, 5],
        'learning_rate': [0.05, 0.1],
        'n_estimators': [100, 150],
        'subsample': [0.8],
        'colsample_bytree': [0.8],
        'gamma': [0, 0.1],
        'min_child_weight': [1]
    }

    # `use_label_encoder` is intentionally not passed. It was deprecated in
    # XGBoost 1.6 and removed in 2.0, where it is only reported as an unused
    # parameter. Callers in this file already pass 0/1 labels, so no label
    # encoding is required.
    xgb_model = xgb.XGBClassifier(
        objective='binary:logistic',
        eval_metric='auc',
        random_state=42
    )

    # Grid search with 3-fold internal cross-validation (training split only).
    grid_search = GridSearchCV(
        xgb_model,
        param_grid,
        cv=3,
        scoring='roc_auc',
        n_jobs=-1,
        verbose=0
    )

    if verbose:
        print("      Starting grid search...")

    grid_search.fit(X_train, y_train)
    best_model = grid_search.best_estimator_

    # Hard labels for accuracy, positive-class probabilities for AUC.
    test_pred = best_model.predict(X_test)
    test_proba = best_model.predict_proba(X_test)[:, 1]

    test_acc = accuracy_score(y_test, test_pred)
    test_auc = roc_auc_score(y_test, test_proba)

    if verbose:
        print(f"      Best params: max_depth={grid_search.best_params_['max_depth']}, "
              f"lr={grid_search.best_params_['learning_rate']}")

    return test_auc, test_acc, best_model


def kfold_evaluation(X, y, feature_indices=None, feature_names=None, n_splits=5, verbose=False,
                     random_state=42):
    """Evaluate XGBoost with shuffled stratified k-fold cross-validation.

    For each outer fold:
    - features are standardised with statistics from that fold's training
      split only;
    - XGBoost is tuned and fit via ``train_xgboost_fold`` (nested grid search);
    - the model is scored on the held-out split.

    Args:
        X: Feature matrix, indexed as ``X[rows]`` and ``X[:, cols]``.
        y: Binary label array aligned with the rows of ``X``.
        feature_indices: Optional column indices; when given, only these
            columns of ``X`` are used.
        feature_names: Optional names of the columns of the full ``X``. They
            are only used to print the averaged top-5 feature importances
            when ``verbose`` is true.
        n_splits: Number of outer folds.
        verbose: Print per-fold progress/results and the top-5 importances.
        random_state: Seed for the outer fold shuffling. The default of 42
            keeps the previously hard-coded behaviour.

    Returns:
        Tuple ``(auc_scores, acc_scores)`` of numpy arrays, one entry per fold.
    """
    if feature_indices is not None:
        X = X[:, feature_indices]
        if feature_names is not None:
            feature_names = [feature_names[i] for i in feature_indices]

    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    auc_scores = []
    acc_scores = []
    feature_importances = []

    for fold, (train_idx, test_idx) in enumerate(kfold.split(X, y)):
        if verbose:
            print(f"    Fold {fold+1}/{n_splits}")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Tree ensembles split on thresholds, so XGBoost does not need feature
        # scaling. Standardisation is kept only for consistency; the scaler is
        # fit on the training split alone to avoid leaking test statistics.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        fold_auc, fold_acc, fold_model = train_xgboost_fold(
            X_train, X_test, y_train, y_test, verbose=verbose
        )

        auc_scores.append(fold_auc)
        acc_scores.append(fold_acc)

        if hasattr(fold_model, 'feature_importances_'):
            feature_importances.append(fold_model.feature_importances_)

        if verbose:
            print(f"        Fold {fold+1} results: AUC={fold_auc:.4f}, ACC={fold_acc:.4f}")

    # The averaged importances are only ever printed, so skip the work when
    # not verbose. Compare names with `is not None` / len() so numpy arrays
    # of names work too (their truth value is ambiguous).
    if verbose and feature_importances and feature_names is not None and len(feature_names) > 0:
        avg_importance = np.mean(feature_importances, axis=0)
        sorted_importance = sorted(
            zip(feature_names, avg_importance), key=lambda item: item[1], reverse=True
        )
        print("    Average feature importance:")
        for feat, imp in sorted_importance[:5]:  # Top 5 features
            print(f"      {feat}: {imp:.4f}")

    return np.array(auc_scores), np.array(acc_scores)


def _unique_in_order(items):
    """Return the elements of *items* with duplicates removed, keeping first occurrences."""
    return list(dict.fromkeys(items))


def _print_scores(title, auc_scores, acc_scores):
    """Print mean +/- std of AUC and ACC, plus the per-fold AUC values."""
    print(f"\n{title}:")
    print(f"  AUC: {auc_scores.mean():.4f} +/- {auc_scores.std():.4f}")
    print(f"  ACC: {acc_scores.mean():.4f} +/- {acc_scores.std():.4f}")
    print(f"  Individual AUC scores: {[f'{score:.4f}' for score in auc_scores]}")


def _score_summary(auc_scores, acc_scores, num_features):
    """Return the JSON-serialisable result summary for one experiment."""
    return {
        'auc_mean': float(auc_scores.mean()),
        'auc_std': float(auc_scores.std()),
        'acc_mean': float(acc_scores.mean()),
        'acc_std': float(acc_scores.std()),
        'auc_scores': auc_scores.tolist(),
        'acc_scores': acc_scores.tolist(),
        'num_features': num_features
    }


def _wilcoxon_p_value(differences):
    """Return the two-sided Wilcoxon signed-rank p-value of *differences*, or None.

    None is returned when SciPy rejects the input (e.g. all differences are
    zero) or yields a NaN p-value.
    """
    from scipy.stats import wilcoxon
    try:
        _, p_value = wilcoxon(differences)
    except ValueError:
        return None
    p_value = float(p_value)
    return None if np.isnan(p_value) else p_value


def main():
    """Compare XGBoost on all features vs the EA/HKAN-selected feature subset.

    Steps:
    1. Load the dataset from ``config.data_path``.
    2. Run 5-fold CV (``kfold_evaluation``) once on all features and once on
       the selected subset.
    3. Print a summary, including a Wilcoxon signed-rank test on the paired
       per-fold AUCs.
    4. Write the results to a timestamped JSON file in ``config.results_dir``
       (created if needed).

    Raises:
        ValueError: If none of the selected features exist in the dataset.
    """
    import json
    import os
    from datetime import datetime

    print("=" * 80)
    print("XGBoost Comparison Experiment: Full Features vs HKAN Selected Features")
    print("5-Fold Cross-Validation on Heart Disease Dataset")
    print("=" * 80)

    config = get_config()
    set_seed(config.seed)

    X, y, feature_names = load_heart_data(config.data_path)
    print(f"Dataset loaded: {X.shape[0]} samples, {X.shape[1]} features")
    print(f"Class distribution: {np.bincount(y)}")

    # EA-discovered best feature groups (from EA.py results). The groups
    # overlap, so duplicates are removed. dict.fromkeys keeps a deterministic
    # order, whereas set() order of strings changes between runs (hash
    # randomisation).
    ea_feature_groups = [
        ['cp', 'restecg', 'thalach', 'exang', 'ca'],    # group_0
        ['fbs', 'thalach', 'oldpeak', 'ca'],           # group_1
        ['age', 'trestbps', 'chol', 'slope', 'thal'],  # group_2
        ['sex', 'slope', 'ca'],                        # group_4
    ]
    wanted_features = _unique_in_order(name for group in ea_feature_groups for name in group)

    available = set(feature_names)
    missing = [name for name in wanted_features if name not in available]
    if missing:
        # Missing names would otherwise be dropped silently while still being
        # counted as "selected" in the report.
        print(f"[WARNING] Selected features not found in dataset and ignored: {missing}")
    wanted = set(wanted_features)
    ea_feature_indices = [i for i, name in enumerate(feature_names) if name in wanted]
    if not ea_feature_indices:
        raise ValueError("None of the HKAN-selected features are present in the dataset")
    # Report exactly the features that are used, in dataset column order.
    ea_best_features = [feature_names[i] for i in ea_feature_indices]

    print(f"\nFeature comparison:")
    print(f"  All features ({len(feature_names)}): {feature_names}")
    print(f"  HKAN selected features ({len(ea_best_features)}): {ea_best_features}")

    print(f"\nRunning 5-fold cross-validation...")

    print(f"\n--- Experiment 1: All Features XGBoost ---")
    all_auc_scores, all_acc_scores = kfold_evaluation(
        X, y, feature_indices=None, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    print(f"\n--- Experiment 2: HKAN Selected Features XGBoost ---")
    selected_auc_scores, selected_acc_scores = kfold_evaluation(
        X, y, feature_indices=ea_feature_indices, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    print(f"\n" + "=" * 80)
    print("RESULTS SUMMARY")
    print("=" * 80)

    _print_scores("All Features XGBoost", all_auc_scores, all_acc_scores)
    _print_scores("HKAN Selected Features XGBoost", selected_auc_scores, selected_acc_scores)

    auc_improvement = selected_auc_scores.mean() - all_auc_scores.mean()
    acc_improvement = selected_acc_scores.mean() - all_acc_scores.mean()
    reduction_ratio = (len(feature_names) - len(ea_best_features)) / len(feature_names)

    print(f"\nComparison:")
    print(f"  AUC improvement: {auc_improvement:+.4f}")
    print(f"  ACC improvement: {acc_improvement:+.4f}")
    print(f"  Feature reduction: {len(feature_names)} -> {len(ea_best_features)} "
          f"({reduction_ratio*100:.1f}% reduction)")

    if auc_improvement > 0:
        print(f"  [SUCCESS] HKAN feature selection improves performance")
    else:
        print(f"  [INFO] HKAN feature selection does not improve performance")

    print(f"\nAdditional Analysis:")

    # Two-sided test on the paired per-fold AUCs. The folds are paired because
    # both experiments use the same seeded StratifiedKFold splits.
    # NOTE: with only 5 pairs, the smallest exact two-sided p-value is
    # 2/32 = 0.0625. p < 0.05 is therefore only reachable if SciPy falls back
    # to the normal approximation (e.g. with ties).
    auc_p_value = _wilcoxon_p_value(selected_auc_scores - all_auc_scores)
    if auc_p_value is None:
        print(f"  Statistical significance test could not be performed")
    else:
        print(f"  AUC difference statistical significance: p-value = {auc_p_value:.4f}")
        if auc_p_value < 0.05 and auc_improvement > 0:
            print(f"  [SUCCESS] AUC improvement is statistically significant (p < 0.05)")
        elif auc_p_value < 0.05:
            # A two-sided test is also significant when the subset is worse.
            print(f"  [INFO] AUC difference is statistically significant (p < 0.05), "
                  f"but selected features perform worse")
        else:
            print(f"  [WARNING] AUC improvement is not statistically significant (p >= 0.05)")

    results = {
        'all_features': _score_summary(all_auc_scores, all_acc_scores, len(feature_names)),
        'hkan_selected': {
            **_score_summary(selected_auc_scores, selected_acc_scores, len(ea_best_features)),
            'selected_features': ea_best_features
        },
        'comparison': {
            'auc_improvement': float(auc_improvement),
            'acc_improvement': float(acc_improvement),
            'feature_reduction_ratio': float(reduction_ratio),
            'statistical_significance': {
                'auc_p_value': auc_p_value
            }
        }
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The results directory may not exist yet on a fresh checkout.
    os.makedirs(config.results_dir, exist_ok=True)
    results_file = os.path.join(config.results_dir, f"xgb_comparison_kfold_results_{timestamp}.json")

    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    print("=" * 80)


if __name__ == "__main__":
    main()