"""
Evolutionary Algorithm for HKAN Feature Grouping Discovery
Factor Guided Evolutionary Architecture Search (FG-EAS) + Dual Regularization

Global Best Results from Original EA.py:
  Fitness: 0.9549
  Discovery Generation: 14
  AUC Score: 0.9777
  Accuracy: 0.8667
  FQS Score: 0.7495
  Parameters: 1652

Global Best Feature Grouping Configuration (4 groups; IDs are grouping-matrix column indices and need not be contiguous):
  Group 0: ['cp', 'restecg', 'thalach', 'exang', 'ca']
  Group 1: ['fbs', 'thalach', 'oldpeak', 'ca']
  Group 2: ['age', 'trestbps', 'chol', 'slope', 'thal']
  Group 4: ['sex', 'slope', 'ca']
"""

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.feature_selection import mutual_info_classif
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import warnings
import random
from tqdm import tqdm
from deap import base, creator, tools
import gc
import os
from collections import Counter

from config import get_config
from data_loader import set_seed
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from models import HKANClassification, calculate_factor_quality_score

# Suppress all Python warnings process-wide.
warnings.filterwarnings('ignore')
# Module-level device shared by the evaluator and feature preprocessing below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Import official KAN library
try:
    from kan import KAN
    print("Official KAN library imported successfully")
except ImportError:
    print("KAN library not found. Please install: pip install pykan")
    # `exit` is injected by the `site` module for interactive use and may be absent
    # (e.g. `python -S`, embedded interpreters); SystemExit is always available.
    raise SystemExit(1)


def cleanup_gpu_memory():
    """Free unreachable objects and return cached CUDA blocks to the driver.

    Does nothing when CUDA is unavailable. ``gc.collect()`` runs first so tensors
    kept alive only by reference cycles are released before
    ``torch.cuda.empty_cache()``; in the opposite order their memory would stay
    in PyTorch's cache until the next cleanup.
    """
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()

def load_heart_disease_data_for_ea(data_path="heart_disease.csv"):
    """Load and standardize the Heart Disease CSV for the evolutionary search.

    The EA deliberately uses every row as training data; the later Bayesian
    optimization stage is responsible for a proper train/val/test split.

    Args:
        data_path: path to a CSV file with a ``target`` column plus feature columns.

    Returns:
        ``(X_train, y_train, feature_names)``: standardized float32 features,
        int64 labels, and the feature column names (every column except ``target``).

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
        KeyError: if the CSV has no ``target`` column.
        ValueError: if ``target`` has missing or non-integer values.
    """
    print("Loading Heart Disease dataset (EA specific: training set only)...")

    # Check if data file exists (`os` is imported at module level).
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    data = pd.read_csv(data_path)
    print(f"Data shape: {data.shape}")

    if 'target' not in data.columns:
        raise KeyError(f"Column 'target' not found in {data_path}")

    # Separate features and target variable
    feature_names = [col for col in data.columns if col != 'target']
    X = data[feature_names].values

    # Validate and cast labels before the diagnostics below: np.bincount only
    # accepts non-negative integer arrays, so float labels (e.g. 0.0/1.0) or NaNs
    # used to crash it with an opaque dtype error.
    y_raw = data['target'].to_numpy()
    if pd.isna(y_raw).any():
        raise ValueError("Column 'target' contains missing values")
    y = y_raw.astype(np.int64)
    if not np.array_equal(y, y_raw):
        raise ValueError("Column 'target' must contain integer class labels")

    print(f"Number of features: {len(feature_names)}")
    print(f"Number of samples: {len(y)}")
    print(f"Class distribution: {np.bincount(y)}")
    print(f"Feature list: {feature_names}")

    # EA algorithm uses all data as training set to avoid data waste;
    # subsequent Bayesian optimization performs the proper train/val/test split.
    scaler = StandardScaler()
    X_train = np.asarray(scaler.fit_transform(X), dtype=np.float32)
    y_train = np.asarray(y, dtype=np.int64)

    print(f"EA training set: {X_train.shape}")
    print(f"EA training set class distribution: {np.bincount(y_train)}")
    print(f"Note: EA algorithm uses all data, subsequent Bayesian optimization will perform proper data split")

    return X_train, y_train, feature_names


# ================================
# Parameter Configuration Center
# ================================
class HeartDiseaseHKANConfig:
    """Hyper-parameters for the Heart Disease FG-EAS / HKAN experiment.

    All settings are plain instance attributes assigned in a fixed order. The
    order matters because the results report iterates ``__dict__`` when it
    lists the configuration.
    """

    def __init__(self):
        # -- Training loop ---------------------------------------------------
        self.epochs, self.learning_rate, self.weight_decay = 200, 0.01, 1e-5
        self.patience, self.grad_clip = 20, 1.0
        self.num_classes = 1  # single logit: binary classification via sigmoid

        # -- Official KAN architecture ----------------------------------------
        self.kan_grid, self.kan_k = 5, 3
        self.kan_hidden_multiplier = 1   # SubKAN hidden layer multiplier
        self.fusion_hidden_layers = [1]  # fusion layer configuration (fresh list per instance)

        # -- Official KAN regularization --------------------------------------
        self.kan_reg_lamb_l1, self.kan_reg_lamb_entropy = 1e-3, 1e-3
        self.kan_reg_lamb_coef, self.kan_reg_lamb_coefdiff = 0.0, 0.0

        # -- Factor regularization --------------------------------------------
        self.lambda_decorrelation, self.lambda_sparsity, self.lambda_stability = 0.01, 0.001, 0.01

        # -- Factor Quality Score (FQS) component weights ---------------------
        self.fqs_independence_weight, self.fqs_stability_weight = 0.4, 0.3
        self.fqs_sparsity_weight, self.fqs_semantic_weight = 0.3, 0

        # -- Unified fitness: performance_weight * AUC + quality_weight * FQS --
        self.performance_weight, self.quality_weight = 0.9, 0.1

        # -- Genetic algorithm (sized for the 13-feature dataset) -------------
        self.ga_pop_size, self.ga_max_generations = 30, 20
        self.ga_min_groups, self.ga_max_groups = 3, 6
        self.ga_min_features_per_group, self.ga_max_features_per_group = 1, 5
        self.ga_crossover_rate, self.ga_mutation_rate = 0.7, 0.8

        self.description = "Heart Disease FG-EAS HKAN Configuration (Binary Classification Task)"


# ================================
# Feature preprocessing
# ================================
def preprocess_features_heart_disease(X_train, X_val, X_test, group_matrix, active_groups, feature_names, device='cpu'):
    """Split train/val/test feature matrices into per-group float32 tensors.

    Args:
        X_train, X_val, X_test: 2-D NumPy arrays (samples x features) with the same columns.
        group_matrix: ``(num_features, max_groups)`` 0/1 matrix; column ``g`` marks
            the features belonging to group ``g``.
        active_groups: group ids (columns of ``group_matrix``) to materialize.
        feature_names: unused; kept for interface compatibility.
        device: torch device on which the output tensors are created.

    Returns:
        ``(train, val, test, info)`` dicts keyed by ``"group_<id>"`` in ascending
        id order. ``info[name]['input_dim']`` is the number of features in that
        group. Active groups with no features are skipped.
    """
    def _to_tensor(block):
        # Create directly on the target device instead of CPU-then-.to(device).
        return torch.as_tensor(block, dtype=torch.float32, device=device)

    group_train_data, group_val_data, group_test_data, group_feature_info = {}, {}, {}, {}

    for group_id in sorted(active_groups):
        feature_indices = np.where(group_matrix[:, group_id] == 1)[0]
        if len(feature_indices) == 0:
            continue

        group_name = f"group_{group_id}"
        # Integer-array indexing already returns a fresh copy, so the tensors never
        # alias the caller's arrays and no explicit .copy() is needed.
        group_train_data[group_name] = _to_tensor(X_train[:, feature_indices])
        group_val_data[group_name] = _to_tensor(X_val[:, feature_indices])
        group_test_data[group_name] = _to_tensor(X_test[:, feature_indices])
        group_feature_info[group_name] = {'input_dim': len(feature_indices)}

    return group_train_data, group_val_data, group_test_data, group_feature_info


# ================================
# Unified fitness evaluator (classification version)
# ================================
class HKANUnifiedClassificationEvaluator:
    """Score feature-grouping individuals by training an HKAN binary classifier.

    Each individual is trained from scratch on the data given to the constructor
    and scored with the unified fitness
    ``config.performance_weight * AUC + config.quality_weight * FQS``.
    Evaluation uses the training data only; proper splitting is left to the
    later Bayesian-optimization stage.
    """

    # Training metrics (and the early-stopping check) are computed every
    # EVAL_INTERVAL epochs plus the final epoch, to keep each evaluation cheap.
    EVAL_INTERVAL = 10

    def __init__(self, X_train, y_train, feature_names, config):
        """Store the EA dataset and configuration.

        Args:
            X_train: feature matrix indexed as ``X[:, cols]`` (samples x features).
            y_train: binary labels as a NumPy array (cast to float32 for the loss).
            feature_names: column names, forwarded to feature preprocessing.
            config: object exposing the training and fitness-weight attributes used below.
        """
        self.X_train, self.y_train = X_train, y_train
        self.feature_names = feature_names
        self.config = config

        print(f"[TARGET] Binary classification task unified fitness weights: Performance(AUC)={config.performance_weight:.2f}, Quality(FQS)={config.quality_weight:.2f}")
        print(f"[TARGET] EA evaluation performed only on training set, optimal feature grouping will be passed to Bayesian optimization for proper data splitting and hyperparameter tuning")

    def evaluate_single(self, individual_matrix, active_groups, generation=None):
        """Train a fresh model for one grouping and return ``(fitness, metrics)``.

        ``metrics`` holds AUC/accuracy/FQS/fitness as 4-decimal strings, the
        parameter count, the generation (0 if not given) and a status flag.

        Raises:
            ValueError: if no active group contains a feature, or the AUC is
                negative or NaN.
        """
        train_data, _, _, info = preprocess_features_heart_disease(
            self.X_train, self.X_train, self.X_train, individual_matrix, active_groups, self.feature_names, device
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not info:
            raise ValueError("Feature preprocessing failed, check grouping matrix and active groups")

        model = HKANClassification(info, self.config, device).to(device)
        try:
            auc, acc, fqs = self._train_model(model, train_data)
            params = model.count_parameters()
        finally:
            # Free the model even if training raises, so memory does not
            # accumulate across individuals.
            del model
            cleanup_gpu_memory()

        # `not auc >= 0` (rather than `auc < 0`) also rejects NaN.
        if not auc >= 0:
            raise ValueError(f"AUC score abnormal: {auc}")
        performance_score = auc
        quality_score = fqs
        fitness = (self.config.performance_weight * performance_score +
                   self.config.quality_weight * quality_score)

        metrics = {
            'auc': f"{auc:.4f}",
            'acc': f"{acc:.4f}",
            'fqs': f"{fqs:.4f}",
            'params': params,
            'fitness': f"{fitness:.4f}",
            'generation': generation or 0,
            'status': 'Normal'
        }

        return fitness, metrics

    def _train_model(self, model, train_data):
        """Full-batch train ``model``; return ``(best_auc, best_acc, best_fqs)``.

        Every ``EVAL_INTERVAL`` epochs (and at the last epoch) the model is scored
        on the training data with the same weighting as the fitness. The metrics
        from the best-scoring round are returned. Training stops early after
        ``max(1, patience // EVAL_INTERVAL)`` consecutive rounds without improvement.
        """
        cfg = self.config
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.learning_rate,
            weight_decay=cfg.weight_decay
        )
        criterion = torch.nn.BCEWithLogitsLoss()  # Binary classification loss

        y_train = torch.FloatTensor(self.y_train.astype(np.float32)).to(device)
        y_train_np = y_train.cpu().numpy()  # hoisted: the labels never change

        # Early stopping counts evaluation rounds, not epochs. Clamp to >= 1: with
        # patience < EVAL_INTERVAL the limit would be 0 and training would stop
        # right after the first evaluation (epoch 0).
        max_stale_rounds = max(1, cfg.patience // self.EVAL_INTERVAL)

        best_metric, stale_rounds = -float('inf'), 0
        best_auc, best_acc, best_fqs = 0.0, 0.0, 0.5

        for epoch in range(cfg.epochs):
            model.train()
            optimizer.zero_grad()

            predictions = model(train_data).squeeze()
            loss = criterion(predictions, y_train) + model.get_total_regularization_loss()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()

            if epoch % self.EVAL_INTERVAL != 0 and epoch != cfg.epochs - 1:
                continue

            model.eval()
            with torch.no_grad():
                train_prob = torch.sigmoid(model(train_data).squeeze()).cpu().numpy()
                train_pred = (train_prob > 0.5).astype(int)

                # NOTE(review): roc_auc_score raises if only one class is present in y.
                auc = roc_auc_score(y_train_np, train_prob)
                acc = accuracy_score(y_train_np, train_pred)

                fqs = calculate_factor_quality_score(model.last_factors_for_fqs, cfg)
                # Same weighting as the fitness in evaluate_single (was hard-coded 0.9/0.1).
                metric = cfg.performance_weight * auc + cfg.quality_weight * fqs

            if metric > best_metric:
                best_metric, best_auc, best_acc, best_fqs, stale_rounds = metric, auc, acc, fqs, 0
            else:
                stale_rounds += 1

            if stale_rounds >= max_stale_rounds:
                break

        return best_auc, best_acc, best_fqs

    def evaluate_serial(self, individuals, generation, current_best_auc_str):
        """Evaluate ``individuals`` one by one with a tqdm progress bar.

        Each individual is indexed as ``[matrix, active_groups]``. Returns a list
        of ``{'fitness': float, 'metrics': dict}`` in input order.
        """
        full_results = []

        with tqdm(total=len(individuals), desc=f"Gen {generation:2d} Eval", unit="ind", ncols=150) as pbar:
            for ind in individuals:
                matrix, active_groups = ind[0], ind[1]
                fitness, metrics = self.evaluate_single(matrix, active_groups, generation)
                full_results.append({'fitness': fitness, 'metrics': metrics})

                group_sizes = [np.sum(matrix[:, g]) for g in active_groups]
                min_feat, max_feat = (min(group_sizes), max(group_sizes)) if group_sizes else (0, 0)
                group_info = f"Grp:{len(active_groups)}({min_feat}-{max_feat})"

                pbar.set_postfix(
                    auc=metrics.get('auc', 'N/A'),
                    acc=metrics.get('acc', 'N/A'),
                    fqs=metrics.get('fqs', 'N/A'),
                    best_AUC=current_best_auc_str,
                    groups=group_info,
                    status=metrics.get('status', 'N/A')[:4]
                )
                pbar.update(1)

        return full_results


# ================================
# Genetic Algorithm (maintain complete EA.py structure)
# ================================
class FactorGuidedGeneticAlgorithm:
    """Genetic search over feature-to-group assignments (FG-EAS).

    An individual is ``[matrix, active_groups]``:
    - ``matrix`` is a ``(num_features, ga_max_groups)`` 0/1 integer array, where
      ``matrix[f, g] == 1`` means feature ``f`` belongs to group ``g``. A feature
      may belong to several groups.
    - ``active_groups`` is the set of columns that contain at least one feature.

    Fitness comes from ``evaluator.evaluate_serial``. The best individual ever
    seen is tracked across generations.
    """

    def __init__(self, evaluator, num_features, config):
        """Store collaborators and initialise global-best tracking.

        Args:
            evaluator: object with ``evaluate_serial(individuals, generation, best_auc_str)``.
            num_features: number of rows in every grouping matrix.
            config: object exposing the ``ga_*`` settings read by this class.
        """
        self.evaluator = evaluator
        self.num_features = num_features
        self.config = config

        self.global_best_fitness = -float('inf')
        self.global_best_individual = None
        self.global_best_metrics = {}
        self.global_best_generation = 0

        self.mutation_rate = config.ga_mutation_rate

        # How often each structural mutation operator has been chosen.
        self.mutation_stats = {
            'add_feature': 0, 'remove_feature': 0, 'swap_feature': 0,
            'merge_groups': 0, 'split_group': 0, 'delete_group': 0
        }

    def create_individual(self):
        """Build a random individual with ``ga_min_groups``..``ga_max_groups`` groups.

        Each group samples its features independently, so groups may overlap.
        Any feature left uncovered is then added to a random active group.
        """
        cfg = self.config
        matrix = np.zeros((self.num_features, cfg.ga_max_groups), dtype=int)
        num_groups = random.randint(cfg.ga_min_groups, cfg.ga_max_groups)
        active_groups = set(range(num_groups))

        available_features = list(range(self.num_features))
        for group_id in active_groups:
            num_features_in_group = random.randint(
                min(cfg.ga_min_features_per_group, len(available_features)),
                min(cfg.ga_max_features_per_group, len(available_features))
            )
            selected_features = random.sample(available_features, num_features_in_group)
            for feat_idx in selected_features:
                matrix[feat_idx, group_id] = 1

        unassigned_features = [i for i in range(self.num_features) if np.sum(matrix[i, :]) == 0]
        if unassigned_features and active_groups:
            for feat_idx in unassigned_features:
                matrix[feat_idx, random.choice(list(active_groups))] = 1

        return [matrix, active_groups]

    def structure_aware_crossover(self, parent1, parent2):
        """One-point crossover over group columns; return two repaired children.

        The union of both parents' active groups is sorted and cut at a point in
        ``[1, n - 1]``. Columns before the cut are exchanged between copies of
        the parents. The parents themselves are not modified.
        """
        matrix1, groups1 = parent1[0].copy(), parent1[1].copy()
        matrix2, groups2 = parent2[0].copy(), parent2[1].copy()

        all_active_groups = sorted(groups1 | groups2)
        if len(all_active_groups) <= 1:
            return [matrix1, groups1], [matrix2, groups2]

        # Cut strictly inside the list: a cut at n would swap every active column,
        # which merely exchanges the two parents' groupings (a no-op crossover).
        crossover_point = random.randint(1, len(all_active_groups) - 1)
        groups_to_swap = all_active_groups[:crossover_point]

        child1_matrix, child2_matrix = matrix1.copy(), matrix2.copy()

        for group_id in groups_to_swap:
            if group_id < self.config.ga_max_groups:
                temp_col = child1_matrix[:, group_id].copy()
                child1_matrix[:, group_id] = child2_matrix[:, group_id]
                child2_matrix[:, group_id] = temp_col

        child1_matrix, child1_groups = self._repair_constraints(child1_matrix)
        child2_matrix, child2_groups = self._repair_constraints(child2_matrix)

        return [child1_matrix, child1_groups], [child2_matrix, child2_groups]

    def _repair_constraints(self, matrix):
        """Repair ``matrix`` in place toward the GA constraints.

        Returns ``(matrix, active_groups)``. The steps are:
          1. Open new groups until ``ga_min_groups`` are active. Each new group is
             filled from unassigned features; otherwise from a feature assigned to
             several groups; otherwise with a feature borrowed from the largest
             group that can spare one.
          2. Give every uncovered feature a random active group.
          3. Move excess features out of groups larger than
             ``ga_max_features_per_group`` and drop groups smaller than
             ``ga_min_features_per_group``.
          4. Re-cover any feature that step 3 left without a group, using the
             least-populated active group.

        The returned group set is recomputed from the final matrix.
        """
        cfg = self.config
        active_groups = {g for g in range(cfg.ga_max_groups) if np.sum(matrix[:, g]) > 0}

        # 1. Ensure the minimum number of groups.
        while len(active_groups) < cfg.ga_min_groups:
            available_groups = set(range(cfg.ga_max_groups)) - active_groups
            if not available_groups:
                break
            new_group = min(available_groups)

            unassigned_features = [i for i in range(self.num_features)
                                   if np.sum(matrix[i, :]) == 0]
            if unassigned_features:
                features_to_assign = min(len(unassigned_features),
                                         cfg.ga_max_features_per_group)
                for feat_idx in random.sample(unassigned_features, features_to_assign):
                    matrix[feat_idx, new_group] = 1
                active_groups.add(new_group)
                continue

            multi_assigned = [i for i in range(self.num_features)
                              if np.sum(matrix[i, :]) > 1]
            if multi_assigned:
                feature_to_move = random.choice(multi_assigned)
                donor_group = random.choice(np.where(matrix[feature_to_move, :] == 1)[0])
            else:
                # Every feature sits in exactly one group. Without this branch the
                # loop would never terminate; borrow a feature from the largest
                # group that stays non-empty, or give up if none can spare one.
                min_keep = max(1, cfg.ga_min_features_per_group)
                donors = [g for g in sorted(active_groups)
                          if np.sum(matrix[:, g]) > min_keep]
                if not donors:
                    break
                donor_group = max(donors, key=lambda g: np.sum(matrix[:, g]))
                feature_to_move = random.choice(np.where(matrix[:, donor_group] == 1)[0].tolist())
            matrix[feature_to_move, donor_group] = 0
            matrix[feature_to_move, new_group] = 1
            active_groups.add(new_group)

        # 2. Assign unassigned features.
        if active_groups:
            for i in range(self.num_features):
                if np.sum(matrix[i, :]) == 0:
                    matrix[i, random.choice(list(active_groups))] = 1

        # 3. Handle groups that are too large or too small.
        for group_id in list(active_groups):
            group_size = int(np.sum(matrix[:, group_id]))
            if group_size > cfg.ga_max_features_per_group:
                features_in_group = np.where(matrix[:, group_id] == 1)[0]
                excess = group_size - cfg.ga_max_features_per_group
                to_redistribute = random.sample(features_in_group.tolist(), excess)

                for feat_idx in to_redistribute:
                    matrix[feat_idx, group_id] = 0
                    available_groups = [g for g in active_groups
                                        if g != group_id and
                                        np.sum(matrix[:, g]) < cfg.ga_max_features_per_group]
                    if available_groups:
                        matrix[feat_idx, random.choice(available_groups)] = 1
            elif group_size < cfg.ga_min_features_per_group:
                matrix[:, group_id] = 0
                active_groups.discard(group_id)

        final_active_groups = set(g for g in range(cfg.ga_max_groups)
                                  if np.sum(matrix[:, g]) > 0)

        # 4. Step 3 can strip a feature of its only group (no spare capacity, or
        #    its group was dropped). Put it back into the least-populated group so
        #    every feature stays covered, as create/mutate guarantee.
        if final_active_groups:
            ordered_groups = sorted(final_active_groups)
            for i in range(self.num_features):
                if np.sum(matrix[i, :]) == 0:
                    target_group = min(ordered_groups, key=lambda g: np.sum(matrix[:, g]))
                    matrix[i, target_group] = 1

        return matrix, final_active_groups

    def mutate_individual(self, individual):
        """Apply one structural mutation to ``individual`` in place.

        The operator is sampled with weights that depend on the current number of
        groups. Afterwards the active-group set is recomputed and uncovered
        features are re-assigned. Returns ``(individual,)``.
        """
        matrix, groups = individual[0], individual[1].copy()

        # Select mutation operation based on current state
        weights = self.get_mutation_weights(len(groups))
        ops, probs = list(weights.keys()), list(weights.values())
        op = random.choices(ops, weights=probs, k=1)[0]
        self.mutation_stats[op] += 1

        mutators = {
            'add_feature': self._mutate_add_feature,
            'remove_feature': self._mutate_remove_feature,
            'swap_feature': self._mutate_swap_feature,
            'merge_groups': self._mutate_merge_groups,
            'split_group': self._mutate_split_group,
            'delete_group': self._mutate_delete_group,
        }

        mutators[op](matrix, groups)
        individual[1] = set(np.where(np.sum(matrix, axis=0) > 0)[0])

        # Handle unassigned features
        unassigned = [i for i in range(self.num_features) if np.sum(matrix[i, :]) == 0]
        if unassigned and individual[1]:
            for feat_idx in unassigned:
                matrix[feat_idx, random.choice(list(individual[1]))] = 1

        return individual,

    def get_mutation_weights(self, num_groups):
        """Return operator -> sampling weight.

        Near the maximum group count, merges are favoured; near the minimum,
        splits are favoured.
        """
        cfg = self.config
        if num_groups >= cfg.ga_max_groups - 1:
            return {
                'add_feature': 0.15, 'remove_feature': 0.20, 'swap_feature': 0.25,
                'merge_groups': 0.25, 'split_group': 0.05, 'delete_group': 0.10
            }
        elif num_groups <= cfg.ga_min_groups + 1:
            return {
                'add_feature': 0.25, 'remove_feature': 0.10, 'swap_feature': 0.25,
                'merge_groups': 0.05, 'split_group': 0.25, 'delete_group': 0.10
            }
        else:
            return {
                'add_feature': 0.20, 'remove_feature': 0.15, 'swap_feature': 0.30,
                'merge_groups': 0.15, 'split_group': 0.15, 'delete_group': 0.05
            }

    def _mutate_add_feature(self, matrix, groups):
        """Add a random non-member feature to a random group that is below max size."""
        if not groups: return False
        group_id = random.choice(list(groups))
        if np.sum(matrix[:, group_id]) < self.config.ga_max_features_per_group:
            candidates = np.where(matrix[:, group_id] == 0)[0]
            if len(candidates) > 0:
                feature_idx = random.choice(candidates)
                matrix[feature_idx, group_id] = 1
                return True
        return False

    def _mutate_remove_feature(self, matrix, groups):
        """Remove a feature from a group, only if that feature stays in another group."""
        if not groups: return False
        group_id = random.choice(list(groups))
        if np.sum(matrix[:, group_id]) > self.config.ga_min_features_per_group:
            candidates = np.where(matrix[:, group_id] == 1)[0]
            if len(candidates) > 0:
                feature_idx = random.choice(candidates)
                if np.sum(matrix[feature_idx, :]) > 1:
                    matrix[feature_idx, group_id] = 0
                    return True
        return False

    def _mutate_swap_feature(self, matrix, groups):
        """Move a random feature from one of its groups to another active group."""
        if len(groups) < 2: return False
        feature_idx = random.randint(0, self.num_features - 1)
        current_groups = np.where(matrix[feature_idx, :] == 1)[0]
        if len(current_groups) == 0: return False

        from_group = random.choice(current_groups)
        possible_to_groups = list(groups - {from_group})
        if not possible_to_groups: return False

        to_group = random.choice(possible_to_groups)
        if (np.sum(matrix[:, from_group]) > self.config.ga_min_features_per_group and
            np.sum(matrix[:, to_group]) < self.config.ga_max_features_per_group):
            matrix[feature_idx, from_group] = 0
            matrix[feature_idx, to_group] = 1
            return True
        return False

    def _mutate_merge_groups(self, matrix, groups):
        """Merge two random groups into the first, if the union fits the max size."""
        if len(groups) < 2: return False
        g1, g2 = random.sample(list(groups), 2)
        merged_features = np.logical_or(matrix[:, g1], matrix[:, g2])
        if np.sum(merged_features) <= self.config.ga_max_features_per_group:
            matrix[:, g1] = merged_features
            matrix[:, g2] = 0
            return True
        return False

    def _mutate_split_group(self, matrix, groups):
        """Move a random half of a large-enough group into an unused column."""
        if len(groups) >= self.config.ga_max_groups: return False
        splittable_groups = [g for g in groups if np.sum(matrix[:, g]) >= 2 * self.config.ga_min_features_per_group]
        if not splittable_groups: return False

        g_to_split = random.choice(splittable_groups)
        features_in_g = np.where(matrix[:, g_to_split] == 1)[0]
        random.shuffle(features_in_g)
        split_point = len(features_in_g) // 2

        if (split_point < self.config.ga_min_features_per_group or
            (len(features_in_g) - split_point) < self.config.ga_min_features_per_group):
            return False

        new_g = next((i for i in range(self.config.ga_max_groups) if i not in groups), None)
        if new_g is None: return False

        matrix[features_in_g[split_point:], g_to_split] = 0
        matrix[features_in_g[split_point:], new_g] = 1
        return True

    def _mutate_delete_group(self, matrix, groups):
        """Delete a random group and move its now-orphaned features to other groups."""
        if len(groups) <= self.config.ga_min_groups: return False
        g_to_del = random.choice(list(groups))

        # Check for a destination *before* clearing the column; otherwise the
        # group's features would be lost when no other group exists.
        remaining_groups = list(groups - {g_to_del})
        if not remaining_groups: return False

        features_to_reassign = np.where(matrix[:, g_to_del] == 1)[0]
        matrix[:, g_to_del] = 0

        for feat_idx in features_to_reassign:
            if np.sum(matrix[feat_idx, :]) == 0:
                target_g = random.choice(remaining_groups)
                matrix[feat_idx, target_g] = 1
        return True

    def mixed_selection(self, population, k):
        """Select ``k`` parents.

        80% are chosen by size-2 tournaments. The remainder are drawn at random
        from individuals not already picked (by identity), so fewer than ``k``
        may be returned if too few remain.
        """
        tournament_k = int(k * 0.8)
        random_k = k - tournament_k

        selected = tools.selTournament(population, tournament_k, tournsize=2)
        if random_k > 0:
            selected_ids = set(id(ind) for ind in selected)
            remaining = [ind for ind in population if id(ind) not in selected_ids]
            if remaining:
                selected.extend(tools.selRandom(remaining, min(random_k, len(remaining))))

        return selected

    def run_evolution(self):
        """Run the GA for ``ga_max_generations`` generations and return a summary dict.

        Returns the best matrix, groups, fitness, metrics and discovery generation,
        plus the mutation-operator usage counts.

        Raises:
            RuntimeError: if no individual was evaluated (e.g. ``ga_max_generations`` is 0).
        """
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
        creator.create("Individual", list, fitness=creator.FitnessMax, metrics=dict)

        toolbox = base.Toolbox()
        toolbox.register("individual", lambda: creator.Individual(self.create_individual()))
        toolbox.register("population", tools.initRepeat, list, toolbox.individual)
        toolbox.register("mate", self.structure_aware_crossover)
        toolbox.register("mutate", self.mutate_individual)
        toolbox.register("select", self.mixed_selection)

        population = toolbox.population(n=self.config.ga_pop_size)

        print(f"\nStarting Heart Disease Factor Guided Evolution Algorithm (FG-EAS)")
        print(f"[TREE] Population size: {self.config.ga_pop_size}, Total generations: {self.config.ga_max_generations}")
        print(f"[TREE] Initialization strategy: Pure random initialization")

        for gen in range(self.config.ga_max_generations):
            print(f"\n--- Generation {gen+1:2d} / {self.config.ga_max_generations} ---")

            best_auc_so_far = self.global_best_metrics.get('auc', 'N/A')
            eval_results = self.evaluator.evaluate_serial(population, gen + 1, best_auc_so_far)

            for ind, result in zip(population, eval_results):
                ind.fitness.values = (result['fitness'],)
                ind.metrics = result['metrics']

            valid_pop = [ind for ind in population if ind.fitness.valid]
            if valid_pop:
                best_ind = tools.selBest(valid_pop, 1)[0]
                if best_ind.fitness.values[0] > self.global_best_fitness:
                    self.global_best_fitness = best_ind.fitness.values[0]
                    self.global_best_individual = best_ind
                    self.global_best_metrics = best_ind.metrics
                    self.global_best_generation = gen + 1
                    print(f"  New global best discovered! Fitness: {self.global_best_fitness:.4f}, AUC: {self.global_best_metrics.get('auc', 'N/A')}, ACC: {self.global_best_metrics.get('acc', 'N/A')}")

            if eval_results:
                current_best_fitness = max(r['fitness'] for r in eval_results)

                print(f"  Current generation best fitness: {current_best_fitness:.4f}")
                print(f"  Global best fitness: {self.global_best_fitness:.4f} (Generation {self.global_best_generation})")
                print(f"  Global best AUC: {self.global_best_metrics.get('auc', 'N/A')}, ACC: {self.global_best_metrics.get('acc', 'N/A')}")

            offspring = toolbox.select(population, len(population))
            offspring = list(map(toolbox.clone, offspring))

            crossover_count = 0
            for child1, child2 in zip(offspring[::2], offspring[1::2]):
                if random.random() < self.config.ga_crossover_rate:
                    new_child1, new_child2 = toolbox.mate(child1, child2)
                    child1[0], child1[1] = new_child1[0], new_child1[1]
                    child2[0], child2[1] = new_child2[0], new_child2[1]
                    del child1.fitness.values, child2.fitness.values
                    crossover_count += 2

            mutation_count = 0
            for mutant in offspring:
                if random.random() < self.mutation_rate:
                    toolbox.mutate(mutant)
                    del mutant.fitness.values
                    mutation_count += 1

            population[:] = offspring

            print(f"  Genetic operations: {crossover_count} crossovers, {mutation_count} mutations")

        if self.global_best_individual is None:
            raise RuntimeError("Evolution finished without evaluating any individual "
                               "(ga_max_generations must be >= 1)")

        return {
            'best_matrix': self.global_best_individual[0],
            'best_groups': self.global_best_individual[1],
            'best_fitness': self.global_best_fitness,
            'best_generation': self.global_best_generation,
            'best_metrics': self.global_best_metrics,
            'mutation_stats': self.mutation_stats
        }


# ================================
# Result saving function
# ================================
def save_results_to_txt(results, total_time, config, feature_names,
                        filename="heart_disease_fgeas_hkan_results_english.txt"):
    """Write a plain-text report of an FG-EAS run.

    Args:
        results: dict returned by ``FactorGuidedGeneticAlgorithm.run_evolution``
            (``best_matrix`` and ``best_groups`` are required; other keys fall back
            to placeholders).
        total_time: wall-clock duration of the run in seconds.
        config: object with ``performance_weight``/``quality_weight``. Every
            attribute in its ``__dict__`` is dumped at the end of the report.
        feature_names: names indexed like the rows of ``best_matrix``.
        filename: output path (defaults to the historical report name in the CWD).

    Returns:
        The path the report was written to.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("="*80 + "\n")
        f.write("[TROPHY] Heart Disease FG-EAS HKAN Binary Classification Experiment Results Report\n")
        f.write("Factor Guided Evolutionary Architecture Search + Dual Regularization (Heart Disease Classification)\n")
        f.write("="*80 + "\n\n")

        f.write(f"Experiment time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total time: {total_time:.1f} seconds\n\n")

        f.write("--- Unified Fitness Configuration ---\n")
        f.write(f"Performance weight: {config.performance_weight:.2f}\n")
        f.write(f"Quality weight: {config.quality_weight:.2f}\n")
        f.write(f"Fitness function: {config.performance_weight:.2f} * AUC + {config.quality_weight:.2f} * FQS\n\n")

        f.write("--- Global Best Performance Metrics ---\n")
        best_metrics = results.get('best_metrics', {})
        f.write(f"Global best fitness: {results.get('best_fitness', -1):.4f}\n")
        f.write(f"Discovery generation: Generation {results.get('best_generation', 0)}\n")
        f.write(f"Best AUC score: {best_metrics.get('auc', 'N/A')}\n")
        f.write(f"Best accuracy: {best_metrics.get('acc', 'N/A')}\n")
        f.write(f"Best FQS score: {best_metrics.get('fqs', 'N/A')}\n")
        f.write(f"Model parameters: {best_metrics.get('params', 'N/A')}\n\n")

        f.write("--- Global Best Feature Grouping Strategy ---\n")
        matrix, groups = results['best_matrix'], results['best_groups']
        f.write(f"Total {len(groups)} groups:\n")
        for group_id in sorted(groups):
            group_features = [name for i, name in enumerate(feature_names) if matrix[i][group_id] == 1]
            f.write(f"  Group {group_id}: {group_features}\n")
        f.write("\n")

        f.write("--- Structured Mutation Operator Usage Statistics ---\n")
        mutation_stats = results.get('mutation_stats', {})
        total_mutations = sum(mutation_stats.values())
        if total_mutations > 0:
            for op, count in mutation_stats.items():
                percentage = (count / total_mutations) * 100
                f.write(f"  {op:<15}: {count:4d} times ({percentage:5.1f}%)\n")
        f.write("\n")

        f.write("="*80 + "\n")
        f.write("Experiment Configuration Parameters\n")
        f.write("="*80 + "\n")
        for key, value in config.__dict__.items():
            f.write(f"{key}: {value}\n")

    print(f"\nExperiment results saved to file: {filename}")
    return filename


# ================================
# Main experiment function
# ================================
def run_heart_disease_fgeas_hkan_experiment():
    """Run the FG-EAS search on the Heart Disease data and report the best grouping.

    Data path and seed come from ``get_config()``. That same ``global_config``
    drives the evaluator and the GA, so it is also the configuration whose
    fitness weights are printed and whose parameters are dumped to the report.
    (Previously the report used the local ``HeartDiseaseHKANConfig``, which did
    not necessarily match what actually ran.) ``HeartDiseaseHKANConfig`` now only
    supplies the description banner.
    """
    print("="*80)
    print("[TROPHY] Heart Disease FG-EAS HKAN Binary Classification Experiment")
    print("Factor Guided Evolutionary Architecture Search + Dual Regularization (Heart Disease Classification)")
    print("[ROCKET] Using official KAN library implementation")
    print("="*80)

    config = HeartDiseaseHKANConfig()
    print(f"Configuration: {config.description}")

    # Get configuration and load data using new modules
    global_config = get_config()
    set_seed(global_config.seed)

    X_train, y_train, feature_names = load_heart_disease_data_for_ea(global_config.data_path)
    print(f"Binary classification task: Predict heart disease risk")

    evaluator = HKANUnifiedClassificationEvaluator(X_train, y_train, feature_names, global_config)
    ga = FactorGuidedGeneticAlgorithm(evaluator, len(feature_names), global_config)

    start_time = time.time()
    print(f"\nStarting Heart Disease FG-EAS...")
    results = ga.run_evolution()
    total_time = time.time() - start_time

    print(f"\n{'='*80}")
    print("[TROPHY] Heart Disease FG-EAS HKAN Binary Classification Experiment Results")
    print(f"{'='*80}")

    print(f"Total experiment time: {total_time:.1f} seconds")

    best_metrics = results.get('best_metrics', {})
    print(f"\nGlobal best results:")
    print(f"  Fitness: {results.get('best_fitness', -1):.4f}")
    print(f"  Discovery generation: Generation {results.get('best_generation', 0)}")
    print(f"  AUC score: {best_metrics.get('auc', 'N/A')}")
    print(f"  Accuracy: {best_metrics.get('acc', 'N/A')}")
    print(f"  FQS score: {best_metrics.get('fqs', 'N/A')}")
    print(f"  Parameters: {best_metrics.get('params', 'N/A')}")

    matrix, groups = results['best_matrix'], results['best_groups']
    print(f"\nGlobal best feature grouping configuration ({len(groups)} groups):")
    for group_id in sorted(groups):
        group_features = [name for i, name in enumerate(feature_names) if matrix[i][group_id] == 1]
        print(f"  Group {group_id}: {group_features}")

    # Report the configuration the GA actually ran with.
    save_results_to_txt(results, total_time, global_config, feature_names)

    print(f"\nHeart Disease FG-EAS HKAN classification experiment completed!")
    print(f"Core features:")
    print(f"  Task type: Binary classification (Heart Disease)")
    print(f"  Loss function: BCEWithLogitsLoss")
    print(f"  Evaluation metrics: AUC + Accuracy")
    print(f"  Fitness: {global_config.performance_weight:.2f} * AUC + {global_config.quality_weight:.2f} * FQS")
    print(f"  Architecture search: Factor guided evolutionary algorithm")
    print(f"  Model architecture: H-KAN (Official KAN library)")


# ================================
# Main Function
# ================================
if __name__ == "__main__":
    # Seed every RNG the script touches before anything random happens.
    # (The experiment later re-seeds through set_seed(global_config.seed).)
    _SEED = 42
    for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        _seed_fn(_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(_SEED)

    run_heart_disease_fgeas_hkan_experiment()