"""
KAN comparison experiment: Full features vs HKAN retained features (5-fold cross-validation)
Testing the effectiveness of HKAN feature selection
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
import warnings

from config import get_config
from data_loader import set_seed

# Import official KAN library
try:
    from kan import KAN
    print("Official KAN library imported successfully")
except ImportError:
    print("KAN library not found. Please install: pip install pykan")
    # The bare `exit` helper is injected by the `site` module for interactive
    # sessions and is not guaranteed to be defined; raising SystemExit gives
    # the same exit status (1) without depending on it.
    raise SystemExit(1)

# NOTE(review): this silences every warning process-wide, including ones
# raised by scikit-learn / torch during training — consider narrowing the
# filter to specific categories or modules.
warnings.filterwarnings('ignore')


def load_heart_data(data_path="heart_disease.csv"):
    """Load the heart disease dataset from a CSV file.

    The label is read from the column named ``target`` and binarised; every
    other column is treated as a feature, regardless of where ``target`` sits
    in the file.

    Args:
        data_path: Path to the CSV file.

    Returns:
        A tuple ``(X, y, feature_names)`` where ``X`` is the array of feature
        values, ``y`` is an integer array holding 1 where ``target > 0`` and 0
        otherwise, and ``feature_names`` lists the feature columns in the same
        order as the columns of ``X``.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If the file has no ``target`` column.
    """
    import os

    # Fail early with a message that names the offending path.
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    data = pd.read_csv(data_path)

    # Look the label up first so a missing column surfaces as KeyError.
    target = data['target'].values

    # Select features by name rather than by position: slicing off "the last
    # column" would leak the label into X (and drop a real feature) whenever
    # 'target' is not the final column of the file.
    feature_names = [column for column in data.columns if column != 'target']
    X = data[feature_names].values

    # Binary classification: disease (1) vs no disease (0).
    y = (target > 0).astype(int)

    return X, y, feature_names


def _kan_train_step(model, optimizer, criterion, X_train_tensor, y_train_tensor):
    """Run one full-batch optimisation step of the KAN on the training data."""
    model.train()
    optimizer.zero_grad()

    outputs = model(X_train_tensor)
    # Regularisation term supplied by the KAN model itself.
    # NOTE(review): the positional weights are presumably
    # (lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) — confirm against the
    # installed pykan version.
    reg_loss = model.reg('edge_forward_spline_n', 1e-3, 1e-3, 0.0, 0.0)
    # squeeze(-1) drops only the trailing output dimension; a bare squeeze()
    # would also collapse the batch dimension for a single-sample batch and
    # make the shape disagree with the target.
    loss = criterion(outputs.squeeze(-1), y_train_tensor) + reg_loss

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()


def _kan_evaluate(model, X_test_tensor, y_test):
    """Return ``(auc, accuracy)`` of the model on the held-out data."""
    model.eval()
    with torch.no_grad():
        test_outputs = model(X_test_tensor)
        # reshape(-1) always yields a 1-D vector of probabilities.
        test_probs = torch.sigmoid(test_outputs).cpu().numpy().reshape(-1)
    test_preds = (test_probs > 0.5).astype(int)
    return roc_auc_score(y_test, test_probs), accuracy_score(y_test, test_preds)


def train_kan_fold(X_train, X_test, y_train, y_test, input_dim, verbose=False):
    """Train, prune and fine-tune a KAN on one cross-validation fold.

    The model is trained for 30 epochs, pruned, then fine-tuned for 20 more
    epochs. Test metrics are computed every 15 epochs of the initial training
    and once after fine-tuning; the metrics of the checkpoint with the highest
    test AUC are returned.

    Args:
        X_train: Training features (array-like, one row per sample).
        X_test: Held-out features.
        y_train: Binary training labels.
        y_test: Binary held-out labels.
        input_dim: Number of input features (width of the first KAN layer).
        verbose: If True, print the test AUC at each initial-training
            checkpoint.

    Returns:
        Tuple ``(best_auc, best_acc)``: the highest test AUC observed at any
        checkpoint and the accuracy measured at that same checkpoint.
    """
    initial_epochs = 30   # Reduced training epochs for speed
    finetune_epochs = 20
    eval_every = 15

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to tensors. The test labels stay as given because the metrics
    # are computed by scikit-learn, not by torch.
    X_train_tensor = torch.FloatTensor(X_train).to(device)
    X_test_tensor = torch.FloatTensor(X_test).to(device)
    y_train_tensor = torch.FloatTensor(y_train).to(device)

    # Initialize KAN - using similar configuration to EA.py
    model = KAN(width=[input_dim, 8, 4, 1], k=3, device=device, seed=42)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # NOTE(review): the best checkpoint is chosen using the *test* fold, so
    # the returned scores are optimistically biased. Reporting the final
    # model's metrics (or selecting on a validation split) would be unbiased.
    best_auc = 0
    best_acc = 0

    # Step 1: Initial training
    for epoch in range(initial_epochs):
        _kan_train_step(model, optimizer, criterion, X_train_tensor, y_train_tensor)

        at_checkpoint = (epoch + 1) % eval_every == 0
        if at_checkpoint or epoch == initial_epochs - 1:
            test_auc, test_acc = _kan_evaluate(model, X_test_tensor, y_test)

            if test_auc > best_auc:
                best_auc = test_auc
                best_acc = test_acc

            if verbose and at_checkpoint:
                print(f"        Epoch {epoch+1:3d} | Test AUC: {test_auc:.4f}")

    # Step 2: Pruning
    model.eval()
    with torch.no_grad():
        _ = model(X_train_tensor)  # Update activations

    model = model.prune(edge_th=0.01)

    # Step 3: Fine-tuning. prune() returned a new model object, so the
    # optimizer is rebuilt over its parameters.
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

    for epoch in range(finetune_epochs):
        _kan_train_step(model, optimizer, criterion, X_train_tensor, y_train_tensor)

    # Final evaluation of the pruned and fine-tuned model.
    final_auc, final_acc = _kan_evaluate(model, X_test_tensor, y_test)
    if final_auc > best_auc:
        best_auc = final_auc
        best_acc = final_acc

    # Clean up
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return best_auc, best_acc


def kfold_evaluation(X, y, feature_indices=None, feature_names=None, n_splits=5, verbose=False):
    """Evaluate a KAN with stratified k-fold cross-validation.

    Args:
        X: 2-D feature array.
        y: Label array aligned with the rows of ``X``.
        feature_indices: Optional column indices; when given, only these
            columns of ``X`` are used.
        feature_names: Optional names of the columns of ``X``; narrowed to
            ``feature_indices`` when both are given.
        n_splits: Number of folds.
        verbose: If True, print per-fold progress and results.

    Returns:
        Tuple ``(auc_scores, acc_scores)`` of NumPy arrays holding one entry
        per fold.
    """
    if feature_indices is not None:
        X = X[:, feature_indices]
        if feature_names is not None:
            feature_names = [feature_names[idx] for idx in feature_indices]

    # Fixed random_state so both experiments see the same fold assignment.
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    per_fold = []

    for fold_no, (train_rows, test_rows) in enumerate(splitter.split(X, y), start=1):
        if verbose:
            print(f"    Fold {fold_no}/{n_splits}")

        train_features, test_features = X[train_rows], X[test_rows]
        train_labels, test_labels = y[train_rows], y[test_rows]

        # Fit the scaler on the training fold only, then apply it to both.
        standardizer = StandardScaler()
        train_features = standardizer.fit_transform(train_features)
        test_features = standardizer.transform(test_features)

        auc, acc = train_kan_fold(
            train_features, test_features, train_labels, test_labels,
            train_features.shape[1], verbose=verbose
        )
        per_fold.append((auc, acc))

        if verbose:
            print(f"        Fold {fold_no} results: AUC={auc:.4f}, ACC={acc:.4f}")

    fold_aucs = [auc for auc, _ in per_fold]
    fold_accs = [acc for _, acc in per_fold]
    return np.array(fold_aucs), np.array(fold_accs)


def main():
    """Run the full-feature vs. HKAN-selected-feature KAN comparison.

    Loads the configuration and dataset, evaluates a KAN with 5-fold
    cross-validation on all features and on the EA-selected subset, prints a
    summary, and writes the results to a timestamped JSON file inside
    ``config.results_dir`` (created if it does not exist).

    Raises:
        ValueError: If none of the EA-selected features exist in the dataset.
    """
    import json
    import os
    from datetime import datetime

    print("=" * 80)
    print("KAN Comparison Experiment: Full Features vs HKAN Selected Features")
    print("5-Fold Cross-Validation on Heart Disease Dataset")
    print("=" * 80)

    # Get configuration
    config = get_config()
    set_seed(config.seed)

    # Load data
    X, y, feature_names = load_heart_data(config.data_path)
    print(f"Dataset loaded: {X.shape[0]} samples, {X.shape[1]} features")
    print(f"Class distribution: {np.bincount(y)}")

    # EA-discovered best feature groups (from EA.py results)
    ea_feature_groups = [
        'cp', 'restecg', 'thalach', 'exang', 'ca',    # group_0
        'fbs', 'thalach', 'oldpeak', 'ca',           # group_1
        'age', 'trestbps', 'chol', 'slope', 'thal',  # group_2
        'sex', 'slope', 'ca'                         # group_4
    ]

    # De-duplicate while keeping first-seen order. A plain set() would make
    # the printed and saved feature order vary from run to run.
    # NOTE(review): the de-duplicated list holds 13 distinct names; if the CSV
    # uses the standard 13 UCI heart-disease columns this is *every* feature,
    # which would make both experiments identical — confirm against the data.
    requested_features = list(dict.fromkeys(ea_feature_groups))
    requested_lookup = set(requested_features)

    ea_feature_indices = [i for i, name in enumerate(feature_names) if name in requested_lookup]
    # Report exactly the columns that are fed to the model (dataset order), so
    # the feature counts below stay correct even if a requested name is absent.
    ea_best_features = [feature_names[i] for i in ea_feature_indices]

    missing_features = [name for name in requested_features if name not in feature_names]
    if missing_features:
        print(f"[WARNING] EA features not found in dataset (ignored): {missing_features}")
    if not ea_feature_indices:
        raise ValueError("None of the EA-selected features are present in the dataset")

    print(f"\nFeature comparison:")
    print(f"  All features ({len(feature_names)}): {feature_names}")
    print(f"  HKAN selected features ({len(ea_best_features)}): {ea_best_features}")

    print(f"\nRunning 5-fold cross-validation...")

    # Experiment 1: All features
    print(f"\n--- Experiment 1: All Features KAN ---")
    all_auc_scores, all_acc_scores = kfold_evaluation(
        X, y, feature_indices=None, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    # Experiment 2: HKAN selected features
    print(f"\n--- Experiment 2: HKAN Selected Features KAN ---")
    selected_auc_scores, selected_acc_scores = kfold_evaluation(
        X, y, feature_indices=ea_feature_indices, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    # Results summary
    print(f"\n" + "=" * 80)
    print("RESULTS SUMMARY")
    print("=" * 80)

    print(f"\nAll Features KAN:")
    print(f"  AUC: {all_auc_scores.mean():.4f} +/- {all_auc_scores.std():.4f}")
    print(f"  ACC: {all_acc_scores.mean():.4f} +/- {all_acc_scores.std():.4f}")
    print(f"  Individual AUC scores: {[f'{score:.4f}' for score in all_auc_scores]}")

    print(f"\nHKAN Selected Features KAN:")
    print(f"  AUC: {selected_auc_scores.mean():.4f} +/- {selected_auc_scores.std():.4f}")
    print(f"  ACC: {selected_acc_scores.mean():.4f} +/- {selected_acc_scores.std():.4f}")
    print(f"  Individual AUC scores: {[f'{score:.4f}' for score in selected_auc_scores]}")

    # Difference of mean scores (selected minus all); no significance test.
    auc_improvement = selected_auc_scores.mean() - all_auc_scores.mean()
    acc_improvement = selected_acc_scores.mean() - all_acc_scores.mean()

    print(f"\nComparison:")
    print(f"  AUC improvement: {auc_improvement:+.4f}")
    print(f"  ACC improvement: {acc_improvement:+.4f}")
    print(f"  Feature reduction: {len(feature_names)} -> {len(ea_best_features)} "
          f"({(1 - len(ea_best_features)/len(feature_names))*100:.1f}% reduction)")

    if auc_improvement > 0:
        print(f"  [SUCCESS] HKAN feature selection improves performance")
    else:
        print(f"  [INFO] HKAN feature selection does not improve performance")

    # Save results
    results = {
        'all_features': {
            'auc_mean': float(all_auc_scores.mean()),
            'auc_std': float(all_auc_scores.std()),
            'acc_mean': float(all_acc_scores.mean()),
            'acc_std': float(all_acc_scores.std()),
            'auc_scores': all_auc_scores.tolist(),
            'acc_scores': all_acc_scores.tolist(),
            'num_features': len(feature_names)
        },
        'hkan_selected': {
            'auc_mean': float(selected_auc_scores.mean()),
            'auc_std': float(selected_auc_scores.std()),
            'acc_mean': float(selected_acc_scores.mean()),
            'acc_std': float(selected_acc_scores.std()),
            'auc_scores': selected_auc_scores.tolist(),
            'acc_scores': selected_acc_scores.tolist(),
            'num_features': len(ea_best_features),
            'selected_features': ea_best_features
        },
        'comparison': {
            'auc_improvement': float(auc_improvement),
            'acc_improvement': float(acc_improvement),
            'feature_reduction_ratio': float((len(feature_names) - len(ea_best_features)) / len(feature_names))
        }
    }

    # Save to JSON file. Create the output directory first so a missing
    # folder does not discard the results of the whole run at the last step.
    os.makedirs(config.results_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = os.path.join(config.results_dir, f"kan_comparison_kfold_results_{timestamp}.json")

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    print("=" * 80)


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()