import os

# Set before the TensorFlow import below so the values are already in place
# when the library initializes.
# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Silence TensorFlow's C++ INFO and WARNING log lines (errors still print).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers, Model
from keras.applications import InceptionV3, VGG16, ResNet50
import matplotlib.pyplot as plt
import cv2
import gc
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cosine
import umap

# (height, width, channels) every image is resized to before it reaches a backbone.
input_shape = (160, 160, 3)
# Batch size used by the tf.data pipelines built in this file.
batch_size = 16
# Number of target classes; also the width of the one-hot label vectors.
num_classes = 10

# Backbones to analyze; the keys drive the main loop and select both the
# architecture and the matching preprocessing.
# NOTE(review): "feature_dim" is not read anywhere in the visible code.
EXPERIMENT_CONFIGS = {
    "inceptionv3": {"feature_dim": 2048},
    "vgg16": {"feature_dim": 512},
    "resnet50": {"feature_dim": 2048},
}

# Number of images run through the occlusion study for each model.
NUM_ANALYSIS_IMAGES = 5
# Training images sampled per class for the linear-probing experiment.
FEW_SHOT_K = 10

print("--- Loading CIFAR-10 dataset ---")
# Loaded once at import time and shared (as module globals) by the analysis
# functions below.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Flatten the label arrays to 1-D so each entry can be used as a class index.
y_train, y_test = y_train.flatten(), y_test.flatten()

# Human-readable class names, indexed by integer class id.
cifar10_labels = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck',
]

def preprocess_image(image, label, model_name, is_one_hot=True):
    """Resize an image and apply the backbone-specific Keras preprocessing.

    Args:
        image: Image tensor or array. It is cast to float32 and resized to
            the height and width given by the module-level ``input_shape``.
        label: Class label(s). Returned untouched unless ``is_one_hot`` is
            true.
        model_name: ``'inceptionv3'``, ``'vgg16'`` or ``'resnet50'``; selects
            which ``preprocess_input`` function is applied.
        is_one_hot: When true, the label is one-hot encoded over
            ``num_classes`` entries.

    Returns:
        A ``(image, label)`` tuple with the preprocessed image.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    target_size = (input_shape[0], input_shape[1])
    resized = tf.image.resize(tf.cast(image, tf.float32), target_size)

    applications = tf.keras.applications
    if model_name == 'inceptionv3':
        normalize = applications.inception_v3.preprocess_input
    elif model_name == 'vgg16':
        normalize = applications.vgg16.preprocess_input
    elif model_name == 'resnet50':
        normalize = applications.resnet50.preprocess_input
    else:
        raise ValueError(f"Unknown model_name for preprocessing: {model_name}")

    prepared = normalize(resized)
    encoded_label = tf.one_hot(label, num_classes) if is_one_hot else label
    return prepared, encoded_label

def create_dataset(x, y, model_name):
    """Build a batched, prefetched tf.data pipeline of preprocessed samples.

    Args:
        x: Images, sliced along the first axis.
        y: Integer labels aligned with ``x``; they are one-hot encoded by
            ``preprocess_image``.
        model_name: Backbone name forwarded to ``preprocess_image``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, one_hot_labels)`` batches of
        ``batch_size`` elements. The dataset is not shuffled.
    """
    def _prepare(img, lbl):
        return preprocess_image(img, lbl, model_name)

    samples = tf.data.Dataset.from_tensor_slices((x, y))
    prepared = samples.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
    batched = prepared.batch(batch_size)
    return batched.prefetch(tf.data.AUTOTUNE)

# NOTE(review): this runs at import time, right after the dataset helpers are
# defined; the tf.data pipelines themselves are only built later, in the
# main block.
print("--- Dataset ready ---")

def create_analysis_model(model_name):
    """Build a classifier and a multi-output analysis model on one backbone.

    The backbone is created without pretrained weights (``weights=None``) and
    with global average pooling; callers are expected to load weights
    afterwards (e.g. via ``set_weights``).

    Args:
        model_name: ``'inceptionv3'``, ``'vgg16'`` or ``'resnet50'``.

    Returns:
        A ``(full_model, analysis_model)`` tuple sharing the same layers:
        ``full_model`` maps images to softmax predictions, while
        ``analysis_model`` returns a dict with keys ``"preds"``,
        ``"features"`` (pooled backbone output) and ``"conv_activation"``
        (output of the last ``Conv2D`` layer of the backbone).

    Raises:
        ValueError: If ``model_name`` is unknown or the backbone contains no
            ``Conv2D`` layer.
    """
    if model_name == 'inceptionv3':
        base_model = InceptionV3(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    elif model_name == 'vgg16':
        base_model = VGG16(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    elif model_name == 'resnet50':
        base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    inputs = keras.Input(shape=input_shape)
    base_out = base_model(inputs, training=False)

    # Pass a default to next(): without it an exhausted generator raises
    # StopIteration and the explicit check below could never trigger.
    last_conv_layer = next(
        (layer for layer in reversed(base_model.layers) if isinstance(layer, layers.Conv2D)),
        None,
    )
    if last_conv_layer is None:
        raise ValueError("Could not find a Conv2D layer in the base model.")

    # Sub-model over the same backbone layers that stops at the last conv layer.
    conv_extractor = Model(base_model.input, last_conv_layer.output)
    conv_activation = conv_extractor(inputs)

    features = base_out
    x = layers.Dropout(0.3, name="top_dropout")(features)
    preds = layers.Dense(num_classes, activation="softmax", name="classifier")(x)

    full_model = Model(inputs, preds, name="full_model")
    analysis_model = Model(
        inputs,
        {"preds": preds, "features": features, "conv_activation": conv_activation},
        name="analysis_model",
    )
    return full_model, analysis_model

def analyze_feature_space(standard_model, custom_model, model_name, val_dataset, report, output_dir):
    """Compare the feature-space geometry of the standard and custom models.

    Both models' weights are copied into freshly built analysis models, the
    ``"features"`` output is extracted for every batch of ``val_dataset``,
    and cosine-based clustering metrics plus t-SNE / UMAP projections are
    written to the report and to ``feature_space_comparison.png``.

    Args:
        standard_model: Trained baseline model; its weights must be
            compatible with the model built by ``create_analysis_model``.
        custom_model: Trained custom-loss model, same requirement.
        model_name: Backbone name passed to ``create_analysis_model``.
        val_dataset: Sized dataset yielding ``(images, one_hot_labels)``
            batches.
        report: Open, writable text stream receiving Markdown output.
        output_dir: Directory the plot is saved into.

    Raises:
        ValueError: If ``val_dataset`` yields no batches.
    """
    print("\n--- Analyzing Feature Space Geometry ---")
    _, standard_analysis_model = create_analysis_model(model_name)
    standard_analysis_model.set_weights(standard_model.get_weights())
    _, custom_analysis_model = create_analysis_model(model_name)
    custom_analysis_model.set_weights(custom_model.get_weights())

    label_batches, std_batches, cust_batches = [], [], []
    for img_batch, lbl_batch in tqdm(val_dataset, desc="Extracting Features", total=len(val_dataset)):
        label_batches.append(np.argmax(lbl_batch.numpy(), axis=-1))
        std_batches.append(standard_analysis_model.predict_on_batch(img_batch)["features"])
        cust_batches.append(custom_analysis_model.predict_on_batch(img_batch)["features"])

    # Fail early with a clear message instead of deep inside t-SNE after a
    # partial report has already been written.
    if not label_batches:
        raise ValueError("val_dataset yielded no batches; cannot analyze the feature space.")

    all_labels = np.concatenate(label_batches)
    standard_features = np.concatenate(std_batches)
    custom_features = np.concatenate(cust_batches)

    def compute_metrics(features, labels):
        """Return (mean intra-class, mean inter-class, silhouette) cosine metrics."""
        present_classes = np.unique(labels)
        if len(present_classes) < 2:
            return 0., 0., 0.
        centroids = np.array([features[labels == cls].mean(axis=0) for cls in present_classes])
        # Pair each class with its centroid once rather than searching for
        # the class position again for every single feature vector.
        intra_dists = [
            np.mean([cosine(vec, centroid) for vec in features[labels == cls]])
            for cls, centroid in zip(present_classes, centroids)
        ]
        inter_dists = [
            np.mean([cosine(centroid, other) for j, other in enumerate(centroids) if j != i])
            for i, centroid in enumerate(centroids)
        ]
        silhouette = silhouette_score(features, labels, metric='cosine')
        return float(np.mean(intra_dists)), float(np.mean(inter_dists)), float(silhouette)

    std_intra, std_inter, std_s_score = compute_metrics(standard_features, all_labels)
    cust_intra, cust_inter, cust_s_score = compute_metrics(custom_features, all_labels)

    report.write("\n### 1. Feature Space Geometry Analysis\n\n")
    report.write("| Metric | Standard Model | Custom Loss Model | Note |\n")
    report.write("|---|---|---|---|\n")
    report.write(f"| Avg. Intra-Class Distance | {std_intra:.4f} | **{cust_intra:.4f}** | Lower is better |\n")
    report.write(f"| Avg. Inter-Class Distance | {std_inter:.4f} | **{cust_inter:.4f}** | Higher is better |\n")
    report.write(f"| Silhouette Score | {std_s_score:.4f} | **{cust_s_score:.4f}** | Higher is better |\n")
    print("Feature space analysis complete.")

    print("Running t-SNE...")
    tsne = TSNE(n_components=2, verbose=0, perplexity=40, random_state=42)
    std_tsne = tsne.fit_transform(standard_features)
    cust_tsne = tsne.fit_transform(custom_features)

    print("Running UMAP...")
    reducer = umap.UMAP(random_state=42)
    std_umap = reducer.fit_transform(standard_features)
    cust_umap = reducer.fit_transform(custom_features)

    fig, axes = plt.subplots(2, 2, figsize=(20, 18))
    fig.suptitle('Feature Space Visualization Comparison', fontsize=20)

    # One row per projection method, one column per model.
    panels = [
        (axes[0, 0], std_tsne, "Standard Model Feature Space (t-SNE)"),
        (axes[0, 1], cust_tsne, "Custom Loss Model Feature Space (t-SNE)"),
        (axes[1, 0], std_umap, "Standard Model Feature Space (UMAP)"),
        (axes[1, 1], cust_umap, "Custom Loss Model Feature Space (UMAP)"),
    ]
    class_ticks = np.unique(all_labels)
    for ax, embedding, title in panels:
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=all_labels, cmap='viridis', alpha=0.6)
        ax.set_title(title)
        fig.colorbar(scatter, ax=ax, ticks=class_ticks)

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # Save the combined plot and close this specific figure so it cannot
    # leak into later plotting code.
    plot_path = os.path.join(output_dir, "feature_space_comparison.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"t-SNE and UMAP plots saved to {plot_path}.")

    # The report references the image by file name only.
    report.write(f"\n![Feature Space Visualization]({os.path.basename(plot_path)})\n")

    del standard_features, custom_features, std_tsne, cust_tsne, std_umap, cust_umap, all_labels
    gc.collect()

def evaluate_robustness(standard_model, custom_model, model_name, val_dataset, report):
    """Measure accuracy of both models under additive Gaussian pixel noise.

    The first 1000 images of the module-level test set are corrupted with
    zero-mean Gaussian noise whose standard deviation is ``severity * 15``
    (in 0-255 pixel units) for severities 1 through 5. One Markdown table
    row per severity is appended to ``report``.

    Args:
        standard_model: Compiled baseline model; ``evaluate`` must return
            ``(loss, accuracy)``.
        custom_model: Compiled custom-loss model, same requirement.
        model_name: Backbone name used to pick the preprocessing.
        val_dataset: Not used by this function; kept so the call signature
            stays unchanged.
        report: Open, writable text stream receiving Markdown output.
    """
    print("\n--- Evaluating Robustness to Corruptions ---")
    report.write("\n### 2. Robustness to Corruptions\n\n")
    report.write("| Corruption | Severity | Standard Accuracy | Custom Loss Accuracy |\n")
    report.write("|---|---|---|---|\n")

    x_test_subset, y_test_subset = x_test[:1000], y_test[:1000]
    clean_pixels = x_test_subset.astype(np.float32)

    for severity in [1, 2, 3, 4, 5]:
        # Draw the noise with NumPy so every pixel of every channel is
        # perturbed independently, in one vectorized step for the whole
        # subset.
        noise = np.random.normal(0.0, severity * 15, size=clean_pixels.shape).astype(np.float32)
        corrupted = np.clip(clean_pixels + noise, 0, 255).astype(np.uint8)

        # Reuse the shared pipeline: it resizes, applies the backbone
        # preprocessing, one-hot encodes the labels and batches the data.
        ds_corrupted = create_dataset(corrupted, y_test_subset, model_name)

        _, std_acc = standard_model.evaluate(ds_corrupted, verbose=0)
        _, cust_acc = custom_model.evaluate(ds_corrupted, verbose=0)
        report.write(f"| gaussian_noise | {severity} | {std_acc:.4f} | **{cust_acc:.4f}** |\n")

        del noise, corrupted, ds_corrupted
        gc.collect()
    print("Robustness evaluation complete.")

def perform_linear_probing(standard_model, custom_model, model_name, train_dataset, val_dataset, report):
    """Train a linear classifier on frozen features from a few-shot sample.

    ``FEW_SHOT_K`` training images per class are drawn at random from the
    module-level training set. For each model, a frozen copy of its backbone
    feeds a new softmax layer that is trained on this sample, and the
    resulting accuracy on ``val_dataset`` is written to ``report``.

    Args:
        standard_model: Trained baseline model; its weights must be
            compatible with the model built by ``create_analysis_model``.
        custom_model: Trained custom-loss model, same requirement.
        model_name: Backbone name used for model construction and
            preprocessing.
        train_dataset: Not used by this function; kept so the call signature
            stays unchanged.
        val_dataset: Dataset of ``(images, one_hot_labels)`` batches used for
            early stopping and for the reported accuracy.
        report: Open, writable text stream receiving Markdown output.
    """
    print("\n--- Performing Few-Shot Learning (Linear Probing) ---")
    _, standard_analysis_model = create_analysis_model(model_name)
    standard_analysis_model.set_weights(standard_model.get_weights())
    _, custom_analysis_model = create_analysis_model(model_name)
    custom_analysis_model.set_weights(custom_model.get_weights())

    standard_backbone = Model(standard_analysis_model.inputs, standard_analysis_model.output["features"])
    custom_backbone = Model(custom_analysis_model.inputs, custom_analysis_model.output["features"])

    shot_indices = np.concatenate([
        np.random.choice(np.where(y_train == c)[0], FEW_SHOT_K, replace=False)
        for c in range(num_classes)
    ])
    # The indices above are grouped by class and create_dataset does not
    # shuffle, so without this permutation nearly every training batch would
    # contain a single class.
    shot_indices = np.random.permutation(shot_indices)
    x_shot, y_shot = x_train[shot_indices], y_train[shot_indices]
    ds_shot = create_dataset(x_shot, y_shot, model_name)

    def train_probe(backbone):
        """Fit a softmax layer on the frozen backbone and return val accuracy."""
        backbone.trainable = False
        inputs = keras.Input(shape=input_shape)
        features = backbone(inputs, training=False)
        outputs = layers.Dense(num_classes, activation='softmax')(features)
        probe_model = Model(inputs, outputs)
        probe_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        # NOTE(review): early stopping selects the best epoch on the same
        # data the final accuracy is reported on.
        early_stopping = keras.callbacks.EarlyStopping(
            monitor='val_accuracy', patience=5, restore_best_weights=True)
        probe_model.fit(ds_shot, epochs=50, validation_data=val_dataset, verbose=0,
                        callbacks=[early_stopping])
        return probe_model.evaluate(val_dataset, verbose=0)[1]

    std_probe_acc = train_probe(standard_backbone)
    cust_probe_acc = train_probe(custom_backbone)

    report.write("\n### 3. Few-Shot Learning (Linear Probing)\n\n")
    report.write(f"On a {FEW_SHOT_K}-shot learning task, the linear probe on the **standard model's** features achieved **{std_probe_acc:.4f}** accuracy.\n")
    report.write(f"The linear probe on the **custom loss model's** features achieved **{cust_probe_acc:.4f}** accuracy.\n")
    print("Linear probing complete.")

def run_occlusion_study(model, model_label, technical_model_name, images_np, labels_np, report, output_dir,
                        patch_size=40, stride=20):
    """Produce occlusion-sensitivity heatmaps for the first few images.

    For each of the first ``NUM_ANALYSIS_IMAGES`` images a square gray patch
    is slid over the preprocessed input; the drop in the true-class
    probability at each patch position forms a heatmap that is overlaid on
    the original image and saved as a PNG referenced from ``report``.

    Args:
        model: Model mapping a batch of preprocessed images to class
            probabilities.
        model_label: Display name used in headings and output file names.
        technical_model_name: Backbone name passed to ``preprocess_image``.
        images_np: Raw images; assumed to be RGB with 0-255 pixel values, as
            loaded from CIFAR-10 -- TODO confirm for other callers.
        labels_np: Integer class labels aligned with ``images_np``.
        report: Open, writable text stream receiving Markdown output.
        output_dir: Directory the figures are saved into.
        patch_size: Side length of the occluding square, in input pixels.
        stride: Step between consecutive patch positions, in input pixels.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive, or the
            patch is larger than the model input.
    """
    # input_shape is (height, width, channels).
    img_height, img_width = input_shape[0], input_shape[1]
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive.")
    if patch_size > img_height or patch_size > img_width:
        raise ValueError("patch_size must not exceed the model input size.")

    print(f"\n--- Running Occlusion Study for {model_label} model ---")
    report.write(f"\n### 4. Occlusion Sensitivity Analysis ({model_label} Model)\n\n")

    patch_rows = range(0, img_height - patch_size + 1, stride)
    patch_cols = range(0, img_width - patch_size + 1, stride)
    patch_origins = [(top, left) for top in patch_rows for left in patch_cols]

    # The patch is written into the *preprocessed* image, so the fill value
    # must be mid-gray expressed in that same space. Pushing one 127.5 pixel
    # through the preprocessing gives the right per-channel value for
    # whichever backbone is in use.
    gray_pixel = np.full((1, 1, input_shape[2]), 127.5, dtype=np.float32)
    gray_prepared, _ = preprocess_image(gray_pixel, 0, technical_model_name, is_one_hot=False)
    gray_fill = gray_prepared.numpy()[0, 0, :]

    for i in range(min(NUM_ANALYSIS_IMAGES, len(images_np))):
        img_np, true_label_idx = images_np[i], int(labels_np[i])

        img_tf, _ = preprocess_image(img_np, true_label_idx, technical_model_name, is_one_hot=False)
        clean_input = img_tf.numpy()
        base_preds = model.predict(clean_input[np.newaxis, ...], verbose=0)
        base_confidence = base_preds[0, true_label_idx]

        # Build every occluded variant up front and score them in a single
        # predict call instead of one call per patch position.
        occluded_batch = np.repeat(clean_input[np.newaxis, ...], len(patch_origins), axis=0)
        for k, (top, left) in enumerate(patch_origins):
            occluded_batch[k, top:top + patch_size, left:left + patch_size, :] = gray_fill
        occluded_preds = model.predict(occluded_batch, batch_size=batch_size, verbose=0)

        confidence_drops = (base_confidence - occluded_preds[:, true_label_idx]).reshape(
            len(patch_rows), len(patch_cols)).astype(np.float32)

        # cv2.resize takes its target size as (width, height).
        heatmap = cv2.resize(confidence_drops, (img_width, img_height), interpolation=cv2.INTER_LINEAR)

        heatmap = np.maximum(0, heatmap)  # Keep only positions where occlusion hurts confidence
        if np.max(heatmap) > 0:
            heatmap = heatmap / np.max(heatmap)  # Normalize to [0, 1] when there is any positive drop

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        pred_idx = int(np.argmax(base_preds[0]))

        # Show the raw image (already RGB, 0-255) rather than the
        # preprocessed tensor, whose values are not displayable pixels.
        display_img = cv2.resize(np.clip(img_np, 0, 255).astype(np.uint8), (img_width, img_height),
                                 interpolation=cv2.INTER_LINEAR)
        ax1.imshow(display_img)
        ax1.set_title(f"Original (Pred: {cifar10_labels[pred_idx]})")
        ax1.axis('off')

        ax2.imshow(display_img)
        ax2.imshow(heatmap, cmap='jet', alpha=0.5, vmin=0, vmax=1)  # Fixed vmin/vmax for consistent scaling
        ax2.set_title("Occlusion Heatmap")
        ax2.axis('off')

        fig_path = os.path.join(output_dir, f"occlusion_study_{model_label.lower()}_{i}.png")
        fig.suptitle(f"Occlusion Analysis (True: {cifar10_labels[true_label_idx]})")
        fig.savefig(fig_path)
        plt.close(fig)

        report.write(f"![Occlusion Analysis Image {i + 1}]({os.path.basename(fig_path)})\n")

    print(f"Occlusion study images saved to '{output_dir}'.")

# --- 6. Main Execution ---
if __name__ == "__main__":
    # Run the full analysis suite once per configured backbone. Backbones
    # whose checkpoints are missing are reported and skipped.
    for model_name in EXPERIMENT_CONFIGS:
        print(f"\n{'='*20} RUNNING ANALYSIS FOR: {model_name.upper()} {'='*20}")

        STD_MODEL_PATH = f"cifar10_{model_name}_std_ckpt/best_model.keras"
        CUSTOM_MODEL_PATH = f"cifar10_{model_name}_enhanced_custom_on_best_ckpt/best_model.keras"

        if not (os.path.exists(STD_MODEL_PATH) and os.path.exists(CUSTOM_MODEL_PATH)):
            print(f"\nERROR: Model files not found for {model_name.upper()}!")
            print(f"Please ensure '{STD_MODEL_PATH}' and '{CUSTOM_MODEL_PATH}' exist.")
            continue

        train_dataset = create_dataset(x_train, y_train, model_name)
        val_dataset = create_dataset(x_test, y_test, model_name)

        # Load without the saved compile state: both models are recompiled
        # just below, and skipping it means the loss the checkpoint was
        # trained with does not have to be deserializable here.
        print(f"--- Loading Standard Model from {STD_MODEL_PATH} ---")
        standard_model = keras.models.load_model(STD_MODEL_PATH, compile=False)
        print(f"--- Loading Custom Model from {CUSTOM_MODEL_PATH} ---")
        custom_model = keras.models.load_model(CUSTOM_MODEL_PATH, compile=False)

        # Same loss and metric for both so evaluate() returns comparable
        # (loss, accuracy) pairs.
        standard_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        custom_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

        output_dir = os.path.join("analysis_results_cifar10", model_name)
        os.makedirs(output_dir, exist_ok=True)

        report_path = os.path.join(output_dir, "full_analysis_report.md")
        with open(report_path, "w", encoding="utf-8") as report:
            report.write(f"# Analysis Report: Standard vs. Custom Loss on CIFAR10 ({model_name.upper()})\n\n")
            analyze_feature_space(standard_model, custom_model, model_name, val_dataset, report, output_dir)
            evaluate_robustness(standard_model, custom_model, model_name, val_dataset, report)
            perform_linear_probing(standard_model, custom_model, model_name, train_dataset, val_dataset, report)

            # The occlusion study needs both a display label and the
            # backbone name (the latter selects the preprocessing).
            run_occlusion_study(standard_model, "Standard", model_name, x_test, y_test, report, output_dir)
            run_occlusion_study(custom_model, "Custom", model_name, x_test, y_test, report, output_dir)

            print(f"\n--- Analysis for {model_name.upper()} complete. Report and images saved to '{output_dir}'. ---")

        # Several models are built per backbone; drop the references and
        # reset Keras' global state so memory does not accumulate across
        # iterations.
        del standard_model, custom_model, train_dataset, val_dataset
        keras.backend.clear_session()
        gc.collect()