"""
Bayesian optimization for EA Best Group HKAN on Heart Disease dataset
Uses evolutionary algorithm discovered feature groups for hierarchical KAN
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
import warnings
import optuna
from optuna.samplers import TPESampler
import json
import os
from datetime import datetime

from config import get_config
from data_loader import set_seed, load_heart_disease_data, preprocess_features_by_groups
from models import HKANClassification, calculate_factor_quality_score
from training_utils import train_hkan_model

# Install an "ignore" filter for every warning category from every module,
# process-wide, presumably to keep per-trial console output readable.
# NOTE(review): this also hides deprecation / numerical warnings that may be
# worth seeing while debugging a failing trial.
warnings.filterwarnings('ignore')


def get_ea_best_groups():
    """Translate the EA-discovered feature groups into column indices.

    The Heart Disease column order is fixed below. Each group's feature names
    are looked up in that order; any name not present in the column list is
    skipped.

    Returns:
        dict[str, list[int]]: group name -> column indices, listed in the same
        order as the group's features.
    """
    columns = (
        'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg',
        'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal',
    )
    position_of = dict(zip(columns, range(len(columns))))

    # Groups exactly as produced by the evolutionary search (the names are
    # kept verbatim, so there is no 'group_3'; features may repeat across groups).
    discovered = {
        'group_0': ('cp', 'restecg', 'thalach', 'exang', 'ca'),
        'group_1': ('fbs', 'thalach', 'oldpeak', 'ca'),
        'group_2': ('age', 'trestbps', 'chol', 'slope', 'thal'),
        'group_4': ('sex', 'slope', 'ca'),
    }

    return {
        name: [position_of[feat] for feat in feats if feat in position_of]
        for name, feats in discovered.items()
    }


def train_and_evaluate_ea_hkan(params, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device):
    """Train one HKAN on the EA best feature groups and return its metrics.

    Args:
        params: Hyperparameters for this run. ``params['seed']`` (default 42)
            is passed to ``set_seed``. Every key that names an existing
            attribute of the object returned by ``get_config()`` overrides
            that attribute.
        X_train, X_val, X_test: Feature splits, passed unchanged to
            ``preprocess_features_by_groups``.
        y_train, y_val, y_test: Label splits, passed unchanged to
            ``train_hkan_model``.
        feature_names: Accepted for interface compatibility; not used here.
        device: torch device the model is moved to and that is handed to the
            preprocessing / training helpers.

    Returns:
        dict with keys 'test_acc', 'test_auc', 'test_f1', 'best_val_auc',
        'best_fqs' (taken from the training result's 'final_fqs', default 0.5)
        and 'total_params'. If building or training the model raises, the
        error and its traceback are printed and every metric is 0.

    Note:
        Errors raised by ``preprocess_features_by_groups`` are not caught here;
        they propagate to the caller.
    """
    set_seed(params.get('seed', 42))

    group_indices = get_ea_best_groups()

    train_data, val_data, test_data, group_feature_info = preprocess_features_by_groups(
        X_train, X_val, X_test, group_indices, device
    )

    model = None
    try:
        # Fresh config per trial, overridden by whichever params it knows about.
        config = get_config()
        for key, value in params.items():
            if hasattr(config, key):
                setattr(config, key, value)

        model = HKANClassification(group_feature_info, config, device).to(device)

        results = train_hkan_model(
            model, train_data, val_data, test_data,
            y_train, y_val, y_test, config, device
        )

        return {
            'test_acc': results['test_acc'],
            'test_auc': results['test_auc'],
            'test_f1': results['test_f1'],
            'best_val_auc': results['best_val_auc'],
            'best_fqs': results.get('final_fqs', 0.5),
            'total_params': results['total_params']
        }

    except Exception as e:
        import traceback

        print(f"HKAN training failed: {str(e)}")
        # The full traceback is the only way to diagnose a failed trial later.
        traceback.print_exc()
        return {
            'test_acc': 0.0,
            'test_auc': 0.0,
            'test_f1': 0.0,
            'best_val_auc': 0.0,
            'best_fqs': 0.0,
            'total_params': 0
        }

    finally:
        # Release the model and cached CUDA blocks on BOTH success and failure.
        # Failures are frequently CUDA OOM, which is exactly when freeing memory
        # before the next trial matters most.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def objective(trial, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device):
    """Optuna objective: sample hyperparameters, train the EA-group HKAN, score it.

    Every metric returned by ``train_and_evaluate_ea_hkan`` is stored on the
    trial via ``trial.set_user_attr`` ('test_auc', 'test_acc', 'test_f1',
    'best_val_auc', 'best_fqs', 'total_params').

    Returns:
        The test-set AUC (the value the study maximizes), or 0.0 if evaluation
        raised. In that case the trial also gets an 'error' user attr, because
        Optuna still records it as COMPLETE and it would otherwise look like a
        genuine low score.

    NOTE(review): selecting hyperparameters on *test* AUC leaks the test set
    into model selection, so the reported best score is optimistically biased.
    'best_val_auc' is recorded and would be the leakage-free target. Switching
    also requires changing how main() labels ``best_trial.value``.
    """
    # Hyperparameter search space (suggest order kept stable so seeded
    # sampling stays reproducible across runs).
    params = {
        # Learning parameters
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-4, log=True),

        # KAN parameters
        'kan_grid': trial.suggest_int('kan_grid', 1, 8),
        'kan_hidden_multiplier': trial.suggest_int('kan_hidden_multiplier', 1, 2),

        # Fusion layer configuration: a single hidden layer of the sampled width
        'fusion_hidden_layers': [trial.suggest_int('fusion_hidden_dim', 1, 8)],

        # KAN regularization
        'kan_reg_lamb_l1': trial.suggest_float('kan_reg_lamb_l1', 1e-5, 1e-2, log=True),
        'kan_reg_lamb_entropy': trial.suggest_float('kan_reg_lamb_entropy', 1e-5, 1e-2, log=True),

        # Factor regularization
        'lambda_decorrelation': trial.suggest_float('lambda_decorrelation', 1e-4, 1e-1, log=True),
        'lambda_sparsity': trial.suggest_float('lambda_sparsity', 1e-5, 1e-2, log=True),
        'lambda_stability': trial.suggest_float('lambda_stability', 1e-4, 1e-1, log=True),

        # Training parameters
        'epochs': trial.suggest_int('epochs', 100, 300),
        'patience': trial.suggest_int('patience', 20, 50),

        'seed': 42
    }

    try:
        results = train_and_evaluate_ea_hkan(
            params, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device
        )

        # Persist all metrics (including the target) so they end up in
        # study.best_trial.user_attrs and in the saved JSON report.
        trial.set_user_attr('test_auc', results['test_auc'])
        for key in ('test_acc', 'test_f1', 'best_val_auc', 'best_fqs', 'total_params'):
            trial.set_user_attr(key, results[key])

        return results['test_auc']

    except Exception as e:
        print(f"Trial failed: {e}")
        # Mark the failure explicitly: the 0.0 below is reported as a COMPLETE trial.
        trial.set_user_attr('error', str(e))
        return 0.0


def _format_metric(label, value, spec):
    """Return an indented ``'  label: value'`` summary line, or ``'  label: N/A'`` if value is None."""
    if value is None:
        return f"  {label}: N/A"
    return f"  {label}: {value:{spec}}"


def main():
    """Run the EA best-group HKAN hyperparameter search and save a JSON report.

    Loads the Heart Disease splits, runs 100 TPE trials of ``objective``,
    prints the best trial's parameters and metrics, and writes a timestamped
    JSON summary into ``config.results_dir``, creating the directory if needed.
    """
    print("=" * 80)
    print("EA Best Group HKAN - Bayesian Optimization")
    print("Heart Disease Dataset (Using EA-discovered optimal feature groups)")
    print("=" * 80)

    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    set_seed(config.seed)

    X_train, X_val, X_test, y_train, y_val, y_test, feature_names = load_heart_disease_data(
        config.data_path, config.test_size, config.val_size_from_remaining, config.seed
    )

    # Resolve group indices to feature names once; reused for printing and saving.
    ea_groups = get_ea_best_groups()
    named_groups = {name: [feature_names[i] for i in indices] for name, indices in ea_groups.items()}

    print("\nEA-discovered best group configuration:")
    for group_name, group_features in named_groups.items():
        print(f"  {group_name}: {group_features}")

    study = optuna.create_study(
        direction='maximize',
        sampler=TPESampler(seed=config.seed),
        study_name="ea_best_group_hkan_heart_disease"
    )

    def objective_with_data(trial):
        # Bind the data splits so Optuna sees a single-argument objective.
        return objective(
            trial, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device
        )

    n_trials = 100
    print(f"\nStarting Bayesian optimization with {n_trials} trials...")

    study.optimize(objective_with_data, n_trials=n_trials, show_progress_bar=True)

    print("\n" + "=" * 80)
    print("Optimization Results")
    print("=" * 80)

    best_trial = study.best_trial
    attrs = best_trial.user_attrs
    print(f"\nBest AUC: {best_trial.value:.4f}")

    print("\nBest parameters:")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")

    print("\nBest trial detailed metrics:")
    print(f"  Test AUC: {best_trial.value:.4f}")
    print(_format_metric("Test Accuracy", attrs.get('test_acc'), ".4f"))
    print(_format_metric("Test F1 Score", attrs.get('test_f1'), ".4f"))
    print(_format_metric("Best Validation AUC", attrs.get('best_val_auc'), ".4f"))
    print(_format_metric("Best FQS", attrs.get('best_fqs'), ".4f"))
    print(_format_metric("Model Parameters", attrs.get('total_params'), ","))

    # Count trial outcomes once; used both in the report and on the console.
    states = [t.state for t in study.trials]
    n_complete = states.count(optuna.trial.TrialState.COMPLETE)
    n_failed = states.count(optuna.trial.TrialState.FAIL)

    results = {
        'model_type': 'EA Best Group HKAN',
        'dataset': 'Heart Disease',
        'ea_best_groups': named_groups,
        'best_auc': best_trial.value,
        'best_params': best_trial.params,
        'best_trial_attrs': attrs,
        'study_stats': {
            'n_trials': len(study.trials),
            'n_complete_trials': n_complete,
            'n_failed_trials': n_failed
        }
    }

    # Make sure the output directory exists; otherwise the whole search is
    # lost to a FileNotFoundError at the very end.
    if config.results_dir:
        os.makedirs(config.results_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(config.results_dir, f"ea_best_group_hkan_optuna_results_{timestamp}.json")

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {filename}")
    print(f"Total trials: {len(study.trials)}")
    print(f"Successful trials: {n_complete}")
    print(f"Failed trials: {n_failed}")

    print("\nFeature group configuration:")
    for group_name, group_features in named_groups.items():
        print(f"  {group_name}: {group_features}")


if __name__ == "__main__":
    main()