import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, Model
from keras.applications import InceptionV3, VGG16, ResNet50
import pandas as pd
from tqdm import tqdm
import gc


# GPU Memory Configuration
def configure_gpu_memory():
    """Turn on memory growth for every GPU TensorFlow can see.

    With no visible GPU this is a no-op. On success it prints how many GPUs
    were configured. A ``RuntimeError`` raised by TensorFlow (typically because
    the devices were already initialized) is reported by printing it rather
    than propagating.
    """
    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    if not physical_gpus:
        return

    try:
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"Configured {len(physical_gpus)} GPU(s) with memory growth enabled")
    except RuntimeError as err:
        print(f"GPU configuration error: {err}")


configure_gpu_memory()

# ------------------ Hyperparameters ------------------
# Input pipeline
BATCH_SIZE = 16
IMAGE_SIZE = 160  # images are resized to IMAGE_SIZE x IMAGE_SIZE before the backbone
NUM_CLASSES = 10
MIXUP_ALPHA = 0.2  # upper bound of the mixup coefficient; 0 disables mixup

# Training schedule (epochs)
PHASE1_EPOCHS = 5  # head-only epochs; standard fine-tuning resumes from this epoch index
TOTAL_EPOCHS_STANDARD = 20  # final epoch index of standard fine-tuning
CUSTOM_TRAIN_EPOCHS_FROM_PHASE1 = 20
CUSTOM_REFINE_EPOCHS_ON_BEST = 20

# Optimizer learning rates and decoupled weight decay
BASE_LR = 1e-3
FINETUNE_LR = 1e-4
CUSTOM_TRAIN_LR = 1e-4
CUSTOM_REFINE_LR = 1e-5
WEIGHT_DECAY = 1e-2

# Prototype-loss settings
INITIAL_LOSS_ALPHA = 0.5  # starting weight of the classifier output in the blended prediction
PROTOTYPE_MOMENTUM = 0.9  # EMA momentum applied to the previous prototype
# Cosine-similarity logits are DIVIDED by this value.
# NOTE(review): with 10.0 the logits lie in [-0.1, 0.1], so the prototype softmax
# is close to uniform -- confirm a multiplicative scale was not intended.
PROTOTYPE_TEMPERATURE = 10.0
PROTOTYPE_REG_WEIGHT = 0.05  # weight of the prototype consistency term in the total loss
MIN_SAMPLES_FOR_PROTOTYPE = 50  # correctly classified samples required per prototype update
ALPHA_DECAY_RATE = 0.98
PROTOTYPE_UPDATE_FREQUENCY = 1  # refresh prototypes every N training steps

NUM_IMAGES_FOR_PROTO = 100
PROTO_SUBSET_SIZE = 10000

# ------------------ Model Configurations ------------------
# Width of the pooled backbone feature vector for each supported backbone.
EXPERIMENT_CONFIGS = dict(
    vgg16=dict(feature_dim=512),
    resnet50=dict(feature_dim=2048),
    inceptionv3=dict(feature_dim=2048),
)


# ------------------ Enhanced Prototype Management ------------------
class PrototypeManager:
    """Track per-class feature prototypes for the prototype-similarity loss.

    State is kept in non-trainable ``tf.Variable`` objects placed on the CPU:

    * ``prototypes`` / ``fallback_prototypes`` -- (num_classes, feature_dim)
      vectors, both initialised from N(0, 0.1^2) noise;
    * ``initialized`` -- per-class flag, set when ``prototypes[c]`` is first
      written from data or copied from its fallback;
    * ``update_count`` -- number of confidence-weighted updates per class;
    * ``class_sample_counts`` -- correctly classified samples seen per class.

    ``update_prototypes_batch`` runs eagerly on NumPy copies of a batch. The
    similarity and consistency methods are pure TF ops.

    Correctly classified samples are buffered per class ACROSS batches until
    ``min_samples`` are available, so the threshold may exceed the batch size.
    With the module defaults (BATCH_SIZE=16, MIN_SAMPLES_FOR_PROTOTYPE=50), a
    per-batch threshold could never be met.
    """

    def __init__(self, num_classes, feature_dim, momentum=PROTOTYPE_MOMENTUM, min_samples=MIN_SAMPLES_FOR_PROTOTYPE):
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.momentum = momentum
        self.min_samples = min_samples

        # Host-side buffers of (normalized feature, confidence) pairs for
        # correctly classified samples awaiting the next weighted update.
        self._pending_features = [[] for _ in range(num_classes)]
        self._pending_confidences = [[] for _ in range(num_classes)]

        with tf.device('/CPU:0'):
            self.prototypes = tf.Variable(
                tf.random.normal((num_classes, feature_dim), stddev=0.1),
                trainable=False,
                name="prototypes"
            )
            self.initialized = tf.Variable(
                tf.zeros(num_classes, dtype=tf.bool),
                trainable=False,
                name="prototype_initialized"
            )
            self.update_count = tf.Variable(
                tf.zeros(num_classes, dtype=tf.int32),
                trainable=False,
                name="update_count"
            )

            # Class balancing components
            self.class_sample_counts = tf.Variable(
                tf.zeros(num_classes, dtype=tf.int32),
                trainable=False,
                name="class_sample_counts"
            )
            self.fallback_prototypes = tf.Variable(
                tf.random.normal((num_classes, feature_dim), stddev=0.1),
                trainable=False,
                name="fallback_prototypes"
            )

    @tf.autograph.experimental.do_not_convert
    def update_prototypes_batch(self, features, labels, predictions):
        """Fold one batch into the prototypes (eager, NumPy-based).

        Args:
            features: eager tensor of backbone features, assumed
                (batch, feature_dim).
            labels: eager tensor of target distributions (one-hot or blended),
                assumed (batch, num_classes); only the argmax is used.
            predictions: eager tensor of predicted class probabilities,
                assumed (batch, num_classes).

        A sample is eligible for class ``c`` when both its label argmax and its
        prediction argmax equal ``c``. Eligible L2-normalized features are
        buffered. Once at least ``min_samples`` are buffered, their mean,
        weighted by the class-``c`` confidence, is EMA-blended into
        ``prototypes[c]`` (or copied in on first use), and ``update_count[c]``
        is incremented. A batch that does not fill the buffer instead refreshes
        ``fallback_prototypes[c]`` with that batch's class mean.
        """
        features_np = features.numpy()
        labels_np = labels.numpy()
        predictions_np = predictions.numpy()

        features_norm = features_np / (np.linalg.norm(features_np, axis=1, keepdims=True) + 1e-8)
        pred_classes = np.argmax(predictions_np, axis=1)
        true_classes = np.argmax(labels_np, axis=1)
        correct_mask = pred_classes == true_classes

        for c in range(self.num_classes):
            valid_mask = (true_classes == c) & correct_mask
            valid_count = int(np.sum(valid_mask))
            if valid_count == 0:
                continue

            # Python ints so the int32 variables accept the value without dtype issues.
            seen = int(self.class_sample_counts[c].numpy())
            self.class_sample_counts[c].assign(seen + valid_count)

            class_features = features_norm[valid_mask]
            self._pending_features[c].append(class_features)
            self._pending_confidences[c].append(predictions_np[valid_mask, c])
            pending_count = sum(len(chunk) for chunk in self._pending_confidences[c])

            if pending_count >= self.min_samples:
                buffered_features = np.concatenate(self._pending_features[c], axis=0)
                weights = np.concatenate(self._pending_confidences[c], axis=0).reshape(-1, 1)
                self._pending_features[c] = []
                self._pending_confidences[c] = []

                # Confidences of argmax-correct samples are > 0, so the denominator is positive.
                new_prototype = np.sum(buffered_features * weights, axis=0) / np.sum(weights)

                if self.initialized[c].numpy():
                    current_proto = self.prototypes[c].numpy()
                    updated_prototype = self.momentum * current_proto + (1 - self.momentum) * new_prototype
                    self.prototypes[c].assign(updated_prototype)
                else:
                    self.prototypes[c].assign(new_prototype)
                    self.initialized[c].assign(True)

                current_count = int(self.update_count[c].numpy())
                self.update_count[c].assign(current_count + 1)
            else:
                fallback_prototype = np.mean(class_features, axis=0)

                if self.initialized[c].numpy():
                    current_fallback = self.fallback_prototypes[c].numpy()
                    updated_fallback = 0.7 * current_fallback + 0.3 * fallback_prototype
                    self.fallback_prototypes[c].assign(updated_fallback)
                else:
                    self.fallback_prototypes[c].assign(fallback_prototype)

    def balance_prototypes_across_classes(self):
        """Copy fallback prototypes into lagging classes.

        A class lags when it is not initialized, or when its ``update_count``
        equals the current minimum (``< min + 1`` on integers). Nothing happens
        until at least one prototype is initialized. Lagging classes are then
        marked initialized.

        NOTE(review): when all classes share the same update count, every
        prototype is replaced by its fallback on each call -- confirm intent.
        """
        update_counts = self.update_count
        initialized = self.initialized
        min_updates = tf.reduce_min(update_counts)

        needs_update_mask = tf.logical_or(
            tf.logical_not(initialized),
            tf.less(update_counts, min_updates + 1)
        )

        if tf.reduce_any(initialized):
            indices_to_update = tf.where(needs_update_mask)

            if tf.shape(indices_to_update)[0] > 0:
                fallbacks_to_assign = tf.gather_nd(self.fallback_prototypes, indices_to_update)
                self.prototypes.scatter_nd_update(indices_to_update, fallbacks_to_assign)

                updates_for_initialized = tf.fill(tf.shape(indices_to_update)[0:1], True)
                self.initialized.scatter_nd_update(indices_to_update, updates_for_initialized)

    def get_balanced_prototype_similarity(self, features, temperature=PROTOTYPE_TEMPERATURE):
        """Return a softmax over classes from cosine similarity to prototypes.

        First rebalances the prototypes (this mutates the prototype variables).
        Cosine logits are divided by ``temperature``. Each class column is then
        scaled by the inverse frequency of its correctly classified samples,
        normalized to mean 1.
        """
        self.balance_prototypes_across_classes()

        features_norm = tf.nn.l2_normalize(features, axis=1)
        prototypes_norm = tf.nn.l2_normalize(self.prototypes, axis=1)

        logits = tf.matmul(features_norm, prototypes_norm, transpose_b=True) / temperature

        total_samples = tf.cast(tf.reduce_sum(self.class_sample_counts), dtype=tf.float32) + 1e-8
        class_frequencies = tf.cast(self.class_sample_counts, tf.float32) / total_samples

        inv_freq_weights = 1.0 / (class_frequencies + 1e-8)
        inv_freq_weights = inv_freq_weights / tf.reduce_mean(inv_freq_weights)  # Normalize

        weighted_logits = logits * tf.expand_dims(inv_freq_weights, 0)

        return tf.nn.softmax(weighted_logits)

    @tf.function
    def get_prototype_consistency_loss(self, features, true_labels):
        """Return the mean squared distance between normalized features and
        the normalized prototype of each sample's argmax label."""
        features_norm = tf.nn.l2_normalize(features, axis=1)
        prototypes_norm = tf.nn.l2_normalize(self.prototypes, axis=1)

        true_class_indices = tf.argmax(true_labels, axis=1)
        target_prototypes = tf.gather(prototypes_norm, true_class_indices)

        consistency_loss = tf.reduce_mean(tf.square(features_norm - target_prototypes))
        return consistency_loss

    def get_prototype_similarity(self, features, temperature=PROTOTYPE_TEMPERATURE):
        """Alias for ``get_balanced_prototype_similarity``."""
        return self.get_balanced_prototype_similarity(features, temperature)


class AdaptiveLossWeight:
    """Hold and schedule ``alpha``, the classifier's share of the blended output.

    The custom training step mixes ``alpha * y_pred + (1 - alpha) * y_proto``.
    ``alpha`` is stored in a non-trainable ``tf.Variable``, so a compiled step
    reads its current value on every call.
    """

    def __init__(self, initial_alpha=INITIAL_LOSS_ALPHA, decay_rate=ALPHA_DECAY_RATE):
        # Anchor of the schedule; alpha is recomputed from this each epoch.
        self.initial_alpha = initial_alpha
        self.alpha = tf.Variable(initial_alpha, trainable=False, name="adaptive_alpha")
        self.decay_rate = decay_rate
        self.min_alpha = 0.1
        self.max_alpha = 0.9

    def update_alpha(self, epoch, prototype_confidence=None):
        """Set alpha for ``epoch``, store it in ``self.alpha`` and return it.

        alpha = clip(initial_alpha * decay_rate**epoch * (2 - max(confidence)),
                     min_alpha, max_alpha)

        The confidence factor is omitted when ``prototype_confidence`` is None.
        The decay is applied to ``initial_alpha``, not to the current alpha.
        Multiplying the already-decayed value by ``decay_rate**epoch`` would
        compound the decay to ``decay_rate**(0 + 1 + ... + epoch)``.

        Args:
            epoch: epoch index (the visible caller passes a 0-based index).
            prototype_confidence: optional tensor; only its maximum is used.

        Returns:
            The clipped alpha as a scalar tensor.
        """
        epoch_factor = tf.pow(self.decay_rate, tf.cast(epoch, tf.float32))
        new_alpha = self.initial_alpha * epoch_factor

        if prototype_confidence is not None:
            # More confident prototypes -> larger factor -> larger alpha before clipping.
            confidence_factor = tf.reduce_max(prototype_confidence)
            new_alpha = new_alpha * (2.0 - confidence_factor)

        new_alpha = tf.clip_by_value(new_alpha, self.min_alpha, self.max_alpha)
        self.alpha.assign(new_alpha)
        return new_alpha


# ------------------ Balanced Dataset Creation ------------------
def create_balanced_prototype_subset(x_train, y_train, subset_size=PROTO_SUBSET_SIZE):
    """Draw a class-balanced random subset of the training data.

    Each class id in ``range(NUM_CLASSES)`` contributes
    ``max(subset_size // NUM_CLASSES, 10)`` samples. They are drawn without
    replacement when the class has enough samples, and with replacement
    otherwise. Class ids absent from ``y_train`` are skipped, because
    ``np.random.choice`` cannot sample from an empty pool.

    Args:
        x_train: array indexed along axis 0 by an integer index array.
        y_train: class-id array aligned with ``x_train`` (1-D expected).
        subset_size: approximate total number of samples to draw.

    Returns:
        ``(x_subset, y_subset, indices)``, where ``indices`` are the shuffled
        positions into the original arrays (integer dtype, possibly empty).
    """
    samples_per_class = max(subset_size // NUM_CLASSES, 10)  # At least 10 samples per class

    per_class_picks = []
    for class_id in range(NUM_CLASSES):
        class_indices = np.where(y_train == class_id)[0]
        if class_indices.size == 0:
            continue

        replace = class_indices.size < samples_per_class
        per_class_picks.append(np.random.choice(class_indices, samples_per_class, replace=replace))

    if per_class_picks:
        balanced_indices = np.concatenate(per_class_picks)
    else:
        # Keep an integer dtype so the fancy-indexing below stays valid.
        balanced_indices = np.empty(0, dtype=np.intp)
    np.random.shuffle(balanced_indices)

    return x_train[balanced_indices], y_train[balanced_indices], balanced_indices


def create_stratified_dataset_pipeline(x, y, model_name, training=True, oversample_rare=False):
    """Build a dataset pipeline, optionally oversampling smaller classes first.

    When both ``training`` and ``oversample_rare`` are true, each class is
    repeated whole as many times as fits into the largest class's count. The
    remainder is topped up with a random draw (with replacement), and the
    combined arrays are shuffled. The arrays are then passed to
    ``create_dataset_pipeline``.
    """
    if not (training and oversample_rare):
        return create_dataset_pipeline(x, y, model_name, training)

    classes, counts = np.unique(y, return_counts=True)
    target_count = np.max(counts)

    x_parts, y_parts = [], []
    for cls, cnt in zip(classes, counts):
        members = y == cls
        cls_x, cls_y = x[members], y[members]
        whole_copies, leftover = divmod(target_count, cnt)

        # cnt <= target_count, so there is always at least one whole copy.
        x_parts.extend([cls_x] * int(whole_copies))
        y_parts.extend([cls_y] * int(whole_copies))

        if leftover > 0:
            extra = np.random.choice(len(cls_x), leftover, replace=True)
            x_parts.append(cls_x[extra])
            y_parts.append(cls_y[extra])

    merged_x = np.concatenate(x_parts, axis=0)
    merged_y = np.concatenate(y_parts, axis=0)

    order = np.random.permutation(len(merged_x))
    return create_dataset_pipeline(merged_x[order], merged_y[order], model_name, training)


# ------------------ Data Loading & Preprocessing ------------------
def load_cifar10():
    """Load CIFAR-10 through Keras and flatten both label arrays.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` with 1-D label arrays.
    """
    train_split, test_split = keras.datasets.cifar10.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split
    return (x_train, y_train.flatten()), (x_test, y_test.flatten())


def preprocess_image(image, label, model_name):
    """Cast one image to float32, resize it, and apply backbone preprocessing.

    The preprocessing function is chosen by ``model_name``. Any other name
    receives only the cast and resize. ``label`` is returned unchanged.
    """
    preprocessors = {
        'inceptionv3': tf.keras.applications.inception_v3.preprocess_input,
        'vgg16': tf.keras.applications.vgg16.preprocess_input,
        'resnet50': tf.keras.applications.resnet50.preprocess_input,
    }

    resized = tf.image.resize(tf.cast(image, tf.float32), (IMAGE_SIZE, IMAGE_SIZE))
    preprocess_fn = preprocessors.get(model_name)
    if preprocess_fn is not None:
        resized = preprocess_fn(resized)
    return resized, label


def preprocess_batch_for_model(images, model_name):
    """Cast, resize and backbone-preprocess a batch of images.

    This is the image-only counterpart of ``preprocess_image``. Names other
    than 'inceptionv3', 'vgg16' and 'resnet50' receive only the cast and
    resize.
    """
    application_module = {
        'inceptionv3': 'inception_v3',
        'vgg16': 'vgg16',
        'resnet50': 'resnet50',
    }.get(model_name)

    batch = tf.image.resize(tf.cast(images, tf.float32), (IMAGE_SIZE, IMAGE_SIZE))
    if application_module is None:
        return batch
    return getattr(tf.keras.applications, application_module).preprocess_input(batch)


@tf.function
def mixup(images, labels):
    """Blend every example with a randomly permuted partner from the same batch.

    One coefficient ``lam ~ Uniform(0, MIXUP_ALPHA)`` is shared by the whole
    batch. Images and labels both become ``lam * own + (1 - lam) * partner``.

    NOTE(review): canonical mixup samples ``lam ~ Beta(alpha, alpha)``; here
    MIXUP_ALPHA acts as a uniform upper bound -- confirm that is intended.
    """
    # Draw lam before the permutation to keep the random-op order unchanged.
    lam = tf.random.uniform(shape=[], minval=0, maxval=MIXUP_ALPHA)
    partner_idx = tf.random.shuffle(tf.range(tf.shape(images)[0]))

    def _blend(own):
        return lam * own + (1 - lam) * tf.gather(own, partner_idx)

    return _blend(images), _blend(labels)


def create_dataset_pipeline(x, y, model_name, training=True):
    """Build a batched, prefetched ``tf.data`` pipeline from in-memory arrays.

    Stages:
    1. Shuffle over the full length, when ``training``.
    2. Per-image preprocessing for ``model_name``.
    3. Batch to BATCH_SIZE.
    4. One-hot encode labels to NUM_CLASSES.
    5. Mixup, when ``training`` and MIXUP_ALPHA > 0.
    6. Prefetch.
    """
    autotune = tf.data.AUTOTUNE

    def _preprocess(img, lbl):
        return preprocess_image(img, lbl, model_name)

    def _to_one_hot(img, lbl):
        return img, tf.one_hot(lbl, NUM_CLASSES)

    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        dataset = dataset.shuffle(len(x))

    dataset = (
        dataset
        .map(_preprocess, num_parallel_calls=autotune)
        .batch(BATCH_SIZE)
        .map(_to_one_hot, num_parallel_calls=autotune)
    )

    if training and MIXUP_ALPHA > 0:
        dataset = dataset.map(mixup, num_parallel_calls=autotune)

    return dataset.prefetch(autotune)


# ------------------ Enhanced Model Definition ------------------
def build_model(model_name, multi_scale=False):
    """Create an ImageNet-pretrained backbone with a dropout + softmax head.

    The backbone is a nested model named ``base_<model_name>`` (average-pooled,
    no top). It is always called with ``training=False``. The head is
    ``Dropout(0.3)`` named ``top_dropout``, followed by a softmax ``Dense``
    layer named ``classifier`` with NUM_CLASSES units.

    ``multi_scale`` is accepted for compatibility but does not change the
    built model.

    Raises:
        ValueError: if ``model_name`` is not 'inceptionv3', 'vgg16' or 'resnet50'.
    """
    backbone_constructors = {
        'inceptionv3': InceptionV3,
        'vgg16': VGG16,
        'resnet50': ResNet50,
    }
    if model_name not in backbone_constructors:
        raise ValueError(f"Unknown model name: {model_name}")

    base_model = backbone_constructors[model_name](
        include_top=False,
        weights='imagenet',
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        pooling='avg',
        name=f"base_{model_name}",
    )

    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
    pooled = base_model(inputs, training=False)
    head = layers.Dropout(0.3, name="top_dropout")(pooled)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax', name="classifier")(head)
    return Model(inputs, outputs)


# ------------------ Standard Experiment (Unchanged) ------------------
def run_standard_experiment(x_train, y_train, x_test, y_test, model_name):
    """Train ``model_name`` with the two-phase transfer-learning recipe.

    Phase 1 freezes the backbone and trains the head with Adam(BASE_LR) for
    PHASE1_EPOCHS, then saves the model. Phase 2 unfreezes the backbone and
    continues with AdamW(FINETUNE_LR, WEIGHT_DECAY) from epoch PHASE1_EPOCHS
    to TOTAL_EPOCHS_STANDARD. It checkpoints the best ``val_accuracy`` and
    stops early after 10 epochs without improvement.

    NOTE(review): the visible caller passes the test split as validation data,
    so checkpoint selection is done on test data.

    Returns:
        ``(best phase-2 val_accuracy, phase-1 model path, best model path)``.
    """
    print(f"\n--- Starting Standard Training: {model_name.upper()} ---")
    tf.keras.backend.clear_session()

    # Per-model checkpoint locations.
    phase1_dir = f"cifar10_{model_name}_std_phase1_ckpt"
    finetune_dir = f"cifar10_{model_name}_std_ckpt"
    for directory in (phase1_dir, finetune_dir):
        os.makedirs(directory, exist_ok=True)
    phase1_model_path = os.path.join(phase1_dir, "phase1_model.keras")
    best_model_path = os.path.join(finetune_dir, "best_model.keras")

    train_ds = create_dataset_pipeline(x_train, y_train, model_name, training=True)
    val_ds = create_dataset_pipeline(x_test, y_test, model_name, training=False)

    model = build_model(model_name)
    backbone = model.get_layer(f"base_{model_name}")

    def _compile_with(optimizer):
        model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    print("\n--- Phase 1: Training the head (backbone frozen) ---")
    backbone.trainable = False
    _compile_with(keras.optimizers.Adam(learning_rate=BASE_LR))
    model.fit(train_ds, validation_data=val_ds, epochs=PHASE1_EPOCHS, verbose=2)
    print(f"Saving phase 1 model to {phase1_model_path}")
    model.save(phase1_model_path)

    print("\n--- Phase 2: Fine-tuning the entire model ---")
    backbone.trainable = True
    _compile_with(keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=FINETUNE_LR))

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(best_model_path, monitor="val_accuracy", save_best_only=True),
        tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
    ]
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=TOTAL_EPOCHS_STANDARD,
        initial_epoch=PHASE1_EPOCHS,
        callbacks=callbacks,
        verbose=2,
    )

    best_val_acc = max(history.history.get('val_accuracy', [0]))
    print(f"Standard Model Best Validation Accuracy: {best_val_acc:.4f}")
    return best_val_acc, phase1_model_path, best_model_path


# ------------------ Enhanced Custom Training ------------------
def run_enhanced_custom_training_from_checkpoint(initial_model_path, epochs_to_run, learning_rate, checkpoint_dir,
                                                 x_train, y_train, x_test, y_test, model_name):
    """Fine-tune a saved model with the prototype-augmented loss.

    The saved model must contain a layer named ``top_dropout``; its input is
    used as the feature vector for the prototypes. Each training step
    minimises

        CE(y, alpha * y_pred + (1 - alpha) * y_proto)
            + PROTOTYPE_REG_WEIGHT * consistency

    where ``y_proto`` is the prototype-similarity distribution. Prototypes are
    refreshed every PROTOTYPE_UPDATE_FREQUENCY steps from an inference pass,
    and ``alpha`` is rescheduled after each epoch. The classifier with the best
    validation accuracy is saved to ``<checkpoint_dir>/best_model.keras``.

    Returns:
        The best validation accuracy observed (starts from 0.0).
    """
    print(f"\n--- Starting Enhanced Custom Loss Training from {os.path.basename(initial_model_path)} ---")

    tf.keras.backend.clear_session()
    gc.collect()

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, "best_model.keras")

    train_ds = create_dataset_pipeline(x_train, y_train, model_name, training=True)
    val_ds = create_dataset_pipeline(x_test, y_test, model_name, training=False)

    loaded_model = keras.models.load_model(initial_model_path)

    inputs = loaded_model.input
    base_features = loaded_model.get_layer("top_dropout").input  # pooled backbone output
    predictions = loaded_model.output

    # ``predictor`` is validated and saved. ``dual_model`` shares every layer
    # with it and yields (features, predictions) from ONE forward pass.
    # Separate feature/prediction models would run the backbone twice, and
    # concatenating their trainable_variables lists each backbone variable
    # twice, which makes the optimizer apply the backbone update twice per step.
    predictor = Model(inputs=inputs, outputs=predictions)
    dual_model = Model(inputs=inputs, outputs=[base_features, predictions])

    # Setting trainable on a model propagates to all of its (shared) layers,
    # e.g. unfreezing a backbone saved frozen in the phase-1 checkpoint.
    dual_model.trainable = True
    predictor.trainable = True

    feature_dim = EXPERIMENT_CONFIGS[model_name]["feature_dim"]
    prototype_manager = PrototypeManager(NUM_CLASSES, feature_dim)
    adaptive_weight = AdaptiveLossWeight()

    optimizer = keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=learning_rate)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    train_acc_metric = keras.metrics.CategoricalAccuracy()
    val_acc_metric = keras.metrics.CategoricalAccuracy()

    @tf.function
    def memory_efficient_train_step(x, y):
        with tf.GradientTape() as tape:
            features, y_pred = dual_model(x, training=True)

            y_proto_sim = prototype_manager.get_prototype_similarity(features)

            # alpha is a tf.Variable, so its latest scheduled value is read on every call.
            alpha = adaptive_weight.alpha
            y_combined = alpha * y_pred + (1 - alpha) * y_proto_sim

            classification_loss = loss_fn(y, y_combined)
            consistency_loss = prototype_manager.get_prototype_consistency_loss(features, y)
            total_loss = classification_loss + PROTOTYPE_REG_WEIGHT * consistency_loss

        # Each variable appears exactly once here, so it is updated once per step.
        trainable_vars = dual_model.trainable_variables
        grads = tape.gradient(total_loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))

        train_acc_metric.update_state(y, y_pred)
        return total_loss, classification_loss, consistency_loss, y_proto_sim

    @tf.function
    def val_step(x, y):
        y_pred = predictor(x, training=False)
        val_acc_metric.update_state(y, y_pred)
        return y_pred

    best_val_acc = 0.0

    print(f"Starting training with batch size: {BATCH_SIZE}")

    for epoch in range(epochs_to_run):
        print(f"\nEpoch {epoch + 1}/{epochs_to_run}")

        train_acc_metric.reset_state()
        epoch_losses = []
        epoch_cls_losses = []
        epoch_consistency_losses = []
        prototype_confidences = []

        for step, (x_batch, y_batch) in enumerate(tqdm(train_ds, desc="Training")):
            try:
                total_loss, cls_loss, cons_loss, proto_sim = memory_efficient_train_step(x_batch, y_batch)

                epoch_losses.append(total_loss.numpy())
                epoch_cls_losses.append(cls_loss.numpy())
                epoch_consistency_losses.append(cons_loss.numpy())
                prototype_confidences.append(proto_sim.numpy())

                if step % PROTOTYPE_UPDATE_FREQUENCY == 0:
                    # Inference pass (dropout off); one backbone pass gives both outputs.
                    features_for_proto, y_pred_for_proto = dual_model(x_batch, training=False)
                    prototype_manager.update_prototypes_batch(features_for_proto, y_batch, y_pred_for_proto)

                if step % 100 == 0:
                    gc.collect()

            except tf.errors.ResourceExhaustedError:
                # Best effort: skip the batch. Do NOT call clear_session() here --
                # it resets Keras global state while these models, the optimizer
                # and the compiled step are still in use.
                print(f"Memory error at step {step}. Trying to continue...")
                gc.collect()
                continue

        if not epoch_losses:
            print("Warning: No successful training steps in this epoch")
            continue

        avg_total_loss = np.mean(epoch_losses)
        avg_cls_loss = np.mean(epoch_cls_losses)
        avg_cons_loss = np.mean(epoch_consistency_losses)

        if prototype_confidences:
            # Per-class mean prototype probability over the epoch.
            avg_proto_confidence = np.mean(np.concatenate(prototype_confidences, axis=0), axis=0)
            new_alpha = adaptive_weight.update_alpha(epoch, tf.constant(avg_proto_confidence))
        else:
            new_alpha = adaptive_weight.update_alpha(epoch)

        val_acc_metric.reset_state()
        try:
            for x_batch, y_batch in tqdm(val_ds, desc="Validating"):
                val_step(x_batch, y_batch)
            val_acc = val_acc_metric.result().numpy()
        except tf.errors.ResourceExhaustedError:
            print("Memory error during validation. Skipping validation for this epoch.")
            val_acc = 0.0

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            try:
                predictor.save(best_model_path)
            except Exception as e:
                print(f"Could not save model: {e}")

        initialized_count = tf.reduce_sum(tf.cast(prototype_manager.initialized, tf.int32)).numpy()
        min_updates = tf.reduce_min(prototype_manager.update_count).numpy()
        max_updates = tf.reduce_max(prototype_manager.update_count).numpy()

        print(f"Train Acc: {float(train_acc_metric.result()):.4f}, Val Acc: {val_acc:.4f} (Best: {best_val_acc:.4f})")
        print(
            f"Losses - Total: {avg_total_loss:.4f}, Classification: {avg_cls_loss:.4f}, Consistency: {avg_cons_loss:.4f}")
        print(
            f"Alpha: {float(new_alpha):.4f}, Prototypes: {initialized_count}/{NUM_CLASSES}, Updates: {min_updates}-{max_updates}")

        if max_updates > min_updates * 3:
            print(f"⚠️  Class imbalance detected! Some classes updated {max_updates}x more than others.")

        gc.collect()

    return best_val_acc


# ------------------ Main Execution ------------------
if __name__ == "__main__":
    (x_train, y_train), (x_test, y_test) = load_cifar10()

    print(f"Dataset info: {len(x_train)} training samples, {len(np.unique(y_train))} classes")

    # Report per-class sample counts to show how balanced the training split is.
    _, class_counts = np.unique(y_train, return_counts=True)
    print(
        f"Class balance - Min: {np.min(class_counts)}, Max: {np.max(class_counts)}, Mean: {np.mean(class_counts):.1f}")

    data_kwargs = dict(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
    all_results = []

    for model_name in EXPERIMENT_CONFIGS:
        print(f"\n{'=' * 25}\nStarting Experiment for: {model_name.upper()} on CIFAR-10\n{'=' * 25}")

        # 1) Two-phase standard transfer learning; provides both checkpoints below.
        acc_standard, phase1_path, best_std_path = run_standard_experiment(
            x_train, y_train, x_test, y_test, model_name)

        # 2) Prototype-loss training starting from the head-only (phase 1) checkpoint.
        acc_custom_from_phase1 = run_enhanced_custom_training_from_checkpoint(
            initial_model_path=phase1_path,
            epochs_to_run=CUSTOM_TRAIN_EPOCHS_FROM_PHASE1,
            learning_rate=CUSTOM_TRAIN_LR,
            checkpoint_dir=f"cifar10_{model_name}_enhanced_custom_from_phase1_ckpt",
            model_name=model_name,
            **data_kwargs,
        )

        # 3) Prototype-loss refinement of the best standard checkpoint.
        acc_custom_on_best_std = run_enhanced_custom_training_from_checkpoint(
            initial_model_path=best_std_path,
            epochs_to_run=CUSTOM_REFINE_EPOCHS_ON_BEST,
            learning_rate=CUSTOM_REFINE_LR,
            checkpoint_dir=f"cifar10_{model_name}_enhanced_custom_on_best_ckpt",
            model_name=model_name,
            **data_kwargs,
        )

        all_results.append({
            'model_name': model_name,
            'standard_accuracy': acc_standard,
            'enhanced_custom_from_phase1_accuracy': acc_custom_from_phase1,
            'enhanced_custom_on_best_standard_accuracy': acc_custom_on_best_std,
        })

    results_file = "cifar10_enhanced_models_comparison.csv"
    results_df = pd.DataFrame(all_results)
    results_df.to_csv(results_file, index=False)

    print(f"\n\n{'=' * 20} FINAL ENHANCED RESULTS {'=' * 20}")
    print(results_df)
    print(f"\nResults saved to {results_file}")