import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
from collections import defaultdict
import random
from transformers import CLIPModel, CLIPProcessor, TFCLIPModel
import torch  # Still needed for CLIP model
import torch.nn.functional as F
from PIL import Image
import requests
import cv2
import os
import zipfile
import urllib.request
from pathlib import Path

# Set seeds for reproducibility. torch is seeded as well because it runs the
# CLIP feature extractor; CLIP is used in eval mode under no_grad, so no
# torch-side randomness is expected, but pinning it is cheap insurance.
random.seed(53)
np.random.seed(53)
tf.random.set_seed(53)
torch.manual_seed(53)

# Load UCF-101 dataset
def load_ucf101_dataset(max_classes=10, frames_per_video=8):
    """Return ``(images, labels, class_names)`` for a fixed subset of UCF-101.

    Frames are read from ``./datasets/ucf101/videos`` (created if missing).
    If that directory is missing or empty, a synthetic stand-in dataset is
    generated there first and loaded instead. Otherwise the ``.avi`` loader
    is used; it falls back to the synthetic-frame loader itself when no
    matching videos are found.

    The real dataset is not downloaded automatically. Obtain it from
    https://www.crcv.ucf.edu/data/UCF101/UCF101.rar and extract the ``.avi``
    files anywhere under ``./datasets/ucf101/videos``.

    Args:
        max_classes: Number of classes to use, taken in order from a fixed
            list of 10 UCF-101 classes. Values above 10 are capped, with a
            warning.
        frames_per_video: Evenly spaced frames to sample from each video.

    Returns:
        ``(images, labels, class_names)``: an array of RGB frames, an array of
        integer labels (the index into ``class_names``), and the list of
        selected class names.
    """
    data_dir = Path("./datasets/ucf101")
    data_dir.mkdir(parents=True, exist_ok=True)

    # Fixed, ordered subset of UCF-101; list position becomes the label id.
    candidate_classes = [
        "ApplyEyeMakeup", "ApplyLipstick", "Archery", "BabyCrawling", "BalanceBeam",
        "BandMarching", "BaseballPitch", "Basketball", "BasketballDunk", "BenchPress"
    ]
    if max_classes > len(candidate_classes):
        print(f"Warning: requested {max_classes} classes but only "
              f"{len(candidate_classes)} are configured; using all of them")
    selected_classes = candidate_classes[:max_classes]

    print(f"Using {len(selected_classes)} classes: {selected_classes}")

    videos_dir = data_dir / "videos"
    if not videos_dir.exists() or not any(videos_dir.iterdir()):
        # Nothing on disk yet: build the synthetic demo data and load it.
        create_sample_ucf_data(videos_dir, selected_classes, frames_per_video)
        return load_sample_data(videos_dir, selected_classes, frames_per_video)

    # Something is on disk: try the real UCF-101 videos.
    return load_actual_ucf_data(videos_dir, selected_classes, frames_per_video)

def _make_sample_frame(color_shift, size=224):
    """Return one random ``(size, size, 3)`` uint8 frame with channel 0 raised by *color_shift*.

    Channel 0 saturates at 255 instead of wrapping around. The frame is drawn
    from the global NumPy RNG, so seeding ``np.random`` makes it reproducible.
    """
    frame = np.random.randint(0, 255, (size, size, 3), dtype=np.uint8)
    # Widen before adding. Adding a Python int to a uint8 array stays uint8 and
    # wraps modulo 256 (and NumPy >= 2 raises OverflowError for shifts >= 256).
    # Either way the clip would never see out-of-range values.
    shifted = frame[:, :, 0].astype(np.int32) + color_shift
    frame[:, :, 0] = np.clip(shifted, 0, 255).astype(np.uint8)
    return frame


def create_sample_ucf_data(videos_dir, classes, frames_per_video, samples_per_class=50):
    """Write a synthetic stand-in for UCF-101 under *videos_dir*.

    Layout: ``videos_dir/<class>/video_NNN/frame_NNN.jpg``, with
    *samples_per_class* "videos" per class and *frames_per_video* frames per
    video. Each frame is uniform noise with channel 0 brightened by
    ``30 * class_position``, which gives the classes a learnable colour cue.

    Note: ``cv2.imwrite`` treats channel 0 as blue, and ``load_sample_data``
    converts BGR to RGB on load. The class cue therefore ends up in the blue
    channel of the loaded RGB frames.
    """
    videos_dir.mkdir(parents=True, exist_ok=True)

    print("Creating sample video frames...")

    for class_name in classes:
        class_dir = videos_dir / class_name
        class_dir.mkdir(exist_ok=True)

        # Per-class offset; constant for the class, so compute it once.
        color_shift = classes.index(class_name) * 30

        for video_idx in range(samples_per_class):
            video_dir = class_dir / f"video_{video_idx:03d}"
            video_dir.mkdir(exist_ok=True)

            for frame_idx in range(frames_per_video):
                frame = _make_sample_frame(color_shift)
                frame_path = video_dir / f"frame_{frame_idx:03d}.jpg"
                # imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(str(frame_path), frame):
                    print(f"Warning: failed to write {frame_path}")

def load_sample_data(videos_dir, classes, frames_per_video):
    """Load the synthetic frames written by ``create_sample_ucf_data``.

    For each class directory under *videos_dir* (missing classes are skipped),
    every video sub-directory with at least *frames_per_video* ``.jpg`` files
    contributes *frames_per_video* evenly spaced frames. Frames are converted
    from BGR to RGB and resized to 224x224. Unreadable files are skipped with
    a warning rather than crashing the run.

    Returns:
        ``(images, labels, classes)``, where ``labels[i]`` is the position of
        the frame's class in *classes*.
    """
    images = []
    labels = []

    print("Loading sample frames...")
    for class_idx, class_name in enumerate(classes):
        class_dir = videos_dir / class_name
        if not class_dir.exists():
            continue

        # Sorted so the frame order (and hence any later split) is stable.
        video_dirs = sorted(d for d in class_dir.iterdir() if d.is_dir())

        for video_dir in video_dirs:
            frame_files = sorted(f for f in video_dir.iterdir() if f.suffix.lower() == '.jpg')
            if len(frame_files) < frames_per_video:
                continue

            # Take evenly spaced frames
            frame_indices = np.linspace(0, len(frame_files) - 1, frames_per_video, dtype=int)

            for frame_idx in frame_indices:
                frame_path = frame_files[frame_idx]
                frame = cv2.imread(str(frame_path))
                # imread signals failure by returning None, not by raising.
                if frame is None:
                    print(f"Warning: could not read {frame_path}, skipping")
                    continue
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (224, 224))

                images.append(frame)
                labels.append(class_idx)

    print(f"Loaded {len(images)} frames from {len(classes)} classes")
    return np.array(images), np.array(labels), classes

def load_actual_ucf_data(videos_dir, classes, frames_per_video, max_videos_per_class=20):
    """Load evenly spaced RGB frames from UCF-101 ``.avi`` files under *videos_dir*.

    For each class, files matching ``**/<ClassName>_*.avi`` are found
    recursively. They are sorted by path for a reproducible subset, and at
    most *max_videos_per_class* of them are used. Each usable video
    contributes up to *frames_per_video* evenly spaced frames (BGR to RGB,
    resized to 224x224). Videos that cannot be opened, or that report fewer
    frames than *frames_per_video*, are skipped. Per-video errors are printed
    and skipped, so one corrupt file does not abort the run.

    If no frames are loaded at all, this falls back to ``load_sample_data``
    on the same directory.

    Returns:
        ``(images, labels, classes)``, where ``labels[i]`` is the position of
        the frame's class in *classes*.
    """
    images = []
    labels = []

    print("Loading UCF-101 frames...")

    for class_idx, class_name in enumerate(classes):
        # glob order is filesystem-dependent; sort so the capped subset is reproducible.
        class_videos = sorted(videos_dir.glob(f"**/{class_name}_*.avi"))

        if not class_videos:
            print(f"No videos found for class {class_name}")
            continue

        for video_path in class_videos[:max_videos_per_class]:
            cap = None
            try:
                cap = cv2.VideoCapture(str(video_path))
                if not cap.isOpened():
                    print(f"Error processing {video_path}: could not open video")
                    continue

                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if frame_count < frames_per_video:
                    continue

                # Extract evenly spaced frames
                frame_indices = np.linspace(0, frame_count - 1, frames_per_video, dtype=int)

                for frame_idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (224, 224))
                    images.append(frame)
                    labels.append(class_idx)

            except Exception as e:
                # Deliberate best-effort: report and move on to the next video.
                print(f"Error processing {video_path}: {e}")
            finally:
                # Runs on every path (normal, `continue`, exception), so the
                # capture handle is never leaked.
                if cap is not None:
                    cap.release()

    if len(images) == 0:
        print("No valid video data found. Using sample data instead.")
        return load_sample_data(videos_dir, classes, frames_per_video)

    print(f"Loaded {len(images)} frames from {len(classes)} classes")
    return np.array(images), np.array(labels), classes

# Load CLIP model (keeping PyTorch version for feature extraction)
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name, output_attentions=True)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()

# Load dataset
print("Loading UCF-101 dataset...")
all_images, all_labels, class_names = load_ucf101_dataset(max_classes=5, frames_per_video=8)
num_classes = len(class_names)


def _class_name_to_phrase(class_name):
    """Return *class_name* as lower-case, space-separated words for a text prompt.

    UCF-101 class names are CamelCase (``"ApplyEyeMakeup"``), so replacing
    underscores alone left run-together prompts such as ``"applyeyemakeup"``.
    Underscores become spaces, and a space is inserted before every
    upper-case letter that directly follows a lower-case one.

    >>> _class_name_to_phrase("ApplyEyeMakeup")
    'apply eye makeup'
    """
    spaced = class_name.replace('_', ' ')
    pieces = []
    for prev, ch in zip(' ' + spaced, spaced):
        if ch.isupper() and prev.islower():
            pieces.append(' ')
        pieces.append(ch)
    return ''.join(pieces).lower()


# Create text features for action classes (one L2-normalised row per class)
action_prompts = [f"a video of {_class_name_to_phrase(class_name)}" for class_name in class_names]
vanilla_inputs = processor(text=action_prompts, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    vanilla_text_outputs = model.get_text_features(**vanilla_inputs)
    vanilla_text_features = F.normalize(vanilla_text_outputs, dim=-1)

# Create the validation/test split: the first `samples_per_class` frames of
# each class go to validation, everything else to test. The validation split
# is later used both for head/beta selection and to train the TF classifiers.
# NOTE(review): the loaders append each video's frames consecutively, so whole
# videos stay on one side of the split only if samples_per_class is a multiple
# of frames_per_video (true for 32 and 8) and every video contributed all of
# its frames. Otherwise frames of one video can leak across the split.
samples_per_class = 32  # Increased for better validation
remaining_samples = []
val_samples_per_class = defaultdict(list)

# First pass: collect validation samples
for idx, label in enumerate(all_labels):
    if len(val_samples_per_class[label]) < samples_per_class:
        val_samples_per_class[label].append(idx)
    else:
        remaining_samples.append(idx)

# Validation indices, grouped by class; warn about under-filled classes.
val_indices = []
for class_label in range(num_classes):
    class_val = val_samples_per_class[class_label]
    val_indices.extend(class_val)
    if len(class_val) < samples_per_class:
        print(f"Warning: Only {len(class_val)} samples for class {class_names[class_label]}")

# Use remaining samples as test set
test_indices = remaining_samples.copy()

print(f"Validation set: {len(val_indices)} samples")
print(f"Test set: {len(test_indices)} samples")

# Verify distributions
val_class_counts = defaultdict(int)
for idx in val_indices:
    val_class_counts[all_labels[idx]] += 1
print("Validation class distribution:")
for class_idx, count in val_class_counts.items():
    print(f"  {class_names[class_idx]}: {count} samples")

test_class_counts = defaultdict(int)
for idx in test_indices:
    test_class_counts[all_labels[idx]] += 1
print("Test class distribution:")
for class_idx, count in test_class_counts.items():
    print(f"  {class_names[class_idx]}: {count} samples")

# Create validation and test datasets
val_images = all_images[val_indices]
val_labels = all_labels[val_indices]
test_images = all_images[test_indices]
test_labels = all_labels[test_indices]

# Helper functions to convert a numpy image array to a PIL Image for CLIP processing
def _to_uint8(img_array):
    """Return *img_array* as a uint8 array with values clipped to [0, 255].

    Only floating-point (or boolean) input whose maximum is <= 1.0 is treated
    as normalised to [0, 1] and rescaled by 255. Integer input is never
    rescaled, so an all-black or near-black uint8 frame is left intact.
    Out-of-range values saturate instead of wrapping around on the cast.
    """
    img_array = np.asarray(img_array)
    if img_array.dtype == np.bool_:
        img_array = img_array.astype(np.float64)
    if np.issubdtype(img_array.dtype, np.floating) and img_array.max() <= 1.0:
        img_array = img_array * 255.0
    return np.clip(img_array, 0, 255).astype(np.uint8)


def numpy_to_pil(img_array):
    """Convert an image array (e.g. H x W x 3) to a PIL Image via ``_to_uint8``."""
    return Image.fromarray(_to_uint8(img_array))

# Helper functions (adapted from original)
def find_best_patch_heads(validation_images, validation_labels, text_features,
                          num_layers=3, heads_per_layer=3, top_k=9):
    """Pick attention heads whose CLS-weighted patch average best matches the label text.

    Up to 80 validation frames are drawn with the global NumPy RNG. For each
    frame, every head of the last *num_layers* attention layers is scored:
    the final-layer patch tokens are weighted by that head's CLS-to-patch
    attention row, averaged over patches, projected with
    ``model.visual_projection`` and L2-normalised. The score is the cosine
    similarity with ``text_features[label]``. The *heads_per_layer* best heads
    of each layer are recorded per frame. The *top_k* ``(layer, head)`` pairs
    recorded most often across frames are returned (ties keep first-seen
    order).

    Uses the module-level ``model`` and ``processor``.

    Args:
        validation_images: Indexable collection of image arrays.
        validation_labels: Integer class label per image.
        text_features: Per-class text embeddings, indexable by label
            (presumably ``(num_classes, dim)``, L2-normalised).
        num_layers: How many of the last attention layers to search.
        heads_per_layer: Heads kept per layer per frame.
        top_k: Number of ``(layer, head)`` pairs to return.

    Returns:
        List of ``(layer_idx, head_idx)`` tuples, most frequent first.
        ``layer_idx`` is negative (``-1`` is the last layer).
    """
    print("Finding best heads for patch enrichment...")
    all_selected_heads = []

    sample_size = min(len(validation_images), 80)  # Limit for efficiency
    sample_indices = np.random.choice(len(validation_images), sample_size, replace=False)

    for i, idx in enumerate(sample_indices):
        if i % 20 == 0:
            print(f"Processing sample {i}/{sample_size}")

        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="pt")['pixel_values']
        label = validation_labels[idx]

        with torch.no_grad():
            vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
            patch_embeds = vision_out.last_hidden_state[:, 1:, :]
            target_text = text_features[label].unsqueeze(0)

            sample_heads = []
            for layer_offset in range(num_layers):
                layer_idx = -(layer_offset + 1)
                layer_attention = vision_out.attentions[layer_idx]
                # Indexed as [batch, head, query, key] below; read the head
                # count from the tensor rather than hard-coding 12.
                num_heads = layer_attention.shape[1]

                head_scores = []
                for head_idx in range(num_heads):
                    # CLS-query attention over the patch tokens (CLS key dropped).
                    attn_weights = layer_attention[0, head_idx, 0, 1:]
                    weighted_patches = patch_embeds[0] * attn_weights.unsqueeze(1)
                    avg_patch = weighted_patches.mean(dim=0)

                    projected = model.visual_projection(avg_patch.unsqueeze(0))
                    projected = F.normalize(projected, dim=-1)

                    similarity = F.cosine_similarity(projected, target_text, dim=-1)
                    head_scores.append((head_idx, similarity.item()))

                head_scores.sort(key=lambda x: x[1], reverse=True)
                top_heads = [head for head, _ in head_scores[:heads_per_layer]]
                sample_heads.extend([(layer_idx, head) for head in top_heads])

            all_selected_heads.append(sample_heads)

    head_frequency = defaultdict(int)
    for sample_heads in all_selected_heads:
        for layer_head in sample_heads:
            head_frequency[layer_head] += 1

    # Stable sort: equal counts keep first-seen order.
    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:top_k]
    common_heads = [layer_head for layer_head, count in most_common]

    print("Most common patch heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_best_cls_heads(validation_images, validation_labels, text_features,
                        num_layers=3, heads_per_layer=3, top_k=9):
    """Pick attention heads whose attention-pooled hidden state best matches the label text.

    Up to 80 validation frames are drawn with the global NumPy RNG. For each
    frame, every head of the last *num_layers* attention layers is scored:
    the head's CLS-query attention row (over all tokens, CLS included) is
    passed through a softmax and used to weight ``hidden_states[layer_idx]``.
    The weighted states are summed over tokens, projected with
    ``model.visual_projection`` and L2-normalised. The score is the cosine
    similarity with ``text_features[label]``. The *heads_per_layer* best heads
    of each layer are recorded per frame. The *top_k* ``(layer, head)`` pairs
    recorded most often are returned (ties keep first-seen order).

    Uses the module-level ``model`` and ``processor``.

    Args:
        validation_images: Indexable collection of image arrays.
        validation_labels: Integer class label per image.
        text_features: Per-class text embeddings, indexable by label
            (presumably ``(num_classes, dim)``, L2-normalised).
        num_layers: How many of the last attention layers to search.
        heads_per_layer: Heads kept per layer per frame.
        top_k: Number of ``(layer, head)`` pairs to return.

    Returns:
        List of ``(layer_idx, head_idx)`` tuples, most frequent first.
        ``layer_idx`` is negative (``-1`` is the last entry).
    """
    print("Finding best heads for CLS enrichment...")
    all_selected_heads = []

    sample_size = min(len(validation_images), 80)  # Limit for efficiency
    sample_indices = np.random.choice(len(validation_images), sample_size, replace=False)

    for i, idx in enumerate(sample_indices):
        if i % 20 == 0:
            print(f"Processing sample {i}/{sample_size}")

        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="pt")['pixel_values']
        label = validation_labels[idx]

        with torch.no_grad():
            vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
            target_text = text_features[label].unsqueeze(0)

            sample_heads = []
            for layer_offset in range(num_layers):
                layer_idx = -(layer_offset + 1)
                layer_attention = vision_out.attentions[layer_idx]
                layer_hidden = vision_out.hidden_states[layer_idx]
                # Indexed as [batch, head, query, key] below; read the head
                # count from the tensor rather than hard-coding 12.
                num_heads = layer_attention.shape[1]

                head_scores = []
                for head_idx in range(num_heads):
                    cls_attention = layer_attention[0, head_idx, 0, :]
                    # NOTE(review): attention outputs are presumably already
                    # softmax-normalised, so this is a second softmax that
                    # flattens the distribution. It is kept for consistency
                    # with the beta search, feature extraction and evaluation,
                    # which do the same; change all of them together.
                    cls_attention = F.softmax(cls_attention, dim=0)

                    enriched_cls = (layer_hidden[0] * cls_attention.unsqueeze(1)).sum(dim=0)

                    projected = model.visual_projection(enriched_cls.unsqueeze(0))
                    projected = F.normalize(projected, dim=-1)

                    similarity = F.cosine_similarity(projected, target_text, dim=-1)
                    head_scores.append((head_idx, similarity.item()))

                head_scores.sort(key=lambda x: x[1], reverse=True)
                top_heads = [head for head, _ in head_scores[:heads_per_layer]]
                sample_heads.extend([(layer_idx, head) for head in top_heads])

            all_selected_heads.append(sample_heads)

    head_frequency = defaultdict(int)
    for sample_heads in all_selected_heads:
        for layer_head in sample_heads:
            head_frequency[layer_head] += 1

    # Stable sort: equal counts keep first-seen order.
    most_common = sorted(head_frequency.items(), key=lambda x: x[1], reverse=True)[:top_k]
    common_heads = [layer_head for layer_head, count in most_common]

    print("Most common CLS heads:")
    for (layer_idx, head_idx), count in most_common:
        print(f"  Layer {layer_idx}, Head {head_idx}: {count} times")

    return common_heads

def find_optimal_beta(validation_images, validation_labels, text_features, common_heads, enrichment_type):
    """Grid-search the weight *beta* mixing vanilla CLS logits with enriched logits.

    For up to 50 validation frames (drawn with the global NumPy RNG), two sets
    of logits against *text_features* are computed once per frame:

    * vanilla: the projected, L2-normalised CLS token of the final layer;
    * enriched: the average over *common_heads* of either the
      attention-weighted mean of the patch tokens (``"patch"``) or the
      softmaxed-attention-weighted sum of that layer's hidden states
      (``"cls"``), projected and L2-normalised.

    Each beta in {0.0, 0.1, ..., 1.0} is then scored by the accuracy of
    ``argmax(beta * vanilla + (1 - beta) * enriched)``. The model is run once
    per frame, not once per frame per beta.

    Args:
        validation_images: Indexable collection of image arrays.
        validation_labels: Integer class label per image.
        text_features: Per-class text embeddings (presumably
            ``(num_classes, dim)``, L2-normalised).
        common_heads: Non-empty list of ``(layer_idx, head_idx)`` pairs.
        enrichment_type: ``"patch"`` or ``"cls"``.

    Returns:
        The best beta as a Python float. Ties keep the smaller beta. If every
        beta scores 0, or there are no validation images, returns 0.0.

    Raises:
        ValueError: if *enrichment_type* is unknown or *common_heads* is empty.
    """
    if enrichment_type not in ("patch", "cls"):
        raise ValueError(f"enrichment_type must be 'patch' or 'cls', got {enrichment_type!r}")
    if not common_heads:
        raise ValueError("common_heads must contain at least one (layer, head) pair")

    print(f"Finding optimal beta for {enrichment_type}...")

    beta_range = np.linspace(0.0, 1.0, 11)
    best_beta = 0.0
    best_accuracy = 0.0

    # Use a smaller sample for beta optimization
    sample_size = min(len(validation_images), 50)
    sample_indices = np.random.choice(len(validation_images), sample_size, replace=False)

    if sample_size == 0:
        print(f"Optimal beta: {best_beta:.3f} (no validation samples)")
        return best_beta

    # Both logit sets are independent of beta, so compute them once per frame.
    cls_logit_rows = []
    enriched_logit_rows = []
    sample_labels = []
    for idx in sample_indices:
        pil_img = numpy_to_pil(validation_images[idx])
        image_input = processor(images=pil_img, return_tensors="pt")['pixel_values']

        with torch.no_grad():
            vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)

            cls_embed = model.visual_projection(vision_out.last_hidden_state[:, 0, :])
            cls_logit_rows.append(F.normalize(cls_embed, dim=-1) @ text_features.T)  # (1, C)

            if enrichment_type == "patch":
                patch_embeds = vision_out.last_hidden_state[:, 1:, :]
                enriched = torch.zeros_like(patch_embeds[0][0])
                for layer_idx, head_idx in common_heads:
                    attn_weights = vision_out.attentions[layer_idx][0, head_idx, 0, 1:]
                    enriched += (patch_embeds[0] * attn_weights.unsqueeze(1)).mean(dim=0)
            else:  # "cls"
                enriched = torch.zeros_like(vision_out.hidden_states[-1][0][0])
                for layer_idx, head_idx in common_heads:
                    cls_attention = F.softmax(vision_out.attentions[layer_idx][0, head_idx, 0, :], dim=0)
                    layer_hidden = vision_out.hidden_states[layer_idx]
                    enriched += (layer_hidden[0] * cls_attention.unsqueeze(1)).sum(dim=0)

            enriched /= len(common_heads)
            projected = model.visual_projection(enriched.unsqueeze(0))
            enriched_logit_rows.append(F.normalize(projected, dim=-1) @ text_features.T)  # (1, C)

        sample_labels.append(int(validation_labels[idx]))

    cls_logits_all = torch.cat(cls_logit_rows, dim=0)            # (N, C)
    enriched_logits_all = torch.cat(enriched_logit_rows, dim=0)  # (N, C)
    labels_t = torch.as_tensor(sample_labels)

    for beta in beta_range:
        b = float(beta)
        combined_logits = b * cls_logits_all + (1 - b) * enriched_logits_all
        correct = int((combined_logits.argmax(dim=1) == labels_t).sum())
        accuracy = correct / sample_size
        # Strict '>' keeps the smallest beta among equally accurate ones.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_beta = b

    print(f"Optimal beta: {best_beta:.3f} (accuracy: {best_accuracy:.4f})")
    return best_beta

# TensorFlow CNN+MLP Classifier (same as before)
class HybridClassifier(tf.keras.Model):
    """Two-branch classifier over CLIP-derived features.

    * A small CNN consumes a 2-D grid of patch features laid out as
      ``(batch, height, width, channels)``; the script feeds
      ``(batch, 7, 7, 768)``.
    * A single Dense layer consumes a 1-D CLS feature vector.
    * The two 512-d outputs are concatenated and mapped to *num_classes*
      logits. There is no softmax; pair it with a ``from_logits=True`` loss.
    """

    def __init__(self, mlp_input_dim, num_classes=5):
        """Build the layers.

        Args:
            mlp_input_dim: Width of the CLS feature vector. Accepted for API
                compatibility only; Keras ``Dense`` layers infer their input
                width on the first call, so the value is not used.
            num_classes: Number of output logits.
        """
        super().__init__()

        # CNN for the 2-D patch grid. On a 7x7 grid the two 2x2 max-pools
        # ('valid' padding) shrink it to 3x3 and then 1x1 before global pooling.
        self.cnn = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(2, 2),
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(512, activation='relu')
        ])

        # MLP for the 1-D CLS features
        self.mlp = tf.keras.layers.Dense(512, activation='relu')

        # Final linear classifier (logits)
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, patch_features, cls_features, training=None):
        """Return ``(batch, num_classes)`` logits.

        Args:
            patch_features: ``(batch, H, W, C)`` patch grid, e.g. ``(batch, 7, 7, 768)``.
            cls_features: ``(batch, D)`` CLS feature vectors.
            training: Standard Keras flag, forwarded to the CNN branch.
        """
        # Keras Conv2D is channels_last unless the Keras config says
        # otherwise, which already matches the (batch, H, W, C) layout here.
        # Do not transpose to (batch, C, H, W): that is the PyTorch NCHW
        # convention, and under channels_last it makes the CNN convolve over
        # a 768x7 "image" with 7 channels, discarding the patch grid.
        cnn_out = self.cnn(patch_features, training=training)

        mlp_out = self.mlp(cls_features)

        combined = tf.concat([cnn_out, mlp_out], axis=1)

        return self.classifier(combined)

# Feature extraction function (same as before)
def extract_features(images, labels, patch_heads, cls_heads):
    """Run CLIP's vision tower over *images* and build the classifier inputs.

    For every image this collects:

    * the projected, L2-normalised final-layer CLS embedding;
    * a ``(7, 7, 768)`` grid of final-layer patch tokens. Each token is
      scaled by the CLS-query attention of every head in *patch_heads*, and
      the results are averaged over those heads;
    * the CLS embedding concatenated with an L2-normalised "enriched CLS"
      vector. For each head in *cls_heads*, the softmaxed CLS-query attention
      weights that layer's hidden states, which are summed over tokens; the
      per-head vectors are averaged, then projected;
    * the image's label.

    Uses the module-level ``model``, ``processor`` and ``numpy_to_pil``.

    Returns:
        ``(cls_features, patch_features, combined_cls_features, labels)``,
        each a NumPy array stacked along axis 0.
    """
    cls_rows, patch_grids, combined_rows, label_rows = [], [], [], []
    total = len(images)

    print("Extracting features...")
    for idx in range(total):
        if not idx % 50:
            print(f"Processing {idx}/{total}")

        pixel_values = processor(images=numpy_to_pil(images[idx]), return_tensors="pt")['pixel_values']

        with torch.no_grad():
            vision_out = model.vision_model(pixel_values=pixel_values, output_attentions=True, output_hidden_states=True)

            # Plain CLS token -> projection -> unit length, shape (1, proj_dim).
            cls_feat = F.normalize(model.visual_projection(vision_out.last_hidden_state[:, 0, :]), dim=-1)

            # Enriched CLS: accumulate one attention-pooled vector per head, then average.
            # NOTE(review): attention outputs are presumably already softmaxed,
            # so this applies a second softmax (kept to match the head search).
            enriched_cls = torch.zeros_like(vision_out.hidden_states[-1][0][0])
            for layer_idx, head_idx in cls_heads:
                pool_weights = F.softmax(vision_out.attentions[layer_idx][0, head_idx, 0, :], dim=0)
                enriched_cls += (vision_out.hidden_states[layer_idx][0] * pool_weights.unsqueeze(1)).sum(dim=0)
            enriched_cls /= len(cls_heads)
            enriched_cls_feat = F.normalize(model.visual_projection(enriched_cls.unsqueeze(0)), dim=-1)

            # Enriched patches: per-head attention-scaled patch tokens, averaged over heads.
            patch_tokens = vision_out.last_hidden_state[0, 1:, :]
            enriched_patches = torch.zeros_like(patch_tokens)
            for layer_idx, head_idx in patch_heads:
                enriched_patches += patch_tokens * vision_out.attentions[layer_idx][0, head_idx, 0, 1:].unsqueeze(1)
            enriched_patches /= len(patch_heads)
            patch_2d = enriched_patches.view(7, 7, 768)  # ViT-Base-Patch32 creates 7x7 patches

            # Hand everything to NumPy for the TensorFlow side.
            cls_rows.append(cls_feat.squeeze(0).numpy())
            patch_grids.append(patch_2d.numpy())
            combined_rows.append(torch.cat([cls_feat, enriched_cls_feat], dim=1).squeeze(0).numpy())
            label_rows.append(labels[idx])

    return (np.array(cls_rows),
            np.array(patch_grids),
            np.array(combined_rows),
            np.array(label_rows))



# Pick attention heads on the validation split; each search subsamples the
# frames internally to limit the number of CLIP forward passes.
common_patch_heads = find_best_patch_heads(val_images, val_labels, vanilla_text_features)
common_cls_heads = find_best_cls_heads(val_images, val_labels, vanilla_text_features)

# Mixing weight between vanilla CLS logits and enriched logits, one per
# enrichment type (patch first, then CLS). The combined method has no search
# of its own and reuses the CLS weight.
optimal_beta1, optimal_beta2 = (
    find_optimal_beta(val_images, val_labels, vanilla_text_features, heads, kind)
    for heads, kind in ((common_patch_heads, "patch"), (common_cls_heads, "cls"))
)
optimal_beta3 = optimal_beta2

# The validation split doubles as the training set for the TensorFlow heads;
# the test split is only touched during evaluation.
cls_features, patch_features, combined_cls_features, train_labels = extract_features(
    images=val_images,
    labels=val_labels,
    patch_heads=common_patch_heads,
    cls_heads=common_cls_heads,
)

# Train classifiers
print("Training TensorFlow classifiers...")

# Patch classifier: CNN over the enriched patch grid + MLP over vanilla CLS features.
patch_classifier = HybridClassifier(mlp_input_dim=512, num_classes=num_classes)
patch_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
patch_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Train patch classifier
print("Training Patch Classifier...")
batch_size = 8  # Reduced for potential memory constraints
epochs = 30  # Reduced for faster training

# Mini-batches per epoch, rounded up so the trailing partial batch is not
# dropped. The CLS and combined training loops below reuse this value.
num_batches = -(-len(cls_features) // batch_size)

for epoch in range(epochs):
    correct = 0
    total = 0

    # Reshuffle every epoch. The extracted features are grouped by class, so
    # consecutive unshuffled batches would each contain a single class and
    # the optimizer would see class-by-class gradient steps.
    order = np.random.permutation(len(cls_features))

    for i in range(num_batches):
        batch_idx = order[i * batch_size:(i + 1) * batch_size]

        batch_cls = cls_features[batch_idx]
        batch_patch = patch_features[batch_idx]
        batch_labels = train_labels[batch_idx]

        with tf.GradientTape() as tape:
            outputs = patch_classifier(batch_patch, batch_cls, training=True)
            loss = patch_loss_fn(batch_labels, outputs)

        gradients = tape.gradient(loss, patch_classifier.trainable_variables)
        patch_optimizer.apply_gradients(zip(gradients, patch_classifier.trainable_variables))

        predicted = tf.argmax(outputs, axis=1).numpy()
        correct += np.sum(predicted == batch_labels)
        total += len(batch_labels)

    if (epoch + 1) % 10 == 0:
        train_acc = 100 * correct / total if total else 0.0
        print(f'Patch Epoch [{epoch+1}/{epochs}], Accuracy: {train_acc:.2f}%')

# CLS classifier: same architecture, but the MLP branch sees the concatenated
# vanilla + enriched CLS features (1024-d).
cls_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
cls_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

print("Training CLS Classifier...")
for epoch in range(epochs):
    correct = 0
    total = 0

    # Reshuffle every epoch: the extracted features are grouped by class, so
    # unshuffled consecutive batches would each contain a single class.
    order = np.random.permutation(len(combined_cls_features))

    for i in range(num_batches):
        batch_idx = order[i * batch_size:(i + 1) * batch_size]

        batch_cls = combined_cls_features[batch_idx]
        batch_patch = patch_features[batch_idx]
        batch_labels = train_labels[batch_idx]

        with tf.GradientTape() as tape:
            outputs = cls_classifier(batch_patch, batch_cls, training=True)
            loss = patch_loss_fn(batch_labels, outputs)

        gradients = tape.gradient(loss, cls_classifier.trainable_variables)
        cls_optimizer.apply_gradients(zip(gradients, cls_classifier.trainable_variables))

        predicted = tf.argmax(outputs, axis=1).numpy()
        correct += np.sum(predicted == batch_labels)
        total += len(batch_labels)

    if (epoch + 1) % 10 == 0:
        train_acc = 100 * correct / total if total else 0.0
        print(f'CLS Epoch [{epoch+1}/{epochs}], Accuracy: {train_acc:.2f}%')

# Combined classifier.
# NOTE(review): this is trained on exactly the same inputs (patch grid +
# concatenated CLS features) and architecture as the CLS classifier, so the
# two differ only by random initialisation and shuffling. Confirm whether a
# different input was intended here.
combined_classifier = HybridClassifier(mlp_input_dim=1024, num_classes=num_classes)
combined_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

print("Training Combined Classifier...")
for epoch in range(epochs):
    correct = 0
    total = 0

    # Reshuffle every epoch: the extracted features are grouped by class, so
    # unshuffled consecutive batches would each contain a single class.
    order = np.random.permutation(len(combined_cls_features))

    for i in range(num_batches):
        batch_idx = order[i * batch_size:(i + 1) * batch_size]

        batch_cls = combined_cls_features[batch_idx]
        batch_patch = patch_features[batch_idx]
        batch_labels = train_labels[batch_idx]

        with tf.GradientTape() as tape:
            outputs = combined_classifier(batch_patch, batch_cls, training=True)
            loss = patch_loss_fn(batch_labels, outputs)

        gradients = tape.gradient(loss, combined_classifier.trainable_variables)
        combined_optimizer.apply_gradients(zip(gradients, combined_classifier.trainable_variables))

        predicted = tf.argmax(outputs, axis=1).numpy()
        correct += np.sum(predicted == batch_labels)
        total += len(batch_labels)

    if (epoch + 1) % 10 == 0:
        train_acc = 100 * correct / total if total else 0.0
        print(f'Combined Epoch [{epoch+1}/{epochs}], Accuracy: {train_acc:.2f}%')


# Evaluate all seven methods on the held-out test frames.
# Slot order in correct_counts / all_predictions (matches `predictions` below):
#   0 vanilla CLIP, 1 patch learning (beta mix), 2 patch training (TF),
#   3 CLS learning, 4 CLS training, 5 combined learning, 6 combined training.
correct_counts = [0] * 7
all_predictions = [[] for _ in range(7)]
true_labels = []

for idx in range(len(test_images)):
    if idx % 50 == 0:
        print(f"Processing {idx}/{len(test_images)}")
    
    pil_img = numpy_to_pil(test_images[idx])
    image_input = processor(images=pil_img, return_tensors="pt")['pixel_values']
    label = test_labels[idx]
    
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=image_input, output_attentions=True, output_hidden_states=True)
        
        # Vanilla CLIP
        # NOTE(review): this projects the raw final-layer CLS token. CLIP's own
        # image embedding presumably applies the vision post-layernorm first
        # (vision_out.pooler_output). Confirm this is intended; the same
        # convention is used in the beta search and feature extraction, so
        # change them together.
        cls_embed = model.visual_projection(vision_out.last_hidden_state[:, 0, :])
        cls_normalized = F.normalize(cls_embed, dim=-1)
        vanilla_logits = (cls_normalized @ vanilla_text_features.T).squeeze()
        pred_vanilla = vanilla_logits.argmax().item()
        
        # Extract features for other methods (same value as cls_normalized)
        cls_feat = F.normalize(cls_embed, dim=-1)
        
        # Enriched patch: attention-scaled patch tokens averaged over the
        # selected heads; kept 2-D for the CNN and mean-pooled for zero-shot logits.
        patch_embeds = vision_out.last_hidden_state[:, 1:, :]
        enriched_patches = torch.zeros_like(patch_embeds[0])
        
        for layer_idx, head_idx in common_patch_heads:
            attn_weights = vision_out.attentions[layer_idx][0, head_idx, 0, 1:]
            weighted_patches = patch_embeds[0] * attn_weights.unsqueeze(1)
            enriched_patches += weighted_patches
        
        enriched_patches /= len(common_patch_heads)
        enriched_patch_1d = enriched_patches.mean(dim=0)
        patch_proj = model.visual_projection(enriched_patch_1d.unsqueeze(0))
        patch_norm = F.normalize(patch_proj, dim=-1)
        patch_logits = (patch_norm @ vanilla_text_features.T).squeeze()
        
        # Enriched CLS: softmaxed CLS-query attention pools each selected
        # layer's hidden states; the per-head vectors are averaged.
        enriched_cls = torch.zeros_like(vision_out.hidden_states[-1][0][0])
        
        for layer_idx, head_idx in common_cls_heads:
            layer_attention = vision_out.attentions[layer_idx]
            layer_hidden = vision_out.hidden_states[layer_idx]
            
            cls_attention = layer_attention[0, head_idx, 0, :]
            cls_attention = F.softmax(cls_attention, dim=0)
            
            layer_enriched = (layer_hidden[0] * cls_attention.unsqueeze(1)).sum(dim=0)
            enriched_cls += layer_enriched
        
        enriched_cls /= len(common_cls_heads)
        enriched_cls_proj = model.visual_projection(enriched_cls.unsqueeze(0))
        enriched_cls_norm = F.normalize(enriched_cls_proj, dim=-1)
        enriched_cls_logits = (enriched_cls_norm @ vanilla_text_features.T).squeeze()
        
        # Predictions using learning-based approach: beta-weighted mix of
        # vanilla and enriched zero-shot logits. The combined method splits
        # (1 - beta) evenly between the CLS and patch logits.
        pred_patch_learning = (optimal_beta1 * vanilla_logits + (1 - optimal_beta1) * patch_logits).argmax().item()
        pred_cls_learning = (optimal_beta2 * vanilla_logits + (1 - optimal_beta2) * enriched_cls_logits).argmax().item()
        pred_combined_learning = (optimal_beta3 * vanilla_logits + 
                                 (1 - optimal_beta3) / 2 * enriched_cls_logits + 
                                 (1 - optimal_beta3) / 2 * patch_logits).argmax().item()
        
        # CNN+MLP predictions (TensorFlow), one frame at a time (batch of 1).
        patch_2d = enriched_patches.view(7, 7, 768).numpy()
        patch_2d = np.expand_dims(patch_2d, 0)  # Add batch dimension
        combined_cls_feat = torch.cat([cls_feat, enriched_cls_norm], dim=1).numpy()
        cls_feat_np = cls_feat.numpy()
        
        pred_patch_training = tf.argmax(patch_classifier(patch_2d, cls_feat_np), axis=1).numpy()[0]
        pred_cls_training = tf.argmax(cls_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]
        pred_combined_training = tf.argmax(combined_classifier(patch_2d, combined_cls_feat), axis=1).numpy()[0]
        
        # Store results (order must match method_names used for reporting)
        predictions = [pred_vanilla, pred_patch_learning, pred_patch_training, 
                      pred_cls_learning, pred_cls_training, pred_combined_learning, pred_combined_training]
        
        for i, pred in enumerate(predictions):
            all_predictions[i].append(pred)
            if pred == label:
                correct_counts[i] += 1
        
        true_labels.append(label)

# Results: test-set accuracy of each method, in the slot order of correct_counts.
eval_size = len(test_images)
method_names = ["Vanilla CLIP", "Patch Learning", "Patch Training",
                "CLS Learning", "CLS Training", "Combined Learning", "INFER"]


for i, (name, hits) in enumerate(zip(method_names, correct_counts), start=1):
    accuracy = hits / eval_size
    print(f"{i}. {name:30} {accuracy:.4f}")

# Visualizations: one confusion matrix per method on a 2x4 grid (cell 8 is hidden).
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
cmaps = ['Blues', 'Greens', 'Oranges', 'Reds', 'Purples', 'RdPu', 'YlOrRd']

# Tick labels are the same for every panel, so build them once. The
# comprehension variable is `cname` so it is not confused with the method
# `name` of the loop below.
tick_labels = [cname.replace('_', ' ').title()[:10] for cname in class_names]

for i, (name, preds, cmap) in enumerate(zip(method_names, all_predictions, cmaps)):
    if i >= len(axes.flat):
        break
    ax = axes.flat[i]  # row-major, same cell as axes[i // 4, i % 4]
    cm = confusion_matrix(true_labels, preds, labels=list(range(num_classes)))
    disp = ConfusionMatrixDisplay(cm, display_labels=tick_labels)
    disp.plot(ax=ax, cmap=cmap, values_format='d')
    ax.set_title(name)
    ax.tick_params(axis='x', rotation=45)

# Hide unused subplots
for ax in axes.flat[len(method_names):]:
    ax.axis('off')

# Add the suptitle before laying out, and reserve the top strip for it.
# Placing it at y=1.02 after tight_layout put it at or above the top edge of
# the canvas, where plt.show() clips it.
fig.suptitle('Confusion Matrices for UCF-101 Action Recognition')
fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# Performance comparison: one bar per method, annotated with its accuracy.
accuracies = [hits / eval_size for hits in correct_counts]
short_names = ['Vanilla', 'P-Learn', 'P-Train', 'C-Learn', 'C-Train', 'Comb-L', 'Comb-T']

plt.figure(figsize=(12, 6))
bars = plt.bar(
    short_names,
    accuracies,
    color=['blue', 'green', 'orange', 'red', 'purple', 'pink', 'brown'],
)
plt.ylabel('Accuracy')
plt.title('UCF-101 Action Recognition Performance Comparison')
plt.xticks(rotation=45)
plt.ylim(0, 1)

# Value label centred just above each bar.
for bar, acc in zip(bars, accuracies):
    x_center = bar.get_x() + bar.get_width() / 2
    y_label = bar.get_height() + 0.01
    plt.text(x_center, y_label, f'{acc:.3f}', ha='center', va='bottom')

plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Show the first test frame of each action class in a single row; classes
# with no test frames leave their slot empty.
plt.figure(figsize=(15, 3))
for i, class_name in enumerate(class_names):
    matches = np.flatnonzero(test_labels == i)
    if matches.size == 0:
        continue

    plt.subplot(1, num_classes, i + 1)
    plt.imshow(test_images[matches[0]].astype(np.uint8))
    plt.title(f'{class_name.replace("_", " ").title()}')
    plt.axis('off')

plt.suptitle('Sample Frames from UCF-101 Action Classes')
plt.tight_layout()
plt.show()

# Create a detailed performance analysis
print("\n" + "=" * 60)
print("DETAILED ANALYSIS")
print("=" * 60)

# Per-class accuracy for the highest-scoring method (first one wins on ties).
best_method_idx = np.argmax(accuracies)
best_predictions = all_predictions[best_method_idx]
best_method_name = method_names[best_method_idx]

print(f"\nPer-class accuracy for best method ({best_method_name}):")
truth_array = np.array(true_labels)
pred_array = np.array(best_predictions)
for class_idx, class_name in enumerate(class_names):
    class_predictions = pred_array[truth_array == class_idx]
    n_frames = len(class_predictions)
    if n_frames == 0:
        continue
    n_correct = np.sum(class_predictions == class_idx)
    class_accuracy = n_correct / n_frames
    print(f"  {class_name.replace('_', ' ').title():20}: {class_accuracy:.4f} ({n_correct}/{n_frames})")

# Feature dimensionality info
print("\nFeature dimensions:")
print(f"  CLS features: {cls_features.shape}")
print(f"  Patch features: {patch_features.shape}")
print(f"  Combined CLS features: {combined_cls_features.shape}")

# Head analysis summary: number of heads and of distinct layers they come from
print("\nSelected attention heads:")
print(f"  Patch heads: {len(common_patch_heads)} heads across {len({layer for layer, _ in common_patch_heads})} layers")
print(f"  CLS heads: {len(common_cls_heads)} heads across {len({layer for layer, _ in common_cls_heads})} layers")


# Additional analysis: compare with a uniform random-guess baseline
random_accuracy = 1.0 / num_classes

# Show improvement over vanilla CLIP (relative, in percent)
vanilla_accuracy = accuracies[0]
best_accuracy = max(accuracies)

print(f"\nRandom baseline accuracy: {random_accuracy:.4f}")
print(f"Vanilla CLIP accuracy:    {vanilla_accuracy:.4f}")
print(f"Best method accuracy:     {best_accuracy:.4f}")

# Guard the division: vanilla accuracy can legitimately be 0 on a tiny test set.
if vanilla_accuracy > 0:
    improvement = (best_accuracy - vanilla_accuracy) / vanilla_accuracy * 100
    print(f"Relative improvement over vanilla CLIP: {improvement:+.2f}%")
else:
    improvement = float('nan')
    print("Relative improvement over vanilla CLIP: undefined (vanilla accuracy is 0)")
