import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
from collections import defaultdict
import random
from transformers import TFCLIPModel, CLIPProcessor
from PIL import Image
import requests
import os
import tarfile
from scipy.io import loadmat

# Seed every global RNG this script draws from (Python's `random`, NumPy,
# TensorFlow) with the same value so repeated runs are reproducible.
SEED = 53
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(SEED)

# Load SUN397 dataset
def load_sun_dataset():
    """Prepare ./datasets/SUN397 and load its images.

    The full SUN397 image archive is not downloaded. If the images directory
    does not exist yet, a synthetic demo subset is generated there via
    ``create_sun_demo_dataset``. The official partitions archive is then
    fetched if missing. A failure in either step is reported. Placeholder
    images are generated only if the images directory is missing or empty,
    so an existing (possibly real) dataset is never polluted.

    Returns:
        Whatever ``load_sun_images_from_directory`` returns for the images
        directory: ``(images_array, labels_array, class_names)``.
    """
    print("Downloading SUN397 dataset...")

    # Full image archive (~37GB) lives at
    # http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz -- not
    # downloaded here; a demo subset is generated instead.
    partitions_url = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"

    dataset_dir = "./datasets/SUN397"
    images_dir = os.path.join(dataset_dir, "SUN397")
    partitions_dir = os.path.join(dataset_dir, "Partitions")

    try:
        os.makedirs(dataset_dir, exist_ok=True)

        # Create the image tree if not already present
        if not os.path.exists(images_dir):
            print("Downloading SUN397 images... (this may take a while, ~37GB)")
            print("Note: This is a large download. You may want to use a subset for testing.")
            print("Creating demo subset of SUN397...")
            create_sun_demo_dataset(images_dir)

        # Download partitions info
        if not os.path.exists(partitions_dir):
            print("Downloading SUN397 partitions...")
            _download_and_extract_zip(partitions_url, dataset_dir, "Partitions.zip")

    except Exception as e:  # best-effort setup: report and fall back
        print(f"Error downloading dataset: {e}")
        # Only synthesize data when nothing usable is on disk; never write
        # placeholder images into an existing images directory.
        if not (os.path.isdir(images_dir) and os.listdir(images_dir)):
            print("Creating demo dataset instead...")
            create_sun_demo_dataset(images_dir)

    return load_sun_images_from_directory(images_dir)


def _download_and_extract_zip(url, dest_dir, zip_name, timeout=60):
    """Stream the zip at ``url`` into ``dest_dir``, extract it there, delete the archive.

    The temporary archive is removed even when download or extraction fails.
    Raises ``requests`` / ``zipfile`` / ``OSError`` exceptions to the caller.
    """
    import zipfile

    zip_path = os.path.join(dest_dir, zip_name)
    try:
        # Context manager releases the streamed connection; timeout avoids hanging forever.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(zip_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(dest_dir)
    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

def create_sun_demo_dataset(target_dir, images_per_category=100):
    """Generate a small synthetic stand-in for SUN397 under ``target_dir``.

    Creates one sub-directory per demo category (kitchen, bedroom,
    living_room, bathroom, office). Each gets ``images_per_category``
    224x224 JPEG placeholders: a random solid background, 3-8 random
    rectangles/ellipses, and the category name drawn at the top-left.
    The images carry no real scene content. All randomness comes from the
    global ``random`` module, so output follows its seed.

    Args:
        target_dir: root directory to write category folders into.
        images_per_category: number of images per category (default 100).
    """
    from PIL import ImageDraw, ImageFont

    print("Creating demo SUN dataset with sample scene categories...")

    demo_categories = ['kitchen', 'bedroom', 'living_room', 'bathroom', 'office']

    for category in demo_categories:
        os.makedirs(os.path.join(target_dir, category), exist_ok=True)

    print("Note: Creating placeholder images for demo. Replace with actual SUN397 data for real experiments.")

    # Load the font once, not once per image. If loading fails, font=None
    # makes Pillow's draw.text fall back to its own default font.
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    for category in demo_categories:
        category_dir = os.path.join(target_dir, category)
        label_text = category.replace('_', ' ').title()

        for i in range(images_per_category):
            background = (random.randint(50, 200),
                          random.randint(50, 200),
                          random.randint(50, 200))
            img = Image.new('RGB', (224, 224), color=background)
            draw = ImageDraw.Draw(img)

            # Random shapes to loosely simulate scene content.
            for _ in range(random.randint(3, 8)):
                shape_type = random.choice(['rectangle', 'circle'])
                x1, y1 = random.randint(0, 150), random.randint(0, 150)
                x2, y2 = x1 + random.randint(20, 70), y1 + random.randint(20, 70)
                color = (random.randint(0, 255),
                         random.randint(0, 255),
                         random.randint(0, 255))

                if shape_type == 'rectangle':
                    draw.rectangle([x1, y1, x2, y2], fill=color)
                else:
                    draw.ellipse([x1, y1, x2, y2], fill=color)

            draw.text((10, 10), label_text, fill=(255, 255, 255), font=font)

            img.save(os.path.join(category_dir, f"{category}_{i:03d}.jpg"))
  

def load_sun_images_from_directory(images_dir, max_categories=10, max_images_per_category=50):
    """Load a class-per-folder image tree into arrays.

    Every directory below ``images_dir`` that directly contains
    .jpg/.jpeg/.png files is a category. Its ``/``-separated path relative to
    ``images_dir`` is the category id. The first ``max_categories`` ids in
    sorted order are kept, and from each the first ``max_images_per_category``
    files in sorted name order. Sorting makes the subset reproducible
    regardless of filesystem listing order. Images are converted to RGB and
    resized to 224x224. Unreadable files are reported and skipped.

    Returns:
        images_array: array of shape [N, 224, 224, 3] (uint8 from PIL RGB)
            when at least one image loads.
        labels_array: [N] integer indices into ``class_names``.
        class_names: display name per kept category (``_`` -> space,
            ``/`` -> ``_``).
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    # Find all sub-directories that hold image files (scene categories)
    scene_categories = []
    for root, _subdirs, files in os.walk(images_dir):
        if any(name.lower().endswith(image_exts) for name in files):
            rel_path = os.path.relpath(root, images_dir)
            if rel_path != '.':
                # Normalize separators so category ids/class names are OS-independent.
                scene_categories.append(rel_path.replace(os.sep, '/'))

    # For demo, limit the number of categories to keep it manageable
    scene_categories = sorted(scene_categories)[:max_categories]
    print(f"Found {len(scene_categories)} scene categories")
    print("Categories:", scene_categories[:5], "..." if len(scene_categories) > 5 else "")

    class_names = [cat.replace('_', ' ').replace('/', '_') for cat in scene_categories]

    images = []
    labels = []

    print("Loading images...")
    for class_idx, category in enumerate(scene_categories):
        category_path = os.path.join(images_dir, category)
        # Sort before truncating: os.listdir order is filesystem-dependent.
        image_files = sorted(
            f for f in os.listdir(category_path) if f.lower().endswith(image_exts)
        )[:max_images_per_category]

        print(f"Loading {len(image_files)} images from {category}...")

        for img_file in image_files:
            img_path = os.path.join(category_path, img_file)
            try:
                # `with` closes the file handle that Image.open keeps open.
                with Image.open(img_path) as src:
                    img_array = np.array(src.convert('RGB').resize((224, 224)))
            except Exception as e:  # best-effort: skip unreadable files
                print(f"Error loading {img_path}: {e}")
                continue

            images.append(img_array)
            labels.append(class_idx)

    images_array = np.array(images)
    labels_array = np.array(labels)

    # Display class distribution
    unique, counts = np.unique(labels_array, return_counts=True)
    print("\nClass distribution:")
    for class_idx, count in zip(unique, counts):
        print(f"  {class_names[class_idx]}: {count} images")

    return images_array, labels_array, class_names

# TensorFlow CLIP Model wrapper
class TensorFlowCLIP:
    """Wrapper around a Hugging Face ``TFCLIPModel`` and its ``CLIPProcessor``.

    Provides L2-normalized text/image embeddings and the raw vision-tower
    outputs (attentions, hidden states) used by the head-selection code.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = TFCLIPModel.from_pretrained(model_name, output_attentions=True)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model_name = model_name
        self._expose_main_layer_components()

    def _expose_main_layer_components(self):
        """Alias ``vision_model`` / ``visual_projection`` onto ``self.model`` if absent.

        ``TFCLIPModel`` holds these sub-layers on its inner ``clip`` main layer,
        not on the model object itself. Without the alias,
        ``self.model.vision_model`` and ``clip_model.model.visual_projection``
        raise AttributeError. This is a no-op when the attributes already exist.
        """
        main_layer = getattr(self.model, "clip", None)
        if main_layer is None:
            return
        for attr in ("vision_model", "visual_projection"):
            if not hasattr(self.model, attr) and hasattr(main_layer, attr):
                setattr(self.model, attr, getattr(main_layer, attr))

    def normalize_features(self, features):
        """Return ``features`` L2-normalized along the last axis."""
        return tf.nn.l2_normalize(features, axis=-1)

    def get_text_features(self, texts):
        """Return L2-normalized CLIP text embeddings, one row per string in ``texts``."""
        inputs = self.processor(text=texts, return_tensors="tf", padding=True, truncation=True)
        outputs = self.model.get_text_features(**inputs)
        return self.normalize_features(outputs)

    def _to_pixel_values(self, images):
        """Preprocess a sequence of PIL images in one batch; pass anything else through.

        Non-PIL input is assumed to already be processor-style ``pixel_values``.
        Raises ValueError on an empty sequence.
        """
        if len(images) == 0:
            raise ValueError("expected at least one image")
        if isinstance(images[0], Image.Image):
            return self.processor(images=list(images), return_tensors="tf")['pixel_values']
        return images

    def get_vision_features_with_attention(self, images):
        """Run the vision tower and return its full outputs.

        Returns a dict with ``last_hidden_state``, ``attentions`` (one tensor per
        layer), ``hidden_states`` (embeddings plus one per layer) and
        ``pooler_output`` (the post-LayerNorm CLS token that CLIP projects).
        """
        pixel_values = self._to_pixel_values(images)

        # The TF vision transformer layer takes return_dict as a required argument.
        vision_outputs = self.model.vision_model(
            pixel_values=pixel_values,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )

        return {
            'last_hidden_state': vision_outputs.last_hidden_state,
            'attentions': vision_outputs.attentions,
            'hidden_states': vision_outputs.hidden_states,
            'pooler_output': vision_outputs.pooler_output
        }

    def get_image_features(self, images):
        """Return L2-normalized CLIP image embeddings for PIL images or pixel values."""
        pixel_values = self._to_pixel_values(images)
        outputs = self.model.get_image_features(pixel_values=pixel_values)
        return self.normalize_features(outputs)

# Load CLIP model
clip_model = TensorFlowCLIP()

# Load dataset
all_images, all_labels, class_names = load_sun_dataset()
num_classes = len(class_names)
print(f"Total samples: {len(all_images)}")
print(f"Number of classes: {num_classes}")

# Create text features for scene classes.
# One zero-shot prompt per class, in class_names order, so row i of
# vanilla_text_features corresponds to label i.
scene_prompts = [f"a photo of a {class_name.replace('_', ' ')}" for class_name in class_names]
vanilla_text_features = clip_model.get_text_features(scene_prompts)

# Create validation/test split. The validation split is also what the
# classifiers are trained on later (extract_features(val_images, ...)).
samples_per_class = 16
remaining_samples = []
val_samples_per_class = defaultdict(list)

# First pass: the first `samples_per_class` samples of each class, in load
# order (no shuffling), go to validation; everything else goes to test.
for idx, label in enumerate(all_labels):
    if len(val_samples_per_class[label]) < samples_per_class:
        val_samples_per_class[label].append(idx)
    else:
        remaining_samples.append(idx)

# Create validation and test indices.
# Validation indices are grouped class by class (all of class 0, then class 1, ...).
val_indices = []
for class_label in range(num_classes):
    val_indices.extend(val_samples_per_class[class_label])

test_indices = remaining_samples.copy()

# Create validation and test datasets
val_images = all_images[val_indices]
val_labels = all_labels[val_indices]
test_images = all_images[test_indices]
test_labels = all_labels[test_indices]

def numpy_to_pil(img_array):
    """Convert an image array (H x W x 3 or H x W) to a PIL Image.

    Floating-point arrays whose maximum is <= 1.0 are treated as [0, 1]
    images and scaled to [0, 255]. Integer arrays are assumed to already be
    on the 0-255 scale; checking the dtype avoids rescaling a dark uint8
    image whose max happens to be 0 or 1. Values are clipped to [0, 255]
    before the uint8 cast so out-of-range values saturate instead of
    wrapping around.
    """
    arr = np.asarray(img_array)
    if np.issubdtype(arr.dtype, np.floating) and arr.size and arr.max() <= 1.0:
        arr = arr * 255.0
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)

def find_best_patch_heads(validation_images, validation_labels, text_features,
                          num_layers=3, heads_per_layer=3, num_common=9):
    """Select attention heads whose CLS->patch attention best pools the patches.

    For every validation image, each head in the last ``num_layers`` layers is
    scored as follows:
      1. Weight the final-layer patch tokens by that head's CLS-query
         attention and average them.
      2. Project the average into CLIP's joint space and normalize it.
      3. Take the dot product with the true class's text feature.
    The ``heads_per_layer`` best heads of each layer are recorded per image.
    The ``num_common`` most frequently recorded heads are returned.

    Args:
        validation_images: sequence of image arrays accepted by ``numpy_to_pil``.
        validation_labels: class indices aligned with ``validation_images``.
        text_features: per-class text embeddings indexed by label (presumably
            L2-normalized, as ``get_text_features`` returns -- then the score is
            a cosine similarity).
        num_layers, heads_per_layer, num_common: search sizes (defaults 3/3/9).

    Returns:
        list of ``(layer_idx, head_idx)``. ``layer_idx`` is negative
        (-1 = last layer). Ordered by selection count, descending; ties keep
        first-seen order.
    """
    from collections import Counter

    print("Finding best heads for patch enrichment...")
    head_frequency = Counter()

    for idx, (image, label) in enumerate(zip(validation_images, validation_labels)):
        if idx % 20 == 0:
            print(f"Processing sample {idx}/{len(validation_images)}")

        vision_outputs = clip_model.get_vision_features_with_attention([numpy_to_pil(image)])

        # Final-layer patch tokens with the CLS token (position 0) dropped:
        # [num_patches, hidden].
        patch_embeds = vision_outputs['last_hidden_state'][0, 1:, :]
        target_text_feat = tf.expand_dims(text_features[label], 0)

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)
            layer_attention = vision_outputs['attentions'][layer_idx]  # [1, heads, tokens, tokens]

            head_scores = []
            for head_idx in range(layer_attention.shape[1]):
                # CLS-query attention over patch tokens, as a column for broadcasting.
                attn_weights = tf.expand_dims(layer_attention[0, head_idx, 0, 1:], axis=-1)
                avg_patch = tf.reduce_mean(patch_embeds * attn_weights, axis=0)

                projected = clip_model.model.visual_projection(tf.expand_dims(avg_patch, 0))
                projected = clip_model.normalize_features(projected)

                similarity = tf.reduce_sum(projected * target_text_feat, axis=-1)
                head_scores.append((head_idx, float(similarity.numpy()[0])))

            # Stable sort: equal scores keep the lower head index first.
            head_scores.sort(key=lambda item: item[1], reverse=True)
            for head_idx, _ in head_scores[:heads_per_layer]:
                head_frequency[(layer_idx, head_idx)] += 1

    most_common = head_frequency.most_common(num_common)
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common patch heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_best_cls_heads(validation_images, validation_labels, text_features,
                        num_layers=3, heads_per_layer=3, num_common=9):
    """Select attention heads whose CLS attention row yields the best enriched CLS.

    For every validation image, each head in the last ``num_layers`` layers is
    scored as follows:
      1. Take that head's CLS-query attention row over all tokens (CLS
         included) and use it to compute a weighted sum of the
         same-index ``hidden_states`` entry.
      2. Project the sum into CLIP's joint space and normalize it.
      3. Take the dot product with the true class's text feature.
    The ``heads_per_layer`` best heads of each layer are recorded per image.
    The ``num_common`` most frequently recorded heads are returned.

    NOTE(review): ``hidden_states`` has one more entry than ``attentions``
    (it starts with the embeddings), so with negative indices
    ``attentions[k]`` is paired with that layer's *output*, not its input --
    confirm this pairing is intended.

    Returns:
        list of ``(layer_idx, head_idx)`` with negative ``layer_idx``, ordered
        by selection count (descending; ties keep first-seen order).
    """
    from collections import Counter

    print("Finding best heads for CLS enrichment...")
    head_frequency = Counter()

    for idx, (image, label) in enumerate(zip(validation_images, validation_labels)):
        if idx % 20 == 0:
            print(f"Processing sample {idx}/{len(validation_images)}")

        vision_outputs = clip_model.get_vision_features_with_attention([numpy_to_pil(image)])
        target_text_feat = tf.expand_dims(text_features[label], 0)

        for layer_offset in range(num_layers):
            layer_idx = -(layer_offset + 1)
            layer_attention = vision_outputs['attentions'][layer_idx]
            layer_hidden = vision_outputs['hidden_states'][layer_idx][0]  # [tokens, hidden]

            head_scores = []
            for head_idx in range(layer_attention.shape[1]):
                # The returned attentions are already softmax probabilities
                # (each row sums to 1); a second softmax would flatten them to
                # nearly uniform weights and erase the differences between heads.
                cls_attention = tf.expand_dims(layer_attention[0, head_idx, 0, :], axis=-1)

                enriched_cls = tf.reduce_sum(layer_hidden * cls_attention, axis=0)

                projected = clip_model.model.visual_projection(tf.expand_dims(enriched_cls, 0))
                projected = clip_model.normalize_features(projected)

                similarity = tf.reduce_sum(projected * target_text_feat, axis=-1)
                head_scores.append((head_idx, float(similarity.numpy()[0])))

            head_scores.sort(key=lambda item: item[1], reverse=True)
            for head_idx, _ in head_scores[:heads_per_layer]:
                head_frequency[(layer_idx, head_idx)] += 1

    most_common = head_frequency.most_common(num_common)
    common_heads = [layer_head for layer_head, _ in most_common]

    print("Most common CLS heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_optimal_beta(validation_images, validation_labels, text_features, common_heads, enrichment_type):
    """Grid-search the weight that mixes vanilla-CLS and enriched logits.

    For each beta in {0.0, 0.1, ..., 1.0}, an image is predicted as
    ``argmax(beta * cls_logits + (1 - beta) * enriched_logits)``:
      - ``cls_logits`` come from CLIP's standard image embedding (the
        projected post-LayerNorm CLS token).
      - ``enriched_logits`` come from the head-weighted ``enrichment_type``
        representation built from ``common_heads``.
    The beta with the highest validation accuracy is returned; on ties the
    smallest such beta wins.

    Args:
        enrichment_type: ``"patch"`` or ``"cls"``.

    Raises:
        ValueError: unknown ``enrichment_type`` or empty validation set.
    """
    if enrichment_type not in ("patch", "cls"):
        raise ValueError(f"enrichment_type must be 'patch' or 'cls', got {enrichment_type!r}")
    if len(validation_images) == 0:
        raise ValueError("find_optimal_beta needs at least one validation image")

    print(f"Finding optimal beta for {enrichment_type}...")

    num_heads = max(len(common_heads), 1)  # avoid 0/0 if no heads were selected

    # The vision forward pass does not depend on beta, so compute both logit
    # sets once per image and sweep beta over the cached arrays.
    cls_logit_rows = []
    enriched_logit_rows = []
    for image in validation_images:
        vision_outputs = clip_model.get_vision_features_with_attention([numpy_to_pil(image)])

        # CLIP's image embedding projects pooler_output (post-LayerNorm CLS),
        # not the raw last_hidden_state[:, 0].
        cls_embed = clip_model.model.visual_projection(vision_outputs['pooler_output'])
        cls_normalized = clip_model.normalize_features(cls_embed)
        cls_logit_rows.append(tf.linalg.matvec(text_features, cls_normalized[0]).numpy())

        if enrichment_type == "patch":
            # Mean over heads of the attention-weighted average of patch tokens.
            patch_embeds = vision_outputs['last_hidden_state'][0, 1:, :]  # [num_patches, hidden]
            enriched = tf.zeros_like(patch_embeds[0])
            for layer_idx, head_idx in common_heads:
                attn_weights = vision_outputs['attentions'][layer_idx][0, head_idx, 0, 1:]
                attn_weights = tf.expand_dims(attn_weights, axis=-1)
                enriched += tf.reduce_mean(patch_embeds * attn_weights, axis=0)
        else:
            # Mean over heads of the CLS-attention-weighted sum of hidden states.
            enriched = tf.zeros_like(vision_outputs['hidden_states'][-1][0, 0])
            for layer_idx, head_idx in common_heads:
                # Attentions are already softmax probabilities; use them as-is.
                cls_attention = vision_outputs['attentions'][layer_idx][0, head_idx, 0, :]
                cls_attention = tf.expand_dims(cls_attention, axis=-1)
                layer_hidden = vision_outputs['hidden_states'][layer_idx][0]
                enriched += tf.reduce_sum(layer_hidden * cls_attention, axis=0)

        enriched /= num_heads
        projected = clip_model.model.visual_projection(tf.expand_dims(enriched, 0))
        normalized = clip_model.normalize_features(projected)
        enriched_logit_rows.append(tf.linalg.matvec(text_features, normalized[0]).numpy())

    cls_logits = np.stack(cls_logit_rows)            # [N, num_classes]
    enriched_logits = np.stack(enriched_logit_rows)  # [N, num_classes]
    labels = np.asarray(validation_labels)

    best_beta = 0.0
    best_accuracy = 0.0
    for beta in np.linspace(0.0, 1.0, 11):
        combined_logits = beta * cls_logits + (1 - beta) * enriched_logits
        accuracy = float(np.mean(np.argmax(combined_logits, axis=1) == labels))
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_beta = float(beta)

    print(f"Optimal beta: {best_beta:.3f} (accuracy: {best_accuracy:.4f})")
    return best_beta

# TensorFlow CNN+MLP Classifier
class HybridClassifier(tf.keras.Model):
    """Two-branch classifier fusing a patch-grid CNN with a CLS-feature MLP.

    The CNN branch runs conv/pool blocks over a 2-D grid of patch features,
    global-average-pools and maps to a 512-d vector. The MLP branch maps a
    flat CLS feature vector to 512-d. The two are concatenated and mapped to
    raw class logits.

    Args:
        mlp_input_dim: accepted but not referenced; the Dense layers infer
            their input width from the first batch they see.
        num_classes: number of output logits.
    """

    def __init__(self, mlp_input_dim, num_classes):
        super().__init__()

        conv_stack = [
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(512, activation='relu'),
        ]
        # Patch branch: [batch, H, W, C] -> [batch, 512]
        self.cnn = tf.keras.Sequential(conv_stack)
        # CLS branch: [batch, D] -> [batch, 512]
        self.mlp = tf.keras.layers.Dense(512, activation='relu')
        # Fusion head: unnormalized logits (pair with from_logits=True losses)
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, patch_features, cls_features):
        """Return logits of shape [batch, num_classes]."""
        branch_outputs = [self.cnn(patch_features), self.mlp(cls_features)]
        fused = tf.concat(branch_outputs, axis=1)
        return self.classifier(fused)

# Feature extraction function
def extract_features(images, labels, patch_heads, cls_heads):
    """Compute the per-image feature sets used to train the hybrid classifiers.

    Args:
        images: sequence of image arrays accepted by ``numpy_to_pil``.
        labels: class indices aligned with ``images``.
        patch_heads: ``(layer_idx, head_idx)`` pairs for patch enrichment.
        cls_heads: ``(layer_idx, head_idx)`` pairs for CLS enrichment.

    Returns:
        cls_features: [N, D] L2-normalized CLIP image embeddings (projected
            post-LayerNorm CLS token).
        patch_features: [N, g, g, hidden] final-layer patch tokens, each
            weighted by the mean CLS attention of ``patch_heads``, laid out on
            a g x g grid (g = sqrt(num_patches); e.g. 7 for ViT-B/32 at 224px).
        combined_cls_features: [N, 2D] concat of ``cls_features`` and the
            normalized, projected enriched-CLS vector.
        labels: [N] labels in the same order.

    Raises:
        ValueError: if the patch count is not a perfect square.
    """
    cls_features_list = []
    patch_features_list = []
    combined_cls_features_list = []
    labels_list = []

    print("Extracting features...")
    for idx, image in enumerate(images):
        if idx % 50 == 0:
            print(f"Processing {idx}/{len(images)}")

        vision_outputs = clip_model.get_vision_features_with_attention([numpy_to_pil(image)])

        # CLS features: CLIP projects pooler_output (post-LayerNorm CLS token).
        cls_embed = clip_model.model.visual_projection(vision_outputs['pooler_output'])
        cls_feat = clip_model.normalize_features(cls_embed)

        # Enriched CLS: mean over heads of the CLS-attention-weighted hidden states.
        enriched_cls = tf.zeros_like(vision_outputs['hidden_states'][-1][0, 0])
        for layer_idx, head_idx in cls_heads:
            # Attentions are already softmax probabilities; use them as-is.
            cls_attention = vision_outputs['attentions'][layer_idx][0, head_idx, 0, :]
            cls_attention = tf.expand_dims(cls_attention, axis=-1)
            layer_hidden = vision_outputs['hidden_states'][layer_idx][0]
            enriched_cls += tf.reduce_sum(layer_hidden * cls_attention, axis=0)
        enriched_cls /= max(len(cls_heads), 1)

        enriched_cls_proj = clip_model.model.visual_projection(tf.expand_dims(enriched_cls, 0))
        enriched_cls_feat = clip_model.normalize_features(enriched_cls_proj)

        # Enriched patches: mean over heads of attention-weighted patch tokens.
        patch_embeds = vision_outputs['last_hidden_state'][0, 1:, :]  # [num_patches, hidden]
        enriched_patches = tf.zeros_like(patch_embeds)
        for layer_idx, head_idx in patch_heads:
            attn_weights = vision_outputs['attentions'][layer_idx][0, head_idx, 0, 1:]
            enriched_patches += patch_embeds * tf.expand_dims(attn_weights, axis=-1)
        enriched_patches /= max(len(patch_heads), 1)

        # Lay patches out on their square grid for the CNN. The grid side
        # depends on the backbone (7 for patch32, 14 for patch16 at 224px),
        # so it is derived from the tensor instead of hard-coded.
        num_patches, hidden_dim = enriched_patches.shape
        grid = int(round(num_patches ** 0.5))
        if grid * grid != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patches on a square grid")
        patch_2d = tf.reshape(enriched_patches, [grid, grid, hidden_dim])

        combined_cls = tf.concat([cls_feat, enriched_cls_feat], axis=1)

        cls_features_list.append(cls_feat[0].numpy())
        patch_features_list.append(patch_2d.numpy())
        combined_cls_features_list.append(combined_cls[0].numpy())
        labels_list.append(labels[idx])

    return (np.array(cls_features_list),
            np.array(patch_features_list),
            np.array(combined_cls_features_list),
            np.array(labels_list))

# Find heads
common_patch_heads = find_best_patch_heads(val_images, val_labels, vanilla_text_features)
common_cls_heads = find_best_cls_heads(val_images, val_labels, vanilla_text_features)

# Find betas
optimal_beta1 = find_optimal_beta(val_images, val_labels, vanilla_text_features, common_patch_heads, "patch")
optimal_beta2 = find_optimal_beta(val_images, val_labels, vanilla_text_features, common_cls_heads, "cls")
# NOTE(review): the combined method reuses the CLS beta instead of tuning its own.
optimal_beta3 = optimal_beta2

# Extract training features: the few-shot validation split doubles as the
# classifiers' training set.
cls_features, patch_features, combined_cls_features, train_labels = extract_features(
    val_images, val_labels, common_patch_heads, common_cls_heads)

# Train classifiers
print("Training TensorFlow classifiers...")

num_epochs = 50
batch_size = 16
# Round up so the final partial batch is trained on too (and a training set
# smaller than one batch still produces a batch).
num_batches = (len(cls_features) + batch_size - 1) // batch_size
patch_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def _train_hybrid_classifier(model, optimizer, cls_inputs, tag):
    """Train ``model`` on (patch_features, cls_inputs) -> train_labels for num_epochs.

    Samples are reshuffled every epoch. The split is stored class by class and
    batch_size equals samples_per_class, so unshuffled consecutive batches
    would each contain a single class. Prints training accuracy every 10 epochs.
    """
    for epoch in range(num_epochs):
        order = np.random.permutation(len(cls_inputs))
        correct = 0
        total = 0

        for i in range(num_batches):
            batch_idx = order[i * batch_size:(i + 1) * batch_size]
            batch_labels = train_labels[batch_idx]

            with tf.GradientTape() as tape:
                outputs = model(patch_features[batch_idx], cls_inputs[batch_idx], training=True)
                loss = patch_loss_fn(batch_labels, outputs)

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            predicted = tf.argmax(outputs, axis=1).numpy()
            correct += int(np.sum(predicted == batch_labels))
            total += len(batch_labels)

        if (epoch + 1) % 10 == 0 and total:
            print(f'{tag} Epoch [{epoch+1}/{num_epochs}], Accuracy: {100*correct/total:.2f}%')


# Patch classifier: patch grid + vanilla CLS embedding
patch_classifier = HybridClassifier(mlp_input_dim=512, num_classes=num_classes)
patch_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
print("Training Patch Classifier...")
_train_hybrid_classifier(patch_classifier, patch_optimizer, cls_features, "Patch")

# CLS classifier: patch grid + [vanilla CLS | enriched CLS]
cls_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
cls_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
print("Training CLS Classifier...")
_train_hybrid_classifier(cls_classifier, cls_optimizer, combined_cls_features, "CLS")

# Combined classifier
# NOTE(review): this receives exactly the same inputs as the CLS classifier
# (here and at evaluation), so the two differ only by random initialization.
combined_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
combined_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
print("Training Combined Classifier...")
_train_hybrid_classifier(combined_classifier, combined_optimizer, combined_cls_features, "Combined")

# Evaluate on test set
print("Evaluating on test set...")

correct_counts = [0] * 7
all_predictions = [[] for _ in range(7)]
true_labels = []

for idx in range(len(test_images)):
    if idx % 50 == 0:
        print(f"Processing {idx}/{len(test_images)}")
    
    pil_img = numpy_to_pil(test_images[idx])
    label = test_labels[idx]
    
    # Get vision features
    vision_outputs = clip_model.get_vision_features_with_attention([pil_img])
    
    # Vanilla CLIP
    cls_embed = clip_model.model.visual_projection(vision_outputs['last_hidden_state'][:, 0, :])
    cls_normalized = clip_model.normalize_features(cls_embed)
    vanilla_logits = tf.linalg.matvec(vanilla_text_features, cls_normalized[0])
    pred_vanilla = tf.argmax(vanilla_logits).numpy()
    
    # Extract features for other methods
    cls_feat = cls_normalized
    
    # Enriched patch
    patch_embeds = vision_outputs['last_hidden_state'][:, 1:, :]
    enriched_patches = tf.zeros_like(patch_embeds[0])
    
    for layer_idx, head_idx in common_patch_heads:
        attn_weights = vision_outputs['attentions'][layer_idx][0, head_idx, 0, 1:]
        attn_weights = tf.expand_dims(attn_weights, axis=-1)
        weighted_patches = patch_embeds[0] * attn_weights
        enriched_patches += weighted_patches
    
    enriched_patches /= len(common_patch_heads)
    enriched_patch_1d = tf.reduce_mean(enriched_patches, axis=0)
    patch_proj = clip_model.model.visual_projection(tf.expand_dims(enriched_patch_1d, 0))
    patch_norm = clip_model.normalize_features(patch_proj)
    patch_logits = tf.linalg.matvec(vanilla_text_features, patch_norm[0])
    
    # Enriched CLS
    enriched_cls = tf.zeros_like(vision_outputs['hidden_states'][-1][0][0])
    
    for layer_idx, head_idx in common_cls_heads:
        layer_attention = vision_outputs['attentions'][layer_idx]
        layer_hidden = vision_outputs['hidden_states'][layer_idx]
        
        cls_attention = layer_attention[0, head_idx, 0, :]
        cls_attention = tf.nn.softmax(cls_attention, axis=0)
        cls_attention = tf.expand_dims(cls_attention, axis=-1)
        
        layer_enriched = tf.reduce_sum(layer_hidden[0] * cls_attention, axis=0)
        enriched_cls += layer_enriched
    
    enriched_cls /= len(common_cls_heads)
    enriched_cls_proj = clip_model.model.visual_projection(tf.expand_dims(enriched_cls, 0))
    enriched_cls_norm = clip_model.normalize_features(enriched_cls_proj)
    enriched_cls_logits = tf.linalg.matvec(vanilla_text_features, enriched_cls_norm[0])
    
    # Predictions using learning-based approach
    pred_patch_learning = tf.argmax(optimal_beta1 * vanilla_logits + (1 - optimal_beta1) * patch_logits).numpy()
    pred_cls_learning = tf.argmax(optimal_beta2 * vanilla_logits + (1 - optimal_beta2) * enriched_cls_logits).numpy()
    pred_combined_learning = tf.argmax(optimal_beta3 * vanilla_logits + 
                                     (1 - optimal_beta3) / 2 * enriched_cls_logits + 
                                     (1 - optimal_beta3) / 2 * patch_logits).numpy()
    
    # CNN+MLP predictions (TensorFlow)
    patch_2d = tf.reshape(enriched_patches, [1, 14, 14, 768])  # Add batch dimension
    combined_cls_feat = tf.concat([cls_feat, enriched_cls_norm], axis=1)
    cls_feat_np = cls_feat
    
    pred_patch_training = tf.argmax(patch_classifier(patch_2d, cls_feat_np), axis=1).numpy()[0]
    pred_cls_training = tf.argmax(cls_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]
    pred_combined_training = tf.argmax(combined_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]
    
    # Store results
    predictions = [pred_vanilla, pred_patch_learning, pred_patch_training, 
                  pred_cls_learning, pred_cls_training, pred_combined_learning, pred_combined_training]
    
    for i, pred in enumerate(predictions):
        all_predictions[i].append(pred)
        if pred == label:
            correct_counts[i] += 1
    
    true_labels.append(label)

# Results
# Normalise by the number of samples actually evaluated. The evaluation loop
# appends exactly one entry to `true_labels` per sample, so this stays correct
# even if that loop covered only part of `test_images`.
eval_size = len(true_labels)
if eval_size == 0:
    raise RuntimeError("No test samples were evaluated; cannot compute accuracy.")

# Order must match the `predictions` list built in the evaluation loop.
method_names = ["Vanilla CLIP", "INFER (Patch Learning)", "INFER (Patch Training)", 
                "INFER (CLS Learning)", "INFER (CLS Training)", 
                "INFER (Combined Learning)", "INFER (Combined Training)"]

print(f"\n{'='*60}")
print("RESULTS ON SUN397 DATASET")
print(f"{'='*60}")

# One line per method: rank, padded name, accuracy to four decimals.
for i, name in enumerate(method_names):
    accuracy = correct_counts[i] / eval_size
    print(f"{i+1}. {name:30} {accuracy:.4f}")

# Visualizations: one confusion matrix per method, four panels per row.
n_methods = len(method_names)
n_cols = 4
n_rows = -(-n_methods // n_cols)  # ceiling division
# squeeze=False keeps `axes` two-dimensional even when only one row is needed.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
cmaps = ['Blues', 'Greens', 'Oranges', 'Reds', 'Purples', 'RdPu', 'YlOrRd']

for i, name in enumerate(method_names):
    ax = axes.flat[i]
    # Index by method position rather than iterating `all_predictions`
    # directly: iterating would yield keys if it is a mapping.
    preds = all_predictions[i]
    cm = confusion_matrix(true_labels, preds, labels=list(range(num_classes)))
    disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
    # Cycle the colormaps so methods beyond len(cmaps) are still drawn.
    disp.plot(ax=ax, cmap=cmaps[i % len(cmaps)], values_format='d')
    ax.set_title(name, fontsize=10)
    ax.tick_params(axis='x', rotation=45, labelsize=8)
    ax.tick_params(axis='y', labelsize=8)

# Hide every grid cell that has no method assigned to it.
for ax in axes.flat[n_methods:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

# Performance comparison
# Index counts by method position so `accuracies` lines up with `method_names`
# and `short_names`. This works whether `correct_counts` is a list or an
# int-keyed mapping; iterating a mapping would yield its keys, not the counts.
accuracies = [correct_counts[i] / eval_size for i in range(len(method_names))]
short_names = ['Vanilla', 'Patch-L', 'Patch-T', 'CLS-L', 'CLS-T', 'Comb-L', 'Comb-T']

plt.figure(figsize=(14, 8))
colors = ['blue', 'green', 'orange', 'red', 'purple', 'pink', 'brown']
bars = plt.bar(short_names, accuracies, color=colors)
plt.ylabel('Accuracy', fontsize=12)
plt.xlabel('Methods', fontsize=12)
plt.title('Performance Comparison on SUN397 Dataset', fontsize=14)
plt.xticks(rotation=45)
plt.ylim(0, 1)

# Annotate each bar with its accuracy, centred just above the bar top.
for bar, acc in zip(bars, accuracies):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{acc:.3f}', ha='center', va='bottom', fontsize=10)

plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Display sample images from each class
# Convert once, outside the loop. If `test_labels` were a plain list,
# `test_labels == i` would be a single bool rather than an element-wise mask.
test_labels_arr = np.asarray(test_labels)

plt.figure(figsize=(20, 4))
for i, class_name in enumerate(class_names):
    # First image of this class in the test set; skip the panel if none exists.
    class_indices = np.where(test_labels_arr == i)[0]
    if len(class_indices) == 0:
        continue
    sample_image = test_images[class_indices[0]]

    plt.subplot(1, num_classes, i + 1)
    plt.imshow(sample_image.astype(np.uint8))
    plt.title(f'{class_name.replace("_", " ").title()}', fontsize=10)
    plt.axis('off')

plt.suptitle('Sample Images from SUN397 Dataset', fontsize=16)
plt.tight_layout()
plt.show()


# Per-class accuracy analysis
print("\nPer-class accuracy for best method:")
# np.argmax returns the first index on ties.
best_method_idx = int(np.argmax(accuracies))
best_method_name = method_names[best_method_idx]
best_predictions = all_predictions[best_method_idx]

print(f"Best method: {best_method_name} (Overall accuracy: {accuracies[best_method_idx]:.4f})")

# Build the arrays once instead of re-converting both lists for every class.
true_labels_arr = np.asarray(true_labels)
best_predictions_arr = np.asarray(best_predictions)

for class_idx, class_name in enumerate(class_names):
    class_predictions = best_predictions_arr[true_labels_arr == class_idx]
    n_total = len(class_predictions)
    if n_total == 0:
        # No test samples of this class were evaluated; nothing to report.
        continue
    n_correct = int(np.sum(class_predictions == class_idx))
    class_accuracy = n_correct / n_total
    print(f"  {class_name.replace('_', ' ').title():20}: {class_accuracy:.4f} ({n_correct}/{n_total})")

# Feature dimensionality info: report the shapes of the CLS, patch and
# combined-CLS feature arrays, one per line.
print("\nFeature dimensions:")
print("  CLS features: {}".format(cls_features.shape))
print("  Patch features: {}".format(patch_features.shape))
print("  Combined CLS features: {}".format(combined_cls_features.shape))

# Head analysis summary: how many attention heads were selected for each
# enrichment path, and how many distinct layers those heads span.
def _head_summary(heads):
    """Return '<n> heads across <m> layers' for a sequence of (layer, head) pairs."""
    distinct_layers = {layer for layer, _ in heads}
    return f"{len(heads)} heads across {len(distinct_layers)} layers"


print("\nSelected attention heads:")
print(f"  Patch heads: {_head_summary(common_patch_heads)}")
print(f"  CLS heads: {_head_summary(common_cls_heads)}")



