"""
Bayesian optimization for Pure KAN on Heart Disease dataset
Official KAN implementation for binary classification
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
import warnings
import optuna
from optuna.samplers import TPESampler
import json
import os
from datetime import datetime

from config import get_config
from data_loader import set_seed, load_heart_disease_data
from models import PureKANClassifier
from training_utils import train_pure_kan_model

warnings.filterwarnings('ignore')


def train_and_evaluate_pure_kan(params, X_train, X_val, X_test, y_train, y_val, y_test, device):
    """Train and evaluate a single Pure KAN model for one hyperparameter setting.

    Args:
        params: Hyperparameter overrides. Each key that names an existing
            attribute of the object returned by ``get_config()`` is copied
            onto it; keys with no matching attribute are ignored.
            ``params.get('seed', 42)`` is passed to ``set_seed`` first.
        X_train, X_val, X_test, y_train, y_val, y_test: Data splits, forwarded
            unchanged to ``train_pure_kan_model``. ``X_train.shape[1]`` is used
            as the model's input dimension.
        device: ``torch.device`` the model is moved to and trained on.

    Returns:
        dict with keys ``'test_acc'``, ``'test_auc'``, ``'test_f1'``,
        ``'best_val_auc'`` and ``'total_params'`` taken from the training
        results. If building or training the model raises, the error is
        printed and all five values are returned as 0 so that a
        hyperparameter search can carry on with its next trial.
    """
    set_seed(params.get('seed', 42))

    # Overlay this trial's hyperparameters on the default configuration.
    # NOTE(review): assumes get_config() returns a fresh object on every call;
    # if it returns a shared instance, overrides persist across trials — confirm.
    config = get_config()
    for key, value in params.items():
        if hasattr(config, key):
            setattr(config, key, value)

    model = None
    try:
        # Built inside the try so construction failures (for example a
        # possible out-of-memory error for a large hidden_size / grid) are
        # reported the same way as training failures.
        model = PureKANClassifier(X_train.shape[1], config, device).to(device)
        results = train_pure_kan_model(
            model, X_train, X_val, X_test, y_train, y_val, y_test, config, device
        )
        return {
            'test_acc': results['test_acc'],
            'test_auc': results['test_auc'],
            'test_f1': results['test_f1'],
            'best_val_auc': results['best_val_auc'],
            'total_params': results['total_params'],
        }
    except Exception as e:
        print(f"KAN training failed: {str(e)}")
        return {
            'test_acc': 0.0,
            'test_auc': 0.0,
            'test_f1': 0.0,
            'best_val_auc': 0.0,
            'total_params': 0,
        }
    finally:
        # Runs on success AND on failure: drop this function's reference to the
        # model and release cached CUDA memory before the next trial starts.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def objective(trial, X_train, X_val, X_test, y_train, y_val, y_test, device, *, target_metric='test_auc'):
    """Optuna objective for Pure KAN: sample hyperparameters, train, return a score.

    Args:
        trial: Optuna trial used to sample the search space defined below.
        X_train, X_val, X_test, y_train, y_val, y_test: Data splits forwarded
            to ``train_and_evaluate_pure_kan``.
        device: ``torch.device`` used for training.
        target_metric: Key of the evaluation result returned to Optuna as the
            value to maximize. Defaults to ``'test_auc'`` (the historical
            behavior). Pass ``'best_val_auc'`` to keep the test split out of
            hyperparameter selection.

    Returns:
        The selected metric for this trial, or 0.0 if the trial raised.

    Raises:
        ValueError: If ``target_metric`` is not a supported metric name. This
            is checked before training so a misconfiguration is not swallowed.
    """
    # All metrics are stored on the trial; any score-like one can be the target.
    recorded_metrics = ('test_acc', 'test_auc', 'test_f1', 'best_val_auc', 'total_params')
    selectable_targets = ('test_acc', 'test_auc', 'test_f1', 'best_val_auc')
    if target_metric not in selectable_targets:
        raise ValueError(
            f"target_metric must be one of {selectable_targets}, got {target_metric!r}"
        )

    # Hyperparameter search space
    params = {
        # Network structure
        'hidden_size': trial.suggest_int('hidden_size', 16, 64),

        # KAN core parameters
        'kan_grid': trial.suggest_int('kan_grid', 3, 8),

        # Learning rate
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True),

        # Regularization.
        # NOTE(review): both ranges are sampled on a linear scale; for the
        # 1e-3..5.0 entropy range almost every draw lands above 0.1. Consider
        # log=True if small values are meant to be explored (this would change
        # the search space).
        'kan_reg_lamb_l1': trial.suggest_float('kan_reg_lamb_l1', 1e-4, 1e-2),
        'kan_reg_lamb_entropy': trial.suggest_float('kan_reg_lamb_entropy', 1e-3, 5.0),

        # Training steps
        'epochs': trial.suggest_int('epochs', 100, 300),

        'seed': 42,
    }

    try:
        results = train_and_evaluate_pure_kan(
            params, X_train, X_val, X_test, y_train, y_val, y_test, device
        )

        # Record every metric (including test_auc) on the trial so all of them
        # are reported regardless of which one is being optimized.
        for name in recorded_metrics:
            trial.set_user_attr(name, results[name])

        # NOTE(review): with the default target the test split drives model
        # selection, so the best trial's test AUC is optimistically biased.
        return results[target_metric]

    except Exception as e:
        # Best-effort: a crashed trial is scored 0.0 so the study keeps going.
        print(f"Trial failed: {e}")
        return 0.0


def _format_optional(value, spec):
    """Return ``format(value, spec)``, or ``'N/A'`` when *value* is None."""
    return 'N/A' if value is None else format(value, spec)


def main(n_trials=100):
    """Run the Bayesian hyperparameter search for Pure KAN on the Heart Disease data.

    Loads the data splits using the paths and sizes in ``get_config()``, runs an
    Optuna TPE study that maximizes ``objective``, prints the best trial, and
    writes a timestamped JSON summary into ``config.results_dir``.

    Args:
        n_trials: Number of Optuna trials to run (default 100).
    """
    print("=" * 80)
    print("Pure KAN (Official Implementation) - Bayesian Optimization")
    print("Heart Disease Dataset")
    print("=" * 80)

    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Create the output directory before the (long) optimization so a missing
    # or unwritable results_dir fails fast instead of after all trials ran.
    # An empty string means the current directory, which needs no creation.
    if config.results_dir:
        os.makedirs(config.results_dir, exist_ok=True)

    set_seed(config.seed)

    X_train, X_val, X_test, y_train, y_val, y_test, feature_names = load_heart_disease_data(
        config.data_path, config.test_size, config.val_size_from_remaining, config.seed
    )

    study = optuna.create_study(
        direction='maximize',
        sampler=TPESampler(seed=config.seed),
        study_name="pure_kan_heart_disease"
    )

    def objective_with_data(trial):
        # Bind the loaded data splits and device to the Optuna objective.
        return objective(trial, X_train, X_val, X_test, y_train, y_val, y_test, device)

    print(f"\nStarting Bayesian optimization with {n_trials} trials...")
    study.optimize(objective_with_data, n_trials=n_trials, show_progress_bar=True)

    # Output results
    print("\n" + "=" * 80)
    print("Optimization Results")
    print("=" * 80)

    best_trial = study.best_trial
    attrs = best_trial.user_attrs
    print(f"\nBest AUC: {best_trial.value:.4f}")

    print("\nBest parameters:")
    # NOTE(review): k is hard-coded here rather than read from config — confirm
    # it matches what the model actually uses.
    print("  k: 3 (fixed)")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")

    print("\nBest trial detailed metrics:")
    # The objective returns test AUC as the trial value by default; prefer the
    # explicitly recorded 'test_auc' attribute when present so this line stays
    # correct if the optimization target is ever changed.
    test_auc = attrs.get('test_auc', best_trial.value)
    print(f"  Test AUC: {test_auc:.4f}")
    print(f"  Test Accuracy: {_format_optional(attrs.get('test_acc'), '.4f')}")
    print(f"  Test F1 Score: {_format_optional(attrs.get('test_f1'), '.4f')}")
    print(f"  Best Validation AUC: {_format_optional(attrs.get('best_val_auc'), '.4f')}")
    print(f"  Model Parameters: {_format_optional(attrs.get('total_params'), ',')}")

    # study.trials returns deep copies on every access, so fetch it once and
    # count the states once.
    # NOTE: objective() scores exceptions as 0.0, so those trials are recorded
    # as COMPLETE; the FAIL count only covers trials Optuna itself marked failed.
    all_trials = study.trials
    trial_state = optuna.trial.TrialState
    n_complete = sum(t.state == trial_state.COMPLETE for t in all_trials)
    n_failed = sum(t.state == trial_state.FAIL for t in all_trials)

    results = {
        'model_type': 'Pure KAN (Official)',
        'dataset': 'Heart Disease',
        'best_auc': best_trial.value,
        'best_params': best_trial.params,
        'best_trial_attrs': attrs,
        'study_stats': {
            'n_trials': len(all_trials),
            'n_complete_trials': n_complete,
            'n_failed_trials': n_failed,
        }
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(config.results_dir, f"pure_kan_heart_optuna_results_{timestamp}.json")

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {filename}")
    print(f"Total trials: {len(all_trials)}")
    print(f"Successful trials: {n_complete}")
    print(f"Failed trials: {n_failed}")


if __name__ == "__main__":
    # Script entry point: run the optimization only when executed directly,
    # never on import. main() returns None, so the process exits with status 0.
    raise SystemExit(main())