"""
Bayesian optimization for Mutual Information Group HKAN on Heart Disease dataset
Uses mutual information-based feature grouping for hierarchical KAN
"""

import pandas as pd
import numpy as np
from sklearn.feature_selection import mutual_info_classif
import torch
import torch.nn as nn
import torch.optim as optim
import time
import warnings
import optuna
from optuna.samplers import TPESampler
import json
import os
from datetime import datetime

from config import get_config
from data_loader import set_seed, load_heart_disease_data, preprocess_features_by_groups
from models import HKANClassification, calculate_factor_quality_score
from training_utils import train_hkan_model

# Suppress every warning for the rest of the process (no category or module
# filter is given, so the 'ignore' action applies to all warnings).
warnings.filterwarnings('ignore')


def create_mi_groups(X_train, y_train, feature_names, num_groups=4, random_state=42):
    """Create feature groups based on mutual information with the target.

    Features are ranked by their ``mutual_info_classif`` score (highest
    first) and dealt out round-robin: the best-ranked feature goes to group 0,
    the next to group 1, and so on, wrapping around after ``num_groups``
    features. The ranking and the resulting groups are printed.

    Args:
        X_train: Training feature matrix passed to ``mutual_info_classif``.
        y_train: Training labels passed to ``mutual_info_classif``.
        feature_names: Feature names indexed by column position; used only
            for the printed report.
        num_groups: Number of groups to create; must be at least 1.
        random_state: Seed forwarded to ``mutual_info_classif``.

    Returns:
        dict mapping ``"group_<i>"`` to the ascending list of feature column
        indices assigned to that group. A group is empty when ``num_groups``
        exceeds the number of features.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError /
    # IndexError from the round-robin assignment further down.
    if num_groups < 1:
        raise ValueError(f"num_groups must be at least 1, got {num_groups}")

    print("\nCalculating feature mutual information...")

    # Mutual information between each feature column and the target
    mi_scores = mutual_info_classif(X_train, y_train, random_state=random_state)

    # Column indices from most to least informative. sorted() is stable even
    # with reverse=True, so tied features keep their column order.
    ranked = sorted(range(len(mi_scores)), key=lambda idx: mi_scores[idx], reverse=True)

    print("\nFeature mutual information scores (sorted):")
    for idx in ranked:
        print(f"  {feature_names[idx]}: {mi_scores[idx]:.4f}")

    # Round-robin assignment: rank position p goes to group p % num_groups,
    # which is exactly the strided slice ranked[g::num_groups] for group g.
    group_dict = {
        f"group_{g}": sorted(ranked[g::num_groups]) for g in range(num_groups)
    }

    print("\nMutual information grouping results:")
    for group_name, indices in group_dict.items():
        group_features = [feature_names[idx] for idx in indices]
        # An empty group has no mean; report nan without calling np.mean([]).
        if indices:
            avg_mi = np.mean([mi_scores[idx] for idx in indices])
        else:
            avg_mi = float("nan")
        print(f"  {group_name}: {group_features}")
        print(f"    Average MI: {avg_mi:.4f}")

    return group_dict


def train_and_evaluate_mi_group_hkan(params, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device):
    """Train and evaluate one mutual-information-grouped HKAN model.

    The features are split into four mutual-information groups, preprocessed
    per group, and used to build an ``HKANClassification`` model that is
    trained with ``train_hkan_model``.

    Args:
        params: Mapping of hyperparameter names to values. Every entry whose
            key is an existing attribute of the object returned by
            ``get_config()`` overrides that attribute; other keys are
            ignored. ``params['seed']`` (default 42) seeds the run.
        X_train, X_val, X_test: Feature matrices of the three splits.
        y_train, y_val, y_test: Labels of the three splits.
        feature_names: Feature names, indexed by column position.
        device: Device handed to preprocessing, the model and the trainer.

    Returns:
        dict with ``test_acc``, ``test_auc``, ``test_f1``, ``best_val_auc``,
        ``final_fqs`` (0.5 if the trainer does not report it),
        ``total_params``, ``epochs_trained`` and ``group_dict``. If building
        or training the model raises, the error is printed and the same keys
        are returned with zeroed metrics and an empty ``group_dict``.
    """

    # Set random seed
    set_seed(params.get('seed', 42))

    # Create mutual information groups - fixed to 4 groups
    num_groups = 4
    group_dict = create_mi_groups(X_train, y_train, feature_names, num_groups)

    # Preprocess features
    train_data, val_data, test_data, group_feature_info = preprocess_features_by_groups(
        X_train, X_val, X_test, group_dict, device
    )

    # Create config object from params (unknown keys are skipped on purpose)
    config = get_config()
    for key, value in params.items():
        if hasattr(config, key):
            setattr(config, key, value)

    model = None
    try:
        # Create model
        model = HKANClassification(group_feature_info, config, device).to(device)

        # Train model
        results = train_hkan_model(
            model, train_data, val_data, test_data,
            y_train, y_val, y_test, config, device
        )

        return {
            'test_acc': results['test_acc'],
            'test_auc': results['test_auc'],
            'test_f1': results['test_f1'],
            'best_val_auc': results['best_val_auc'],
            'final_fqs': results.get('final_fqs', 0.5),
            'total_params': results['total_params'],
            'epochs_trained': results['epochs_trained'],
            'group_dict': group_dict
        }

    except Exception as e:
        # Deliberate best-effort: a failing configuration must not abort the
        # surrounding hyperparameter search, so report it and return zeros.
        print(f"MI Group HKAN training failed: {e}")
        return {
            'test_acc': 0.0,
            'test_auc': 0.0,
            'test_f1': 0.0,
            'best_val_auc': 0.0,
            'final_fqs': 0.0,
            'total_params': 0,
            'epochs_trained': 0,
            'group_dict': {}
        }

    finally:
        # Clean up memory on both the success and the failure path so a
        # failed trial does not leave its cached GPU blocks behind.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def objective(trial, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device):
    """Optuna objective for the Mutual Information Group HKAN.

    Samples one hyperparameter configuration from ``trial``, trains and
    evaluates a model with it, stores the secondary metrics as trial user
    attributes and returns the test AUC. Any exception raised along the way
    is printed and turned into a score of 0.0.

    NOTE(review): the returned score is the *test* AUC, so the search selects
    hyperparameters on the test split - confirm this is intended.
    """

    def log_uniform(name, low, high):
        # All continuous hyperparameters are searched on a log scale.
        return trial.suggest_float(name, low, high, log=True)

    # The suggest_* calls below are kept in a fixed order on purpose.
    params = {}

    # KAN core parameters
    params['kan_grid'] = trial.suggest_int('kan_grid', 3, 20)

    # Hidden-size multiplier of each group's sub-KAN.
    # Note: actual size will be limited to 3 to 2*input_dim+1 based on input dimension
    params['kan_hidden_multiplier'] = trial.suggest_int('kan_hidden_multiplier', 1, 3)

    # Fusion network: a single hidden layer whose width is searched
    fusion_width = trial.suggest_int('fusion_hidden_dim', 4, 16)
    params['fusion_hidden_layers'] = [fusion_width]

    # Optimizer (AdamW) settings
    params['learning_rate'] = log_uniform('learning_rate', 1e-4, 1e-1)
    params['weight_decay'] = log_uniform('weight_decay', 1e-6, 1e-2)

    # Learning rate scheduler
    scheduler_choices = ['None', 'ReduceLROnPlateau', 'CosineAnnealingLR']
    params['scheduler_type'] = trial.suggest_categorical('scheduler_type', scheduler_choices)

    # KAN regularization strengths
    params['kan_reg_lamb_l1'] = log_uniform('kan_reg_lamb_l1', 1e-5, 1e-2)
    params['kan_reg_lamb_entropy'] = log_uniform('kan_reg_lamb_entropy', 1e-5, 1e-2)

    # Factor regularization strengths
    params['lambda_decorrelation'] = log_uniform('lambda_decorrelation', 1e-4, 1e-1)
    params['lambda_sparsity'] = log_uniform('lambda_sparsity', 1e-5, 1e-2)
    params['lambda_stability'] = log_uniform('lambda_stability', 1e-4, 1e-1)

    # Gradient clipping threshold
    params['grad_clip'] = trial.suggest_categorical('grad_clip', [0, 0.5, 1.0, 2.0, 5.0])

    # Settings that are not searched
    params['epochs'] = 1000   # maximum number of epochs
    params['patience'] = 60   # early stopping patience
    params['seed'] = 42

    # Score reported when anything below fails
    score = 0.0
    try:
        results = train_and_evaluate_mi_group_hkan(
            params, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device
        )

        # Keep the secondary metrics alongside the trial
        recorded_metrics = (
            'test_acc',
            'test_f1',
            'best_val_auc',
            'final_fqs',
            'epochs_trained',
            'total_params',
        )
        for metric in recorded_metrics:
            trial.set_user_attr(metric, results[metric])

        # Test AUC is the optimization target
        score = results['test_auc']
    except Exception as err:
        print(f"Trial failed: {err}")

    return score


def main():
    """Run the Optuna search for the MI Group HKAN and retrain the best setup.

    Loads the configuration and the heart-disease splits, runs 100 TPE trials
    of ``objective``, prints and saves the best trial (``best_results.json``)
    and the full trial table (``all_trials.csv``), then retrains once with the
    best hyperparameters and saves that run to ``final_model_results.json``.
    All files are written to
    ``<config.results_dir>/mi_group_hkan_optuna_results``.
    """

    def fmt(value, spec):
        # Format a metric with `spec`; fall back to plain text for
        # placeholders such as 'N/A' that do not accept a numeric format.
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    print("=" * 80)
    print("MI (Mutual Information) Group HKAN - Bayesian Optimization")
    print("Mutual Information Group HKAN Bayesian Optimization")
    print("=" * 80)

    # Get configuration
    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Set random seed
    set_seed(config.seed)

    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test, feature_names = load_heart_disease_data(
        config.data_path, config.test_size, config.val_size_from_remaining, config.seed
    )

    # Create results directory
    output_dir = os.path.join(config.results_dir, "mi_group_hkan_optuna_results")
    os.makedirs(output_dir, exist_ok=True)

    # Set up Optuna
    study = optuna.create_study(
        direction='maximize',
        sampler=TPESampler(seed=config.seed),
        study_name='mi_group_hkan_heart_disease'
    )

    # Run optimization
    n_trials = 100
    print(f"\nStarting Bayesian optimization (MI Group HKAN) with {n_trials} trials...")

    def objective_wrapper(trial):
        return objective(trial, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device)

    study.optimize(objective_wrapper, n_trials=n_trials, show_progress_bar=True)

    # Get best results
    print("\n" + "=" * 80)
    print("Optimization Results (MI Group HKAN)")
    print("=" * 80)

    best_trial = study.best_trial
    best_attrs = best_trial.user_attrs

    print(f"\nBest AUC: {best_trial.value:.4f}")
    print("\nBest parameters:")
    print("  Grouping strategy: Mutual Information Grouping")
    print("  Number of groups: 4 (fixed)")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")

    # A trial that failed before its metrics were recorded has no user
    # attributes, so every metric is formatted through fmt() to survive the
    # 'N/A' placeholder instead of raising on the numeric format spec.
    print("\nBest trial detailed metrics:")
    print(f"  Test AUC: {best_trial.value:.4f}")
    print(f"  Test Accuracy: {fmt(best_attrs.get('test_acc', 'N/A'), '.4f')}")
    print(f"  Test F1 Score: {fmt(best_attrs.get('test_f1', 'N/A'), '.4f')}")
    print(f"  Best Validation AUC: {fmt(best_attrs.get('best_val_auc', 'N/A'), '.4f')}")
    print(f"  Final FQS: {fmt(best_attrs.get('final_fqs', 'N/A'), '.4f')}")
    print(f"  Epochs Trained: {best_attrs.get('epochs_trained', 'N/A')}")
    print(f"  Model Parameters: {fmt(best_attrs.get('total_params', 'N/A'), ',')}")

    # Save results
    results = {
        'model_type': 'MI Group HKAN',
        'best_auc': best_trial.value,
        'best_params': best_trial.params,
        'best_metrics': {
            'test_acc': best_attrs.get('test_acc'),
            'test_auc': best_trial.value,
            'test_f1': best_attrs.get('test_f1'),
            'best_val_auc': best_attrs.get('best_val_auc'),
            'final_fqs': best_attrs.get('final_fqs'),
            'epochs_trained': best_attrs.get('epochs_trained'),
            'total_params': best_attrs.get('total_params')
        },
        'grouping_strategy': 'mutual_information',
        'num_groups': 4,
        'study_stats': {
            'n_trials': len(study.trials),
            'n_complete': len([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]),
            'n_failed': len([t for t in study.trials if t.state == optuna.trial.TrialState.FAIL])
        },
        'timestamp': datetime.now().isoformat()
    }

    results_file = os.path.join(output_dir, 'best_results.json')
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Save all trials history
    trials_df = study.trials_dataframe()
    trials_file = os.path.join(output_dir, 'all_trials.csv')
    trials_df.to_csv(trials_file, index=False)

    print(f"\nResults saved to: {results_file}")
    print(f"All trials saved to: {trials_file}")

    # Train final model with best parameters
    print("\n" + "=" * 80)
    print("Training final model with best parameters (MI Group HKAN)")
    print("=" * 80)

    # best_trial.params is keyed by the names given to trial.suggest_*, which
    # is not the params dict the trials were trained with: objective() passes
    # the sampled 'fusion_hidden_dim' as the one-element list
    # 'fusion_hidden_layers' and adds fixed 'epochs', 'patience' and 'seed'.
    # Rebuild those entries so the final run reproduces the best trial's setup.
    final_params = best_trial.params.copy()
    if 'fusion_hidden_dim' in final_params:
        final_params['fusion_hidden_layers'] = [final_params['fusion_hidden_dim']]
    final_params.setdefault('epochs', 1000)
    final_params.setdefault('patience', 60)
    final_params.setdefault('seed', 42)

    final_results = train_and_evaluate_mi_group_hkan(
        final_params, X_train, X_val, X_test, y_train, y_val, y_test, feature_names, device
    )

    print("\nFinal model performance (MI Group HKAN):")
    print("  Grouping strategy: Mutual Information Grouping")
    print("  Number of groups: 4 (fixed)")
    print(f"  Group configuration: {final_results['group_dict']}")
    print(f"  Validation - Best AUC: {final_results['best_val_auc']:.4f}")
    print(f"  Test Set - ACC: {final_results['test_acc']:.4f}, AUC: {final_results['test_auc']:.4f}, F1: {final_results['test_f1']:.4f}")
    print(f"  Model Parameters: {final_results['total_params']:,}")

    # Save final results
    final_results['best_params'] = final_params
    final_results_file = os.path.join(output_dir, 'final_model_results.json')
    with open(final_results_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2)

    print(f"\nAll results saved to {output_dir} directory")
    print("=" * 80)
    print("MI Group HKAN optimization completed!")
    print("=" * 80)


# Start the optimization only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()