import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import math
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, Model
from keras.applications import InceptionV3, VGG16, ResNet50
import pandas as pd
from tqdm import tqdm
import gc

def configure_gpu_memory():
    """Turn on memory growth for every GPU that TensorFlow lists.

    Prints how many GPUs were configured on success. A ``RuntimeError``
    raised while configuring is reported on stdout rather than propagated.
    Nothing happens (and nothing is printed) when no GPU is listed.
    """
    detected = tf.config.experimental.list_physical_devices('GPU')
    if not detected:
        return

    try:
        for device in detected:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(f"GPU configuration error: {err}")
    else:
        print(f"Configured {len(detected)} GPU(s) with memory growth enabled")


# Executed at module import time, before any model or dataset is created.
configure_gpu_memory()

# ------------------ Hyperparameters ------------------
BATCH_SIZE = 16  # batch size of the tf.data pipelines and of prototype inference
IMAGE_SIZE = 160  # images are resized to IMAGE_SIZE x IMAGE_SIZE; also the model input size
NUM_CLASSES = 10  # one-hot depth, classifier width and number of prototype rows
MIXUP_ALPHA = 0.2  # mixup strength; mixup is only added to training pipelines when this is > 0
PROTO_SUBSET_SIZE = 10000  # number of training images sampled (without replacement) for prototypes

# Training Phase Hyperparameters
PHASE1_EPOCHS = 10  # epochs trained with the backbone frozen
# Passed as `epochs` to the second fit, which starts at initial_epoch=PHASE1_EPOCHS,
# so the fine-tuning phase runs TOTAL_EPOCHS_STANDARD - PHASE1_EPOCHS epochs.
TOTAL_EPOCHS_STANDARD = 20
CUSTOM_REFINE_EPOCHS_ON_BEST = 20  # epochs of each custom refinement / ablation run

# Learning Rates
BASE_LR = 0.001  # Adam learning rate for the frozen-backbone phase
FINETUNE_LR = 1e-4  # AdamW learning rate for the standard fine-tuning phase
CUSTOM_REFINE_LR = 1e-5  # AdamW learning rate for the custom refinement loop
WEIGHT_DECAY = 0.01  # AdamW weight decay (standard fine-tuning and custom loop)

INITIAL_LOSS_ALPHA = 0.5  # starting weight of the classifier output in the blended prediction
PROTOTYPE_TEMPERATURE = 10.0  # cosine similarities are divided by this before the softmax
PROTOTYPE_REG_WEIGHT = 0.05  # weight of the feature/prototype consistency loss
ALPHA_DECAY_RATE = 0.98  # base of the per-epoch decay factor in AdaptiveLossWeight
NUM_IMAGES_FOR_PROTO = 50  # max number of most-confident correct images averaged per class prototype

# ------------------ Model Configurations ------------------
# The keys select which backbones the main script iterates over.
# NOTE(review): the "feature_dim" values are not read anywhere in the visible code.
EXPERIMENT_CONFIGS = {
    "inceptionv3": {"feature_dim": 2048},
    "vgg16": {"feature_dim": 512},
    "resnet50": {"feature_dim": 2048},
}


class AdaptiveLossWeight:
    """Non-trainable blending weight ``alpha`` that is rescaled once per epoch.

    ``alpha`` is stored in a ``tf.Variable`` so it can be read inside a
    ``tf.function``. Every update is clipped to ``[min_alpha, max_alpha]``
    (0.1 and 0.9).
    """

    def __init__(self, initial_alpha=INITIAL_LOSS_ALPHA, decay_rate=ALPHA_DECAY_RATE):
        self.min_alpha = 0.1
        self.max_alpha = 0.9
        self.decay_rate = decay_rate
        self.alpha = tf.Variable(initial_alpha, trainable=False, name="adaptive_alpha")

    def update_alpha(self, epoch, prototype_confidence=None):
        """Rescale, clip and store ``alpha``; return the clipped value.

        Args:
            epoch: Epoch index; the decay factor is ``decay_rate ** epoch``.
            prototype_confidence: Optional tensor. When given, alpha is also
                multiplied by ``2.0 - max(prototype_confidence)``.

        Returns:
            The clipped alpha tensor that was assigned to ``self.alpha``.
        """
        # NOTE(review): the factor multiplies the *current* alpha, so the
        # per-epoch decay factors compound across calls — confirm intended.
        candidate = self.alpha * tf.pow(self.decay_rate, tf.cast(epoch, tf.float32))
        if prototype_confidence is not None:
            candidate = candidate * (2.0 - tf.reduce_max(prototype_confidence))

        clipped = tf.clip_by_value(candidate, self.min_alpha, self.max_alpha)
        self.alpha.assign(clipped)
        return clipped


def load_cifar10():
    """Load CIFAR-10 through ``keras.datasets`` with 1-D label arrays.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where both label arrays
        have been flattened with ``.flatten()``.
    """
    # The dataset module is ``cifar10``; the previous ``cifar00`` attribute
    # does not exist and raised AttributeError before any data was loaded.
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    return (x_train, y_train.flatten()), (x_test, y_test.flatten())

def preprocess_image(image, label, model_name):
    """Cast, resize and backbone-normalise one image; the label is untouched.

    The image is cast to float32 and resized to ``IMAGE_SIZE`` x
    ``IMAGE_SIZE``. For 'inceptionv3', 'vgg16' and 'resnet50' the matching
    ``tf.keras.applications.<module>.preprocess_input`` is then applied; for
    any other ``model_name`` the resized image is returned as is.

    Returns:
        ``(image, label)``.
    """
    # Backbone name -> submodule of tf.keras.applications providing preprocess_input.
    backbone_modules = {'inceptionv3': 'inception_v3', 'vgg16': 'vgg16', 'resnet50': 'resnet50'}

    resized = tf.image.resize(tf.cast(image, tf.float32), (IMAGE_SIZE, IMAGE_SIZE))
    module_name = backbone_modules.get(model_name)
    if module_name is None:
        return resized, label

    preprocess_input = getattr(tf.keras.applications, module_name).preprocess_input
    return preprocess_input(resized), label

def preprocess_batch_for_model(images, model_name):
    """Cast, resize and backbone-normalise a batch of images.

    Args:
        images: Image batch; cast to float32 and resized to
            ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
        model_name: One of 'inceptionv3', 'vgg16' or 'resnet50'.

    Returns:
        The batch after the matching
        ``tf.keras.applications.<module>.preprocess_input``.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone. This used
            to fall through and return ``None`` implicitly, which only
            surfaced later as an unrelated error in the caller.
    """
    # Backbone name -> submodule of tf.keras.applications providing preprocess_input.
    backbone_modules = {'inceptionv3': 'inception_v3', 'vgg16': 'vgg16', 'resnet50': 'resnet50'}
    # Validate before doing any tensor work so a bad name fails fast.
    if model_name not in backbone_modules:
        raise ValueError(f"Unknown model name: {model_name}")

    images = tf.cast(images, tf.float32)
    images = tf.image.resize(images, (IMAGE_SIZE, IMAGE_SIZE))
    return getattr(tf.keras.applications, backbone_modules[model_name]).preprocess_input(images)


def create_dataset_pipeline(x, y, model_name, training=True):
    """Build the batched, prefetched ``tf.data`` pipeline for one data split.

    Stages, in order: optional full-buffer shuffle (training only),
    per-image preprocessing, batching by ``BATCH_SIZE``, one-hot encoding of
    the labels, optional mixup (training only and only when
    ``MIXUP_ALPHA > 0``), prefetch.

    Returns:
        A dataset yielding ``(images, one_hot_labels)`` batches.
    """
    def _prepare(img, lbl):
        return preprocess_image(img, lbl, model_name)

    def _encode_labels(img_batch, lbl_batch):
        return img_batch, tf.one_hot(lbl_batch, NUM_CLASSES)

    autotune = tf.data.AUTOTUNE
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        # Shuffle buffer spans the whole split.
        pipeline = pipeline.shuffle(len(x))
    pipeline = pipeline.map(_prepare, num_parallel_calls=autotune)
    pipeline = pipeline.batch(BATCH_SIZE)
    pipeline = pipeline.map(_encode_labels, num_parallel_calls=autotune)
    if training and MIXUP_ALPHA > 0:
        pipeline = pipeline.map(mixup, num_parallel_calls=autotune)
    return pipeline.prefetch(autotune)


@tf.function
def mixup(images, labels):
    """Apply mixup to a batch: blend each example with a random partner.

    One mixing coefficient ``lam`` is drawn per batch from
    Beta(``MIXUP_ALPHA``, ``MIXUP_ALPHA``); images and labels are blended
    with the same coefficient and the same random permutation of the batch.
    The previous implementation drew ``lam`` from Uniform(0, MIXUP_ALPHA),
    which used the Beta concentration parameter as an upper bound on ``lam``.

    Args:
        images: Float image batch.
        labels: Float (one-hot / soft) label batch aligned with ``images``.

    Returns:
        ``(mixed_images, mixed_labels)`` with the same shapes as the inputs.
    """
    # Beta(a, a) sample from two independent Gamma(a) draws: X / (X + Y).
    gammas = tf.random.gamma(shape=[2], alpha=MIXUP_ALPHA)
    # divide_no_nan guards the (very unlikely) case where both draws underflow to 0.
    lam = tf.math.divide_no_nan(gammas[0], gammas[0] + gammas[1])

    idx = tf.random.shuffle(tf.range(tf.shape(images)[0]))
    mixed_images = lam * images + (1 - lam) * tf.gather(images, idx)
    mixed_labels = lam * labels + (1 - lam) * tf.gather(labels, idx)
    return mixed_images, mixed_labels


# ------------------ Model Definition ------------------
def build_model(model_name):
    """Assemble an ImageNet-pretrained backbone with a dropout + softmax head.

    The backbone is created without its top, with global average pooling,
    and is called with ``training=False`` inside the functional graph. A
    ``Dropout(0.3)`` layer named "top_dropout" and a ``NUM_CLASSES``-way
    softmax ``Dense`` layer named "classifier" are stacked on top.

    Args:
        model_name: 'inceptionv3', 'vgg16' or 'resnet50'.

    Returns:
        The assembled ``Model``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name not in ('inceptionv3', 'vgg16', 'resnet50'):
        raise ValueError(f"Unknown model name: {model_name}")

    backbone_cls = {'inceptionv3': InceptionV3, 'vgg16': VGG16, 'resnet50': ResNet50}[model_name]
    input_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)
    base_model = backbone_cls(include_top=False, weights='imagenet', input_shape=input_shape, pooling='avg')

    inputs = layers.Input(shape=input_shape)
    pooled_features = base_model(inputs, training=False)
    regularized = layers.Dropout(0.3, name="top_dropout")(pooled_features)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax', name="classifier")(regularized)
    return Model(inputs, outputs)


def calculate_prototypes(feature_extractor, predictor, x_subset, y_subset, model_name):
    """Compute one feature-space prototype per class from a labelled subset.

    For each class the prototype is the mean feature vector of the (at most
    ``NUM_IMAGES_FOR_PROTO``) most confident samples that belong to the class
    and are predicted as that class. If no sample of the class is predicted
    correctly, the mean over all samples of the class is used; if the class
    is absent from the subset its prototype stays all-zero.

    Args:
        feature_extractor: Model mapping preprocessed images to feature vectors.
        predictor: Model mapping preprocessed images to class probabilities.
        x_subset: Raw images (preprocessed here for ``model_name``).
        y_subset: Integer class labels aligned with ``x_subset``.
        model_name: Backbone name forwarded to ``preprocess_batch_for_model``.

    Returns:
        A ``tf.constant`` of shape ``(NUM_CLASSES, feature_dim)``, float32.
    """
    print("Calculating prototypes for the new epoch...")

    # Preprocess lazily, one batch at a time. Resizing the whole subset up front
    # materialises a single float32 tensor of len(x_subset) * IMAGE_SIZE^2 * 3
    # values (about 3 GB for 10000 images at 160x160) on every call.
    subset_ds = (
        tf.data.Dataset.from_tensor_slices(x_subset)
        .batch(BATCH_SIZE)
        .map(lambda batch: preprocess_batch_for_model(batch, model_name),
             num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )

    features = feature_extractor.predict(subset_ds, verbose=0)
    predictions = predictor.predict(subset_ds, verbose=0)

    pred_classes = np.argmax(predictions, axis=1)

    proto_matrix = np.zeros((NUM_CLASSES, features.shape[-1]), dtype=np.float32)

    for c in range(NUM_CLASSES):
        class_mask = (y_subset == c)
        correct_pred_mask = (pred_classes == c)
        valid_indices = np.where(class_mask & correct_pred_mask)[0]

        if len(valid_indices) > 0:
            class_features = features[valid_indices]
            class_confidences = predictions[valid_indices, c]

            # Keep the most confident correct predictions (descending confidence).
            top_indices = np.argsort(-class_confidences)[:NUM_IMAGES_FOR_PROTO]

            proto_matrix[c] = np.mean(class_features[top_indices], axis=0)
        else:
            # Fallback: nothing was classified correctly, average the whole class.
            class_indices = np.where(class_mask)[0]
            if len(class_indices) > 0:
                proto_matrix[c] = np.mean(features[class_indices], axis=0)

    return tf.constant(proto_matrix)


# ------------------ Standard Experiment ------------------
def run_standard_experiment(x_train, y_train, x_test, y_test, model_name):
    """Train the baseline model in two phases and report its best accuracy.

    Phase 1 trains with ``model.layers[1]`` (the backbone) frozen using Adam
    at ``BASE_LR`` for ``PHASE1_EPOCHS`` epochs and saves the model. Phase 2
    unfreezes the backbone and continues with AdamW at ``FINETUNE_LR`` from
    epoch ``PHASE1_EPOCHS`` up to ``TOTAL_EPOCHS_STANDARD``, checkpointing
    the best model by ``val_accuracy`` and using early stopping.

    Returns:
        ``(best_val_acc, best_model_path)`` where ``best_val_acc`` is the
        maximum phase-2 ``val_accuracy`` and ``best_model_path`` is the
        checkpoint file written by ``ModelCheckpoint``.
    """
    print(f"\n--- Starting Standard Training: {model_name.upper()} ---")
    tf.keras.backend.clear_session()

    phase1_model_path = f"cifar10_{model_name}_std_phase1.keras"
    best_model_path = f"cifar10_{model_name}_std_best.keras"

    train_ds = create_dataset_pipeline(x_train, y_train, model_name, training=True)
    val_ds = create_dataset_pipeline(x_test, y_test, model_name, training=False)

    model = build_model(model_name)
    backbone = model.layers[1]

    # ---- Phase 1: frozen backbone ----
    backbone.trainable = False
    head_optimizer = keras.optimizers.Adam(learning_rate=BASE_LR)
    model.compile(optimizer=head_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_ds, validation_data=val_ds, epochs=PHASE1_EPOCHS, verbose=2)
    model.save(phase1_model_path)

    # ---- Phase 2: fine-tune with the backbone unfrozen ----
    backbone.trainable = True
    finetune_optimizer = keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=FINETUNE_LR)
    # Recompile so the trainable change takes effect with the new optimizer.
    model.compile(optimizer=finetune_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    phase2_callbacks = [
        tf.keras.callbacks.ModelCheckpoint(best_model_path, monitor="val_accuracy", save_best_only=True),
        tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
    ]
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=TOTAL_EPOCHS_STANDARD,
        initial_epoch=PHASE1_EPOCHS,
        callbacks=phase2_callbacks,
        verbose=2,
    )

    val_curve = history.history.get('val_accuracy', [0])
    best_val_acc = max(val_curve)
    print(f"Standard Model Best Validation Accuracy: {best_val_acc:.4f}")
    return best_val_acc, best_model_path


def run_enhanced_custom_training_from_checkpoint(initial_model_path, epochs_to_run, learning_rate, checkpoint_dir,
                                                 x_train, y_train, x_test, y_test, x_train_subset, y_train_subset,
                                                 model_name,
                                                 # --- ABLATION STUDY PARAMETERS ---
                                                 use_consistency_loss=True,
                                                 use_adaptive_alpha=True):
    """Refine a saved classifier with a prototype-augmented custom training loop.

    The model at ``initial_model_path`` is loaded and trained with AdamW. At
    the start of every epoch class prototypes are recomputed from the subset.
    In each step the classifier output is blended with a softmax over the
    cosine similarities between the features and the prototypes, and
    cross-entropy is computed on the blend; optionally a consistency term
    pulls the normalised features towards the prototype of the true class.

    Args:
        initial_model_path: Path of the saved Keras model to start from. It
            must contain a layer named "top_dropout" whose input is the
            feature vector.
        epochs_to_run: Number of refinement epochs.
        learning_rate: AdamW learning rate.
        checkpoint_dir: Directory (created if missing) where
            ``best_model.keras`` is written whenever validation accuracy improves.
        x_train, y_train, x_test, y_test: Raw training / validation data.
        x_train_subset, y_train_subset: Data used to compute prototypes.
        model_name: Backbone name used for preprocessing.
        use_consistency_loss: Add ``PROTOTYPE_REG_WEIGHT`` times the mean
            squared distance between normalised features and the true-class
            prototype to the loss.
        use_adaptive_alpha: Update the blending weight after every epoch;
            when False it stays at its initial value.

    Returns:
        The best validation accuracy observed (0.0 if it never exceeded 0).
    """
    print(f"\n--- Starting Custom Training from {os.path.basename(initial_model_path)} ---")
    print(f"Ablation Settings: Consistency Loss={use_consistency_loss}, Adaptive Alpha={use_adaptive_alpha}")
    tf.keras.backend.clear_session()
    gc.collect()

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, "best_model.keras")

    train_ds = create_dataset_pipeline(x_train, y_train, model_name, training=True)
    val_ds = create_dataset_pipeline(x_test, y_test, model_name, training=False)

    loaded_model = keras.models.load_model(initial_model_path)
    inputs = loaded_model.input
    base_features = loaded_model.get_layer("top_dropout").input
    predictions = loaded_model.output
    feature_extractor = Model(inputs=inputs, outputs=base_features)
    predictor = Model(inputs=inputs, outputs=predictions)
    feature_extractor.trainable, predictor.trainable = True, True

    # Two-output view over the same layers: one forward pass per step yields
    # both the features and the class probabilities.
    joint_model = Model(inputs=inputs, outputs=[base_features, predictions])
    # feature_extractor and predictor share the backbone, so concatenating
    # their variable lists named every backbone variable twice and the
    # optimizer updated (and weight-decayed) each of them twice per step.
    # The joint model lists every variable exactly once.
    trainable_vars = joint_model.trainable_variables

    adaptive_weight = AdaptiveLossWeight()
    optimizer = keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=learning_rate)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    train_acc_metric, val_acc_metric = keras.metrics.CategoricalAccuracy(), keras.metrics.CategoricalAccuracy()

    @tf.function
    def train_step(x, y, prototypes):
        with tf.GradientTape() as tape:
            features, y_pred = joint_model(x, training=True)

            # --- Prototype similarity calculation ---
            features_norm = tf.nn.l2_normalize(features, axis=1)
            prototypes_norm = tf.nn.l2_normalize(prototypes, axis=1)
            # NOTE(review): cosine similarities lie in [-1, 1]; dividing by
            # PROTOTYPE_TEMPERATURE (10.0) makes this softmax close to
            # uniform — confirm the temperature direction is intended.
            logits = tf.matmul(features_norm, prototypes_norm, transpose_b=True) / PROTOTYPE_TEMPERATURE
            y_proto_sim = tf.nn.softmax(logits)

            # Blend classifier probabilities with prototype similarities.
            alpha = adaptive_weight.alpha
            y_combined = alpha * y_pred + (1 - alpha) * y_proto_sim
            classification_loss = loss_fn(y, y_combined)

            if use_consistency_loss:
                true_class_indices = tf.argmax(y, axis=1)
                target_prototypes = tf.gather(prototypes_norm, true_class_indices)
                consistency_loss = tf.reduce_mean(tf.square(features_norm - target_prototypes))
                total_loss = classification_loss + PROTOTYPE_REG_WEIGHT * consistency_loss
            else:
                total_loss = classification_loss

        grads = tape.gradient(total_loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))
        # Accuracy is tracked on the raw classifier output, not on the blend.
        train_acc_metric.update_state(y, y_pred)
        return y_proto_sim

    @tf.function
    def val_step(x, y):
        y_pred = predictor(x, training=False)
        val_acc_metric.update_state(y, y_pred)

    best_val_acc = 0.0
    for epoch in range(epochs_to_run):
        print(f"\nEpoch {epoch + 1}/{epochs_to_run}")

        prototypes = calculate_prototypes(feature_extractor, predictor, x_train_subset, y_train_subset, model_name)

        train_acc_metric.reset_state()
        prototype_confidences = []

        for x_batch, y_batch in tqdm(train_ds, desc="Training"):
            proto_sim = train_step(x_batch, y_batch, prototypes)
            if use_adaptive_alpha:
                # Only needed for the alpha update; skip the host copy otherwise.
                prototype_confidences.append(proto_sim.numpy())

        if use_adaptive_alpha and prototype_confidences:
            # Mean prototype-similarity distribution over all training samples.
            avg_proto_confidence = np.mean(np.concatenate(prototype_confidences, axis=0), axis=0)
            new_alpha = adaptive_weight.update_alpha(epoch, tf.constant(avg_proto_confidence))
            print(f"Alpha updated to: {float(new_alpha):.4f}")

        val_acc_metric.reset_state()
        for x_batch, y_batch in tqdm(val_ds, desc="Validating"):
            val_step(x_batch, y_batch)
        val_acc = val_acc_metric.result().numpy()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            predictor.save(best_model_path)

        print(f"Train Acc: {train_acc_metric.result():.4f}, Val Acc: {val_acc:.4f} (Best: {best_val_acc:.4f})")
        gc.collect()

    return best_val_acc


if __name__ == "__main__":
    # Driver: for every configured backbone run the standard training, then
    # the full custom refinement and its two ablations, and save all
    # accuracies to a CSV file.
    (x_train, y_train), (x_test, y_test) = load_cifar10()

    # Random subset of the training set, drawn once and reused for every
    # prototype computation.
    subset_idx = np.random.choice(len(x_train), PROTO_SUBSET_SIZE, replace=False)
    x_train_subset = x_train[subset_idx]
    y_train_subset = y_train[subset_idx]

    # (result column, checkpoint-dir tag, use_consistency_loss, use_adaptive_alpha),
    # listed in execution order.
    refinement_runs = (
        ('enhanced_full_accuracy', 'enhanced_full', True, True),
        # Ablation 1: No Consistency Loss
        ('ablation_no_consistency_loss', 'ablation_no_consistency', False, True),
        # Ablation 2: Fixed Alpha
        ('ablation_fixed_alpha', 'ablation_fixed_alpha', True, False),
    )

    all_results = []

    for model_name in EXPERIMENT_CONFIGS:
        print(f"\n{'=' * 25}\nStarting Experiment for: {model_name.upper()} on CIFAR-10\n{'=' * 25}")

        acc_standard, best_std_path = run_standard_experiment(x_train, y_train, x_test, y_test, model_name)

        row = {
            'model_name': model_name,
            'standard_accuracy': acc_standard,
        }
        # Every refinement run restarts from the same best standard checkpoint.
        for result_column, ckpt_tag, with_consistency, with_adaptive_alpha in refinement_runs:
            row[result_column] = run_enhanced_custom_training_from_checkpoint(
                initial_model_path=best_std_path,
                epochs_to_run=CUSTOM_REFINE_EPOCHS_ON_BEST,
                learning_rate=CUSTOM_REFINE_LR,
                checkpoint_dir=f"cifar10_{model_name}_{ckpt_tag}_ckpt",
                x_train=x_train,
                y_train=y_train,
                x_test=x_test,
                y_test=y_test,
                x_train_subset=x_train_subset,
                y_train_subset=y_train_subset,
                model_name=model_name,
                use_consistency_loss=with_consistency,
                use_adaptive_alpha=with_adaptive_alpha,
            )

        all_results.append(row)

    results_df = pd.DataFrame(all_results)
    results_file = "cifar10_ablation_study_results.csv"
    results_df.to_csv(results_file, index=False)

    print(f"\n\n{'=' * 20} FINAL ABLATION RESULTS {'=' * 20}")
    print(results_df)
    print(f"\nResults saved to {results_file}")

