import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
from collections import defaultdict
import random
from transformers import TFCLIPModel, CLIPProcessor
from PIL import Image
import requests

# Seed every RNG the script draws from (Python, NumPy, TensorFlow) for reproducibility.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(53)

# tensorflow_datasets is optional: load_pets_dataset() takes a manual-download
# path when it is unavailable (tfds is then None).
try:
    import tensorflow_datasets as tfds
except ImportError:
    print("tensorflow_datasets not available")
    tfds = None
else:
    print("tensorflow_datasets available")

# Load Oxford-IIIT Pet Dataset
def _load_pets_via_tfds():
    """Load Oxford-IIIT Pet through tensorflow_datasets, merging its train and test splits.

    Returns:
        (images, labels, class_names): images resized to 224x224 with
        ``tf.image.resize`` (float output), integer labels, and the label names
        reported by the dataset info.

    Raises:
        Whatever ``tfds.load`` / iteration raises; ``load_pets_dataset`` catches it
        and falls back to the manual download path.
    """
    dataset, info = tfds.load('oxford_iiit_pet',
                              with_info=True,
                              as_supervised=True,
                              split=['train', 'test'])
    train_dataset, test_dataset = dataset

    class_names = info.features['label'].names
    print(f"Classes found: {class_names}")
    print(f"Number of classes: {len(class_names)}")

    # Combine train and test; the script makes its own validation/test split later.
    all_images = []
    all_labels = []
    for split_dataset in (train_dataset, test_dataset):
        for image, label in split_dataset:
            all_images.append(tf.image.resize(image, [224, 224]).numpy())
            all_labels.append(label.numpy())

    return np.array(all_images), np.array(all_labels), class_names


def _breed_from_filename(filename):
    """Return the breed prefix of a pet image filename ('Abyssinian_12.jpg' -> 'Abyssinian')."""
    return '_'.join(filename.split('_')[:-1])


def _load_pets_via_manual_download():
    """Download/extract the raw archives and load them with ``image_dataset_from_directory``.

    On the first run, images are copied into ``./datasets/pets/organized/<breed>/``.
    If no extracted images exist (e.g. the download failed), 50 solid-colour dummy
    images per fallback class are written instead, so the pipeline can still run.

    Returns:
        (images, labels, class_names) with class names taken from the directory names.
    """
    import os
    import shutil

    annotations_url = "http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"
    images_url = "http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"

    data_dir = "./datasets/pets"
    os.makedirs(data_dir, exist_ok=True)

    try:
        # With cache_dir=data_dir, get_file stores/extracts under data_dir/<cache_subdir>;
        # the "datasets" path used below presumably relies on Keras' default cache_subdir.
        tf.keras.utils.get_file(fname='annotations.tar.gz', origin=annotations_url,
                                extract=True, cache_dir=data_dir)
        tf.keras.utils.get_file(fname='images.tar.gz', origin=images_url,
                                extract=True, cache_dir=data_dir)
    except Exception as e:
        # Non-fatal: continue with whatever is already on disk, or with dummy data below.
        print(f"Error downloading Oxford-IIIT Pet archives: {e}")

    images_dir = os.path.join(data_dir, "datasets", "images")
    has_raw_images = os.path.exists(images_dir)

    if has_raw_images:
        breeds = {_breed_from_filename(f) for f in os.listdir(images_dir) if f.endswith('.jpg')}
        class_names = sorted(breeds)[:37]  # Limit to 37 classes as in original dataset
        print(f"Found {len(class_names)} pet breeds")
    else:
        # Fallback: create demo classes
        class_names = [
            'Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair',
            'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue',
            'Siamese', 'Sphynx', 'american_bulldog', 'american_pit_bull_terrier',
            'basset_hound', 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel',
            'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese',
            'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher',
            'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed',
            'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier',
            'wheaten_terrier', 'yorkshire_terrier'
        ]
        print(f"Using fallback class names: {len(class_names)} classes")

    # Build one sub-directory per class the first time through.
    organized_dir = os.path.join(data_dir, "organized")
    if not os.path.exists(organized_dir):
        os.makedirs(organized_dir)
        for class_name in class_names:
            os.makedirs(os.path.join(organized_dir, class_name), exist_ok=True)

        if has_raw_images:
            wanted = set(class_names)
            for f in os.listdir(images_dir):
                if not f.endswith('.jpg'):
                    continue
                breed = _breed_from_filename(f)
                if breed not in wanted:
                    continue
                src = os.path.join(images_dir, f)
                dst = os.path.join(organized_dir, breed, f)
                if os.path.exists(dst):
                    continue
                try:
                    shutil.copy2(src, dst)
                except OSError as e:
                    # Best effort: skip files that cannot be copied, but report them.
                    print(f"Skipping {src}: {e}")
        else:
            print("Creating dummy data for testing...")
            for class_name in class_names:
                class_dir = os.path.join(organized_dir, class_name)
                for i in range(50):  # 50 images per class
                    # randint's upper bound is exclusive, so 256 covers the full 0-255 range.
                    color = tuple(int(np.random.randint(0, 256)) for _ in range(3))
                    Image.new('RGB', (224, 224), color=color).save(
                        os.path.join(class_dir, f"{class_name}_{i}.jpg"))

    print("Loading images from organized directory...")
    full_dataset = tf.keras.utils.image_dataset_from_directory(
        organized_dir,
        seed=53,
        image_size=(224, 224),
        batch_size=1,
        shuffle=False,
        label_mode='int'
    )
    class_names = full_dataset.class_names

    images = []
    labels = []
    for image_batch, label_batch in full_dataset:
        images.append(image_batch[0].numpy())
        labels.append(label_batch[0].numpy())
        if len(images) % 500 == 0:
            print(f"Loaded {len(images)} images...")

    return np.array(images), np.array(labels), class_names


def _print_class_distribution(labels_array, class_names):
    """Print image counts for the first 10 class ids present in ``labels_array``."""
    unique, counts = np.unique(labels_array, return_counts=True)
    print("\nClass distribution (top 10):")
    for class_idx, count in list(zip(unique, counts))[:10]:
        class_name = class_names[class_idx] if class_idx < len(class_names) else f"Class_{class_idx}"
        print(f"  {class_name}: {count} images")


def load_pets_dataset():
    """Load every Oxford-IIIT Pet image (train + test) as numpy arrays.

    Tries tensorflow_datasets first (when the module-level ``tfds`` is available) and
    falls back to a manual archive download (or generated dummy data) if that is
    missing or fails. The module-level ``tfds`` is only read, never reassigned.

    Returns:
        (images_array, labels_array, class_names).
    """
    print("Downloading Oxford-IIIT Pet dataset...")

    result = None
    if tfds is not None:
        try:
            result = _load_pets_via_tfds()
        except Exception as e:
            print(f"Error with tfds.load: {e}")

    if result is None:
        result = _load_pets_via_manual_download()

    images_array, labels_array, class_names = result
    _print_class_distribution(labels_array, class_names)
    return images_array, labels_array, class_names

# Load the TensorFlow CLIP checkpoint and its matching processor.
# NOTE(review): later code calls model.vision_model / model.visual_projection directly;
# confirm the installed transformers version exposes these on TFCLIPModel itself.
model_name = "openai/clip-vit-base-patch32"
model = TFCLIPModel.from_pretrained(
    model_name,
    output_attentions=True,
)
processor = CLIPProcessor.from_pretrained(model_name)

# Load every image/label pair plus the class-name list used for prompts.
print("Loading Oxford-IIIT Pet dataset...")
all_images, all_labels, class_names = load_pets_dataset()
num_classes = len(class_names)


# Create text features for pet classes
def format_pet_name(class_name):
    """Build the CLIP text prompt for one breed label.

    Underscores become spaces and the name is title-cased, e.g.
    'american_pit_bull_terrier' -> 'a photo of a American Pit Bull Terrier'.
    """
    readable_name = " ".join(class_name.split("_")).title()
    return "a photo of a " + readable_name

# One prompt per class; the text embeddings are L2-normalised so that a dot product
# with a normalised image embedding is a cosine similarity.
pet_prompts = list(map(format_pet_name, class_names))
vanilla_inputs = processor(
    text=pet_prompts,
    return_tensors="tf",
    padding=True,
    truncation=True,
)
vanilla_text_outputs = model.get_text_features(**vanilla_inputs)
vanilla_text_features = tf.nn.l2_normalize(vanilla_text_outputs, axis=-1)

samples_per_class = 16  # For validation set (16 per class)
remaining_samples = []
val_samples_per_class = defaultdict(list)

# First pass: the first `samples_per_class` occurrences of each label go to validation,
# everything else is test data.
for idx, label in enumerate(all_labels):
    if len(val_samples_per_class[label]) < samples_per_class:
        val_samples_per_class[label].append(idx)
    else:
        remaining_samples.append(idx)

# Interleave classes round-robin rather than concatenating them class by class.
# Downstream code only inspects a prefix of the validation set (head search: first
# 200, beta search: first 100) and trains on sequential mini-batches, so a
# class-sorted order would let only the first few classes drive those choices and
# would make every training batch single-class.
val_indices = []
for slot in range(samples_per_class):
    for class_label in range(num_classes):
        class_pool = val_samples_per_class.get(class_label, [])
        if slot < len(class_pool):
            val_indices.append(class_pool[slot])

# Use remaining samples as test set
test_indices = remaining_samples.copy()


def _print_split_distribution(title, class_counts):
    """Print per-class sample counts for the first 10 class ids in ``class_counts``."""
    print(title)
    for class_idx, count in sorted(class_counts.items())[:10]:
        class_name = class_names[class_idx] if class_idx < len(class_names) else f"Class_{class_idx}"
        print(f"  {class_name}: {count} samples")


# Verify validation set has reasonable distribution
val_class_counts = defaultdict(int)
for idx in val_indices:
    val_class_counts[all_labels[idx]] += 1
_print_split_distribution("Validation class distribution (top 10):", val_class_counts)

# Verify test set distribution
test_class_counts = defaultdict(int)
for idx in test_indices:
    test_class_counts[all_labels[idx]] += 1
_print_split_distribution("Test class distribution (top 10):", test_class_counts)

# Create validation and test datasets
val_images = all_images[val_indices]
val_labels = all_labels[val_indices]
test_images = all_images[test_indices]
test_labels = all_labels[test_indices]

# Helper function to convert numpy array to PIL Image for CLIP processing
def numpy_to_pil(img_array):
    """Convert an image array (callers pass 224x224 RGB arrays) to a ``PIL.Image``.

    Float arrays whose maximum is <= 1.0 are treated as [0, 1] images and scaled to
    [0, 255]; integer arrays and other floats are assumed to be on a [0, 255] scale
    already. Values are rounded and clipped to [0, 255] before the uint8 cast, so
    interpolation overshoot cannot wrap around (e.g. 256.0 -> 0) and values are not
    biased downward by truncation.
    """
    img_array = np.asarray(img_array)
    # Only floats can be [0, 1]-scaled; a dark uint8 image must not be multiplied by 255.
    if np.issubdtype(img_array.dtype, np.floating) and img_array.max() <= 1.0:
        img_array = img_array * 255.0
    img_array = np.clip(np.rint(img_array), 0, 255).astype(np.uint8)
    return Image.fromarray(img_array)

# TensorFlow helper functions for attention-based enrichment
def extract_attention_weights(attention_outputs, layer_idx, head_idx):
    """Return the CLS-token attention over the non-CLS tokens for one head.

    Indexes ``attention_outputs[layer_idx]`` as (batch, head, query, key), takes query
    row 0 (CLS) for ``head_idx`` and drops key 0, giving a (batch, keys - 1) slice.
    """
    layer_attention = attention_outputs[layer_idx]
    cls_query_row = layer_attention[:, head_idx, 0, :]
    return cls_query_row[:, 1:]

def extract_hidden_states(outputs, layer_idx):
    """Return ``outputs.hidden_states[layer_idx]``.

    NOTE(review): in Hugging Face model outputs, index 0 is presumably the embedding
    output, so negative indices select encoder-layer outputs — confirm when pairing
    these with ``outputs.attentions`` of the same index.
    """
    per_layer_states = outputs.hidden_states
    return per_layer_states[layer_idx]

def compute_enriched_patches(hidden_states, attention_weights):
    """Pool patch tokens into one vector, weighting each patch by its attention score.

    Drops token 0 (CLS) from ``hidden_states``, scales every remaining token by the
    matching entry of ``attention_weights`` and takes the mean (not the sum) over tokens.

    Args:
        hidden_states: Token states, assumed (batch, 1 + num_patches, hidden) — TODO confirm.
        attention_weights: Per-patch weights of shape (batch, num_patches), e.g. from
            ``extract_attention_weights``.

    Returns:
        Tensor of shape (batch, hidden).
    """
    per_patch_weight = attention_weights[..., tf.newaxis]
    weighted_tokens = hidden_states[:, 1:, :] * per_patch_weight
    return tf.reduce_mean(weighted_tokens, axis=1)

def compute_enriched_cls(hidden_states, attention_weights, from_logits=False):
    """Return the attention-weighted sum of all token states (CLS included).

    Args:
        hidden_states: Token states, assumed (batch, tokens, hidden) — TODO confirm.
        attention_weights: (batch, tokens) weights over the same tokens. Every caller in
            this script passes a row of the model's returned attention maps, which are
            post-softmax probabilities in Hugging Face models.
        from_logits: If True, treat ``attention_weights`` as unnormalised scores and
            softmax them first (the previous behaviour).

    Returns:
        Tensor of shape (batch, hidden).
    """
    if from_logits:
        attention_probs = tf.nn.softmax(attention_weights, axis=-1)
    else:
        # The weights are already probabilities. A second softmax would flatten them
        # toward uniform (exp of values in [0, 1] only spans [1, e]), so just
        # renormalise in case a caller passes a partial row.
        attention_probs = attention_weights / tf.reduce_sum(attention_weights, axis=-1, keepdims=True)

    attention_expanded = tf.expand_dims(attention_probs, -1)
    return tf.reduce_sum(hidden_states * attention_expanded, axis=1)


def find_best_patch_heads(validation_images, validation_labels, text_features,
                          max_samples=200, num_layers=3, heads_per_layer=3, num_selected=9):
    """Select the attention heads whose patch enrichment best matches the true class text.

    For each of the first ``max_samples`` validation images, every head of the last
    ``num_layers`` encoder layers is scored by the cosine similarity between the
    projected ``compute_enriched_patches`` vector and the ground-truth class text
    embedding. The ``heads_per_layer`` best heads per layer are tallied, and the
    ``num_selected`` most frequently chosen (layer, head) pairs are returned.

    Args:
        validation_images: Images accepted by ``numpy_to_pil``.
        validation_labels: Integer class ids aligned with ``validation_images``.
        text_features: L2-normalised class text embeddings, one row per class.
        max_samples: Cap on the number of images processed (one forward pass each).
        num_layers: Number of final encoder layers searched (indices -1 .. -num_layers).
        heads_per_layer: Heads kept per layer per sample.
        num_selected: Number of (layer, head) pairs returned.

    Returns:
        List of (layer_idx, head_idx) tuples, most frequently selected first.
    """
    print("Finding best heads for patch enrichment...")
    num_samples = min(len(validation_images), max_samples)
    head_frequency = defaultdict(int)

    for idx in range(num_samples):
        if idx % 20 == 0:
            print(f"Processing sample {idx}/{num_samples}")

        label = validation_labels[idx]
        if label >= len(text_features):
            continue

        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="tf")['pixel_values']
        vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
        target_text = tf.expand_dims(text_features[label], 0)

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)
            num_heads = vision_out.attentions[layer_idx].shape[1]

            head_scores = []
            for head_idx in range(num_heads):
                attn_weights = extract_attention_weights(vision_out.attentions, layer_idx, head_idx)
                enriched_patch = compute_enriched_patches(vision_out.last_hidden_state, attn_weights)
                projected = tf.nn.l2_normalize(model.visual_projection(enriched_patch), axis=-1)
                similarity = tf.reduce_sum(projected * target_text, axis=-1)  # shape (1,)
                head_scores.append((head_idx, float(similarity[0])))

            head_scores.sort(key=lambda x: x[1], reverse=True)
            for head_idx, _ in head_scores[:heads_per_layer]:
                head_frequency[(layer_idx, head_idx)] += 1

    # sorted() is stable, so ties keep first-seen order.
    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:num_selected]
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common patch heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_best_cls_heads(validation_images, validation_labels, text_features,
                        max_samples=200, num_layers=3, heads_per_layer=3, num_selected=9):
    """Select the attention heads whose CLS enrichment best matches the true class text.

    For each of the first ``max_samples`` validation images, every head of the last
    ``num_layers`` encoder layers is scored. The score is the cosine similarity between
    the ground-truth class text embedding and the projected ``compute_enriched_cls``
    output. That output takes the head's full CLS attention row, CLS token included,
    and applies it to the same-index hidden states. The ``heads_per_layer`` best heads
    per layer are tallied, and the ``num_selected`` most frequent (layer, head) pairs
    are returned.

    Args:
        validation_images: Images accepted by ``numpy_to_pil``.
        validation_labels: Integer class ids aligned with ``validation_images``.
        text_features: L2-normalised class text embeddings, one row per class.
        max_samples: Cap on the number of images processed (one forward pass each).
        num_layers: Number of final encoder layers searched (indices -1 .. -num_layers).
        heads_per_layer: Heads kept per layer per sample.
        num_selected: Number of (layer, head) pairs returned.

    Returns:
        List of (layer_idx, head_idx) tuples, most frequently selected first.
    """
    print("Finding best heads for CLS enrichment...")
    num_samples = min(len(validation_images), max_samples)
    head_frequency = defaultdict(int)

    for idx in range(num_samples):
        if idx % 20 == 0:
            print(f"Processing sample {idx}/{num_samples}")

        label = validation_labels[idx]
        if label >= len(text_features):
            continue

        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="tf")['pixel_values']
        vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
        target_text = tf.expand_dims(text_features[label], 0)

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)
            layer_attention = vision_out.attentions[layer_idx]
            layer_hidden = extract_hidden_states(vision_out, layer_idx)

            head_scores = []
            for head_idx in range(layer_attention.shape[1]):
                # Full CLS attention row for this head (CLS token included).
                full_attention = layer_attention[:, head_idx, 0, :]
                enriched_cls = compute_enriched_cls(layer_hidden, full_attention)
                projected = tf.nn.l2_normalize(model.visual_projection(enriched_cls), axis=-1)
                similarity = tf.reduce_sum(projected * target_text, axis=-1)  # shape (1,)
                head_scores.append((head_idx, float(similarity[0])))

            head_scores.sort(key=lambda x: x[1], reverse=True)
            for head_idx, _ in head_scores[:heads_per_layer]:
                head_frequency[(layer_idx, head_idx)] += 1

    # sorted() is stable, so ties keep first-seen order.
    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:num_selected]
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common CLS heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_optimal_beta(validation_images, validation_labels, text_features, common_heads, enrichment_type,
                      max_samples=100):
    """Grid-search the mixing weight between vanilla CLS logits and enriched logits.

    Each of the first ``max_samples`` validation images gets two logit vectors: one
    from the vanilla CLS embedding and one from the enriched feature, averaged over
    ``common_heads``. The logits are combined as
    ``beta * cls_logits + (1 - beta) * enriched_logits`` for ``beta`` in
    ``0.0, 0.1, ..., 1.0``. The vision forward pass runs once per image; only the
    cheap combination is repeated per beta.

    Args:
        validation_images: Images accepted by ``numpy_to_pil``.
        validation_labels: Integer class ids aligned with ``validation_images``.
        text_features: L2-normalised class text embeddings, one row per class.
        common_heads: Non-empty list of (layer_idx, head_idx) pairs.
        enrichment_type: ``"patch"`` or ``"cls"``.
        max_samples: Cap on the number of images evaluated.

    Returns:
        The beta with the highest accuracy. The first beta wins ties; 0.0 is returned
        if no sample is usable.

    Raises:
        ValueError: for an unknown ``enrichment_type`` or an empty ``common_heads``.
    """
    print(f"Finding optimal beta for {enrichment_type}...")
    if enrichment_type not in ("patch", "cls"):
        raise ValueError(f"Unknown enrichment_type: {enrichment_type!r}")
    if not common_heads:
        raise ValueError("common_heads must contain at least one (layer, head) pair")

    cls_rows = []
    enriched_rows = []
    kept_labels = []

    for idx in range(min(len(validation_images), max_samples)):
        label = validation_labels[idx]
        if label >= len(text_features):
            continue

        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="tf")['pixel_values']
        vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

        # NOTE(review): CLIP's own image embedding applies the vision post-layernorm
        # before visual_projection (HF exposes it as pooler_output); this uses the raw
        # CLS state, consistent with the rest of the script — confirm intent.
        cls_embed = model.visual_projection(vision_out.last_hidden_state[:, 0, :])
        cls_normalized = tf.nn.l2_normalize(cls_embed, axis=-1)
        cls_rows.append(tf.matmul(cls_normalized, text_features, transpose_b=True)[0].numpy())

        if enrichment_type == "patch":
            per_head = [
                compute_enriched_patches(
                    vision_out.last_hidden_state,
                    extract_attention_weights(vision_out.attentions, layer_idx, head_idx))
                for layer_idx, head_idx in common_heads
            ]
        else:
            per_head = [
                compute_enriched_cls(
                    extract_hidden_states(vision_out, layer_idx),
                    vision_out.attentions[layer_idx][:, head_idx, 0, :])
                for layer_idx, head_idx in common_heads
            ]
        enriched = tf.add_n(per_head) / len(common_heads)  # (1, hidden)

        projected = model.visual_projection(enriched)
        normalized = tf.nn.l2_normalize(projected, axis=-1)
        enriched_rows.append(tf.matmul(normalized, text_features, transpose_b=True)[0].numpy())
        kept_labels.append(label)

    best_beta = 0.0
    best_accuracy = 0.0
    if kept_labels:
        cls_logits = np.stack(cls_rows)
        enriched_logits = np.stack(enriched_rows)
        label_arr = np.asarray(kept_labels)
        for beta in np.linspace(0.0, 1.0, 11):
            combined_logits = beta * cls_logits + (1 - beta) * enriched_logits
            accuracy = float(np.mean(np.argmax(combined_logits, axis=1) == label_arr))
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_beta = beta

    print(f"Optimal beta: {best_beta:.3f} (accuracy: {best_accuracy:.4f})")
    return best_beta

# TensorFlow CNN+MLP Classifier
class HybridClassifier(tf.keras.Model):
    """Fuse a CNN over a 2-D patch-feature map with a dense layer over a 1-D CLS vector.

    CNN branch: three 3x3 'same' convolutions (256, 128, 64 filters) with 2x2 max-pooling
    after the first two, global average pooling, then a 512-unit ReLU dense layer.
    MLP branch: a single 512-unit ReLU dense layer. The two 512-d outputs are
    concatenated and mapped to ``num_classes`` raw logits.

    Args:
        mlp_input_dim: Width of the CLS input. Not used to build any layer (Keras infers
            input sizes on the first call); callers pass it for readability.
        num_classes: Number of output logits (no softmax applied).
    """

    def __init__(self, mlp_input_dim, num_classes):
        super().__init__()
        layers = tf.keras.layers

        self.cnn = tf.keras.Sequential([
            layers.Conv2D(256, 3, padding='same', activation='relu'),
            layers.MaxPool2D(2, 2),
            layers.Conv2D(128, 3, padding='same', activation='relu'),
            layers.MaxPool2D(2, 2),
            layers.Conv2D(64, 3, padding='same', activation='relu'),
            layers.GlobalAveragePooling2D(),
            layers.Dense(512, activation='relu'),
        ])
        self.mlp = layers.Dense(512, activation='relu')
        self.classifier = layers.Dense(num_classes)

    def call(self, patch_features, cls_features):
        branch_outputs = [self.cnn(patch_features), self.mlp(cls_features)]
        return self.classifier(tf.concat(branch_outputs, axis=1))

# Feature extraction function
def extract_features(images, labels, patch_heads, cls_heads):
    """Compute the three feature views used to train the hybrid classifiers.

    For every image whose label is below the module-level ``num_classes``, runs the
    CLIP vision tower once and derives:

    * ``cls_feat``: the L2-normalised projection of the final-layer CLS state;
    * ``combined_cls``: ``cls_feat`` concatenated with the L2-normalised projection of
      the enriched CLS (``compute_enriched_cls`` averaged over ``cls_heads``);
    * ``patch_2d``: the per-patch attention-weighted patch tokens, averaged over
      ``patch_heads`` and laid out on the square patch grid (7x7x768 for ViT-B/32 at
      224px). The map is kept spatial; pooling it first would leave nothing to reshape.

    Args:
        images: Images accepted by ``numpy_to_pil``.
        labels: Integer labels aligned with ``images``.
        patch_heads: Non-empty list of (layer_idx, head_idx) pairs for patch enrichment.
        cls_heads: Non-empty list of (layer_idx, head_idx) pairs for CLS enrichment.

    Returns:
        Numpy arrays ``(cls_features, patch_features, combined_cls_features, labels)``
        with one row per kept image.

    Raises:
        ValueError: if either head list is empty (the per-head average would divide by 0).
    """
    if not patch_heads or not cls_heads:
        raise ValueError("patch_heads and cls_heads must both be non-empty")

    cls_features_list = []
    patch_features_list = []
    combined_cls_features_list = []
    labels_list = []

    for idx in range(len(images)):
        if idx % 50 == 0:
            print(f"Processing {idx}/{len(images)}")

        if labels[idx] >= num_classes:
            continue

        pil_img = numpy_to_pil(images[idx])
        image_input = processor(images=pil_img, return_tensors="tf")['pixel_values']
        vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
        last_hidden = vision_out.last_hidden_state

        # CLS features
        cls_embed = model.visual_projection(last_hidden[:, 0, :])
        cls_feat = tf.nn.l2_normalize(cls_embed, axis=-1)

        # Enriched CLS: mean over the selected heads -> (1, hidden)
        enriched_cls = tf.add_n([
            compute_enriched_cls(extract_hidden_states(vision_out, layer_idx),
                                 vision_out.attentions[layer_idx][:, head_idx, 0, :])
            for layer_idx, head_idx in cls_heads
        ]) / len(cls_heads)
        enriched_cls_feat = tf.nn.l2_normalize(model.visual_projection(enriched_cls), axis=-1)

        # Enriched patches: attention-weighted patch tokens, mean over the selected heads
        # -> (1, num_patches, hidden).
        patch_embeds = last_hidden[:, 1:, :]
        enriched_patches = tf.add_n([
            patch_embeds * tf.expand_dims(
                extract_attention_weights(vision_out.attentions, layer_idx, head_idx), -1)
            for layer_idx, head_idx in patch_heads
        ]) / len(patch_heads)

        # Lay the patches out on their square grid (ViT-B/32 at 224px -> 7x7).
        grid = int(round(enriched_patches.shape[1] ** 0.5))
        patch_2d = tf.reshape(enriched_patches[0], [grid, grid, -1])

        combined_cls = tf.concat([cls_feat, enriched_cls_feat], axis=1)

        # Convert to numpy for storage
        cls_features_list.append(cls_feat.numpy().squeeze())
        patch_features_list.append(patch_2d.numpy())
        combined_cls_features_list.append(combined_cls.numpy().squeeze())
        labels_list.append(labels[idx])

    return (np.array(cls_features_list),
            np.array(patch_features_list),
            np.array(combined_cls_features_list),
            np.array(labels_list))


# Head selection on the labelled few-shot validation split.
common_patch_heads = find_best_patch_heads(val_images, val_labels, vanilla_text_features)
common_cls_heads = find_best_cls_heads(val_images, val_labels, vanilla_text_features)

# Mixing weights between vanilla CLIP logits and each enriched variant
# (evaluated in order: patch first, then CLS).
optimal_beta1, optimal_beta2 = (
    find_optimal_beta(val_images, val_labels, vanilla_text_features, heads, kind)
    for heads, kind in ((common_patch_heads, "patch"), (common_cls_heads, "cls"))
)
# NOTE(review): the combined method reuses the CLS beta rather than searching its own.
optimal_beta3 = optimal_beta2

# The same validation images double as the classifiers' training set.
cls_features, patch_features, combined_cls_features, train_labels = extract_features(
    val_images, val_labels, common_patch_heads, common_cls_heads)


def _train_hybrid_classifier(classifier, optimizer, loss_fn, patch_inputs, cls_inputs,
                             labels, log_name, epochs=50, batch_size=16):
    """Train ``classifier`` in place with mini-batch updates.

    Samples are reshuffled every epoch with ``np.random.permutation`` (seeded at the
    top of the script), so batches are not stuck in the input order. The final partial
    batch is kept. Training accuracy over the epoch is printed every 10 epochs.

    Raises:
        ValueError: if there are no training samples.
    """
    num_samples = len(labels)
    if num_samples == 0:
        raise ValueError("No training samples were extracted; cannot train a classifier.")

    for epoch in range(epochs):
        order = np.random.permutation(num_samples)
        correct = 0

        for start in range(0, num_samples, batch_size):
            batch_idx = order[start:start + batch_size]
            batch_labels = labels[batch_idx]

            with tf.GradientTape() as tape:
                outputs = classifier(patch_inputs[batch_idx], cls_inputs[batch_idx], training=True)
                loss = loss_fn(batch_labels, outputs)

            gradients = tape.gradient(loss, classifier.trainable_variables)
            optimizer.apply_gradients(zip(gradients, classifier.trainable_variables))

            predicted = tf.argmax(outputs, axis=1).numpy()
            correct += int(np.sum(predicted == batch_labels))

        if (epoch + 1) % 10 == 0:
            print(f'{log_name} Epoch [{epoch+1}/{epochs}], Accuracy: {100*correct/num_samples:.2f}%')


# Patch classifier: CNN over the enriched patch map + dense branch over the plain CLS feature.
patch_classifier = HybridClassifier(mlp_input_dim=512, num_classes=num_classes)
patch_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
patch_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

batch_size = 16
# Number of full batches; the training loop also uses the trailing partial batch.
num_batches = len(cls_features) // batch_size

print("Training Patch Classifier...")
_train_hybrid_classifier(patch_classifier, patch_optimizer, patch_loss_fn,
                         patch_features, cls_features, train_labels, "Patch",
                         batch_size=batch_size)

# CLS classifier: same patch map + [CLS, enriched CLS] concatenation.
cls_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
cls_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

print("Training CLS Classifier...")
_train_hybrid_classifier(cls_classifier, cls_optimizer, patch_loss_fn,
                         patch_features, combined_cls_features, train_labels, "CLS",
                         batch_size=batch_size)

# Combined classifier.
# NOTE(review): this receives exactly the same inputs as the CLS classifier, so it
# differs only by random initialisation and shuffle order — confirm this is intended.
combined_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
combined_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

print("Training Combined Classifier...")
_train_hybrid_classifier(combined_classifier, combined_optimizer, patch_loss_fn,
                         patch_features, combined_cls_features, train_labels, "Combined",
                         batch_size=batch_size)

# Evaluate on test set
print("Evaluating on test set...")

# One slot per method, in the order of `predictions` below.
correct_counts = [0] * 7
all_predictions = [[] for _ in range(7)]
true_labels = []

# Limit test set for efficiency if it's very large
max_test_samples = min(1000, len(test_images))
test_indices_subset = np.random.choice(len(test_images), max_test_samples, replace=False)

for i, idx in enumerate(test_indices_subset):
    if i % 50 == 0:
        print(f"Processing {i}/{max_test_samples}")

    if test_labels[idx] >= num_classes:
        continue

    pil_img = numpy_to_pil(test_images[idx])
    image_input = processor(images=pil_img, return_tensors="tf")['pixel_values']
    label = test_labels[idx]

    vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

    # Vanilla CLIP
    # NOTE(review): CLIP's own image embedding applies the vision post-layernorm before
    # visual_projection (HF: pooler_output); this uses the raw CLS state, consistent with
    # training and beta search — confirm intent.
    cls_embed = model.visual_projection(vision_out.last_hidden_state[:, 0, :])
    cls_normalized = tf.nn.l2_normalize(cls_embed, axis=-1)
    vanilla_logits = tf.squeeze(tf.matmul(cls_normalized, vanilla_text_features, transpose_b=True))
    pred_vanilla = tf.argmax(vanilla_logits).numpy()

    # Same normalised CLS embedding that extract_features produced for training.
    cls_feat = cls_normalized

    # Enriched patches: per-patch attention-weighted tokens averaged over the selected
    # heads -> (1, num_patches, hidden). Keeping the spatial map lets us both pool it
    # (for zero-shot logits) and reshape it (for the CNN classifiers).
    patch_embeds = vision_out.last_hidden_state[:, 1:, :]
    enriched_patches = tf.add_n([
        patch_embeds * tf.expand_dims(
            extract_attention_weights(vision_out.attentions, layer_idx, head_idx), -1)
        for layer_idx, head_idx in common_patch_heads
    ]) / len(common_patch_heads)

    # Mean over patches (axis=1) -> (1, hidden). By linearity this equals averaging
    # compute_enriched_patches over heads, i.e. the feature optimal_beta1 was tuned on.
    enriched_patch_1d = tf.reduce_mean(enriched_patches, axis=1)
    patch_proj = model.visual_projection(enriched_patch_1d)
    patch_norm = tf.nn.l2_normalize(patch_proj, axis=-1)
    patch_logits = tf.squeeze(tf.matmul(patch_norm, vanilla_text_features, transpose_b=True))

    # Enriched CLS
    enriched_cls = tf.zeros_like(extract_hidden_states(vision_out, -1)[0, 0, :])

    for layer_idx, head_idx in common_cls_heads:
        layer_attention = vision_out.attentions[layer_idx][:, head_idx, 0, :]
        layer_hidden = extract_hidden_states(vision_out, layer_idx)

        layer_enriched = compute_enriched_cls(layer_hidden, layer_attention)
        enriched_cls += tf.squeeze(layer_enriched)

    enriched_cls /= len(common_cls_heads)
    enriched_cls_proj = model.visual_projection(tf.expand_dims(enriched_cls, 0))
    enriched_cls_norm = tf.nn.l2_normalize(enriched_cls_proj, axis=-1)
    enriched_cls_logits = tf.squeeze(tf.matmul(enriched_cls_norm, vanilla_text_features, transpose_b=True))

    # Predictions using learning-based approach (fixed logit mixing with tuned betas)
    pred_patch_learning = tf.argmax(optimal_beta1 * vanilla_logits + (1 - optimal_beta1) * patch_logits).numpy()
    pred_cls_learning = tf.argmax(optimal_beta2 * vanilla_logits + (1 - optimal_beta2) * enriched_cls_logits).numpy()
    pred_combined_learning = tf.argmax(optimal_beta3 * vanilla_logits +
                                       (1 - optimal_beta3) / 2 * enriched_cls_logits +
                                       (1 - optimal_beta3) / 2 * patch_logits).numpy()

    # Classifier inputs, built the same way as in extract_features.
    grid = int(round(enriched_patches.shape[1] ** 0.5))  # 7 for ViT-B/32 at 224px
    patch_2d = tf.reshape(enriched_patches[0], [grid, grid, -1]).numpy()
    patch_2d = np.expand_dims(patch_2d, 0)  # Add batch dimension
    combined_cls_feat = tf.concat([cls_feat, enriched_cls_norm], axis=1).numpy()
    cls_feat_np = cls_feat.numpy()

    pred_patch_training = tf.argmax(patch_classifier(patch_2d, cls_feat_np), axis=1).numpy()[0]
    pred_cls_training = tf.argmax(cls_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]
    pred_combined_training = tf.argmax(combined_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]

    # Store results
    predictions = [pred_vanilla, pred_patch_learning, pred_patch_training,
                   pred_cls_learning, pred_cls_training, pred_combined_learning, pred_combined_training]

    for j, pred in enumerate(predictions):
        all_predictions[j].append(pred)
        if pred == label:
            correct_counts[j] += 1

    true_labels.append(label)

# Results
# Number of evaluated samples; every accuracy below is a ratio over it.
eval_size = len(true_labels)
# Display names, in the same order as the `predictions` list built in the
# evaluation loop (the last entry, "INFER", is the combined-training method).
method_names = ["Vanilla CLIP", "Patch Learning", "Patch Training", 
                "CLS Learning", "CLS Training", "Combined Learning", "INFER"]

# Fail with an actionable message instead of an opaque ZeroDivisionError
# when the evaluation loop recorded no samples.
if eval_size == 0:
    raise RuntimeError(
        "No evaluation samples were recorded in true_labels; "
        "cannot compute per-method accuracies."
    )

# One line per method: 1-based index, padded name, accuracy on the eval subset.
for i, name in enumerate(method_names):
    accuracy = correct_counts[i] / eval_size
    print(f"{i+1}. {name:25} {accuracy:.4f}")

# Visualizations
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
cmaps = ['Blues', 'Greens', 'Reds', 'Purples', 'Oranges', 'YlOrRd', 'RdPu']

# Show confusion matrices for first few methods.
methods_to_show = min(7, len(method_names))

# Limit classes shown in each confusion matrix for readability. The class
# subset, its display names and the sample filter are identical for every
# method, so they are computed once rather than once per subplot.
max_classes_shown = min(10, num_classes)
class_subset = list(range(max_classes_shown))
class_names_subset = [class_names[j] if j < len(class_names) else f"Class_{j}"
                      for j in class_subset]

# Positions of evaluation samples whose TRUE label falls inside the subset.
# NOTE: with `labels=class_subset`, sklearn's confusion_matrix also drops
# samples whose PREDICTED label is outside the subset, so misclassifications
# into classes >= max_classes_shown do not appear in these matrices.
subset_positions = [pos for pos, true_label in enumerate(true_labels)
                    if true_label < max_classes_shown]
filtered_true = [true_labels[pos] for pos in subset_positions]

for i in range(methods_to_show):
    ax = axes[i // 4, i % 4]
    if i >= len(all_predictions) or not filtered_true:
        # Nothing to plot for this method: hide the axes instead of
        # leaving an empty frame in the grid.
        ax.axis('off')
        continue

    filtered_pred = [all_predictions[i][pos] for pos in subset_positions]
    cm = confusion_matrix(filtered_true, filtered_pred, labels=class_subset)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_names_subset)
    disp.plot(ax=ax, cmap=cmaps[i % len(cmaps)], values_format='d')
    ax.set_title(method_names[i])
    ax.tick_params(axis='x', rotation=45, labelsize=8)
    ax.tick_params(axis='y', labelsize=8)

# Hide empty subplots
for i in range(methods_to_show, 8):
    axes[i // 4, i % 4].axis('off')

plt.tight_layout()
plt.show()

# Performance comparison
# Accuracy per method, in the same order as correct_counts.
accuracies = [hits / eval_size for hits in correct_counts]
short_names = ['Vanilla', 'Patch-L', 'Patch-T', 'CLS-L', 'CLS-T', 'Comb-L', 'Comb-T']
bar_colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink']

plt.figure(figsize=(12, 6))
perf_ax = plt.gca()
bars = perf_ax.bar(short_names, accuracies, color=bar_colors)
perf_ax.set_ylabel('Accuracy')
perf_ax.set_title('Performance Comparison on Oxford-IIIT Pet Dataset')
plt.xticks(rotation=45)
perf_ax.set_ylim(0, 1)

# Write each accuracy just above the top of its bar.
for rect, value in zip(bars, accuracies):
    x_center = rect.get_x() + rect.get_width() / 2
    perf_ax.text(x_center, rect.get_height() + 0.01, f'{value:.3f}',
                 ha='center', va='bottom')

perf_ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Display sample images from each class (first 10 classes)
max_classes_to_show = min(10, num_classes)
# Array form of the evaluated labels, used to locate each class's first sample.
labels_arr = np.array(true_labels)
plt.figure(figsize=(15, 6))
for cls in range(max_classes_to_show):
    positions = np.flatnonzero(labels_arr == cls)
    if positions.size == 0:
        # Class not present in the evaluated subset: leave its cell empty.
        continue

    # Map the position within the evaluated subset back to the test set.
    first_pos = positions[0]
    image = test_images[test_indices_subset[first_pos]]
    display_name = class_names[cls] if cls < len(class_names) else f"Class_{cls}"

    plt.subplot(2, 5, cls + 1)
    plt.imshow(image.astype(np.uint8))
    plt.title(f'{display_name.replace("_", " ").title()}', fontsize=10)
    plt.axis('off')

plt.tight_layout()
plt.show()

# Per-class accuracy analysis for best method
best_method_idx = np.argmax(accuracies)
best_predictions = all_predictions[best_method_idx]
best_method_name = method_names[best_method_idx]

print(f"\nBest performing method: {best_method_name} ({accuracies[best_method_idx]:.4f})")

# The breakdown covers class indices 0 .. classes_to_report-1, i.e. the FIRST
# classes by index (not the best-scoring ones), so the header says so and
# reports the real count when there are fewer than 10 classes.
classes_to_report = min(10, num_classes)
print(f"\nPer-class accuracy for best method (first {classes_to_report} classes):")

# Convert once; the per-class boolean masks below index into these arrays.
true_arr = np.array(true_labels)
pred_arr = np.array(best_predictions)

class_accuracies = []
for class_idx in range(classes_to_report):
    class_mask = true_arr == class_idx
    n_samples = int(np.sum(class_mask))
    if n_samples == 0:
        continue  # class absent from the evaluated subset
    n_correct = int(np.sum(pred_arr[class_mask] == class_idx))
    class_accuracy = n_correct / n_samples
    class_name = class_names[class_idx] if class_idx < len(class_names) else f"Class_{class_idx}"
    print(f"  {class_name.replace('_', ' ').title():20}: {class_accuracy:.4f} ({n_correct}/{n_samples})")
    class_accuracies.append(class_accuracy)

# Feature dimensionality info
print("\nFeature dimensions:")
print(f"  CLS features: {cls_features.shape}")
print(f"  Patch features: {patch_features.shape}")
print(f"  Combined CLS features: {combined_cls_features.shape}")

# Head analysis summary: how many selected (layer, head) pairs per group,
# and how many distinct layers they span.
print("\nSelected attention heads:")
for group_title, head_pairs in (("Patch", common_patch_heads), ("CLS", common_cls_heads)):
    distinct_layers = len({layer for layer, _ in head_pairs})
    print(f"  {group_title} heads: {len(head_pairs)} heads across {distinct_layers} layers")

