import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
from collections import defaultdict
import random
from transformers import TFCLIPModel, CLIPProcessor
from PIL import Image
import requests
import os
import json

# Seed every RNG this script draws from (Python's `random`, NumPy's global RNG,
# TensorFlow) with the same value so the later sampling -- head selection,
# beta search, and the val/test split shuffle -- is reproducible.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(53)

def _readable_class_names(class_names, class_index_path):
    """Map WordNet ids (e.g. ``"n01440764"``) to readable names when possible.

    ``class_index_path`` is read as a Keras-style ``imagenet_class_index.json``
    (``{"<index>": [wnid, name], ...}``). Names are rewritten only when every
    entry looks like a wnid and the file exists; wnids missing from the file
    are kept unchanged. Otherwise the names are returned as a new list as-is.
    """
    def looks_like_wnid(name):
        return len(name) == 9 and name[0] == "n" and name[1:].isdigit()

    if not class_names or not all(looks_like_wnid(name) for name in class_names):
        return list(class_names)
    if not os.path.exists(class_index_path):
        print(f"Warning: class names look like WordNet ids and {class_index_path} was not found; "
              "text prompts will use the raw ids.")
        return list(class_names)

    with open(class_index_path, 'r') as f:
        class_index = json.load(f)
    wnid_to_name = {wnid: name for wnid, name in class_index.values()}
    return [wnid_to_name.get(name, name) for name in class_names]


def load_imagenet_dataset(max_images=16000, class_index_path="./imagenet/imagenet_class_index.json"):
    """Load a prefix of the ImageNet validation split, resized to 224x224.

    Uses ``tensorflow_datasets`` when it is installed; only a missing
    ``tensorflow_datasets`` package triggers the fallback to
    ``load_imagenet_manual()``. Errors raised while actually loading data
    propagate instead of being mistaken for a missing package.

    Args:
        max_images: Number of examples taken from the tfds validation split
            (default 16000, the previously hard-coded value).
        class_index_path: Optional Keras-style ``imagenet_class_index.json``
            used to turn WordNet-id label names into readable names for the
            text prompts. NOTE(review): tfds ``imagenet2012`` label names are,
            to our knowledge, WordNet ids -- verify with ``ds_info``.

    Returns:
        ``(images, labels, class_names)``: float32 array of resized images,
        integer label array, and the list of class names.
    """
    try:
        import tensorflow_datasets as tfds
    except ImportError:
        print("tensorflow_datasets not available. Using manual ImageNet loading...")
        return load_imagenet_manual()

    print("Loading ImageNet validation set...")
    ds, ds_info = tfds.load(
        'imagenet2012',
        split='validation',
        shuffle_files=False,
        as_supervised=True,
        with_info=True,
        download=True
    )

    # Class names from the dataset info, made readable for the prompts when possible
    class_names = _readable_class_names(ds_info.features['label'].names, class_index_path)
    num_classes = len(class_names)
    print(f"Total classes: {num_classes}")

    print("Converting to numpy arrays...")
    images = []
    labels = []

    for idx, (image, label) in enumerate(ds.take(max_images)):
        if idx % 1000 == 0:
            print(f"Processed {idx}/{max_images} images...")

        # Resize to the 224x224 input size used throughout the script
        image = tf.image.resize(image, [224, 224])
        image = tf.cast(image, tf.float32)

        images.append(image.numpy())
        labels.append(label.numpy())

    return np.array(images), np.array(labels), class_names

def load_imagenet_manual(imagenet_path="./imagenet"):
    """Load every image of a local ImageNet validation directory tree.

    Expected layout (pass ``imagenet_path`` to match your machine)::

        <imagenet_path>/imagenet_class_index.json
        <imagenet_path>/val/<class_dir>/<image>.{png,jpg,jpeg}

    The index file is read as ``{"<index>": [wnid, name], ...}``; element 1 is
    the class name. When every class directory name appears as a wnid in that
    file, the directory's images get that wnid's index as label, so a missing
    directory no longer shifts all later labels. Otherwise labels fall back to
    each directory's position in sorted order (the previous behaviour).

    Args:
        imagenet_path: Root directory holding ``val/`` and the index file.

    Returns:
        ``(images, labels, class_names)``: uint8 array of 224x224 RGB images,
        integer label array, and class names ordered by index.

    Raises:
        FileNotFoundError: if ``<imagenet_path>/val`` does not exist.
    """
    val_path = os.path.join(imagenet_path, "val")
    if not os.path.exists(val_path):
        raise FileNotFoundError(f"ImageNet validation directory not found at {val_path}")

    with open(os.path.join(imagenet_path, "imagenet_class_index.json"), 'r') as f:
        class_index = json.load(f)
    num_entries = len(class_index)
    class_names = [class_index[str(i)][1] for i in range(num_entries)]
    wnid_to_label = {class_index[str(i)][0]: i for i in range(num_entries)}

    class_dirs = sorted(d for d in os.listdir(val_path) if os.path.isdir(os.path.join(val_path, d)))
    # All-or-nothing: mixing wnid labels and positional labels could collide.
    use_wnid_labels = bool(class_dirs) and all(d in wnid_to_label for d in class_dirs)

    images = []
    labels = []

    print("Loading ImageNet validation images...")
    for position, class_dir in enumerate(class_dirs):
        label = wnid_to_label[class_dir] if use_wnid_labels else position
        class_path = os.path.join(val_path, class_dir)
        # Sorted so the load order (and therefore the later split) is deterministic.
        image_files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        for img_file in image_files:
            img_path = os.path.join(class_path, img_file)
            try:
                # Context manager closes the underlying file handle.
                with Image.open(img_path) as img:
                    img_array = np.array(img.convert('RGB').resize((224, 224)))
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue

            images.append(img_array)
            labels.append(label)

            if len(images) % 1000 == 0:
                print(f"Loaded {len(images)} images...")

    return np.array(images), np.array(labels), class_names

# Load TensorFlow CLIP model
# NOTE(review): later code calls tf_model.vision_model(...); on TFCLIPModel the
# vision tower may only be reachable as tf_model.clip.vision_model -- confirm.
model_name = "openai/clip-vit-base-patch32"
tf_model = TFCLIPModel.from_pretrained(model_name, output_attentions=True)
processor = CLIPProcessor.from_pretrained(model_name)

print("TensorFlow CLIP model loaded successfully!")

all_images, all_labels, class_names = load_imagenet_dataset()
num_classes = len(class_names)

# Use class names directly for ImageNet
# NOTE(review): prompts are only meaningful if class_names are human-readable;
# WordNet ids (e.g. "n01440764") would give near-random zero-shot accuracy.
imagenet_prompts = [f"a photo of a {class_name.replace('_', ' ')}" for class_name in class_names]

# Process text prompts in batches to avoid memory issues
# (batch_size is reassigned to 8 later for classifier training.)
batch_size = 100
vanilla_text_features_list = []

for i in range(0, len(imagenet_prompts), batch_size):
    batch_prompts = imagenet_prompts[i:i+batch_size]
    vanilla_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True, truncation=True)
    
    # Get text features using TensorFlow model
    text_outputs = tf_model.get_text_features(**vanilla_inputs)
    # Normalize features
    vanilla_text_features_batch = tf.nn.l2_normalize(text_outputs, axis=-1)
    vanilla_text_features_list.append(vanilla_text_features_batch)

# One L2-normalised text embedding per class, rows in class-index order.
vanilla_text_features = tf.concat(vanilla_text_features_list, axis=0)

samples_per_class = 16  # Much smaller validation set for ImageNet
max_test_samples = 50000  # Limit test set size for computational efficiency
# NOTE(review): the tfds loader takes only 16000 images (~16 per class on
# average), so with samples_per_class = 16 most images likely land in the
# validation split and the test split may be small -- verify the split sizes.

# Collect validation samples (fewer per class for ImageNet)
val_samples_per_class = defaultdict(list)
remaining_samples = []

# First pass: the first `samples_per_class` images of each class (in load
# order) go to validation; everything else is a test candidate.
for idx, label in enumerate(all_labels):
    if len(val_samples_per_class[label]) < samples_per_class:
        val_samples_per_class[label].append(idx)
    else:
        remaining_samples.append(idx)

# Create validation indices, grouped by class (class 0 first, then class 1, ...)
val_indices = []
for class_label in range(num_classes):
    val_indices.extend(val_samples_per_class[class_label])

# Uses Python's `random` (seeded above); keep this call's position to preserve the split.
random.shuffle(remaining_samples)
test_indices = remaining_samples[:max_test_samples]


# Create validation and test datasets
val_images = all_images[val_indices]
val_labels = all_labels[val_indices]
test_images = all_images[test_indices]
test_labels = all_labels[test_indices]

def preprocess_image_for_clip(img_array):
    """Convert one image array into CLIP ``pixel_values`` via the CLIP processor.

    Input scaled to [0, 1] is detected by its maximum value and rescaled to
    [0, 255]. The values are clamped to the uint8 range, wrapped as a PIL
    image and passed through the module-level ``processor``.

    Args:
        img_array: uint8 or float image array (presumably (H, W, 3)).

    Returns:
        The processor's ``pixel_values`` tensor for this single image.
    """
    img_tensor = tf.convert_to_tensor(img_array, dtype=tf.float32)

    # Heuristic: a max of <= 1.0 is taken to mean [0, 1]-scaled input.
    if tf.reduce_max(img_tensor) <= 1.0:
        img_tensor = img_tensor * 255.0

    # Clamp before narrowing: float -> uint8 casts of out-of-range values are
    # not guaranteed to saturate, so stray values could wrap around.
    img_tensor = tf.clip_by_value(img_tensor, 0.0, 255.0)
    pil_img = Image.fromarray(tf.cast(img_tensor, tf.uint8).numpy())

    return processor(images=pil_img, return_tensors="tf")['pixel_values']

def get_attention_weights(vision_outputs, layer_idx, head_idx):
    """Return a single head's attention map from a single layer.

    Args:
        vision_outputs: Model output exposing ``attentions``, a sequence of
            per-layer tensors indexed here as ``[batch, head, query, key]``.
        layer_idx: Index into ``attentions`` (negative values count from the end).
        head_idx: Head index within that layer.

    Returns:
        The ``[batch, query, key]`` slice belonging to ``head_idx``.
    """
    layer_attention = vision_outputs.attentions[layer_idx]
    return layer_attention[:, head_idx]

def find_best_patch_heads(validation_images, validation_labels, text_features,
                          num_samples=200, num_layers=3, top_k=3, num_common=9):
    """Select the heads whose CLS->patch attention best matches the label text.

    For up to ``num_samples`` randomly drawn validation images, every head of
    the last ``num_layers`` layers is scored as follows:
      1. Weight the final-layer patch tokens by that head's CLS->patch attention.
      2. Average them and project into CLIP space.
      3. Take the cosine similarity with the true class's text feature.

    The ``top_k`` heads per layer are tallied. The ``num_common`` most
    frequently selected (layer, head) pairs are returned.

    Args:
        validation_images: Images accepted by ``preprocess_image_for_clip``.
        validation_labels: Integer class labels aligned with the images.
        text_features: L2-normalised text embeddings, one row per class.
        num_samples, num_layers, top_k, num_common: Search sizes. The defaults
            reproduce the previous hard-coded 200 / 3 / 3 / 9.

    Returns:
        List of ``(layer_idx, head_idx)`` with negative ``layer_idx``, most
        frequent first (ties keep first-seen order).
    """
    print("Finding best heads for patch enrichment...")
    head_frequency = defaultdict(int)

    # Draws from NumPy's global RNG; the call order relative to other sampling matters.
    sample_indices = np.random.choice(len(validation_images), min(num_samples, len(validation_images)), replace=False)

    for i, idx in enumerate(sample_indices):
        if i % 20 == 0:
            print(f"Processing sample {i}/{len(sample_indices)}")

        image_input = preprocess_image_for_clip(validation_images[idx])
        label = validation_labels[idx]
        target_text_feature = tf.expand_dims(text_features[label], axis=0)

        vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
        patch_embeds = vision_outputs.last_hidden_state[:, 1:, :]  # drop the CLS token

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)  # count back from the last layer
            # Head count taken from the tensor instead of assuming ViT-Base's 12.
            num_heads = vision_outputs.attentions[layer_idx].shape[1]

            head_scores = []
            for head_idx in range(num_heads):
                attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
                cls_to_patch_attention = attention_weights[0, 0, 1:]  # CLS query -> patch keys

                weighted_patches = patch_embeds[0] * tf.expand_dims(cls_to_patch_attention, axis=-1)
                avg_patch = tf.reduce_mean(weighted_patches, axis=0)

                projected = tf_model.visual_projection(tf.expand_dims(avg_patch, axis=0))
                projected = tf.nn.l2_normalize(projected, axis=-1)

                similarity = tf.reduce_sum(projected * target_text_feature, axis=-1)
                head_scores.append((head_idx, float(similarity.numpy()[0])))

            head_scores.sort(key=lambda x: x[1], reverse=True)
            for head_idx, _ in head_scores[:top_k]:
                head_frequency[(layer_idx, head_idx)] += 1

    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:num_common]
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common patch heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_best_cls_heads(validation_images, validation_labels, text_features,
                        num_samples=200, num_layers=3, top_k=3, num_common=9):
    """Select the heads whose CLS attention yields the best label-matching CLS vector.

    For up to ``num_samples`` randomly drawn validation images, every head of
    the last ``num_layers`` layers is scored as follows:
      1. Softmax the head's CLS-row attention.
      2. Use it to pool that layer's hidden states.
      3. Project the pooled vector into CLIP space.
      4. Take the cosine similarity with the true class's text feature.

    The ``top_k`` heads per layer are tallied. The ``num_common`` most
    frequent (layer, head) pairs are returned.

    Args:
        validation_images: Images accepted by ``preprocess_image_for_clip``.
        validation_labels: Integer class labels aligned with the images.
        text_features: L2-normalised text embeddings, one row per class.
        num_samples, num_layers, top_k, num_common: Search sizes. The defaults
            reproduce the previous hard-coded 200 / 3 / 3 / 9.

    Returns:
        List of ``(layer_idx, head_idx)`` with negative ``layer_idx``, most
        frequent first (ties keep first-seen order).
    """
    print("Finding best heads for CLS enrichment...")
    head_frequency = defaultdict(int)

    # Draws from NumPy's global RNG; the call order relative to other sampling matters.
    sample_indices = np.random.choice(len(validation_images), min(num_samples, len(validation_images)), replace=False)

    for i, idx in enumerate(sample_indices):
        if i % 20 == 0:
            print(f"Processing sample {i}/{len(sample_indices)}")

        image_input = preprocess_image_for_clip(validation_images[idx])
        label = validation_labels[idx]
        target_text_feature = tf.expand_dims(text_features[label], axis=0)

        vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)
            # NOTE(review): hidden_states has one more entry than attentions
            # (embeddings first), so this pairs a layer's attention with that
            # layer's *output* tokens rather than its inputs -- confirm intended.
            layer_hidden = vision_outputs.hidden_states[layer_idx]
            num_heads = vision_outputs.attentions[layer_idx].shape[1]

            head_scores = []
            for head_idx in range(num_heads):
                attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)

                # NOTE(review): HF `output_attentions` are documented as
                # post-softmax weights; if so, this extra softmax flattens them
                # toward uniform. It is kept so scoring matches the identical
                # transform used in find_optimal_beta, extract_features and the
                # test loop -- change all of them together.
                cls_attention = attention_weights[0, 0, :]
                cls_attention = tf.nn.softmax(cls_attention, axis=0)

                enriched_cls = tf.reduce_sum(layer_hidden[0] * tf.expand_dims(cls_attention, axis=-1), axis=0)

                projected = tf_model.visual_projection(tf.expand_dims(enriched_cls, axis=0))
                projected = tf.nn.l2_normalize(projected, axis=-1)

                similarity = tf.reduce_sum(projected * target_text_feature, axis=-1)
                head_scores.append((head_idx, float(similarity.numpy()[0])))

            head_scores.sort(key=lambda x: x[1], reverse=True)
            for head_idx, _ in head_scores[:top_k]:
                head_frequency[(layer_idx, head_idx)] += 1

    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:num_common]
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common CLS heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_optimal_beta(validation_images, validation_labels, text_features, common_heads, enrichment_type):
    """Grid-search the weight mixing vanilla and head-enriched zero-shot logits.

    For up to 500 randomly drawn validation images, the vanilla CLS logits and
    the enriched logits are computed once per image (``"patch"`` or ``"cls"``
    enrichment, averaged over ``common_heads``). Then
    ``beta * vanilla + (1 - beta) * enriched`` is scored for
    beta in {0.0, 0.1, ..., 1.0}.

    The logits do not depend on beta. The previous version re-ran the full
    CLIP forward pass for every beta; the selected beta is unchanged.

    Args:
        validation_images: Images accepted by ``preprocess_image_for_clip``.
        validation_labels: Integer class labels aligned with the images.
        text_features: L2-normalised text embeddings, one row per class.
        common_heads: Non-empty list of ``(layer_idx, head_idx)`` pairs.
        enrichment_type: ``"patch"`` or ``"cls"``.

    Returns:
        The smallest beta reaching the best top-1 accuracy (0.0 if none > 0).

    Raises:
        ValueError: Unknown ``enrichment_type``, empty ``common_heads`` (the
            average would be NaN), or empty ``validation_images``.
    """
    if enrichment_type not in ("patch", "cls"):
        raise ValueError(f"enrichment_type must be 'patch' or 'cls', got {enrichment_type!r}")
    if not common_heads:
        raise ValueError("common_heads must contain at least one (layer, head) pair")
    if len(validation_images) == 0:
        raise ValueError("validation_images is empty")

    print(f"Finding optimal beta for {enrichment_type}...")

    beta_range = np.linspace(0.0, 1.0, 11)
    best_beta = 0.0
    best_accuracy = 0.0

    # Subset for speed; draws from NumPy's global RNG exactly once, as before.
    sample_indices = np.random.choice(len(validation_images), min(500, len(validation_images)), replace=False)

    cls_rows = []
    enriched_rows = []
    sample_labels = []
    for idx in sample_indices:
        image_input = preprocess_image_for_clip(validation_images[idx])
        sample_labels.append(validation_labels[idx])

        vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

        # Vanilla CLS logits.
        # NOTE(review): HF CLIP's image embedding applies post_layernorm to the CLS
        # token (pooler_output) before visual_projection; last_hidden_state does
        # not include it. Kept as-is to stay consistent with the evaluation loop.
        cls_embed = tf_model.visual_projection(vision_outputs.last_hidden_state[:, 0, :])
        cls_normalized = tf.nn.l2_normalize(cls_embed, axis=-1)
        cls_rows.append(tf.squeeze(tf.matmul(cls_normalized, text_features, transpose_b=True)))

        if enrichment_type == "patch":
            # Attention-weighted mean of final-layer patch tokens, averaged over heads.
            patch_embeds = vision_outputs.last_hidden_state[:, 1:, :]
            enriched = tf.zeros_like(patch_embeds[0][0])
            for layer_idx, head_idx in common_heads:
                attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
                cls_to_patch_attention = attention_weights[0, 0, 1:]
                weighted_patches = patch_embeds[0] * tf.expand_dims(cls_to_patch_attention, axis=-1)
                enriched += tf.reduce_mean(weighted_patches, axis=0)
        else:  # "cls"
            # Hidden states pooled by the (re-softmaxed) CLS attention row, averaged over heads.
            enriched = tf.zeros_like(vision_outputs.hidden_states[-1][0][0])
            for layer_idx, head_idx in common_heads:
                attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
                layer_hidden = vision_outputs.hidden_states[layer_idx]
                cls_attention = tf.nn.softmax(attention_weights[0, 0, :], axis=0)
                enriched += tf.reduce_sum(layer_hidden[0] * tf.expand_dims(cls_attention, axis=-1), axis=0)

        enriched /= len(common_heads)
        projected = tf_model.visual_projection(tf.expand_dims(enriched, axis=0))
        normalized = tf.nn.l2_normalize(projected, axis=-1)
        enriched_rows.append(tf.squeeze(tf.matmul(normalized, text_features, transpose_b=True)))

    # Shapes: (num_samples, num_classes); rows align with sample_labels.
    cls_logits = tf.stack(cls_rows)
    enriched_logits = tf.stack(enriched_rows)
    sample_labels = np.asarray(sample_labels)
    total = len(sample_labels)

    for beta in beta_range:
        combined_logits = beta * cls_logits + (1 - beta) * enriched_logits
        preds = tf.argmax(combined_logits, axis=1).numpy()
        correct = int(np.sum(preds == sample_labels))
        accuracy = correct / total
        # Strict '>' keeps the first (smallest) beta among ties.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_beta = beta

    print(f"Optimal beta: {best_beta:.3f} (accuracy: {best_accuracy:.4f})")
    return best_beta

# TensorFlow CNN+MLP Classifier (updated for ImageNet's 1000 classes)
class HybridClassifier(tf.keras.Model):
    """Two-branch classifier: a small CNN over the patch grid plus an MLP over CLS features.

    The two 512-d branch outputs are concatenated and mapped to class logits.

    Args:
        mlp_input_dim: Width of the 1-D CLS feature vector. Keras infers the
            first Dense layer's input width on the first call, so this is only
            stored for reference as ``self.mlp_input_dim``.
        num_classes: Number of output logits (default 1000 for ImageNet).
    """

    def __init__(self, mlp_input_dim, num_classes=1000):
        super().__init__()
        self.mlp_input_dim = mlp_input_dim

        # CNN over the (grid_h, grid_w, channels) patch map. Keras Conv2D uses
        # channels_last by default, which matches that layout directly.
        self.cnn = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dropout(0.5)
        ])

        # MLP for the 1-D CLS features
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dropout(0.5)
        ])

        # Final classifier over the concatenated branch outputs
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, patch_features, cls_features, training=False):
        """Return unnormalised class logits.

        Args:
            patch_features: ``(batch, grid_h, grid_w, channels)`` patch map, e.g.
                ``(batch, 7, 7, 768)`` as built by ``extract_features``.
            cls_features: ``(batch, mlp_input_dim)`` CLS feature vectors.
            training: Enables dropout when True.
        """
        # No transpose: the earlier NCHW transpose (a PyTorch convention) made the
        # channels_last CNN treat the 768 feature axis as image height and
        # convolve over it.
        cnn_out = self.cnn(patch_features, training=training)
        mlp_out = self.mlp(cls_features, training=training)

        combined = tf.concat([cnn_out, mlp_out], axis=1)
        return self.classifier(combined)

def extract_features(images, labels, patch_heads, cls_heads):
    """Encode images with CLIP and build the feature sets used to train the classifiers.

    For each image:
      * cls      -- L2-normalised projection of the final-layer CLS token.
      * patch    -- final-layer patch tokens weighted by each ``patch_heads``
                    head's CLS->patch attention, averaged over heads and
                    reshaped into a square ``(grid, grid, hidden)`` map.
      * combined -- ``cls`` concatenated with the L2-normalised projection of
                    the CLS-attention-pooled hidden states (averaged over
                    ``cls_heads``).

    Args:
        images: Images accepted by ``preprocess_image_for_clip``.
        labels: Labels aligned with ``images`` (copied through unchanged).
        patch_heads, cls_heads: Non-empty lists of ``(layer_idx, head_idx)``.

    Returns:
        ``(cls_features, patch_features, combined_cls_features, labels)`` as
        NumPy arrays aligned by image.

    Raises:
        ValueError: If either head list is empty (the averages would be NaN) or
            the number of patch tokens is not a perfect square.
    """
    if not patch_heads or not cls_heads:
        raise ValueError("patch_heads and cls_heads must both be non-empty")

    cls_features_list = []
    patch_features_list = []
    combined_cls_features_list = []
    labels_list = []

    print("Extracting features...")
    for idx in range(len(images)):
        if idx % 100 == 0:
            print(f"Processing {idx}/{len(images)}")

        image_input = preprocess_image_for_clip(images[idx])
        vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

        # Vanilla CLS feature
        cls_embed = tf_model.visual_projection(vision_outputs.last_hidden_state[:, 0, :])
        cls_feat = tf.nn.l2_normalize(cls_embed, axis=-1)

        # Enriched CLS: hidden states pooled by the (re-softmaxed) CLS attention row
        enriched_cls = tf.zeros_like(vision_outputs.hidden_states[-1][0][0])
        for layer_idx, head_idx in cls_heads:
            attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
            layer_hidden = vision_outputs.hidden_states[layer_idx]
            cls_attention = tf.nn.softmax(attention_weights[0, 0, :], axis=0)
            enriched_cls += tf.reduce_sum(layer_hidden[0] * tf.expand_dims(cls_attention, axis=-1), axis=0)

        enriched_cls /= len(cls_heads)
        enriched_cls_proj = tf_model.visual_projection(tf.expand_dims(enriched_cls, axis=0))
        enriched_cls_feat = tf.nn.l2_normalize(enriched_cls_proj, axis=-1)

        # Enriched patches: per-token attention weighting, averaged over heads
        patch_embeds = vision_outputs.last_hidden_state[:, 1:, :]
        enriched_patches = tf.zeros_like(patch_embeds[0])
        for layer_idx, head_idx in patch_heads:
            attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
            cls_to_patch_attention = attention_weights[0, 0, 1:]
            enriched_patches += patch_embeds[0] * tf.expand_dims(cls_to_patch_attention, axis=-1)

        enriched_patches /= len(patch_heads)

        # Grid from the token count (7x7x768 for ViT-B/32 at 224px) instead of hard-coding it.
        num_patches, hidden_dim = enriched_patches.shape
        grid = int(round(np.sqrt(num_patches)))
        if grid * grid != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens into a square grid")
        patch_2d = tf.reshape(enriched_patches, [grid, grid, hidden_dim])

        combined_cls = tf.concat([cls_feat, enriched_cls_feat], axis=1)

        cls_features_list.append(cls_feat[0].numpy())
        patch_features_list.append(patch_2d.numpy())
        combined_cls_features_list.append(combined_cls[0].numpy())
        labels_list.append(labels[idx])

    return (np.array(cls_features_list),
            np.array(patch_features_list),
            np.array(combined_cls_features_list),
            np.array(labels_list))

# Find heads
# These four searches each draw from NumPy's global RNG (np.random.choice), so
# their order determines which validation samples each one sees.
common_patch_heads = find_best_patch_heads(val_images, val_labels, vanilla_text_features)
common_cls_heads = find_best_cls_heads(val_images, val_labels, vanilla_text_features)

# Find betas
# NOTE(review): optimal_beta1 (patch) is not used by the evaluation loop below;
# the combined method mixes with optimal_beta3, which is just the CLS beta.
optimal_beta1 = find_optimal_beta(val_images, val_labels, vanilla_text_features, common_patch_heads, "patch")
optimal_beta2 = find_optimal_beta(val_images, val_labels, vanilla_text_features, common_cls_heads, "cls")
optimal_beta3 = optimal_beta2

# Extract training features
# The classifier training data is the same validation split used to pick heads and betas.
cls_features, patch_features, combined_cls_features, train_labels = extract_features(
    val_images, val_labels, common_patch_heads, common_cls_heads)


# Patch classifier
patch_classifier = HybridClassifier(mlp_input_dim=512, num_classes=num_classes)
patch_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
patch_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

epochs = 30
batch_size = 8
# Only full batches are used; a trailing partial batch is dropped each epoch.
num_batches = len(cls_features) // batch_size

# The training rows follow val_indices, which are grouped by class. Iterating
# them in order made every mini-batch single-class, so each epoch now visits
# the samples in a fresh random order. A dedicated generator keeps this
# shuffling from shifting the global NumPy RNG used by later sampling.
_train_rng = np.random.default_rng(53)


def _train_hybrid_classifier(model, optimizer, cls_inputs, tag):
    """Train ``model`` in place for ``epochs`` epochs with per-epoch shuffling.

    The CLS branch reads ``cls_inputs``. The patch branch reads the
    module-level ``patch_features``, and the targets are ``train_labels``.
    Running training accuracy is printed every 5 epochs as ``"<tag> Epoch [...]"``.
    """
    for epoch in range(epochs):
        order = _train_rng.permutation(len(cls_inputs))
        correct = 0
        total = 0

        for i in range(num_batches):
            batch_idx = order[i * batch_size:(i + 1) * batch_size]
            batch_cls = cls_inputs[batch_idx]
            batch_patch = patch_features[batch_idx]
            batch_labels = train_labels[batch_idx]

            with tf.GradientTape() as tape:
                outputs = model(batch_patch, batch_cls, training=True)
                loss = patch_loss_fn(batch_labels, outputs)

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            predicted = tf.argmax(outputs, axis=1).numpy()
            correct += np.sum(predicted == batch_labels)
            total += len(batch_labels)

        # `total` is 0 when there are fewer samples than one batch.
        if (epoch + 1) % 5 == 0 and total:
            print(f'{tag} Epoch [{epoch+1}/{epochs}], Accuracy: {100*correct/total:.2f}%')


# Train patch classifier
print("Training Patch Classifier...")
_train_hybrid_classifier(patch_classifier, patch_optimizer, cls_features, "Patch")

# CLS classifier
cls_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
cls_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

print("Training CLS Classifier...")
# NOTE(review): the CLS and Combined classifiers receive identical inputs
# (combined_cls_features + patch_features) -- confirm this is intended.
_train_hybrid_classifier(cls_classifier, cls_optimizer, combined_cls_features, "CLS")

# Combined classifier
combined_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
combined_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

print("Training Combined Classifier...")
_train_hybrid_classifier(combined_classifier, combined_optimizer, combined_cls_features, "Combined")

# Evaluate on test set (using top-1 and top-5 accuracy for ImageNet)
# Index 0 of each per-method list is vanilla CLIP, index 1 is the combined (INFER) mix.
# NOTE(review): only zero-shot logits are evaluated here; the HybridClassifier
# models trained above are not used in this loop.
print("Evaluating on test set...")

correct_counts_top1 = [0] * 2
correct_counts_top5 = [0] * 2
all_predictions = [[] for _ in range(2)]
true_labels = []

for idx in range(len(test_images)):
    if idx % 100 == 0:
        print(f"Processing {idx}/{len(test_images)}")
    
    image_input = preprocess_image_for_clip(test_images[idx])
    label = test_labels[idx]
    
    # Get vision model outputs
    vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
    
    # Vanilla CLIP
    # NOTE(review): HF CLIP's image embedding applies post_layernorm to the CLS
    # token (pooler_output) before visual_projection; last_hidden_state does
    # not, so this "vanilla" baseline may differ from get_image_features -- verify.
    cls_embed = tf_model.visual_projection(vision_outputs.last_hidden_state[:, 0, :])
    cls_normalized = tf.nn.l2_normalize(cls_embed, axis=-1)
    vanilla_logits = tf.squeeze(tf.matmul(cls_normalized, vanilla_text_features, transpose_b=True))
    
    # Extract features for enriched methods
    # (cls_feat equals cls_normalized and is not used further in this loop.)
    cls_feat = tf.nn.l2_normalize(cls_embed, axis=-1)
    
    # Enriched patch: attention-weighted patch tokens averaged over heads, then over tokens
    patch_embeds = vision_outputs.last_hidden_state[:, 1:, :]
    enriched_patches = tf.zeros_like(patch_embeds[0])
    
    for layer_idx, head_idx in common_patch_heads:
        attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
        cls_to_patch_attention = attention_weights[0, 0, 1:]
        weighted_patches = patch_embeds[0] * tf.expand_dims(cls_to_patch_attention, axis=-1)
        enriched_patches += weighted_patches
    
    enriched_patches /= len(common_patch_heads)
    enriched_patch_1d = tf.reduce_mean(enriched_patches, axis=0)
    patch_proj = tf_model.visual_projection(tf.expand_dims(enriched_patch_1d, axis=0))
    patch_norm = tf.nn.l2_normalize(patch_proj, axis=-1)
    patch_logits = tf.squeeze(tf.matmul(patch_norm, vanilla_text_features, transpose_b=True))
    
    # Enriched CLS: hidden states pooled by the CLS attention row, averaged over heads
    enriched_cls = tf.zeros_like(vision_outputs.hidden_states[-1][0][0])
    
    for layer_idx, head_idx in common_cls_heads:
        attention_weights = get_attention_weights(vision_outputs, layer_idx, head_idx)
        layer_hidden = vision_outputs.hidden_states[layer_idx]
        
        # NOTE(review): HF attentions are documented as post-softmax; a second
        # softmax flattens them toward uniform. Must match head/beta selection.
        cls_attention = attention_weights[0, 0, :]
        cls_attention = tf.nn.softmax(cls_attention, axis=0)
        
        layer_enriched = tf.reduce_sum(layer_hidden[0] * tf.expand_dims(cls_attention, axis=-1), axis=0)
        enriched_cls += layer_enriched
    
    enriched_cls /= len(common_cls_heads)
    enriched_cls_proj = tf_model.visual_projection(tf.expand_dims(enriched_cls, axis=0))
    enriched_cls_norm = tf.nn.l2_normalize(enriched_cls_proj, axis=-1)
    enriched_cls_logits = tf.squeeze(tf.matmul(enriched_cls_norm, vanilla_text_features, transpose_b=True))
    
    # Combined learning-based approach
    # Fixed linear mix: optimal_beta3 on vanilla, the remainder split equally
    # between the enriched-CLS and enriched-patch logits.
    combined_learning_logits = (optimal_beta3 * vanilla_logits + 
                               (1 - optimal_beta3) / 2 * enriched_cls_logits + 
                               (1 - optimal_beta3) / 2 * patch_logits)
    
    # Get top-1 and top-5 predictions
    # tf.nn.top_k returns indices sorted by descending score, so [0] is the top-1 class.
    vanilla_top5 = tf.nn.top_k(vanilla_logits, k=5)[1].numpy()
    combined_top5 = tf.nn.top_k(combined_learning_logits, k=5)[1].numpy()
    
    pred_vanilla_top1 = vanilla_top5[0]
    pred_combined_top1 = combined_top5[0]
    
    # Store predictions
    predictions = [pred_vanilla_top1, pred_combined_top1]
    top5_predictions = [vanilla_top5, combined_top5]
    
    for i, pred in enumerate(predictions):
        all_predictions[i].append(pred)
        # Top-1 accuracy
        if pred == label:
            correct_counts_top1[i] += 1
        # Top-5 accuracy
        if label in top5_predictions[i]:
            correct_counts_top5[i] += 1
    
    true_labels.append(label)

# Results
eval_size = len(test_images)
method_names = ["Vanilla CLIP", "INFER (Combined)"]

# An empty test split previously surfaced as a bare ZeroDivisionError.
if eval_size == 0:
    raise RuntimeError(
        "Test split is empty: every loaded image was assigned to validation "
        "(load more images or lower samples_per_class)."
    )

# Computed once and shared by the printout and the plots.
accuracies_top1 = [count / eval_size for count in correct_counts_top1]
accuracies_top5 = [count / eval_size for count in correct_counts_top5]


print("\nTop-1 Accuracy:")
for rank, (name, acc) in enumerate(zip(method_names, accuracies_top1), start=1):
    print(f"{rank}. {name:30} {acc:.4f}")

print("\nTop-5 Accuracy:")
for rank, (name, acc) in enumerate(zip(method_names, accuracies_top5), start=1):
    print(f"{rank}. {name:30} {acc:.4f}")

# Performance comparison visualization
short_names = ['Vanilla CLIP', 'INFER']

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))


def _plot_accuracy_bars(ax, accuracies, metric):
    """Draw one labelled accuracy bar chart on ``ax`` and return the bar container."""
    bars = ax.bar(short_names, accuracies, color=['blue', 'pink'])
    ax.set_ylabel(f'{metric} Accuracy')
    ax.set_title(f'ImageNet {metric} Accuracy Comparison')
    # max * 1.1 is 0 when every method scores 0, which would collapse the axis.
    ax.set_ylim(0, max(accuracies) * 1.1 or 1.0)

    for bar, acc in zip(bars, accuracies):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
                f'{acc:.3f}', ha='center', va='bottom')

    ax.grid(True, alpha=0.3)
    return bars


bars1 = _plot_accuracy_bars(ax1, accuracies_top1, 'Top-1')
bars2 = _plot_accuracy_bars(ax2, accuracies_top5, 'Top-5')

plt.tight_layout()
plt.show()

# Show one test image for each of up to 10 randomly drawn classes present in
# the test split (the class draw uses NumPy's global RNG).
print("Displaying sample images from ImageNet...")
plt.figure(figsize=(20, 4))

test_class_indices = np.unique(test_labels)
num_panels = min(10, len(test_class_indices))
selected_classes = np.random.choice(test_class_indices, num_panels, replace=False)

for i, class_idx in enumerate(selected_classes):
    class_mask = test_labels == class_idx
    if not np.any(class_mask):
        continue
    # First test image of this class.
    sample_idx = np.flatnonzero(class_mask)[0]
    sample_image = test_images[sample_idx]

    plt.subplot(1, len(selected_classes), i + 1)
    plt.imshow(sample_image.astype(np.uint8))
    # Class names are cut to their first 15 characters.
    plt.title(f'{class_names[class_idx][:15]}...', fontsize=8)
    plt.axis('off')

plt.suptitle('Sample Images from ImageNet Dataset')
plt.tight_layout()
plt.show()

# Create a detailed performance analysis
print(f"\n{'='*60}")
print("DETAILED ANALYSIS")
print(f"{'='*60}")

# Class distribution in test set
unique_test_labels, counts = np.unique(test_labels, return_counts=True)

# Feature dimensionality info
print("\nFeature dimensions:")
print(f"  CLS features: {cls_features.shape}")
print(f"  Patch features: {patch_features.shape}")
print(f"  Combined CLS features: {combined_cls_features.shape}")

# Head analysis summary
print("\nSelected attention heads:")
print(f"  Patch heads: {len(common_patch_heads)} heads across {len({layer for layer, _ in common_patch_heads})} layers")
print(f"  CLS heads: {len(common_cls_heads)} heads across {len({layer for layer, _ in common_cls_heads})} layers")

# Improvement analysis: relative change of INFER over vanilla, in percent
vanilla_top1 = accuracies_top1[0]
infer_top1 = accuracies_top1[1]
vanilla_top5 = accuracies_top5[0]
infer_top5 = accuracies_top5[1]

improvement_top1 = ((infer_top1 - vanilla_top1) / vanilla_top1) * 100 if vanilla_top1 > 0 else 0
improvement_top5 = ((infer_top5 - vanilla_top5) / vanilla_top5) * 100 if vanilla_top5 > 0 else 0

print("\nPerformance improvements:")
print(f"  Top-1 accuracy improvement (relative): {improvement_top1:+.2f}%")
print(f"  Top-5 accuracy improvement (relative): {improvement_top5:+.2f}%")

# Analysis of challenging classes (classes with lowest accuracy)
print("\nAnalyzing challenging classes...")

# Per-class accuracy for the method with the best top-1 accuracy
best_method_idx = np.argmax(accuracies_top1)
best_predictions = all_predictions[best_method_idx]

# Convert once instead of rebuilding both arrays for every class.
true_labels_array = np.asarray(true_labels)
best_predictions_array = np.asarray(best_predictions)

class_accuracies = {}
class_counts = {}
for class_idx in unique_test_labels:
    class_mask = true_labels_array == class_idx
    if np.any(class_mask):
        class_predictions = best_predictions_array[class_mask]
        class_accuracies[class_idx] = np.sum(class_predictions == class_idx) / len(class_predictions)
        class_counts[class_idx] = int(np.sum(class_mask))

if class_accuracies:
    sorted_classes = sorted(class_accuracies.items(), key=lambda x: x[1])  # ascending accuracy

    print("\nTop 5 most challenging classes:")
    for i, (class_idx, acc) in enumerate(sorted_classes[:5]):
        print(f"  {i+1}. {class_names[class_idx][:30]:30}: {acc:.3f} ({class_counts[class_idx]} samples)")

    # Walk the ascending list's tail backwards so "1." is the easiest class.
    print("\nTop 5 easiest classes:")
    for i, (class_idx, acc) in enumerate(reversed(sorted_classes[-5:])):
        print(f"  {i+1}. {class_names[class_idx][:30]:30}: {acc:.3f} ({class_counts[class_idx]} samples)")


# Memory cleanup: request a full garbage-collection pass. Nothing is deleted
# explicitly here, so only objects that are already unreachable get freed.
import gc

gc.collect()
print("Memory cleanup completed.")


# Display model summaries.
# Each classifier is called once on random dummy inputs so that its weights
# are built before `summary()` is requested. Failures are best-effort: they
# are reported (with the reason) instead of aborting the script, but no longer
# swallow KeyboardInterrupt/SystemExit the way a bare `except:` did.
# NOTE(review): the patch classifier is probed with a 512-wide CLS input and
# the CLS/combined classifiers with 1024 — presumably matching the features
# each was built for; confirm against the model definitions.
print("\nPatch Classifier Summary:")
try:
    dummy_patch = tf.random.normal([1, 7, 7, 768])
    dummy_cls = tf.random.normal([1, 512])
    _ = patch_classifier(dummy_patch, dummy_cls)
    patch_classifier.summary()
except Exception as e:
    print(f"Could not display patch classifier summary: {e}")

print("\nCLS Classifier Summary:")
try:
    dummy_patch = tf.random.normal([1, 7, 7, 768])
    dummy_cls = tf.random.normal([1, 1024])
    _ = cls_classifier(dummy_patch, dummy_cls)
    cls_classifier.summary()
except Exception as e:
    print(f"Could not display CLS classifier summary: {e}")

print("\nCombined Classifier Summary:")
try:
    dummy_patch = tf.random.normal([1, 7, 7, 768])
    dummy_cls = tf.random.normal([1, 1024])
    _ = combined_classifier(dummy_patch, dummy_cls)
    combined_classifier.summary()
except Exception as e:
    print(f"Could not display combined classifier summary: {e}")

# Save model weights (weights only; the architecture is not serialized).
# NOTE(review): under Keras 3, `save_weights` requires a path ending in
# `.weights.h5`; these extension-less paths rely on the tf.keras (Keras 2)
# TensorFlow-checkpoint format — confirm which Keras this runs against.
print("\nSaving TensorFlow models...")


def _try_save_weights(get_model, path):
    """Save one model's weights to *path*, reporting rather than raising.

    Args:
        get_model: zero-argument callable returning the model. The model is
            resolved lazily inside the ``try`` so an undefined model is
            reported as a failure instead of crashing the script.
        path: destination passed to ``save_weights``.

    Returns:
        True if the weights were saved, False if saving raised.
    """
    try:
        get_model().save_weights(path)
    except Exception as e:
        print(f"Error saving models: {path}: {e}")
        return False
    return True


# Every model is attempted independently (a list, not a short-circuiting
# expression), so one failure no longer skips the models after it.
_save_results = [
    _try_save_weights(lambda: patch_classifier, './patch_classifier_weights'),
    _try_save_weights(lambda: cls_classifier, './cls_classifier_weights'),
    _try_save_weights(lambda: combined_classifier, './combined_classifier_weights'),
]
if all(_save_results):
    print("Model weights saved successfully!")

# Performance profiling
print(f"\n{'='*60}")
print("PERFORMANCE PROFILING")
print(f"{'='*60}")

# Test inference speed
print("\nTesting inference speed...")
import time

_NUM_WARMUP_RUNS = 1   # untimed passes: absorb one-off tracing/initialisation cost
_NUM_TIMED_RUNS = 10   # timed passes averaged into avg_inference_time

# Single image inference time (vanilla CLIP zero-shot path on one image).
# NOTE(review): this projects the raw CLS token `last_hidden_state[:, 0, :]`;
# HF CLIP's own image features project `pooler_output` (CLS after
# post_layernorm). Kept as-is to time the same computation the rest of the
# script uses — confirm this is intended.
sample_image = test_images[0:1]
start_time = time.perf_counter()

for _run in range(_NUM_WARMUP_RUNS + _NUM_TIMED_RUNS):
    if _run == _NUM_WARMUP_RUNS:
        # Start the clock only after the warm-up pass(es).
        start_time = time.perf_counter()
    image_input = preprocess_image_for_clip(sample_image[0])
    vision_outputs = tf_model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
    cls_embed = tf_model.visual_projection(vision_outputs.last_hidden_state[:, 0, :])
    cls_normalized = tf.nn.l2_normalize(cls_embed, axis=-1)
    vanilla_logits = tf.squeeze(tf.matmul(cls_normalized, vanilla_text_features, transpose_b=True))
    # Eager ops may execute asynchronously on an accelerator; fetching the
    # result forces the work to finish inside the timed region.
    _ = vanilla_logits.numpy()

end_time = time.perf_counter()
avg_inference_time = (end_time - start_time) / _NUM_TIMED_RUNS
_throughput = 1 / avg_inference_time if avg_inference_time > 0 else float("inf")
print(f"Average inference time per image: {avg_inference_time:.3f} seconds")
print(f"Estimated throughput: {_throughput:.1f} images/second")

# Memory usage estimation. These are rough figures: every size below assumes
# 4 bytes per element (float32), and the per-image figure hard-codes a
# 768-element feature vector.
# NOTE(review): the "batch processing" threshold (1000 / avg time) has no
# stated rationale — confirm what it is meant to represent.
print("\nMemory usage estimates:")
_text_shape = vanilla_text_features.shape
_text_mb = np.prod(_text_shape) * 4 / 1e6
print(f"  Text features tensor: {_text_shape} = {_text_mb:.1f} MB")
_image_kb = 768 * 4 / 1e3
print(f"  Single image features: ~{_image_kb:.1f} KB")
_batch_threshold = 1000 / avg_inference_time
print(f"  Batch processing recommended for: >{_batch_threshold:.0f} images")

