"""
Demo HKAN - Using EA algorithm discovered optimal feature grouping configuration
Optimal grouping configuration from EA.py:
  Group 0: ['cp', 'restecg', 'thalach', 'exang', 'ca']
  Group 1: ['fbs', 'thalach', 'oldpeak', 'ca']
  Group 2: ['age', 'trestbps', 'chol', 'slope', 'thal']
  Group 3: ['sex', 'slope', 'ca']
Expected performance: AUC ~0.9777, Parameters ~1652
"""

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim
import time
import warnings
import random
import os
from datetime import datetime

from config import get_config
from data_loader import set_seed

# NOTE: this silences every warning category for the whole process, including
# numerical warnings raised by numpy/torch during training.
warnings.filterwarnings('ignore')

# Single module-level device used by the model and all tensor conversions below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Import official KAN library
try:
    from kan import KAN
    print("Official KAN library imported successfully")
except ImportError:
    print("KAN library not found. Please install: pip install pykan")
    # `exit` is a convenience injected by the `site` module and is not
    # guaranteed to exist (e.g. `python -S`, embedded interpreters); raising
    # SystemExit is the portable way to terminate with status 1.
    raise SystemExit(1) from None


# ================================
# Heart Disease Data Loading
# ================================
def load_heart_disease_data_demo(data_path="heart_disease.csv"):
    """Read the heart-disease CSV and return a standardized 80/20 split.

    Every column except ``target`` is treated as a feature. The split is
    stratified on the label with a fixed ``random_state`` of 42, and the
    scaler is fitted on the training portion only before being applied to
    the test portion.

    Args:
        data_path: Path of the CSV file to read.

    Returns:
        Tuple ``(X_train_scaled, X_test_scaled, y_train, y_test, feature_names)``.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    print("Loading Heart Disease dataset...")

    # Fail with an explicit message before handing the path to pandas.
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    frame = pd.read_csv(data_path)
    print(f"Data shape: {frame.shape}")

    # Everything that is not the label column is a feature.
    feature_names = [column for column in frame.columns if column != 'target']
    features = frame[feature_names].values
    labels = frame['target'].values

    print(f"Number of features: {len(feature_names)}")
    print(f"Number of samples: {len(labels)}")
    print(f"Class distribution: {np.bincount(labels)}")
    print(f"Feature list: {feature_names}")

    # 80:20 stratified split with a fixed seed.
    split_options = dict(test_size=0.2, random_state=42, stratify=labels)
    train_x, test_x, train_y, test_y = train_test_split(features, labels, **split_options)

    # Fit the scaler on the training rows only, then reuse it for the test rows.
    standardizer = StandardScaler()
    train_x_scaled = standardizer.fit_transform(train_x)
    test_x_scaled = standardizer.transform(test_x)

    print(f"Train set: {train_x_scaled.shape}, Test set: {test_x_scaled.shape}")
    print(f"Train class distribution: {np.bincount(train_y)}")
    print(f"Test class distribution: {np.bincount(test_y)}")

    return train_x_scaled, test_x_scaled, train_y, test_y, feature_names


# ================================
# HKAN Model Definition
# ================================
class HKAN(nn.Module):
    """Grouped KAN model: one KAN per feature group, fused by a final KAN.

    Each entry of ``feature_groups`` gets its own ``KAN`` of width
    ``[len(group), len(group) + 1, 1]``. The single-column outputs of all
    group networks are concatenated and passed through a final ``KAN`` of
    width ``[num_groups, 1, 1]``.

    The input given to ``forward`` must already be laid out group by group:
    group ``i`` reads the next ``len(feature_groups[i])`` consecutive columns.
    Only the *length* of each group is used for slicing, not its contents.

    Args:
        feature_groups: Sequence of groups; each group is a sized collection
            with one entry per input column belonging to that group.
        kan_grid: Passed to every ``KAN`` as ``grid``.
        kan_k: Passed to every ``KAN`` as ``k``.
    """

    def __init__(self, feature_groups, kan_grid=5, kan_k=3):
        super().__init__()
        self.feature_groups = feature_groups
        self.num_groups = len(feature_groups)
        self.kan_grid = kan_grid
        self.kan_k = kan_k

        # EA.py regularization parameters (forwarded to KAN.reg)
        self.kan_reg_lamb_l1 = 1e-3
        self.kan_reg_lamb_entropy = 1e-3
        self.kan_reg_lamb_coef = 0.0
        self.kan_reg_lamb_coefdiff = 0.0

        # Weights of the three factor-level penalties
        self.lambda_decorrelation = 0.01
        self.lambda_sparsity = 0.001
        self.lambda_stability = 0.01

        # Create KAN sub-networks for each feature group
        self.group_kans = nn.ModuleList()
        self.group_outputs = []

        for group in feature_groups:
            input_dim = len(group)
            # According to EA.py configuration: hidden_dim = input_dim * 1 + 1
            hidden_dim = input_dim * 1 + 1

            # Create KAN network: [input_dim, hidden_dim, 1]
            kan = KAN(width=[input_dim, hidden_dim, 1],
                      grid=kan_grid, k=kan_k, device=device, seed=42)
            self.group_kans.append(kan)
            self.group_outputs.append(1)  # every group contributes one factor

        # Final fusion layer - According to EA.py configuration: [num_groups, 1, 1]
        total_group_outputs = sum(self.group_outputs)
        self.final_kan = KAN(width=[total_group_outputs, 1, 1],
                             grid=kan_grid, k=kan_k, device=device, seed=42)

        # Caches filled by forward(): the factor tensor (still attached to the
        # autograd graph) and detached numpy copies used for the FQS metric.
        self.last_combined_factors = None
        self.last_factors_for_fqs = None

    def forward(self, x):
        """Run every group KAN on its column slice and fuse the results.

        Args:
            x: 2-D tensor whose columns are ordered group by group.

        Returns:
            Raw logits (no sigmoid applied). ``squeeze()`` removes every
            size-1 dimension, so a single-row batch yields a 0-d tensor.
        """
        group_results = []

        # Walk the columns left to right, one contiguous slice per group
        start_idx = 0
        for i, group in enumerate(self.feature_groups):
            end_idx = start_idx + len(group)
            group_input = x[:, start_idx:end_idx]

            # Through corresponding KAN sub-network
            group_output = self.group_kans[i](group_input)
            group_results.append(group_output)

            start_idx = end_idx

        # Concatenate all group outputs
        combined_output = torch.cat(group_results, dim=1)
        self.last_combined_factors = combined_output  # Save for regularization
        self.last_factors_for_fqs = [f.detach().cpu().numpy().flatten() for f in group_results]

        # Final fusion
        final_output = self.final_kan(combined_output)
        return final_output.squeeze()  # Don't use sigmoid, let BCEWithLogitsLoss handle it

    def get_kan_regularization_loss(self):
        """Sum ``KAN.reg`` over every group network and the fusion network."""
        total_reg = 0.0

        # Regularization for each sub-KAN
        for kan in self.group_kans:
            reg = kan.reg(
                'edge_forward_spline_n',
                self.kan_reg_lamb_l1,
                self.kan_reg_lamb_entropy,
                self.kan_reg_lamb_coef,
                self.kan_reg_lamb_coefdiff
            )
            total_reg += reg

        # Regularization for fusion KAN
        reg = self.final_kan.reg(
            'edge_forward_spline_n',
            self.kan_reg_lamb_l1,
            self.kan_reg_lamb_entropy,
            self.kan_reg_lamb_coef,
            self.kan_reg_lamb_coefdiff
        )
        total_reg += reg

        return total_reg

    def get_factor_regularization_loss(self):
        """Penalty on the factors cached by the most recent ``forward`` call.

        Adds three weighted terms: squared upper-triangle correlations
        between factors, mean absolute factor value, and the summed
        per-factor variance. Returns a zero scalar if ``forward`` has not
        been called yet.
        """
        if self.last_combined_factors is None:
            return torch.tensor(0.0, device=device)

        factors = self.last_combined_factors
        # Accumulate on the factors' own device and dtype. Building the
        # accumulator from the module-level `device` and adding in place
        # fails when the cached factors live elsewhere or are not float32.
        factor_reg = factors.new_zeros(())
        # Correlation and unbiased variance are undefined for a single row
        # (torch.var returns NaN), so those terms are skipped in that case.
        has_multiple_rows = factors.shape[0] > 1

        # Decorrelation regularization
        if factors.shape[1] > 1 and has_multiple_rows:
            corr_matrix = torch.corrcoef(factors.T)
            # A constant factor produces NaN correlations; skip the term then.
            if not torch.isnan(corr_matrix).any():
                factor_reg = factor_reg + self.lambda_decorrelation * torch.sum(
                    torch.triu(corr_matrix, diagonal=1) ** 2
                )

        # Sparsity regularization
        factor_reg = factor_reg + self.lambda_sparsity * torch.mean(torch.abs(factors))

        # Stability regularization
        if has_multiple_rows:
            factor_reg = factor_reg + self.lambda_stability * torch.sum(torch.var(factors, dim=0))

        return factor_reg

    def get_total_regularization_loss(self):
        """Return the KAN regularization plus the factor regularization."""
        return self.get_kan_regularization_loss() + self.get_factor_regularization_loss()


# ================================
# Pruning and Symbolic Functions
# ================================
def clean_formula(formula_str, threshold=0.05):
    """Strip negligible terms from a printed formula string.

    Two passes are applied:

    1. A regex removes product terms whose coefficient is written as
       ``0.0…`` (for example ``0.03*x``), together with a directly
       preceding sign. This pass does not consult ``threshold`` and stops
       at the next ``+`` or ``-`` character, so it is not parenthesis-aware.
    2. Whitespace-separated tokens that parse as a number with absolute
       value below ``threshold`` are dropped, along with the ``+``/``-``
       token that joined them to the previous term.

    Args:
        formula_str: Formula text; may be empty or ``None``.
        threshold: Cut-off for standalone numeric tokens.

    Returns:
        The cleaned formula, or ``'0'`` if nothing is left.
    """
    import re

    # Handle special cases
    if not formula_str or formula_str.strip() in ['', '0', '-0.0', '0.0', '-0.005']:
        return '0'

    # Handle scientific notation (like 0.e-1 becomes 0). The lookbehind keeps
    # the pattern from firing on the tail of a longer number such as 10.e-1.
    formula_str = re.sub(r'(?<![\d.])0\.e[+-]\d+', '0', formula_str)

    # Remove product terms whose coefficient starts with 0.0. Without the
    # lookbehind the pattern also matched inside larger numbers and turned
    # "10.05*x" into "1".
    formula_str = re.sub(r'[+-]?\s*(?<![\d.])0\.0\d*\s*\*[^+-]+', '', formula_str)

    # Remove standalone small decimal terms
    terms = formula_str.split()
    cleaned_terms = []
    i = 0
    while i < len(terms):
        term = terms[i]
        try:
            is_negligible = abs(float(term)) < threshold
        except ValueError:
            # Not a plain number (variable, operator, compound term): keep it.
            is_negligible = False

        if is_negligible:
            # Also drop the operator that attached this term to the previous
            # one, otherwise "x + 0.01" would be left as "x +".
            if cleaned_terms and cleaned_terms[-1] in ('+', '-'):
                cleaned_terms.pop()
            if i + 1 < len(terms) and terms[i + 1] == '*':
                i += 3  # Skip coefficient, *, and variable
            else:
                i += 1
            continue

        cleaned_terms.append(term)
        i += 1

    formula_str = ' '.join(cleaned_terms)

    # Clean up extra plus/minus signs
    formula_str = re.sub(r'\s*\+\s*-', ' - ', formula_str)
    formula_str = re.sub(r'\s+', ' ', formula_str)
    formula_str = formula_str.strip()

    # If empty after cleaning, return 0
    if not formula_str or formula_str in ['+', '-', '']:
        return '0'

    # Remove leading plus sign
    if formula_str.startswith('+ '):
        formula_str = formula_str[2:]

    return formula_str


def prune_and_symbolize_model(model, X_train, y_train, X_test, y_test, feature_groups, feature_names):
    """Prune every KAN in ``model``, fit symbolic functions and print formulas.

    Steps performed:

    1. One forward pass over ``X_train`` in eval mode under ``no_grad``.
    2. ``prune(edge_th=0.001)`` on each group KAN, with the returned network
       stored back into ``model.group_kans``; ``prune(edge_th=0.1)`` on the
       fusion KAN.
    3. ``auto_symbolic`` on every KAN, followed by a best-effort print of the
       rounded formula with KAN symbols renamed to feature or factor names.

    Args:
        model: Trained model exposing ``group_kans`` and ``final_kan``.
        X_train: Training features, columns ordered group by group.
        y_train, X_test, y_test, feature_names: Accepted for call
            compatibility; not used by this function.
        feature_groups: Feature-name lists, one per group KAN, used to
            rename ``x_1``, ``x_2``, … in the printed group formulas.

    Returns:
        The same ``model`` object, moved to the module-level ``device``.
    """
    import re

    def rename_symbols(text, names):
        """Replace each whole ``x_<n>`` symbol with ``names[n - 1]``.

        A single regex pass is used instead of successive ``str.replace``
        calls so that ``x_1`` can never rewrite the prefix of ``x_10``.
        Symbols whose index has no matching name are left untouched.
        """
        def swap(match):
            position = int(match.group(1)) - 1
            if 0 <= position < len(names):
                return names[position]
            return match.group(0)

        return re.sub(r'\bx_(\d+)\b', swap, text)

    print("\n" + "="*60)
    print("Starting pruning and symbolization process")
    print("="*60)

    # Step 1: Pruning
    print("\nStep 1: Pruning each KAN sub-network...")

    # Only the training features are needed here; labels and test data are
    # not converted because nothing below reads them.
    X_train_tensor = torch.FloatTensor(X_train).to(device)

    # Ensure model is on correct device
    model = model.to(device)

    # Perform complete forward pass to update all activation values
    print("  Computing activation values...")
    model.eval()
    with torch.no_grad():
        _ = model(X_train_tensor)  # Use all training data to update activation values

    # Prune each sub-KAN and keep the network returned by prune().
    # NOTE(review): only edge_th is overridden; whether prune() also applies
    # node pruning with its default threshold depends on the installed pykan
    # version - confirm before relying on "edge-only" pruning.
    for i, kan in enumerate(model.group_kans):
        print(f"  Pruning group {i}...")
        model.group_kans[i] = kan.prune(edge_th=0.001)

    print("  Pruning final fusion KAN...")
    model.final_kan = model.final_kan.prune(edge_th=0.1)

    # Step 2: Symbolic regression (symbolize all at once)
    print("\nStep 2: Symbolic regression...")

    # PyKAN symbol library - conservative selection but includes necessary functions
    full_lib = ['x', 'x^2', 'x^3', 'exp', 'log', 'sqrt',
                'sin', 'cos', 'abs', 'tanh']
    fusion_lib = ['x', 'x^2', 'x^3', 'exp', 'log', 'abs']

    # Symbolize each group KAN
    for i, kan in enumerate(model.group_kans):
        print(f"  Symbolizing group {i} KAN...")
        kan.auto_symbolic(lib=full_lib, a_range=(-5, 5), b_range=(-1, 1), weight_simple=0.8)
        # Formula printing is best-effort: any failure is reported, not raised.
        try:
            from kan.utils import ex_round
            formula = kan.symbolic_formula()[0][0]
            # Round to 2 decimal places for this progress print
            formula_simple = str(ex_round(formula, 2))
            # Map feature names
            formula_simple = rename_symbols(formula_simple, feature_groups[i])
            # Clean formula
            formula_simple = clean_formula(formula_simple)
            print(f"    Group {i} formula: {formula_simple}")
        except Exception as e:
            print(f"    Group {i} formula extraction failed: {e}")

    # Symbolize final fusion KAN
    print("  Symbolizing final fusion layer...")
    model.final_kan.auto_symbolic(lib=fusion_lib, a_range=(-10, 10), b_range=(-2, 2), weight_simple=0.8)
    try:
        from kan.utils import ex_round
        fusion_formula = model.final_kan.symbolic_formula()[0][0]
        # Keep 1 decimal place
        fusion_simple = str(ex_round(fusion_formula, 1))
        # Map to factor names
        factor_names = [f'factor_{i}' for i in range(len(model.group_kans))]
        fusion_simple = rename_symbols(fusion_simple, factor_names)
        # Clean formula
        fusion_simple = clean_formula(fusion_simple)
        print(f"    Fusion layer formula: {fusion_simple}")
    except Exception as e:
        print(f"    Fusion layer formula extraction failed: {e}")

    print("\nModel symbolization completed")

    return model


def extract_formulas_with_feature_names(model, feature_groups, feature_names):
    """Collect the symbolic formula of every KAN with readable symbol names.

    For each group KAN the first formula returned by ``symbolic_formula()``
    is rounded to one decimal, its ``x_<n>`` symbols are renamed to the
    matching entry of ``feature_groups[i]``, and the text is passed through
    ``clean_formula``. The fusion KAN is handled the same way with the names
    ``factor_0``, ``factor_1``, … . A failure for one network is reported and
    recorded as ``'N/A'`` without stopping the others.

    Args:
        model: Object exposing ``group_kans`` and ``final_kan``.
        feature_groups: Feature-name lists, one per group KAN.
        feature_names: Accepted for call compatibility; not used here.

    Returns:
        Dict keyed ``'group_<i>'`` and ``'fusion'``; each value holds
        ``'formula'``, ``'features'`` and ``'description'``. The dict is
        empty when ``kan.utils.ex_round`` cannot be imported.
    """
    import re

    def rename_symbols(text, names):
        """Replace each whole ``x_<n>`` symbol with ``names[n - 1]``.

        A single regex pass is used instead of successive ``str.replace``
        calls so that ``x_1`` can never rewrite the prefix of ``x_10``.
        Symbols whose index has no matching name are left untouched.
        """
        def swap(match):
            position = int(match.group(1)) - 1
            if 0 <= position < len(names):
                return names[position]
            return match.group(0)

        return re.sub(r'\bx_(\d+)\b', swap, text)

    print("\n" + "="*60)
    print("Extracting symbolic formulas")
    print("="*60)

    formulas = {}

    try:
        from kan.utils import ex_round

        # Extract formulas for each factor group
        for i, kan in enumerate(model.group_kans):
            try:
                formula = kan.symbolic_formula()[0][0]
                # Keep 1 decimal place
                formula_simple = str(ex_round(formula, 1))

                # Map KAN symbols to actual feature names
                formula_simple = rename_symbols(formula_simple, feature_groups[i])

                # Clean formula
                formula_simple = clean_formula(formula_simple)

                formulas[f'group_{i}'] = {
                    'formula': formula_simple,
                    'features': feature_groups[i],
                    'description': f'Factor Group {i}: {feature_groups[i]}'
                }

                print(f"\nFactor Group {i} Formula ({feature_groups[i]}):")
                print(f"  {formula_simple}")

            except Exception as e:
                print(f"  Group {i} symbol extraction failed: {e}")
                formulas[f'group_{i}'] = {
                    'formula': 'N/A',
                    'features': feature_groups[i],
                    'description': f'Factor Group {i}: {feature_groups[i]} (symbol extraction failed)'
                }

        # Extract final fusion KAN formula
        try:
            fusion_formula = model.final_kan.symbolic_formula()[0][0]
            # Keep 1 decimal place
            fusion_simple = str(ex_round(fusion_formula, 1))

            # Map fusion layer symbols to factor group names
            factor_names = [f'factor_{i}' for i in range(len(model.group_kans))]
            fusion_simple = rename_symbols(fusion_simple, factor_names)

            # Clean formula
            fusion_simple = clean_formula(fusion_simple)

            formulas['fusion'] = {
                'formula': fusion_simple,
                'features': [f'factor_{i}' for i in range(len(model.group_kans))],
                'description': 'Final Fusion Layer'
            }

            print(f"\nFinal Fusion Layer Formula:")
            print(f"  {fusion_simple}")

        except Exception as e:
            print(f"  Fusion layer symbol extraction failed: {e}")
            formulas['fusion'] = {
                'formula': 'N/A',
                'features': [f'factor_{i}' for i in range(len(model.group_kans))],
                'description': 'Final Fusion Layer (symbol extraction failed)'
            }

    except ImportError:
        # Without ex_round nothing can be formatted; return what we have.
        print("Cannot import kan.utils.ex_round, skipping symbolic formula extraction")

    return formulas


def save_kan_plots(model, feature_groups, feature_names, save_dir='kan_plots'):
    """Write one PNG per KAN into a fresh timestamped directory.

    The directory is named ``<save_dir>_<YYYYmmdd_HHMMSS>``. Each group KAN
    is plotted to ``group_<i>_network.png`` and the fusion KAN to
    ``fusion_network.png``. A failure while plotting one network is printed
    and does not stop the remaining ones.

    Args:
        model: Object exposing ``group_kans`` and ``final_kan``.
        feature_groups, feature_names: Accepted for call compatibility; not
            used by this function.
        save_dir: Prefix of the output directory name.

    Returns:
        The directory path, or ``None`` if the plotting stage as a whole
        raised (for example when matplotlib cannot be imported).
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Saving KAN network structure plots")
    print(rule)

    # Create save directory
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    full_save_dir = f"{save_dir}_{stamp}"
    os.makedirs(full_save_dir, exist_ok=True)

    try:
        import matplotlib.pyplot as plt

        def render(fetch_network, announcement, file_name, failure_label):
            """Plot one network to ``file_name``; report instead of raising."""
            try:
                print(announcement)
                plt.figure(figsize=(10, 6))

                # plot() is called without labels; beta controls how strongly
                # the activation functions are displayed.
                fetch_network().plot(beta=5)

                # Save plot (no title added)
                target = os.path.join(full_save_dir, file_name)
                plt.savefig(target, dpi=150, bbox_inches='tight')
                plt.close()

                print(f"    Saved: {file_name}")
            except Exception as e:
                print(f"    {failure_label} network plot save failed: {e}")

        # One plot per group sub-network
        for i, kan in enumerate(model.group_kans):
            render(
                lambda kan=kan: kan,
                f"  Saving factor group {i} network plot...",
                f'group_{i}_network.png',
                f"Group {i}",
            )

        # Fusion layer plot; the attribute is looked up lazily so that a
        # lookup failure is reported like any other plotting failure.
        render(
            lambda: model.final_kan,
            "  Saving fusion layer network plot...",
            'fusion_network.png',
            "Fusion layer",
        )

        print(f"\nAll network plots saved to: {full_save_dir}/")
        return full_save_dir

    except Exception as e:
        print(f"Network plot saving process error: {e}")
        return None


def save_results_to_file(formulas, metrics, param_count, feature_groups, save_dir):
    """Write a plain-text report of the run into the current directory.

    The file is named ``hkan_heart_disease_demo_results_<timestamp>.txt``.
    The whole report is assembled first and written in a single call, so a
    missing metric key raises before any file is created.

    Args:
        formulas: Mapping whose values hold ``'description'``, ``'formula'``
            and ``'features'``.
        metrics: Mapping with at least ``'auc'``, ``'accuracy'`` and ``'f1'``.
        param_count: Integer parameter count, written with thousands
            separators.
        feature_groups: Iterable of groups, written one per line.
        save_dir: Directory the plots were written to, or ``None`` when no
            plots were saved.

    Returns:
        The name of the file that was written.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"hkan_heart_disease_demo_results_{timestamp}.txt"

    report = [
        "HKAN Heart Disease Prediction Demo - Pruning and Symbolization Results\n",
        "=" * 60 + "\n\n",
        # Performance metrics
        "Performance Metrics:\n",
        f"  Test AUC: {metrics['auc']:.4f}\n",
        f"  Test Accuracy: {metrics['accuracy']:.4f}\n",
        f"  Test F1 Score: {metrics['f1']:.4f}\n",
        f"  Model Parameters: {param_count:,}\n\n",
        # Feature grouping configuration
        "Feature Grouping Configuration:\n",
    ]
    report.extend(f"  Group {i}: {group}\n" for i, group in enumerate(feature_groups))
    report.append("\n")

    # Symbolic formulas
    report.append("Extracted Symbolic Formulas:\n")
    for formula_info in formulas.values():
        report.append(f"\n{formula_info['description']}:\n")
        report.append(f"  Formula: {formula_info['formula']}\n")
        report.append(f"  Features: {formula_info['features']}\n")

    # The plot step reports failure by returning None; do not present that
    # as a directory called "None/".
    if save_dir is None:
        report.append("\nSpline plots were not saved\n")
    else:
        report.append(f"\nSpline plots saved to: {save_dir}/\n")

    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(report)

    print(f"\nDetailed results saved to: {filename}")
    return filename


# ================================
# Training Function
# ================================
def train_model(model, X_train, y_train, X_test, y_test, epochs=200, lr=0.01, weight_decay=1e-5):
    """Full-batch training loop with LR scheduling and early stopping.

    Every epoch performs one optimizer step on the whole training set using
    ``BCEWithLogitsLoss`` plus ``model.get_total_regularization_loss()``,
    then evaluates on ``X_test``/``y_test``. The evaluation loss drives
    ``ReduceLROnPlateau``; the blended score ``0.9 * AUC + 0.1 * FQS`` drives
    early stopping (20 epochs without improvement).

    Note that the data passed as ``X_test``/``y_test`` is therefore used for
    model selection, and that the model keeps the weights of the last epoch
    run: they are not rolled back to the best-scoring epoch.

    Args:
        model: Model exposing ``get_total_regularization_loss`` and
            ``last_factors_for_fqs``.
        X_train, y_train: Training features and binary labels.
        X_test, y_test: Evaluation features and binary labels.
        epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        Dict with ``'auc'``, ``'accuracy'``, ``'f1'``, ``'fqs'`` and
        ``'epoch'`` for the best-scoring epoch; empty if no epoch scored
        above zero (for example when ``epochs`` is 0).
    """
    print("Starting HKAN model training (complete EA.py configuration with regularization)...")

    # Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train).to(device)
    y_train_tensor = torch.FloatTensor(y_train).to(device)
    X_test_tensor = torch.FloatTensor(X_test).to(device)
    y_test_tensor = torch.FloatTensor(y_test).to(device)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()  # Use BCEWithLogitsLoss (consistent with EA.py)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)  # Use AdamW
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)  # Learning rate scheduler

    best_score = 0  # best blended score (0.9 * AUC + 0.1 * FQS) seen so far
    best_metrics = {}
    patience = 0
    max_patience = 20

    for epoch in range(epochs):
        # Training mode
        model.train()
        optimizer.zero_grad()

        outputs = model(X_train_tensor)
        # Main loss + regularization loss (completely consistent with EA.py)
        loss = criterion(outputs.squeeze(), y_train_tensor) + model.get_total_regularization_loss()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            test_outputs = model(X_test_tensor)
            val_loss = criterion(test_outputs.squeeze(), y_test_tensor)

            # Use sigmoid to get probabilities
            test_probs = torch.sigmoid(test_outputs).cpu().numpy().squeeze()
            test_preds = (test_probs > 0.5).astype(int)

            auc = roc_auc_score(y_test, test_probs)
            acc = accuracy_score(y_test, test_preds)
            f1 = f1_score(y_test, test_preds)

            # Calculate FQS (Factor Quality Score)
            if model.last_factors_for_fqs is not None and len(model.last_factors_for_fqs) >= 2:
                try:
                    fqs = calculate_factor_quality_score(model.last_factors_for_fqs)
                except Exception:
                    # Best-effort metric: fall back to the neutral score, but
                    # let KeyboardInterrupt/SystemExit propagate.
                    fqs = 0.5
            else:
                fqs = 0.5

            # A NaN/inf FQS would make the blended score NaN, and a NaN score
            # never compares greater than the best one, so no epoch would
            # ever be recorded.
            if not np.isfinite(fqs):
                fqs = 0.5

            # Comprehensive metric (consistent with EA.py)
            metric = 0.9 * auc + 0.1 * fqs

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Early stopping
        if metric > best_score:
            best_score = metric
            best_metrics = {
                'auc': auc,
                'accuracy': acc,
                'f1': f1,
                'fqs': fqs,
                'epoch': epoch + 1
            }
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

        # Print progress
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {loss.item():.4f} | "
                  f"AUC: {auc:.4f} | Acc: {acc:.4f} | F1: {f1:.4f} | FQS: {fqs:.4f}")

    return best_metrics


def calculate_factor_quality_score(factors, independence_weight=0.4, stability_weight=0.3, sparsity_weight=0.3):
    """Blend independence, stability and sparsity of factors into one score.

    * Independence: ``1 - mean(|pairwise correlation|)`` over factor pairs.
    * Stability: ``exp(-mean(log(1 + variance)))`` over factors.
    * Sparsity: ``1 -`` the mean fraction of values whose magnitude exceeds
      10% of the overall mean magnitude.

    Args:
        factors: Sequence of equally long 1-D value sequences, one per factor.
        independence_weight: Weight of the independence term.
        stability_weight: Weight of the stability term.
        sparsity_weight: Weight of the sparsity term.

    Returns:
        The weighted score, or the neutral value ``0.5`` when fewer than two
        factors are given or the score is not a finite number (for example
        when a factor is constant and its correlation is undefined).
    """
    import numpy as np

    factors_array = np.array(factors).T
    # Anything that is not a (samples, factors) table with at least two
    # factors cannot be scored; an empty input previously raised IndexError.
    if factors_array.ndim != 2 or factors_array.shape[1] < 2:
        return 0.5

    # A zero-variance factor makes corrcoef divide by zero; the resulting
    # NaN is handled below, so the floating-point warnings are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Independence score
        correlation_matrix = np.corrcoef(factors_array.T)
        off_diagonal = correlation_matrix[np.triu_indices_from(correlation_matrix, k=1)]
        independence_score = 1 - np.mean(np.abs(off_diagonal))

        # Stability score
        factor_vars = np.var(factors_array, axis=0)
        stability_score = np.exp(-np.mean(np.log(1 + factor_vars)))

        # Sparsity score
        tau = np.mean(np.abs(factors_array)) * 0.1
        sparsity_activations = np.mean(np.abs(factors_array) > tau, axis=0)
        sparsity_score = 1.0 - np.mean(sparsity_activations)

    fqs = (independence_weight * independence_score +
           stability_weight * stability_score +
           sparsity_weight * sparsity_score)

    # A NaN score would silently poison any metric built on top of it;
    # return the same neutral value used for the "too few factors" case.
    if not np.isfinite(fqs):
        return 0.5

    return fqs


# ================================
# Model Parameter Counting
# ================================
def count_parameters(model):
    """Return the total element count of all trainable parameters.

    Only parameters yielded by ``model.named_parameters()`` whose
    ``requires_grad`` flag is set contribute to the total.
    """
    return sum(
        tensor.numel()
        for _, tensor in model.named_parameters()
        if tensor.requires_grad
    )


# ================================
# Main Program
# ================================
def main():
    """Run the full demo: load data, train, prune, symbolize, plot, report.

    Returns:
        Tuple ``(final_metrics, formulas, plot_save_dir)`` where
        ``plot_save_dir`` is ``None`` if the plots could not be saved.
    """
    print("=" * 60)
    print("HKAN Heart Disease Prediction Demo - Using EA Algorithm Optimal Grouping")
    print("=" * 60)

    # Get configuration
    config = get_config()
    set_seed(config.seed)

    # Load data
    X_train, X_test, y_train, y_test, feature_names = load_heart_disease_data_demo(config.data_path)

    # EA algorithm discovered optimal feature grouping configuration
    feature_groups = [
        ['cp', 'restecg', 'thalach', 'exang', 'ca'],        # Group 0: 5 features
        ['fbs', 'thalach', 'oldpeak', 'ca'],                # Group 1: 4 features
        ['age', 'trestbps', 'chol', 'slope', 'thal'],       # Group 2: 5 features
        ['sex', 'slope', 'ca']                              # Group 3: 3 features
    ]

    # Map feature names to indices
    feature_to_idx = {name: idx for idx, name in enumerate(feature_names)}

    # Features absent from the dataset are skipped. They must be removed from
    # the name lists as well as the index lists: the formula extraction maps
    # the j-th KAN input to the j-th name of its group, so a name without a
    # column would shift every later name onto the wrong input.
    missing = sorted({feat for group in feature_groups for feat in group
                      if feat not in feature_to_idx})
    if missing:
        print(f"Warning: features not found in dataset and skipped: {missing}")
    feature_groups = [[feat for feat in group if feat in feature_to_idx]
                      for group in feature_groups]
    feature_group_indices = [[feature_to_idx[feat] for feat in group]
                             for group in feature_groups]

    print(f"\nFeature grouping configuration:")
    for i, (group, indices) in enumerate(zip(feature_groups, feature_group_indices)):
        print(f"  Group {i}: {group} -> indices {indices}")

    # Rearrange data to match grouping: columns are laid out group by group,
    # and a feature used by several groups is duplicated.
    all_indices = []
    for group_indices in feature_group_indices:
        all_indices.extend(group_indices)

    X_train_grouped = X_train[:, all_indices]
    X_test_grouped = X_test[:, all_indices]

    # Create HKAN model - using EA.py configuration. Only the size of each
    # group matters to the model, so local positions 0..n-1 are passed.
    model = HKAN(
        feature_groups=[list(range(len(group))) for group in feature_group_indices],
        kan_grid=5,
        kan_k=3
    )

    print(f"\nKAN configuration: grid=5, k=3")
    print(f"Hidden layer rule: hidden_dim = input_dim * 1 + 1")
    print(f"Fusion layer structure: [num_groups={len(feature_groups)}, 1, 1]")

    # Calculate parameter count
    param_count = count_parameters(model)
    print(f"\nModel parameters: {param_count:,}")

    # Stage 1: Initial training
    print("\n" + "="*60)
    print("Stage 1: Initial Training")
    print("="*60)
    start_time = time.time()

    best_metrics = train_model(
        model, X_train_grouped, y_train, X_test_grouped, y_test,
        epochs=200, lr=0.01, weight_decay=1e-5  # Use EA.py parameters
    )

    initial_training_time = time.time() - start_time
    print(f"Initial training time: {initial_training_time:.2f} seconds")

    # train_model returns an empty dict when no epoch was recorded as best;
    # report NaN in that case instead of failing with a KeyError.
    initial_auc = best_metrics.get('auc', float('nan'))

    # Stage 2: Pruning and symbolization
    print("\n" + "="*60)
    print("Stage 2: Pruning, Symbolization and Visualization")
    print("="*60)

    # Pruning and symbolization
    model = prune_and_symbolize_model(
        model, X_train_grouped, y_train, X_test_grouped, y_test,
        feature_groups, feature_names
    )

    # Re-evaluate performance after pruning
    print("\nEvaluating performance after pruning...")
    model.eval()
    with torch.no_grad():
        X_test_tensor = torch.FloatTensor(X_test_grouped).to(device)
        test_outputs = model(X_test_tensor)
        test_probs = torch.sigmoid(test_outputs).cpu().numpy().squeeze()
        test_preds = (test_probs > 0.5).astype(int)

        final_auc = roc_auc_score(y_test, test_probs)
        final_acc = accuracy_score(y_test, test_preds)
        final_f1 = f1_score(y_test, test_preds)

    # Update final metrics
    final_metrics = {
        'auc': final_auc,
        'accuracy': final_acc,
        'f1': final_f1,
        'initial_auc': initial_auc,
        'pruning_auc_change': final_auc - initial_auc
    }

    # Extract symbolic formulas
    formulas = extract_formulas_with_feature_names(model, feature_groups, feature_names)

    # Save spline plots
    plot_save_dir = save_kan_plots(model, feature_groups, feature_names)

    # Recalculate parameter count after pruning
    final_param_count = count_parameters(model)
    param_reduction = param_count - final_param_count

    # Measured from the start of training, so it covers both stages.
    total_time = time.time() - start_time

    # Output final results
    print("\n" + "="*60)
    print("HKAN Complete Process Report")
    print("="*60)
    print(f"Original model AUC: {initial_auc:.4f}")
    print(f"After pruning AUC: {final_auc:.4f} (change: {final_auc - initial_auc:+.4f})")
    print(f"Final accuracy: {final_acc:.4f}")
    print(f"Final F1 score: {final_f1:.4f}")
    print(f"Original parameters: {param_count:,}")
    print(f"After pruning parameters: {final_param_count:,} (reduced: {param_reduction:,})")
    print(f"Total time: {total_time:.2f} seconds")

    print("\nFeature grouping configuration:")
    for i, group in enumerate(feature_groups):
        print(f"  Group {i}: {group}")

    print("\nSymbolic formulas summary:")
    for formula_info in formulas.values():
        print(f"  {formula_info['description']}: {formula_info['formula']}")

    # Save detailed results to file
    save_results_to_file(formulas, final_metrics, final_param_count, feature_groups, plot_save_dir)

    return final_metrics, formulas, plot_save_dir


if __name__ == "__main__":
    main()