"""
MLP comparison experiment: Full features vs HKAN retained features (5-fold cross-validation)
Testing the effectiveness of HKAN feature selection
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
import warnings

from config import get_config
from data_loader import set_seed

# Install a blanket 'ignore' filter so no warning category is shown for the
# rest of the process (this also hides warnings raised by imported libraries).
warnings.filterwarnings('ignore')


def load_heart_data(data_path="heart_disease.csv"):
    """Load the heart disease dataset from a CSV file.

    Every column except ``target`` is treated as a feature, so the label
    column does not have to be the last column of the file.

    Args:
        data_path: Path to a CSV file that contains a ``target`` column.

    Returns:
        A ``(X, y, feature_names)`` tuple where ``X`` holds the values of the
        non-target columns, ``y`` is the label vector binarized to 1 where
        ``target > 0`` and 0 otherwise, and ``feature_names`` lists the
        columns of ``X`` in file order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If the file has no ``target`` column.
    """
    import os

    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    data = pd.read_csv(data_path)

    if 'target' not in data.columns:
        raise KeyError(f"Column 'target' not found in {data_path}")

    # Select features by name rather than by position so the label can never
    # end up inside X when 'target' is not the last column.
    feature_names = [column for column in data.columns if column != 'target']
    X = data[feature_names].values

    # Binary classification: disease (1) vs no disease (0).
    y = (data['target'].values > 0).astype(int)

    return X, y, feature_names


class MLP(nn.Module):
    """Multi-layer perceptron that emits one raw output value per sample.

    Each entry of ``hidden_dims`` adds a ``Linear -> ReLU -> Dropout`` stage.
    The final layer is a ``Linear`` down to a single unit with no activation,
    so callers apply the sigmoid (or a logits-based loss) themselves.

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden layer widths; an empty iterable gives
            a single linear layer from ``input_dim`` to 1.
        dropout: Dropout probability used after every hidden activation.
    """

    # The default is a tuple rather than a list so the shared default object
    # can never be mutated between instantiations.
    def __init__(self, input_dim, hidden_dims=(64, 32, 16), dropout=0.2):
        super().__init__()
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x``; the last dimension has size 1."""
        return self.model(x)


def train_mlp_fold(X_train, X_test, y_train, y_test, input_dim, epochs=200, lr=1e-3, verbose=False):
    """Train an MLP on one fold and return its best test AUC and accuracy.

    The model is trained with ``BCEWithLogitsLoss`` and AdamW in mini-batches
    of 32. It is evaluated on the test split after every 50th epoch and after
    the final epoch.

    NOTE: the reported scores come from the evaluation checkpoint with the
    highest *test* AUC, so they are an optimistic estimate (the test split is
    used for checkpoint selection).

    Args:
        X_train: Training features, convertible by ``torch.FloatTensor``.
        X_test: Test features, convertible by ``torch.FloatTensor``.
        y_train: Binary training labels.
        y_test: Binary test labels, passed to the scikit-learn metrics.
        input_dim: Number of input features given to the MLP.
        epochs: Number of training epochs.
        lr: Initial learning rate for AdamW.
        verbose: If true, print the test AUC at every 50th epoch.

    Returns:
        ``(best_auc, best_acc)``: the highest test AUC seen at an evaluation
        checkpoint and the accuracy (threshold 0.5) at that same checkpoint.
        Both stay 0 if no checkpoint produced an AUC above 0.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train).to(device)
    X_test_tensor = torch.FloatTensor(X_test).to(device)
    y_train_tensor = torch.FloatTensor(y_train).to(device)

    # Create data loader
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Initialize model
    model = MLP(input_dim=input_dim).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=0.5)

    best_auc = 0
    best_acc = 0

    # Training loop
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            # Drop only the trailing output dimension. A bare squeeze() would
            # also collapse a batch of size 1 to a 0-d tensor, which no longer
            # matches the shape of batch_y and makes the loss raise.
            outputs = model(batch_x).squeeze(-1)
            loss = criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        # Step the plateau scheduler once per epoch: its patience of 20 counts
        # scheduler steps, so stepping it only at evaluation checkpoints would
        # never let it reduce the learning rate within a normal run.
        scheduler.step(train_loss / len(train_loader))

        # Evaluation
        if (epoch + 1) % 50 == 0 or epoch == epochs - 1:
            model.eval()
            with torch.no_grad():
                test_outputs = model(X_test_tensor).squeeze(-1)
                test_probs = torch.sigmoid(test_outputs).cpu().numpy()

            test_preds = (test_probs > 0.5).astype(int)
            test_auc = roc_auc_score(y_test, test_probs)
            test_acc = accuracy_score(y_test, test_preds)

            # Keep the accuracy that belongs to the best-AUC checkpoint.
            if test_auc > best_auc:
                best_auc = test_auc
                best_acc = test_acc

            if verbose and (epoch + 1) % 50 == 0:
                print(f"        Epoch {epoch+1:3d} | Test AUC: {test_auc:.4f}")

    # Clean up
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return best_auc, best_acc


def kfold_evaluation(X, y, feature_indices=None, feature_names=None, n_splits=5, verbose=False,
                     epochs=150, lr=1e-3, random_state=42):
    """Evaluate the MLP with stratified k-fold cross-validation.

    For every fold the features are standardized with a ``StandardScaler``
    fitted on the training split only, then ``train_mlp_fold`` is run.

    Args:
        X: 2-D feature array (indexed as ``X[:, feature_indices]``).
        y: Label array aligned with the rows of ``X``.
        feature_indices: Optional column indices to keep; ``None`` keeps all
            columns.
        feature_names: Accepted for call compatibility; not used by the
            evaluation itself.
        n_splits: Number of folds.
        verbose: If true, print per-fold progress and results.
        epochs: Training epochs per fold.
        lr: Initial learning rate per fold.
        random_state: Seed for the shuffled fold assignment.

    Returns:
        ``(auc_scores, acc_scores)``: two NumPy arrays with one entry per fold.

    Raises:
        ValueError: If the selection leaves no feature columns.
    """
    if feature_indices is not None:
        X = X[:, feature_indices]

    if X.shape[1] == 0:
        raise ValueError("No features selected for evaluation")

    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    auc_scores = []
    acc_scores = []

    for fold, (train_idx, test_idx) in enumerate(kfold.split(X, y)):
        if verbose:
            print(f"    Fold {fold+1}/{n_splits}")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Fit the scaler on the training split only so no test-split
        # statistics leak into training.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # Train MLP
        fold_auc, fold_acc = train_mlp_fold(
            X_train, X_test, y_train, y_test,
            X_train.shape[1], epochs=epochs, lr=lr, verbose=verbose
        )

        auc_scores.append(fold_auc)
        acc_scores.append(fold_acc)

        if verbose:
            print(f"        Fold {fold+1} results: AUC={fold_auc:.4f}, ACC={fold_acc:.4f}")

    return np.array(auc_scores), np.array(acc_scores)


def _summarize_scores(auc_scores, acc_scores, num_features):
    """Build the JSON-serializable summary dict for one experiment."""
    return {
        'auc_mean': float(auc_scores.mean()),
        'auc_std': float(auc_scores.std()),
        'acc_mean': float(acc_scores.mean()),
        'acc_std': float(acc_scores.std()),
        'auc_scores': auc_scores.tolist(),
        'acc_scores': acc_scores.tolist(),
        'num_features': num_features,
    }


def _print_scores(title, auc_scores, acc_scores):
    """Print mean +/- std of AUC and accuracy plus the per-fold AUC values."""
    print(f"\n{title}:")
    print(f"  AUC: {auc_scores.mean():.4f} +/- {auc_scores.std():.4f}")
    print(f"  ACC: {acc_scores.mean():.4f} +/- {acc_scores.std():.4f}")
    print(f"  Individual AUC scores: {[f'{score:.4f}' for score in auc_scores]}")


def main():
    """Run both cross-validation experiments, print a summary and save it.

    Experiment 1 trains on every feature column; experiment 2 trains on the
    hard-coded selected feature names that are present in the dataset. The
    summary is printed and written as a timestamped JSON file under
    ``config.results_dir``.

    Raises:
        ValueError: If none of the selected feature names exist in the data.
    """
    import json
    import os
    from datetime import datetime

    print("=" * 80)
    print("MLP Comparison Experiment: Full Features vs HKAN Selected Features")
    print("5-Fold Cross-Validation on Heart Disease Dataset")
    print("=" * 80)

    # Get configuration
    config = get_config()
    set_seed(config.seed)

    # Load data
    X, y, feature_names = load_heart_data(config.data_path)
    print(f"Dataset loaded: {X.shape[0]} samples, {X.shape[1]} features")
    print(f"Class distribution: {np.bincount(y)}")

    # EA-discovered best feature groups (from EA.py results)
    # NOTE(review): these groups name 13 distinct features; if the dataset has
    # exactly those 13 columns, both experiments train on identical inputs and
    # the reported reduction is 0% -- confirm this is intended.
    requested_features = [
        'cp', 'restecg', 'thalach', 'exang', 'ca',    # group_0
        'fbs', 'thalach', 'oldpeak', 'ca',           # group_1
        'age', 'trestbps', 'chol', 'slope', 'thal',  # group_2
        'sex', 'slope', 'ca'                         # group_4
    ]

    # De-duplicate via a set for membership tests, but report the selection in
    # dataset column order: iterating a set of strings is not stable across
    # interpreter runs, which would make the printout and saved JSON differ
    # from run to run.
    requested = set(requested_features)
    ea_feature_indices = [i for i, name in enumerate(feature_names) if name in requested]
    ea_best_features = [feature_names[i] for i in ea_feature_indices]
    missing_features = sorted(requested.difference(feature_names))

    if not ea_feature_indices:
        raise ValueError("None of the selected features are present in the dataset")

    print("\nFeature comparison:")
    print(f"  All features ({len(feature_names)}): {feature_names}")
    print(f"  HKAN selected features ({len(ea_best_features)}): {ea_best_features}")
    if missing_features:
        # Counts below refer to the columns actually used, so make dropped
        # names visible instead of silently ignoring them.
        print(f"  [WARN] Selected features not found in dataset: {missing_features}")

    print("\nRunning 5-fold cross-validation...")

    # Experiment 1: All features
    print("\n--- Experiment 1: All Features MLP ---")
    all_auc_scores, all_acc_scores = kfold_evaluation(
        X, y, feature_indices=None, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    # Experiment 2: HKAN selected features
    print("\n--- Experiment 2: HKAN Selected Features MLP ---")
    selected_auc_scores, selected_acc_scores = kfold_evaluation(
        X, y, feature_indices=ea_feature_indices, feature_names=feature_names,
        n_splits=5, verbose=True
    )

    # Results summary
    print("\n" + "=" * 80)
    print("RESULTS SUMMARY")
    print("=" * 80)

    _print_scores("All Features MLP", all_auc_scores, all_acc_scores)
    _print_scores("HKAN Selected Features MLP", selected_auc_scores, selected_acc_scores)

    # Comparison of fold means (a raw difference, not a significance test)
    auc_improvement = selected_auc_scores.mean() - all_auc_scores.mean()
    acc_improvement = selected_acc_scores.mean() - all_acc_scores.mean()
    reduction_ratio = (len(feature_names) - len(ea_best_features)) / len(feature_names)

    print("\nComparison:")
    print(f"  AUC improvement: {auc_improvement:+.4f}")
    print(f"  ACC improvement: {acc_improvement:+.4f}")
    print(f"  Feature reduction: {len(feature_names)} -> {len(ea_best_features)} "
          f"({(1 - len(ea_best_features)/len(feature_names))*100:.1f}% reduction)")

    if auc_improvement > 0:
        print("  [SUCCESS] HKAN feature selection improves performance")
    else:
        print("  [INFO] HKAN feature selection does not improve performance")

    # Save results
    selected_summary = _summarize_scores(
        selected_auc_scores, selected_acc_scores, len(ea_best_features)
    )
    selected_summary['selected_features'] = ea_best_features

    results = {
        'all_features': _summarize_scores(all_auc_scores, all_acc_scores, len(feature_names)),
        'hkan_selected': selected_summary,
        'comparison': {
            'auc_improvement': float(auc_improvement),
            'acc_improvement': float(acc_improvement),
            'feature_reduction_ratio': float(reduction_ratio)
        }
    }

    # Save to JSON file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = os.path.join(config.results_dir, f"mlp_comparison_kfold_results_{timestamp}.json")

    # Create the output directory on first use so a finished run is not lost
    # to a missing folder; an empty results_dir means the current directory.
    if config.results_dir:
        os.makedirs(config.results_dir, exist_ok=True)

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    print("=" * 80)


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()