import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Try to import AutoAttack
try:
    from autoattack import AutoAttack
    AUTOATTACK_AVAILABLE = True
    print("AutoAttack imported successfully")
except ImportError:
    AUTOATTACK_AVAILABLE = False
    print("Warning: AutoAttack not available. Install with: pip install autoattack")

from models import *
try:
    from utils import progress_bar
    PROGRESS_BAR_AVAILABLE = True
except (ImportError, ValueError, OSError):
    # ImportError: the project-local utils module is missing or lacks progress_bar.
    # ValueError/OSError: presumably raised by utils at import time when it cannot
    # query the terminal (e.g. stdout is not a TTY) — TODO confirm against utils.py.
    PROGRESS_BAR_AVAILABLE = False
    print("Warning: Progress bar not available, using simple progress display")

    def progress_bar(current, total, msg=None):
        """Minimal stand-in for utils.progress_bar.

        Prints a one-line status on every 10th step (0-based ``current``) and on
        the final step (``current == total - 1``); all other calls are silent.
        """
        if current % 10 == 0 or current == total - 1:
            print(f"Progress: {current+1}/{total} {msg or ''}")

# Adversarial attack parameters
# Valid input range used when clamping perturbed images
# (presumably un-normalised pixels in [0, 1] — confirm against the data pipeline).
lower_limit, upper_limit = 0, 1
# Per-channel statistics shaped (1, 3, 1, 1) for broadcasting over NCHW batches.
# NOTE(review): must match the Normalize() transform actually used for training — confirm.
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR_STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(1, 3, 1, 1)

def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` element-wise into ``[lower_limit, upper_limit]``.

    Args:
        X: tensor to clamp.
        lower_limit: Python scalar or tensor broadcastable to ``X``.
        upper_limit: Python scalar or tensor broadcastable to ``X``.

    Returns:
        A new tensor with the same shape and dtype as ``X``.
    """
    if torch.is_tensor(lower_limit) or torch.is_tensor(upper_limit):
        # Tensor limits (e.g. per-channel bounds): element-wise min/max with broadcasting.
        lower = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
        upper = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
        return torch.max(torch.min(X, upper), lower)
    # Scalar limits: torch.min(X, 1) would treat the int as a *dim* argument and
    # return (values, indices), so scalars must go through torch.clamp instead.
    return torch.clamp(X, min=lower_limit, max=upper_limit)

class DynamicPrototypeModelL0(nn.Module):
    """Prototype classifier trained only with a soft-L0 similarity.

    Images are embedded by ``backbone``, L2-normalised, and scored against one
    learnable prototype per class with :meth:`compute_l0_similarity`. The
    prototypes are deliberately left un-normalised so the separation penalty in
    :meth:`_compute_separation_loss` can push them apart in raw space.

    Args:
        backbone: module producing one embedding per input; its output is
            assumed to be ``(batch, embedding_dim)`` — TODO confirm.
        num_classes: number of classes, i.e. number of prototypes.
        embedding_dim: length of each embedding / prototype vector.
        dropout_rate: probability for the two dropout layers used in ``forward``.
    """

    def __init__(self, backbone, num_classes=10, embedding_dim=512, dropout_rate=0.3):
        super().__init__()
        self.backbone = backbone
        # Small random init; may later be replaced by class centroids via
        # initialize_prototypes_from_data().
        self.class_prototypes = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self._prototypes_initialized = False

        # Dropout to combat overfitting: once on the raw backbone output and once
        # on the normalised embedding right before the L0 comparison.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, targets=None, return_individual=False, tau=0.75,
                compute_separation=False, margin=0.2, class_boost=0.5):
        """Score ``x`` against every class prototype.

        Args:
            x: input batch, passed straight to ``backbone``.
            targets: accepted for signature parity with DynamicPrototypeModelKL;
                no per-class boost is applied to the L0 logits.
            return_individual: if True, return the tuple layout
                ``(l0_logits, None, l0_logits[, separation_loss])``.
            tau: threshold multiplier forwarded to :meth:`compute_l0_similarity`.
            compute_separation: also compute the prototype separation loss.
            margin: minimum desired distance between prototypes.
            class_boost: unused (see ``targets``).

        Returns:
            ``(batch, num_classes)`` logits; with ``compute_separation`` a
            ``(logits, separation_loss)`` pair; with ``return_individual`` the
            tuple layout described above.
        """
        image_embeddings = self.dropout1(self.backbone(x))
        # Only the image side is normalised; prototypes stay raw (see class docstring).
        image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
        image_embeddings = self.dropout2(image_embeddings)

        l0_similarities = self.compute_l0_similarity(image_embeddings, self.class_prototypes, tau=tau)

        separation_loss = None
        if compute_separation:
            separation_loss = self._compute_separation_loss(self.class_prototypes, margin)

        if return_individual:
            # Layout shared with the KL model: (l0, kl, third slot, [separation]).
            # This model has no KL head, so that slot is None.
            individual = (l0_similarities, None, l0_similarities)
            return (individual + (separation_loss,)) if compute_separation else individual

        if compute_separation:
            return l0_similarities, separation_loss
        return l0_similarities

    def _compute_separation_loss(self, class_prototypes, margin=0.2):
        """Hinge penalty keeping every prototype at least ``margin`` from its nearest neighbour.

        Computes ``sum_i relu(margin - min_{j != i} ||p_i - p_j||_2)`` and stays in
        the autograd graph so it can be added to the training loss.

        Self-distances (always 0) are masked with +inf before the row-wise min.
        Without that, each prototype's "nearest neighbour" is itself, the loss is
        the constant ``num_classes * margin`` and it carries no gradient.

        Args:
            class_prototypes: ``(num_classes, dim)`` prototype matrix.
            margin: desired minimum pairwise distance.

        Returns:
            Scalar tensor (0 when there are fewer than two prototypes).
        """
        num_classes = class_prototypes.size(0)
        distances = torch.cdist(class_prototypes, class_prototypes, p=2)
        self_pairs = torch.eye(num_classes, dtype=torch.bool, device=class_prototypes.device)
        distances = distances.masked_fill(self_pairs, float("inf"))
        nearest_other = distances.min(dim=1).values
        return F.relu(margin - nearest_other).sum()

    def compute_l0_similarity(self, img_features, text_features, tau=0.75):
        """Soft-L0 similarity between every image embedding and every prototype.

        For each (image, prototype) pair, count the coordinates whose absolute
        difference exceeds ``tau`` times that pair's mean absolute difference.
        A sigmoid (temperature 0.5) replaces the hard step so the count is
        differentiable. The similarity is the fraction of coordinates *not*
        counted, multiplied by 3.0 to form logits.

        Args:
            img_features: ``(batch, dim)`` image embeddings.
            text_features: ``(num_classes, dim)`` prototypes.
            tau: threshold multiplier relative to the mean difference.

        Returns:
            ``(batch, num_classes)`` logits.
        """
        dim_size = img_features.shape[1]

        # (batch, 1, dim) - (1, classes, dim) -> (batch, classes, dim)
        diff = torch.abs(img_features.unsqueeze(1) - text_features.unsqueeze(0))

        # Per-pair threshold: tau * mean absolute difference over the feature axis.
        thresholds = tau * torch.mean(diff, dim=2, keepdim=True)

        temperature = 0.5
        smooth_indicator = torch.sigmoid((diff - thresholds) / temperature)

        # Approximate number of "large" coordinate differences per pair.
        l0_approximation = torch.sum(smooth_indicator, dim=2)

        # Fraction of coordinates that are close -> higher means more similar.
        similarity_scores = (dim_size - l0_approximation) * (1 / dim_size)

        temp_scale = 3.0  # reduced from 8.0 to combat overfitting
        return similarity_scores * temp_scale

    def get_prototypes(self):
        """Return the raw (un-normalised) prototype parameter, e.g. for regularisation."""
        return self.class_prototypes

    def get_image_embeddings(self, x):
        """Return L2-normalised backbone embeddings for ``x`` (no dropout applied)."""
        embeddings = self.backbone(x)
        return F.normalize(embeddings, p=2, dim=1)

    def initialize_prototypes_from_data(self, dataloader, device, max_batches=100):
        """Set each class prototype to the mean backbone embedding of that class.

        Runs once per instance; later calls print a message and return.

        Args:
            dataloader: iterable of batches; only ``(inputs, targets, ...)``
                lists/tuples are used, anything else is skipped.
            device: device the batches are moved to.
            max_batches: number of leading batches to read (default 100).

        Behaviour:
            - Classes with no samples keep their current prototype rather than
              being reset to the zero vector (which would make them coincide).
            - The model is run in eval mode and its previous train/eval mode
              is restored afterwards.
            - Centroids are taken from the *raw* backbone output, while
              ``forward`` compares L2-normalised embeddings.
              NOTE(review): confirm this scale mismatch is intended.
        """
        if self._prototypes_initialized:
            print("Prototypes already initialized from data")
            return

        print("Initializing prototypes from data centroids...")

        embedding_sums = torch.zeros(self.num_classes, self.embedding_dim, device=device)
        sample_counts = torch.zeros(self.num_classes, dtype=torch.long, device=device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for batch_idx, batch_data in enumerate(dataloader):
                    if batch_idx >= max_batches:
                        break
                    if not (isinstance(batch_data, (list, tuple)) and len(batch_data) >= 2):
                        continue

                    inputs = batch_data[0].to(device)
                    targets = batch_data[1].to(device).long()
                    embeddings = self.backbone(inputs)

                    # Vectorised per-class accumulation (no per-sample .item() syncs).
                    embedding_sums.index_add_(0, targets, embeddings.to(embedding_sums.dtype))
                    sample_counts += torch.bincount(targets, minlength=self.num_classes)
        finally:
            self.train(was_training)

        valid_classes = 0
        with torch.no_grad():
            new_prototypes = self.class_prototypes.detach().clone()
            for class_idx, count in enumerate(sample_counts.tolist()):
                if count > 0:
                    centroid = embedding_sums[class_idx] / count
                    new_prototypes[class_idx] = centroid.to(new_prototypes)
                    valid_classes += 1
                    print(f"Class {class_idx}: {count} samples, centroid norm: {torch.norm(centroid, p=2).item():.4f}")
                else:
                    print(f"Warning: No samples found for class {class_idx}; keeping its current prototype")
            self.class_prototypes.copy_(new_prototypes)

        print(f"Initialized {valid_classes}/{self.num_classes} prototypes from data centroids")

        self._verify_prototype_separation()
        self._prototypes_initialized = True

    def _verify_prototype_separation(self):
        """Print min/max/mean distance between *distinct* prototypes; warn if any pair is closer than 1.0."""
        with torch.no_grad():
            prototypes = self.class_prototypes
            distances = torch.cdist(prototypes, prototypes, p=2)
            # Exclude the zero self-distances; including them pinned the reported
            # minimum at 0 (so the warning always fired) and dragged the mean down.
            distinct_pairs = ~torch.eye(prototypes.size(0), dtype=torch.bool, device=prototypes.device)
            pair_distances = distances[distinct_pairs]
            if pair_distances.numel() == 0:
                print("Prototype separation - fewer than two prototypes; nothing to compare")
                return

            min_separation = pair_distances.min().item()
            max_separation = pair_distances.max().item()
            avg_separation = pair_distances.mean().item()

            print(f"Prototype separation - Min: {min_separation:.4f}, Max: {max_separation:.4f}, Avg: {avg_separation:.4f}")
            if min_separation < 1.0:
                print("⚠️  Warning: Some prototypes are very close together!")

class DynamicPrototypeModelKL(nn.Module):
    """Prototype classifier trained only with a KL-divergence similarity.

    Images are embedded by ``backbone``, L2-normalised, and compared against one
    learnable prototype per class with :meth:`compute_kl_similarity`. The
    prototypes are left un-normalised so the separation penalty in
    :meth:`_compute_separation_loss` can push them apart in raw space.

    Args:
        backbone: module producing one embedding per input; its output is
            assumed to be ``(batch, embedding_dim)`` — TODO confirm.
        num_classes: number of classes, i.e. number of prototypes.
        embedding_dim: length of each embedding / prototype vector.
        dropout_rate: probability for the two dropout layers used in ``forward``.
        debug_every: print gradient-flow diagnostics every this many training
            forward passes (only counted in train mode with grad enabled);
            0 disables them.
    """

    def __init__(self, backbone, num_classes=10, embedding_dim=512, dropout_rate=0.3, debug_every=100):
        super().__init__()
        self.backbone = backbone
        # Small random init; may later be replaced by class centroids via
        # initialize_prototypes_from_data().
        self.class_prototypes = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self._prototypes_initialized = False

        # Dropout to combat overfitting: once on the raw backbone output and once
        # on the normalised embedding right before the KL comparison.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.debug_every = debug_every
        self._kl_debug_counter = 0
        self._kl_internal_debug_counter = 0

    def _should_debug(self, counter_attr):
        """Advance the named debug counter and report whether diagnostics should print now.

        Only calls made in train mode with grad enabled are counted. In eval or
        no_grad every tensor legitimately lacks a grad_fn, so the "no gradient
        flow" warnings would be false alarms there.
        """
        every = getattr(self, "debug_every", 100)
        if not every or not (self.training and torch.is_grad_enabled()):
            return False
        count = getattr(self, counter_attr, 0) + 1
        setattr(self, counter_attr, count)
        return count % every == 0

    def forward(self, x, targets=None, return_individual=False, tau=0.75,
                compute_separation=False, margin=0.2, class_boost=0.5):
        """Score ``x`` against every class prototype.

        Args:
            x: input batch, passed straight to ``backbone``.
            targets: optional class indices. When given, ``class_boost`` is
                subtracted from each sample's true-class logit.
            return_individual: if True, return the tuple layout
                ``(None, kl_logits, kl_logits[, separation_loss])``.
            tau: accepted for signature parity with the L0 model; unused here.
            compute_separation: also compute the prototype separation loss.
            margin: minimum desired distance between prototypes.
            class_boost: amount subtracted from the true-class logit.

        Returns:
            ``(batch, num_classes)`` logits; with ``compute_separation`` a
            ``(logits, separation_loss)`` pair; with ``return_individual`` the
            tuple layout described above.
        """
        image_embeddings = self.dropout1(self.backbone(x))
        # Only the image side is normalised; prototypes stay raw (see class docstring).
        image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
        image_embeddings = self.dropout2(image_embeddings)

        # The prototype Parameter is a leaf, so using it directly keeps it in the graph.
        kl_similarities = self.compute_kl_similarity(image_embeddings, self.class_prototypes)

        if targets is not None:
            # Lowering the true-class logit makes a softmax/cross-entropy objective
            # require that class to win by at least `class_boost` (an additive margin).
            one_hot = F.one_hot(targets.view(-1), num_classes=kl_similarities.size(1))
            kl_similarities = kl_similarities - class_boost * one_hot.to(kl_similarities.dtype)

        if self._should_debug("_kl_debug_counter"):
            self._print_forward_debug(image_embeddings, kl_similarities)

        separation_loss = None
        if compute_separation:
            separation_loss = self._compute_separation_loss(self.class_prototypes, margin)

        if return_individual:
            # Layout shared with the L0 model: (l0, kl, third slot, [separation]).
            individual = (None, kl_similarities, kl_similarities)
            return (individual + (separation_loss,)) if compute_separation else individual

        if compute_separation:
            return kl_similarities, separation_loss
        return kl_similarities

    def _print_forward_debug(self, image_embeddings, kl_similarities):
        """Print gradient-flow diagnostics for one training forward pass."""
        print(f"\n=== KL Debug (Batch {self._kl_debug_counter}) ===")
        print(f"Image embeddings requires_grad: {image_embeddings.requires_grad}")
        print(f"Prototypes requires_grad: {self.class_prototypes.requires_grad}")
        print(f"KL similarities requires_grad: {kl_similarities.requires_grad}")

        if image_embeddings.grad_fn is not None:
            print(f"Image embeddings grad_fn: {image_embeddings.grad_fn}")
        else:
            print("⚠️  Image embeddings grad_fn: None (no gradient flow!)")

        # A leaf nn.Parameter never has a grad_fn; its gradient lands in .grad.
        # Report is_leaf instead of a false "no gradient flow" alarm.
        print(f"Prototypes is_leaf: {self.class_prototypes.is_leaf} (gradients accumulate in .grad)")

        if kl_similarities.grad_fn is not None:
            print(f"KL similarities grad_fn: {kl_similarities.grad_fn}")
        else:
            print("⚠️  KL similarities grad_fn: None (no gradient flow!)")

        print(f"KL similarities range: {kl_similarities.min().item():.6f} to {kl_similarities.max().item():.6f}")
        print(f"KL predictions: {torch.argmax(kl_similarities, dim=1)[:5].tolist()}")
        print("=" * 50)

    def _compute_separation_loss(self, class_prototypes, margin=0.2):
        """Hinge penalty keeping every prototype at least ``margin`` from its nearest neighbour.

        Computes ``sum_i relu(margin - min_{j != i} ||p_i - p_j||_2)`` and stays in
        the autograd graph so it can be added to the training loss.

        Self-distances (always 0) are masked with +inf before the row-wise min.
        Without that, each prototype's "nearest neighbour" is itself, the loss is
        the constant ``num_classes * margin`` and it carries no gradient.

        Args:
            class_prototypes: ``(num_classes, dim)`` prototype matrix.
            margin: desired minimum pairwise distance.

        Returns:
            Scalar tensor (0 when there are fewer than two prototypes).
        """
        num_classes = class_prototypes.size(0)
        distances = torch.cdist(class_prototypes, class_prototypes, p=2)
        self_pairs = torch.eye(num_classes, dtype=torch.bool, device=class_prototypes.device)
        distances = distances.masked_fill(self_pairs, float("inf"))
        nearest_other = distances.min(dim=1).values
        return F.relu(margin - nearest_other).sum()

    def compute_kl_similarity(self, img_features, text_features):
        """Negative, scaled KL(prototype || image) between every image and every prototype.

        Both inputs are turned into distributions over the feature axis with
        ``softmax(features / 0.1)``. Probabilities are clamped to
        ``[1e-10, 1]`` before taking logs, and the result is ``-8.0 * KL`` so a
        lower divergence gives a higher logit.

        Args:
            img_features: ``(batch, dim)`` image embeddings.
            text_features: ``(num_classes, dim)`` prototypes.

        Returns:
            ``(batch, num_classes)`` logits.
        """
        t = 0.1
        img_dist = F.softmax(img_features / t, dim=-1)
        text_dist = F.softmax(text_features / t, dim=-1)

        img_expanded = img_dist.unsqueeze(1)    # (batch, 1, dim)
        text_expanded = text_dist.unsqueeze(0)  # (1, classes, dim)

        # KL(text || img) = sum(text * log(text / img)); clamp to avoid log(0).
        epsilon = 1e-10
        img_dist_safe = torch.clamp(img_expanded, min=epsilon, max=1.0)
        text_dist_safe = torch.clamp(text_expanded, min=epsilon, max=1.0)
        kl_div = torch.sum(text_dist_safe * torch.log(text_dist_safe / img_dist_safe), dim=2)

        temperature = 8.0
        kl_similarity = -kl_div * temperature

        if self._should_debug("_kl_internal_debug_counter"):
            print(f"\n=== KL Internal Debug (Batch {self._kl_internal_debug_counter}) ===")
            print(f"img_features requires_grad: {img_features.requires_grad}")
            print(f"text_features requires_grad: {text_features.requires_grad}")
            print(f"img_dist requires_grad: {img_dist.requires_grad}")
            print(f"text_dist requires_grad: {text_dist.requires_grad}")
            print(f"kl_div requires_grad: {kl_div.requires_grad}")
            print(f"kl_similarity requires_grad: {kl_similarity.requires_grad}")
            if img_features.grad_fn is not None:
                print(f"img_features grad_fn: {img_features.grad_fn}")
            if text_features.grad_fn is not None:
                print(f"text_features grad_fn: {text_features.grad_fn}")
            if kl_similarity.grad_fn is not None:
                print(f"kl_similarity grad_fn: {kl_similarity.grad_fn}")
            print(f"KL divergence range: {kl_div.min().item():.6f} to {kl_div.max().item():.6f}")
            print(f"Final KL similarity range: {kl_similarity.min().item():.6f} to {kl_similarity.max().item():.6f}")
            print("=" * 50)

        return kl_similarity

    def get_prototypes(self):
        """Return the raw (un-normalised) prototype parameter, e.g. for regularisation."""
        return self.class_prototypes

    def get_image_embeddings(self, x):
        """Return L2-normalised backbone embeddings for ``x`` (no dropout applied)."""
        embeddings = self.backbone(x)
        return F.normalize(embeddings, p=2, dim=1)

    def initialize_prototypes_from_data(self, dataloader, device, max_batches=100):
        """Set each class prototype to the mean backbone embedding of that class.

        Runs once per instance; later calls print a message and return.

        Args:
            dataloader: iterable of batches; only ``(inputs, targets, ...)``
                lists/tuples are used, anything else is skipped.
            device: device the batches are moved to.
            max_batches: number of leading batches to read (default 100).

        Behaviour:
            - Classes with no samples keep their current prototype rather than
              being reset to the zero vector (which would make them coincide).
            - The model is run in eval mode and its previous train/eval mode
              is restored afterwards.
            - Centroids are taken from the *raw* backbone output, while
              ``forward`` compares L2-normalised embeddings.
              NOTE(review): confirm this scale mismatch is intended.
        """
        if self._prototypes_initialized:
            print("Prototypes already initialized from data")
            return

        print("Initializing prototypes from data centroids...")

        embedding_sums = torch.zeros(self.num_classes, self.embedding_dim, device=device)
        sample_counts = torch.zeros(self.num_classes, dtype=torch.long, device=device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for batch_idx, batch_data in enumerate(dataloader):
                    if batch_idx >= max_batches:
                        break
                    if not (isinstance(batch_data, (list, tuple)) and len(batch_data) >= 2):
                        continue

                    inputs = batch_data[0].to(device)
                    targets = batch_data[1].to(device).long()
                    embeddings = self.backbone(inputs)

                    # Vectorised per-class accumulation (no per-sample .item() syncs).
                    embedding_sums.index_add_(0, targets, embeddings.to(embedding_sums.dtype))
                    sample_counts += torch.bincount(targets, minlength=self.num_classes)
        finally:
            self.train(was_training)

        valid_classes = 0
        with torch.no_grad():
            new_prototypes = self.class_prototypes.detach().clone()
            for class_idx, count in enumerate(sample_counts.tolist()):
                if count > 0:
                    centroid = embedding_sums[class_idx] / count
                    new_prototypes[class_idx] = centroid.to(new_prototypes)
                    valid_classes += 1
                    print(f"Class {class_idx}: {count} samples, centroid norm: {torch.norm(centroid, p=2).item():.4f}")
                else:
                    print(f"Warning: No samples found for class {class_idx}; keeping its current prototype")
            self.class_prototypes.copy_(new_prototypes)

        print(f"Initialized {valid_classes}/{self.num_classes} prototypes from data centroids")

        self._verify_prototype_separation()
        self._prototypes_initialized = True

    def _verify_prototype_separation(self):
        """Print min/max/mean distance between *distinct* prototypes; warn if any pair is closer than 1.0."""
        with torch.no_grad():
            prototypes = self.class_prototypes
            distances = torch.cdist(prototypes, prototypes, p=2)
            # Exclude the zero self-distances; including them pinned the reported
            # minimum at 0 (so the warning always fired) and dragged the mean down.
            distinct_pairs = ~torch.eye(prototypes.size(0), dtype=torch.bool, device=prototypes.device)
            pair_distances = distances[distinct_pairs]
            if pair_distances.numel() == 0:
                print("Prototype separation - fewer than two prototypes; nothing to compare")
                return

            min_separation = pair_distances.min().item()
            max_separation = pair_distances.max().item()
            avg_separation = pair_distances.mean().item()

            print(f"Prototype separation - Min: {min_separation:.4f}, Max: {max_separation:.4f}, Avg: {avg_separation:.4f}")
            if min_separation < 1.0:
                print("⚠️  Warning: Some prototypes are very close together!")

class ConsensusEvaluator:
    """Combine a separately trained L0 model and KL model by agreement.

    A sample gets a class prediction only when both models' argmax agree;
    otherwise it is marked unknown (-1).

    Args:
        l0_model: model called as ``l0_model(x, tau=...)`` returning logits.
        kl_model: model called as ``kl_model(x, tau=...)`` returning logits.
        device: device evaluation batches are moved to.
    """

    def __init__(self, l0_model, kl_model, device):
        self.l0_model = l0_model
        self.kl_model = kl_model
        self.device = device

    def predict_with_consensus(self, x, tau=0.75):
        """Predict with both models and keep only the agreeing predictions.

        Does not change the models' train/eval mode; call ``.eval()`` first
        if dropout should be disabled (``evaluate_accuracy`` does this).

        Args:
            x: input batch already on the models' device.
            tau: forwarded to both models.

        Returns:
            predictions: agreed class per sample, or -1 where the models disagree.
            consensus_flags: bool tensor, True where both argmaxes match.
            l0_preds: argmax of the L0 model's logits.
            kl_preds: argmax of the KL model's logits.
        """
        with torch.no_grad():
            l0_sims = self.l0_model(x, tau=tau)
            kl_sims = self.kl_model(x, tau=tau)

        l0_preds = torch.argmax(l0_sims, dim=1)
        kl_preds = torch.argmax(kl_sims, dim=1)

        consensus_flags = l0_preds == kl_preds
        predictions = torch.where(consensus_flags, l0_preds, torch.full_like(l0_preds, -1))

        return predictions, consensus_flags, l0_preds, kl_preds

    def evaluate_accuracy(self, dataloader, set_name="Test", tau=0.75):
        """Evaluate both models and their consensus over ``dataloader`` and print a report.

        Puts both models in eval mode (they are not switched back).

        Args:
            dataloader: iterable of ``(inputs, targets)`` batches.
            set_name: label used in the printed report.
            tau: forwarded to both models.

        Returns:
            ``(consensus_acc, l0_acc, kl_acc, consensus_rate)`` in percent.
            ``consensus_acc`` is measured only on samples where the models
            agree. All four are 0.0 when the dataloader yields no samples.
        """
        self.l0_model.eval()
        self.kl_model.eval()

        total = 0
        consensus_total = 0
        consensus_correct = 0
        true_consensus_correct = 0
        l0_correct = 0
        kl_correct = 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                predictions, consensus_flags, l0_preds, kl_preds = self.predict_with_consensus(inputs, tau=tau)

                l0_correct += (l0_preds == targets).sum().item()
                kl_correct += (kl_preds == targets).sum().item()

                # Agreed and correct (denominator: agreed samples only).
                consensus_total += consensus_flags.sum().item()
                consensus_correct += (consensus_flags & (l0_preds == targets)).sum().item()

                # Whole-set consensus accuracy: disagreements are -1, i.e. counted as wrong.
                true_consensus_correct += (predictions == targets).sum().item()

                total += targets.size(0)

        if total == 0:
            print(f'{set_name} Results: no samples evaluated')
            return 0.0, 0.0, 0.0, 0.0

        l0_acc = 100. * l0_correct / total
        kl_acc = 100. * kl_correct / total
        consensus_acc = 100. * consensus_correct / consensus_total if consensus_total > 0 else 0
        consensus_rate = 100. * consensus_total / total
        true_consensus_acc = 100. * true_consensus_correct / total

        print(f'{set_name} Results:')
        print(f'  Consensus Rate: {consensus_rate:.2f}% ({consensus_total}/{total})')
        print(f'  Consensus Accuracy (when agree): {consensus_acc:.2f}% ({consensus_correct}/{consensus_total})')
        print(f'  TRUE Consensus Accuracy (entire set): {true_consensus_acc:.2f}% ({true_consensus_correct}/{total})')
        print(f'  L0 Accuracy: {l0_acc:.2f}% ({l0_correct}/{total})')
        print(f'  KL Accuracy: {kl_acc:.2f}% ({kl_correct}/{total})')

        print(f'  Note: "Consensus Accuracy" only counts predictions where L0 and KL agree.')
        print(f'        "TRUE Consensus Accuracy" counts all predictions (unknown = incorrect).')
        print(f'  Breakdown:')
        print(f'    - L0 and KL agree: {consensus_total} images ({consensus_rate:.1f}%)')
        print(f'    - L0 and KL disagree: {total - consensus_total} images ({100-consensus_rate:.1f}%)')
        print(f'    - When they agree, correct: {consensus_correct} images ({consensus_acc:.1f}%)')
        print(f'    - When they disagree, all marked as unknown (incorrect)')
        print(f'    - Total correct: {true_consensus_correct} out of {total} images')
        print(f'    - TRUE accuracy: {true_consensus_acc:.2f}%')

        return consensus_acc, l0_acc, kl_acc, consensus_rate

def _validation_accuracy(model, loader, device, tau):
    """Return top-1 accuracy (percent) of ``model`` on ``loader``; 0.0 if the loader is empty."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs, tau=tau)
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / seen if seen else 0.0


def _report_prototype_separation(model, epoch):
    """Print pairwise-distance statistics and a 5x5 distance preview for the model's prototypes."""
    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (Epoch {epoch}) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        count = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        # Statistics over distinct pairs only; the zero diagonal would otherwise
        # force the minimum to 0 and trigger the warning every epoch.
        distinct_pairs = ~torch.eye(count, dtype=torch.bool, device=prototypes.device)
        pair_distances = distances[distinct_pairs]

        if pair_distances.numel() == 0:
            print("Fewer than two prototypes; nothing to compare")
        else:
            min_separation = pair_distances.min().item()
            print(f"Min separation: {min_separation:.4f}")
            print(f"Max separation: {pair_distances.max().item():.4f}")
            print(f"Avg separation: {pair_distances.mean().item():.4f}")

            print("Distance Matrix:")
            shown = min(5, count)  # first 5x5 only, to avoid clutter
            for i in range(shown):
                cells = "".join(
                    " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                    for j in range(shown)
                )
                print("  " + cells)
            if count > 5:
                print(f"  ... (showing first 5x5, total: {count}x{count})")

            if min_separation < 1.0:
                print("⚠️  Warning: Some prototypes are very close together!")
            elif min_separation < 2.0:
                print("⚠️  Caution: Prototype separation could be improved")
            else:
                print("✅ Good prototype separation maintained")
    print("=" * 60)


def train_l0_model(model, trainloader, valloader, testloader, device, args):
    """Train the L0 prototype model and checkpoint the best validation weights.

    The loss is ``args.alpha * cross_entropy + args.beta * separation_loss``.
    The backbone and the prototypes get separate SGD optimizers (the backbone
    at 0.1x the learning rate), each with cosine annealing over
    ``args.epochs``. Training stops early after ``args.patience`` epochs
    without a validation improvement.

    Args:
        model: prototype model exposing ``backbone``, ``class_prototypes`` and
            ``get_prototypes()``. Its forward must accept ``targets``, ``tau``,
            ``compute_separation``, ``margin``, ``class_boost`` and return
            ``(logits, separation_loss)`` when ``compute_separation=True``.
        trainloader: iterable of ``(inputs, targets)`` training batches.
        valloader: iterable of ``(inputs, targets)`` validation batches.
        testloader: unused; test evaluation is done separately with
            ConsensusEvaluator after both models are trained.
        device: device batches are moved to.
        args: namespace providing ``lr``, ``epochs``, ``tau``, ``margin``,
            ``class_boost``, ``alpha``, ``beta``, ``l0_checkpoint``, ``patience``.

    Returns:
        Best validation accuracy (percent) seen during training.
    """
    print("==> Training L0 Model...")

    backbone_optimizer = optim.SGD(list(model.backbone.parameters()), lr=args.lr * 0.1,
                                   momentum=0.9, weight_decay=1e-3)
    prototype_optimizer = optim.SGD([model.class_prototypes], lr=args.lr,
                                    momentum=0.9, weight_decay=1e-3)
    optimizers = (backbone_optimizer, prototype_optimizer)
    schedulers = [torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs) for opt in optimizers]

    best_acc = 0
    patience_counter = 0

    for epoch in range(args.epochs):
        print(f'\nEpoch: {epoch}')
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)

            for opt in optimizers:
                opt.zero_grad()

            outputs, separation_loss = model(inputs, targets=targets, tau=args.tau,
                                             compute_separation=True, margin=args.margin,
                                             class_boost=args.class_boost)
            classification_loss = F.cross_entropy(outputs, targets)
            total_loss = args.alpha * classification_loss + args.beta * separation_loss
            total_loss.backward()

            for opt in optimizers:
                opt.step()

            train_loss += total_loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}: Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

        val_acc = _validation_accuracy(model, valloader, device, args.tau)
        print(f'Validation Accuracy: {val_acc:.2f}%')

        _report_prototype_separation(model, epoch)

        if val_acc > best_acc:
            print('Saving best L0 model..')
            state = {
                'net': model.state_dict(),
                'acc': val_acc,
                'epoch': epoch,
                'model_type': 'l0'
            }
            torch.save(state, args.l0_checkpoint)
            best_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping after {patience_counter} epochs without improvement')
                break

        for scheduler in schedulers:
            scheduler.step()

    print(f'L0 Training completed! Best validation accuracy: {best_acc:.2f}%')
    return best_acc

def train_kl_model(model, trainloader, valloader, testloader, device, args):
    """Train KL model separately"""
    print("==> Training KL Model...")
    
    criterion = nn.CrossEntropyLoss()
    
    # Separate optimizers for backbone and prototypes
    backbone_params = list(model.backbone.parameters())
    prototype_params = [model.class_prototypes]
    
    backbone_optimizer = optim.SGD(backbone_params, lr=args.lr * 0.1,
                                   momentum=0.9, weight_decay=1e-3)  # Increased from 5e-4
    prototype_optimizer = optim.SGD(prototype_params, lr=args.lr,
                                    momentum=0.9, weight_decay=1e-3)  # Increased from 1e-4
    
    # Schedulers
    backbone_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(backbone_optimizer, T_max=args.epochs)
    prototype_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(prototype_optimizer, T_max=args.epochs)
    
    best_acc = 0
    patience_counter = 0
    
    for epoch in range(args.epochs):
        print(f'\nEpoch: {epoch}')
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Zero gradients
            backbone_optimizer.zero_grad()
            prototype_optimizer.zero_grad()
            
            # Forward pass with separation loss
            outputs, separation_loss = model(inputs, targets=targets, tau=args.tau, 
                                           compute_separation=True, margin=args.margin, 
                                           class_boost=args.class_boost)
            
            # Compute classification loss
            classification_loss = F.cross_entropy(outputs, targets)
            
            # Combined loss
            total_loss = args.alpha * classification_loss + args.beta * separation_loss
            
            total_loss.backward()
            
            # Step optimizers
            backbone_optimizer.step()
            prototype_optimizer.step()
            
            # Update metrics
            train_loss += total_loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}: Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')
        
        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs, tau=args.tau)
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        
        val_acc = 100. * val_correct / val_total
        print(f'Validation Accuracy: {val_acc:.2f}%')
        
        # Display prototype separation analysis during validation
        print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (Epoch {epoch}) ===')
        with torch.no_grad():
            prototypes = model.get_prototypes()
            distances = torch.cdist(prototypes, prototypes, p=2)
            mask = torch.eye(prototypes.size(0), device=prototypes.device)
            off_diagonal = distances * (1 - mask)
            
            print(f"Min separation: {off_diagonal.min().item():.4f}")
            print(f"Max separation: {off_diagonal.max().item():.4f}")
            print(f"Avg separation: {off_diagonal.mean().item():.4f}")
            
            # Show distance matrix (compact format)
            print("Distance Matrix:")
            for i in range(min(5, prototypes.size(0))):  # Show first 5 rows to avoid clutter
                row_str = "  "
                for j in range(min(5, prototypes.size(0))):
                    if i == j:
                        row_str += " 0.0000"
                    else:
                        row_str += f" {distances[i, j].item():.4f}"
                print(row_str)
            if prototypes.size(0) > 5:
                print(f"  ... (showing first 5x5, total: {prototypes.size(0)}x{prototypes.size(0)})")
            
            # Warning if separation is poor
            if off_diagonal.min().item() < 1.0:
                print("⚠️  Warning: Some prototypes are very close together!")
            elif off_diagonal.min().item() < 2.0:
                print("⚠️  Caution: Prototype separation could be improved")
            else:
                print("✅ Good prototype separation maintained")
        print("=" * 60)
        
        # Save best model
        if val_acc > best_acc:
            print('Saving best KL model..')
            state = {
                'net': model.state_dict(),
                'acc': val_acc,
                'epoch': epoch,
                'model_type': 'kl'
            }
            torch.save(state, args.kl_checkpoint)
            best_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping after {patience_counter} epochs without improvement')
                break
        
        backbone_scheduler.step()
        prototype_scheduler.step()
    
    print(f'KL Training completed! Best validation accuracy: {best_acc:.2f}%')
    # Note: Test evaluation is done separately using ConsensusEvaluator
    # after both L0 and KL models are trained
    
    return best_acc

def ensemble_attack_pgd(l0_model, kl_model, X, target, alpha, attack_iters, norm, device, epsilon=0, tau=0.75):
    """PGD attack that maximizes the summed cross-entropy of the L0 and KL models.

    One shared perturbation is optimized against both models: every step ascends
    ``CE(l0_model(X + delta)) + CE(kl_model(X + delta))``. The perturbation is
    bounded by ``epsilon`` in the chosen norm and clipped so ``X + delta`` stays
    inside the module-level ``[lower_limit, upper_limit]`` box.

    Gradients are taken only w.r.t. ``delta`` (``torch.autograd.grad``), so the
    models' parameter ``.grad`` fields are left untouched, and the attack works
    even when the caller has autograd disabled.

    Args:
        l0_model, kl_model: Callables ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        alpha: Step size per iteration.
        attack_iters: Number of gradient steps.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device the attack runs on.
        epsilon: Perturbation budget.
        tau: Forwarded to both models.

    Returns:
        Detached perturbation with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")
    X = X.to(device)
    target = target.to(device)

    # Random start inside the epsilon-ball.
    delta = torch.zeros_like(X)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    else:
        delta.normal_()
        n = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon

    # Keep X + delta inside the valid input box.
    delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
    delta.requires_grad_(True)

    for _ in range(attack_iters):
        with torch.enable_grad():
            x_adv = X + delta
            l0_loss = F.cross_entropy(l0_model(x_adv, tau=tau), target)
            kl_loss = F.cross_entropy(kl_model(x_adv, tau=tau), target)
            combined_loss = l0_loss + kl_loss  # or max(l0_loss, kl_loss)
            grad, = torch.autograd.grad(combined_loss, delta)

        with torch.no_grad():
            if norm == "l_inf":
                d = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
            else:
                g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
                step = delta + grad / (g_norm + 1e-10) * alpha
                d = step.reshape(step.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(step)
            delta.copy_(torch.max(torch.min(d, upper_limit - X), lower_limit - X))

    return delta.detach()

def attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, tau=0.75, mean=None, std=None):
    """Untargeted PGD attack on a single model, in un-normalized input space.

    Each step ascends the cross-entropy loss. The perturbation is bounded by
    ``epsilon`` in the chosen norm and clipped so ``X + delta`` stays inside the
    module-level ``[lower_limit, upper_limit]`` box. ``epsilon`` and ``alpha``
    are used as given: no per-channel rescaling by ``mean``/``std`` is done.

    Gradients are taken only w.r.t. ``delta`` (``torch.autograd.grad``), so the
    model's parameter ``.grad`` fields are left untouched, and the attack works
    even when the caller has autograd disabled.

    Args:
        model: Callable ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        alpha: Step size per iteration.
        attack_iters: Number of gradient steps per restart.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device the attack runs on.
        epsilon: Perturbation budget.
        restarts: Number of random restarts (values < 1 are treated as 1). With
            several restarts, each sample keeps the perturbation that gives the
            highest final cross-entropy loss.
        tau: Forwarded to ``model``.
        mean, std: Accepted for call-site compatibility; currently unused.

    Returns:
        Detached perturbation with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")
    X = X.to(device)
    target = target.to(device)
    num_restarts = max(1, int(restarts))

    best_delta = None
    best_loss = None
    for _ in range(num_restarts):
        # Random start inside the epsilon-ball.
        delta = torch.zeros_like(X)
        if norm == "l_inf":
            delta.uniform_(-epsilon, epsilon)
        else:
            delta.normal_()
            n = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon

        # Keep X + delta inside the valid input box.
        delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
        delta.requires_grad_(True)

        for _ in range(attack_iters):
            with torch.enable_grad():
                loss = F.cross_entropy(model(X + delta, tau=tau), target)
                grad, = torch.autograd.grad(loss, delta)

            with torch.no_grad():
                if norm == "l_inf":
                    d = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
                else:
                    g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
                    step = delta + grad / (g_norm + 1e-10) * alpha
                    d = step.reshape(step.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(step)
                delta.copy_(torch.max(torch.min(d, upper_limit - X), lower_limit - X))

        delta = delta.detach()
        if num_restarts == 1:
            return delta

        # Keep, per sample, the restart that reached the highest loss.
        with torch.no_grad():
            sample_loss = F.cross_entropy(model(X + delta, tau=tau), target, reduction="none")
        if best_delta is None:
            best_delta, best_loss = delta.clone(), sample_loss
        else:
            improved = sample_loss >= best_loss
            best_delta[improved] = delta[improved]
            best_loss = torch.max(best_loss, sample_loss)

    return best_delta

def attack_cw(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, tau=0.75, mean=None, std=None, kappa=50):
    """Carlini & Wagner margin-loss PGD attack on a single model, in un-normalized input space.

    Each step ascends ``-sum(relu(correct_logit - max_wrong_logit + kappa))``,
    which pushes the true-class logit below the largest other logit. The
    perturbation is bounded by ``epsilon`` in the chosen norm and clipped so
    ``X + delta`` stays inside the module-level ``[lower_limit, upper_limit]``
    box. ``epsilon`` and ``alpha`` are used as given: no per-channel rescaling
    by ``mean``/``std`` is done.

    Gradients are taken only w.r.t. ``delta`` (``torch.autograd.grad``), so the
    model's parameter ``.grad`` fields are left untouched, and the attack works
    even when the caller has autograd disabled.

    Args:
        model: Callable ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        alpha: Step size per iteration.
        attack_iters: Number of gradient steps per restart.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device the attack runs on.
        epsilon: Perturbation budget.
        restarts: Number of random restarts (values < 1 are treated as 1). With
            several restarts, each sample keeps the perturbation that gives the
            smallest final logit margin.
        tau: Forwarded to ``model``.
        mean, std: Accepted for call-site compatibility; currently unused.
        kappa: Confidence margin inside the hinge (previously hard-coded to 50).

    Returns:
        Detached perturbation with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")
    X = X.to(device)
    target = target.to(device)
    num_restarts = max(1, int(restarts))

    def logit_margin(logits):
        # Correct-class logit minus the largest other logit; the -1e4 term
        # removes the true class from the max over "wrong" logits.
        label_mask = torch.zeros_like(logits)
        label_mask.scatter_(1, target.unsqueeze(1), 1)
        correct_logit = torch.sum(label_mask * logits, dim=1)
        wrong_logit, _ = torch.max((1 - label_mask) * logits - 1e4 * label_mask, dim=1)
        return correct_logit - wrong_logit

    best_delta = None
    best_margin = None
    for _ in range(num_restarts):
        # Random start inside the epsilon-ball.
        delta = torch.zeros_like(X)
        if norm == "l_inf":
            delta.uniform_(-epsilon, epsilon)
        else:
            delta.normal_()
            n = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon

        # Keep X + delta inside the valid input box.
        delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
        delta.requires_grad_(True)

        for _ in range(attack_iters):
            with torch.enable_grad():
                margin = logit_margin(model(X + delta, tau=tau))
                loss = -torch.sum(F.relu(margin + kappa))
                grad, = torch.autograd.grad(loss, delta)

            with torch.no_grad():
                if norm == "l_inf":
                    d = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
                else:
                    g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
                    step = delta + grad / (g_norm + 1e-10) * alpha
                    d = step.reshape(step.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(step)
                delta.copy_(torch.max(torch.min(d, upper_limit - X), lower_limit - X))

        delta = delta.detach()
        if num_restarts == 1:
            return delta

        # Keep, per sample, the restart that drove the margin lowest.
        with torch.no_grad():
            final_margin = logit_margin(model(X + delta, tau=tau))
        if best_delta is None:
            best_delta, best_margin = delta.clone(), final_margin
        else:
            improved = final_margin <= best_margin
            best_delta[improved] = delta[improved]
            best_margin = torch.min(best_margin, final_margin)

    return best_delta

def attack_auto(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, tau=0.75, mean=None, std=None):
    """Run the standard AutoAttack suite against one model and return the perturbation.

    Falls back to ``attack_pgd`` with the same arguments when the ``autoattack``
    package is unavailable (``AUTOATTACK_AVAILABLE`` is False) or when
    AutoAttack raises. ``alpha``, ``attack_iters``, ``restarts``, ``mean`` and
    ``std`` are used only by that fallback.

    The standard AutoAttack suite contains gradient-based attacks (APGD), so the
    wrapped forward pass must stay differentiable. Running the model under
    ``torch.no_grad()`` makes those attacks fail, which previously sent every
    call down the PGD fallback. AutoAttack is run directly on ``device``.

    Args:
        model: Callable ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        norm: ``"l_inf"`` (AutoAttack ``Linf``) or ``"l_2"`` (``L2``).
        device: Device for the model and AutoAttack.
        epsilon: Perturbation budget.
        tau: Forwarded to ``model``.

    Returns:
        Perturbation ``x_adv - X`` on ``device``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    aa_norm = {"l_inf": "Linf", "l_2": "L2"}.get(norm)
    if aa_norm is None:
        raise ValueError(f"Norm {norm} not supported")

    if not AUTOATTACK_AVAILABLE:
        print("Warning: AutoAttack not available, falling back to PGD attack")
        return attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon, restarts, tau, mean, std)

    X = X.to(device)
    target = target.to(device)

    def forward_pass(images):
        # Deliberately NOT under torch.no_grad(): AutoAttack differentiates through this.
        return model(images.to(device), tau=tau)

    adversary = AutoAttack(forward_pass, norm=aa_norm, eps=epsilon,
                           version='standard', verbose=False, device=device)

    # Best-effort: any AutoAttack failure degrades to PGD rather than aborting evaluation.
    try:
        x_adv = adversary.run_standard_evaluation(X, target, bs=X.shape[0])
    except Exception as e:
        print(f"AutoAttack failed: {e}, falling back to PGD attack")
        return attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon, restarts, tau, mean, std)
    return (x_adv.to(device) - X).detach()  # Return perturbation

def ensemble_attack_cw(l0_model, kl_model, X, target, alpha, attack_iters, norm, device, epsilon=0, tau=0.75, kappa=50):
    """C&W margin-loss PGD attack that targets the L0 and KL models simultaneously.

    One shared perturbation ascends the sum of both models' C&W objectives
    ``-sum(relu(correct_logit - max_wrong_logit + kappa))``. The perturbation
    is bounded by ``epsilon`` in the chosen norm and clipped so ``X + delta``
    stays inside the module-level ``[lower_limit, upper_limit]`` box.

    Gradients are taken only w.r.t. ``delta`` (``torch.autograd.grad``), so the
    models' parameter ``.grad`` fields are left untouched, and the attack works
    even when the caller has autograd disabled.

    Args:
        l0_model, kl_model: Callables ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        alpha: Step size per iteration.
        attack_iters: Number of gradient steps.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device the attack runs on.
        epsilon: Perturbation budget.
        tau: Forwarded to both models.
        kappa: Confidence margin inside the hinge (previously hard-coded to 50).

    Returns:
        Detached perturbation with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")
    X = X.to(device)
    target = target.to(device)

    def cw_objective(logits):
        # -sum(relu(margin + kappa)); the -1e4 term removes the true class from
        # the max over "wrong" logits.
        label_mask = torch.zeros_like(logits)
        label_mask.scatter_(1, target.unsqueeze(1), 1)
        correct_logit = torch.sum(label_mask * logits, dim=1)
        wrong_logit, _ = torch.max((1 - label_mask) * logits - 1e4 * label_mask, dim=1)
        return -torch.sum(F.relu(correct_logit - wrong_logit + kappa))

    # Random start inside the epsilon-ball.
    delta = torch.zeros_like(X)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    else:
        delta.normal_()
        n = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon

    # Keep X + delta inside the valid input box.
    delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
    delta.requires_grad_(True)

    for _ in range(attack_iters):
        with torch.enable_grad():
            x_adv = X + delta
            combined_loss = cw_objective(l0_model(x_adv, tau=tau)) + cw_objective(kl_model(x_adv, tau=tau))
            grad, = torch.autograd.grad(combined_loss, delta)

        with torch.no_grad():
            if norm == "l_inf":
                d = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
            else:
                g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
                step = delta + grad / (g_norm + 1e-10) * alpha
                d = step.reshape(step.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(step)
            delta.copy_(torch.max(torch.min(d, upper_limit - X), lower_limit - X))

    return delta.detach()

def ensemble_attack_auto(l0_model, kl_model, X, target, alpha, attack_iters, norm, device, epsilon=0, tau=0.75):
    """Run the standard AutoAttack suite against the averaged L0/KL ensemble.

    The attacked model is ``(l0_model(x) + kl_model(x)) / 2``. This falls back
    to ``ensemble_attack_pgd`` with the same arguments when the ``autoattack``
    package is unavailable (``AUTOATTACK_AVAILABLE`` is False) or when
    AutoAttack raises. ``alpha`` and ``attack_iters`` are used only by that
    fallback.

    The standard AutoAttack suite contains gradient-based attacks (APGD), so the
    wrapped forward pass must stay differentiable. Running the models under
    ``torch.no_grad()`` makes those attacks fail, which previously sent every
    call down the PGD fallback. AutoAttack is run directly on ``device``.

    Args:
        l0_model, kl_model: Callables ``model(x, tau=tau)`` returning class logits.
        X: Clean input batch.
        target: Ground-truth labels for ``X``.
        norm: ``"l_inf"`` (AutoAttack ``Linf``) or ``"l_2"`` (``L2``).
        device: Device for the models and AutoAttack.
        epsilon: Perturbation budget.
        tau: Forwarded to both models.

    Returns:
        Perturbation ``x_adv - X`` on ``device``.

    Raises:
        ValueError: If ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    aa_norm = {"l_inf": "Linf", "l_2": "L2"}.get(norm)
    if aa_norm is None:
        raise ValueError(f"Norm {norm} not supported")

    if not AUTOATTACK_AVAILABLE:
        print("Warning: AutoAttack not available, falling back to ensemble PGD attack")
        return ensemble_attack_pgd(l0_model, kl_model, X, target, alpha, attack_iters, norm, device, epsilon, tau)

    X = X.to(device)
    target = target.to(device)

    def forward_pass(images):
        # Deliberately NOT under torch.no_grad(): AutoAttack differentiates through this.
        images = images.to(device)
        # Combine predictions (simple average)
        return (l0_model(images, tau=tau) + kl_model(images, tau=tau)) / 2

    adversary = AutoAttack(forward_pass, norm=aa_norm, eps=epsilon,
                           version='standard', verbose=False, device=device)

    # Best-effort: any AutoAttack failure degrades to ensemble PGD rather than aborting.
    try:
        x_adv = adversary.run_standard_evaluation(X, target, bs=X.shape[0])
    except Exception as e:
        print(f"AutoAttack failed: {e}, falling back to ensemble PGD attack")
        return ensemble_attack_pgd(l0_model, kl_model, X, target, alpha, attack_iters, norm, device, epsilon, tau)
    return (x_adv.to(device) - X).detach()  # Return perturbation


def evaluate_adversarial_robustness(evaluator, testloader, device, args):
    """Evaluate adversarial robustness using consensus prediction with 4-category breakdown"""
    print("==> Evaluating adversarial robustness...")
    
    # CRITICAL: Set models to eval mode and enable gradients for attacks
    evaluator.l0_model.eval()
    evaluator.kl_model.eval()
    torch.set_grad_enabled(True)  # Attacks need gradients
    
    print("✅ Models set to eval mode for reliable adversarial evaluation")
    print("✅ Gradients enabled for PGD attack computation")
    
    # Print prototype separation analysis for both models
    print("\n=== PROTOTYPE SEPARATION ANALYSIS ===")
    
    # L0 Model prototype distances
    with torch.no_grad():
        l0_prototypes = evaluator.l0_model.get_prototypes()
        l0_distances = torch.cdist(l0_prototypes, l0_prototypes, p=2)
        l0_mask = torch.eye(l0_prototypes.size(0), device=l0_prototypes.device)
        l0_off_diagonal = l0_distances * (1 - l0_mask)
        
        print("L0 Model Prototype Distances:")
        print(f"  Min separation: {l0_off_diagonal.min().item():.4f}")
        print(f"  Max separation: {l0_off_diagonal.max().item():.4f}")
        print(f"  Avg separation: {l0_off_diagonal.mean().item():.4f}")
        print(f"  Full distance matrix:")
        for i in range(l0_prototypes.size(0)):
            row_str = "    "
            for j in range(l0_prototypes.size(0)):
                if i == j:
                    row_str += " 0.0000"
                else:
                    row_str += f" {l0_distances[i, j].item():.4f}"
            print(row_str)
    
    # KL Model prototype distances
    with torch.no_grad():
        kl_prototypes = evaluator.kl_model.get_prototypes()
        kl_distances = torch.cdist(kl_prototypes, kl_prototypes, p=2)
        kl_mask = torch.eye(kl_prototypes.size(0), device=kl_prototypes.device)
        kl_off_diagonal = kl_distances * (1 - kl_mask)
        
        print("\nKL Model Prototype Distances:")
        print(f"  Min separation: {kl_off_diagonal.min().item():.4f}")
        print(f"  Max separation: {kl_off_diagonal.max().item():.4f}")
        print(f"  Avg separation: {kl_off_diagonal.mean().item():.4f}")
        print(f"  Full distance matrix:")
        for i in range(kl_prototypes.size(0)):
            row_str = "    "
            for j in range(kl_prototypes.size(0)):
                if i == j:
                    row_str += " 0.0000"
                else:
                    row_str += f" {kl_distances[i, j].item():.4f}"
            print(row_str)
    
    print("=" * 60)
    
    # Clean accuracy
    clean_consensus_acc, clean_l0_acc, clean_kl_acc, clean_consensus_rate = evaluator.evaluate_accuracy(testloader, "Clean")
    
    # Adversarial evaluation with 4-category breakdown
    total = 0
    
    # Category counters for clean predictions
    clean_same_correct = 0      # L0 and KL same and correct
    clean_same_incorrect = 0    # L0 and KL same but incorrect
    clean_diff_one_correct = 0  # L0 and KL different but one correct
    clean_diff_both_wrong = 0   # L0 and KL different and both incorrect
    
    # Category counters for adversarial predictions (L0 attack)
    adv_same_correct_l0_attack = 0        # L0 and KL same and correct
    adv_same_incorrect_l0_attack = 0      # L0 and KL same but incorrect
    adv_diff_one_correct_l0_attack = 0    # L0 and KL different but one correct
    adv_diff_both_wrong_l0_attack = 0     # L0 and KL different and both incorrect
    
    # Category counters for adversarial predictions (KL attack)
    adv_same_correct_kl_attack = 0        # L0 and KL same and correct
    adv_same_incorrect_kl_attack = 0      # L0 and KL same but incorrect
    adv_diff_one_correct_kl_attack = 0    # L0 and KL different but one correct
    adv_diff_both_wrong_kl_attack = 0     # L0 and KL different and both incorrect
    
    # Track individual metric accuracies for adversarial data (L0 attack)
    adv_l0_correct_l0_attack = 0
    adv_kl_correct_l0_attack = 0
    
    # Track individual metric accuracies for adversarial data (KL attack)
    adv_l0_correct_kl_attack = 0
    adv_kl_correct_kl_attack = 0
    
    # Category counters for adversarial predictions (Ensemble attack)
    adv_same_correct_ensemble_attack = 0        # L0 and KL same and correct
    adv_same_incorrect_ensemble_attack = 0      # L0 and KL same but incorrect
    adv_diff_one_correct_ensemble_attack = 0    # L0 and KL different but one correct
    adv_diff_both_wrong_ensemble_attack = 0     # L0 and KL different and both incorrect
    
    # Track individual metric accuracies for adversarial data (Ensemble attack)
    adv_l0_correct_ensemble_attack = 0
    adv_kl_correct_ensemble_attack = 0
    
    # NOTE: Models are already in eval mode from above, ensuring dropout is disabled
    # and batch normalization uses running statistics for consistent gradients
    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader, desc='Evaluating')):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)
        
        # Clean consensus prediction
        clean_predictions, clean_consensus_flags, clean_l0_preds, clean_kl_preds = evaluator.predict_with_consensus(inputs)
        
        # Generate adversarial examples for L0 model
        if args.attack_type == 'pgd':
            delta_l0 = attack_pgd(evaluator.l0_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)
            
            # Generate adversarial examples for KL model
            delta_kl = attack_pgd(evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)

            delta_ensemble=ensemble_attack_pgd(evaluator.l0_model, evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau)
        elif args.attack_type == 'cw':
            delta_l0 = attack_cw(evaluator.l0_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)
            
            # Generate adversarial examples for KL model
            delta_kl = attack_cw(evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)

            delta_ensemble=ensemble_attack_cw(evaluator.l0_model, evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau)
        elif args.attack_type == 'auto':
            delta_l0 = attack_auto(evaluator.l0_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)

            delta_kl = attack_auto(evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau, 
                                mean=CIFAR_MEAN, std=CIFAR_STD)

            delta_ensemble=ensemble_attack_auto(evaluator.l0_model, evaluator.kl_model, inputs, targets, args.attack_stepsize, 
                                10, args.norm, device, args.attack_epsilon, tau=args.tau)
        
        # Adversarial consensus prediction using L0 adversarial examples
        adv_predictions_l0, adv_consensus_flags_l0, adv_l0_preds_l0, adv_kl_preds_l0 = evaluator.predict_with_consensus(inputs + delta_l0)
        
        # Adversarial consensus prediction using KL adversarial examples
        adv_predictions_kl, adv_consensus_flags_kl, adv_l0_preds_kl, adv_kl_preds_kl = evaluator.predict_with_consensus(inputs + delta_kl)
        
        # Adversarial consensus prediction using ensemble adversarial examples
        adv_predictions_ensemble, adv_consensus_flags_ensemble, adv_l0_preds_ensemble, adv_kl_preds_ensemble = evaluator.predict_with_consensus(inputs + delta_ensemble)
        
        # Count individual metric accuracies for adversarial data (L0 attack)
        adv_l0_correct_l0_attack += (adv_l0_preds_l0 == targets).sum().item()
        adv_kl_correct_l0_attack += (adv_kl_preds_l0 == targets).sum().item()
        
        # Count individual metric accuracies for adversarial data (KL attack)
        adv_l0_correct_kl_attack += (adv_l0_preds_kl == targets).sum().item()
        adv_kl_correct_kl_attack += (adv_kl_preds_kl == targets).sum().item()
        
        # Count individual metric accuracies for adversarial data (Ensemble attack)
        adv_l0_correct_ensemble_attack += (adv_l0_preds_ensemble == targets).sum().item()
        adv_kl_correct_ensemble_attack += (adv_kl_preds_ensemble == targets).sum().item()
        
        # Categorize clean predictions
        for i in range(batch_size):
            l0_pred = clean_l0_preds[i].item()
            kl_pred = clean_kl_preds[i].item()
            target = targets[i].item()
            
            if l0_pred == kl_pred:  # Same prediction
                if l0_pred == target:  # Both correct
                    clean_same_correct += 1
                else:  # Both incorrect
                    clean_same_incorrect += 1
            else:  # Different predictions
                if l0_pred == target or kl_pred == target:  # One correct
                    clean_diff_one_correct += 1
                else:  # Both incorrect
                    clean_diff_both_wrong += 1
        
        # Categorize adversarial predictions (L0 attack)
        for i in range(batch_size):
            l0_pred = adv_l0_preds_l0[i].item()
            kl_pred = adv_kl_preds_l0[i].item()
            target = targets[i].item()
            
            if l0_pred == kl_pred:  # Same prediction
                if l0_pred == target:  # Both correct
                    adv_same_correct_l0_attack += 1
                else:  # Both incorrect
                    adv_same_incorrect_l0_attack += 1
            else:  # Different predictions
                if l0_pred == target or kl_pred == target:  # One correct
                    adv_diff_one_correct_l0_attack += 1
                else:  # Both incorrect
                    adv_diff_both_wrong_l0_attack += 1
        
        # Categorize adversarial predictions (KL attack)
        for i in range(batch_size):
            l0_pred = adv_l0_preds_kl[i].item()
            kl_pred = adv_kl_preds_kl[i].item()
            target = targets[i].item()
            
            if l0_pred == kl_pred:  # Same prediction
                if l0_pred == target:  # Both correct
                    adv_same_correct_kl_attack += 1
                else:  # Both incorrect
                    adv_same_incorrect_kl_attack += 1
            else:  # Different predictions
                if l0_pred == target or kl_pred == target:  # One correct
                    adv_diff_one_correct_kl_attack += 1
                else:  # Both incorrect
                    adv_diff_both_wrong_kl_attack += 1
        
        # Categorize adversarial predictions (Ensemble attack)
        for i in range(batch_size):
            l0_pred = adv_l0_preds_ensemble[i].item()
            kl_pred = adv_kl_preds_ensemble[i].item()
            target = targets[i].item()
            
            if l0_pred == kl_pred:  # Same prediction
                if l0_pred == target:  # Both correct
                    adv_same_correct_ensemble_attack += 1
                else:  # Both incorrect
                    adv_same_incorrect_ensemble_attack += 1
            else:  # Different predictions
                if l0_pred == target or kl_pred == target:  # One correct
                    adv_diff_one_correct_ensemble_attack += 1
                else:  # Both incorrect
                    adv_diff_both_wrong_ensemble_attack += 1
        
        total += batch_size
    
    # Calculate percentages
    clean_same_correct_pct = 100. * clean_same_correct / total
    clean_same_incorrect_pct = 100. * clean_same_incorrect / total
    clean_diff_one_correct_pct = 100. * clean_diff_one_correct / total
    clean_diff_both_wrong_pct = 100. * clean_diff_both_wrong / total
    
    # Calculate percentages for L0 attack
    adv_same_correct_l0_attack_pct = 100. * adv_same_correct_l0_attack / total
    adv_same_incorrect_l0_attack_pct = 100. * adv_same_incorrect_l0_attack / total
    adv_diff_one_correct_l0_attack_pct = 100. * adv_diff_one_correct_l0_attack / total
    adv_diff_both_wrong_l0_attack_pct = 100. * adv_diff_both_wrong_l0_attack / total
    
    # Calculate percentages for KL attack
    adv_same_correct_kl_attack_pct = 100. * adv_same_correct_kl_attack / total
    adv_same_incorrect_kl_attack_pct = 100. * adv_same_incorrect_kl_attack / total
    adv_diff_one_correct_kl_attack_pct = 100. * adv_diff_one_correct_kl_attack / total
    adv_diff_both_wrong_kl_attack_pct = 100. * adv_diff_both_wrong_kl_attack / total
    
    # Calculate percentages for Ensemble attack
    adv_same_correct_ensemble_attack_pct = 100. * adv_same_correct_ensemble_attack / total
    adv_same_incorrect_ensemble_attack_pct = 100. * adv_same_incorrect_ensemble_attack / total
    adv_diff_one_correct_ensemble_attack_pct = 100. * adv_diff_one_correct_ensemble_attack / total
    adv_diff_both_wrong_ensemble_attack_pct = 100. * adv_diff_both_wrong_ensemble_attack / total
    
    # Calculate consensus rates
    clean_consensus_rate = 100. * (clean_same_correct + clean_same_incorrect) / total
    adv_consensus_rate_l0_attack = 100. * (adv_same_correct_l0_attack + adv_same_incorrect_l0_attack) / total
    adv_consensus_rate_kl_attack = 100. * (adv_same_correct_kl_attack + adv_same_incorrect_kl_attack) / total
    adv_consensus_rate_ensemble_attack = 100. * (adv_same_correct_ensemble_attack + adv_same_incorrect_ensemble_attack) / total
    
    # Calculate accuracies
    clean_consensus_acc = 100. * clean_same_correct / (clean_same_correct + clean_same_incorrect) if (clean_same_correct + clean_same_incorrect) > 0 else 0
    adv_consensus_acc_l0_attack = 100. * adv_same_correct_l0_attack / (adv_same_correct_l0_attack + adv_same_incorrect_l0_attack) if (adv_same_correct_l0_attack + adv_same_incorrect_l0_attack) > 0 else 0
    adv_consensus_acc_kl_attack = 100. * adv_same_correct_kl_attack / (adv_same_correct_kl_attack + adv_same_incorrect_kl_attack) if (adv_same_correct_kl_attack + adv_same_incorrect_kl_attack) > 0 else 0
    adv_consensus_acc_ensemble_attack = 100. * adv_same_correct_ensemble_attack / (adv_same_correct_ensemble_attack + adv_same_incorrect_ensemble_attack) if (adv_same_correct_ensemble_attack + adv_same_incorrect_ensemble_attack) > 0 else 0
    
    # Calculate true consensus accuracy (entire test set)
    clean_true_consensus_acc = 100. * clean_same_correct / total
    adv_true_consensus_acc_l0_attack = 100. * adv_same_correct_l0_attack / total
    adv_true_consensus_acc_kl_attack = 100. * adv_same_correct_kl_attack / total
    adv_true_consensus_acc_ensemble_attack = 100. * adv_same_correct_ensemble_attack / total
    
    # Calculate individual metric accuracies for adversarial data (L0 attack)
    adv_l0_acc_l0_attack = 100. * adv_l0_correct_l0_attack / total
    adv_kl_acc_l0_attack = 100. * adv_kl_correct_l0_attack / total
    
    # Calculate individual metric accuracies for adversarial data (KL attack)
    adv_l0_acc_kl_attack = 100. * adv_l0_correct_kl_attack / total
    adv_kl_acc_kl_attack = 100. * adv_kl_correct_kl_attack / total
    
    # Calculate individual metric accuracies for adversarial data (Ensemble attack)
    adv_l0_acc_ensemble_attack = 100. * adv_l0_correct_ensemble_attack / total
    adv_kl_acc_ensemble_attack = 100. * adv_kl_correct_ensemble_attack / total
    
    # Calculate robustness drops (L0 attack)
    consensus_robustness_drop_l0_attack = clean_consensus_acc - adv_consensus_acc_l0_attack
    true_consensus_robustness_drop_l0_attack = clean_true_consensus_acc - adv_true_consensus_acc_l0_attack
    l0_robustness_drop_l0_attack = clean_l0_acc - adv_l0_acc_l0_attack
    kl_robustness_drop_l0_attack = clean_kl_acc - adv_kl_acc_l0_attack
    
    # Calculate robustness drops (KL attack)
    consensus_robustness_drop_kl_attack = clean_consensus_acc - adv_consensus_acc_kl_attack
    true_consensus_robustness_drop_kl_attack = clean_true_consensus_acc - adv_true_consensus_acc_kl_attack
    l0_robustness_drop_kl_attack = clean_l0_acc - adv_l0_acc_kl_attack
    kl_robustness_drop_kl_attack = clean_kl_acc - adv_kl_acc_kl_attack
    
    # Calculate robustness drops (Ensemble attack)
    consensus_robustness_drop_ensemble_attack = clean_consensus_acc - adv_consensus_acc_ensemble_attack
    true_consensus_robustness_drop_ensemble_attack = clean_true_consensus_acc - adv_true_consensus_acc_ensemble_attack
    l0_robustness_drop_ensemble_attack = clean_l0_acc - adv_l0_acc_ensemble_attack
    kl_robustness_drop_ensemble_attack = clean_kl_acc - adv_kl_acc_ensemble_attack
    
    print(f'\n=== CLEAN PREDICTIONS BREAKDOWN ===')
    print(f'Total images: {total}')
    print(f'1. L0 and KL same and correct: {clean_same_correct} ({clean_same_correct_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {clean_same_incorrect} ({clean_same_incorrect_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {clean_diff_one_correct} ({clean_diff_one_correct_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {clean_diff_both_wrong} ({clean_diff_both_wrong_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {clean_consensus_rate:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {clean_consensus_acc:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {clean_true_consensus_acc:.2f}% (Group 1 / Total)')
    
    print(f'\n=== ADVERSARIAL PREDICTIONS BREAKDOWN (L0 ATTACK) ===')
    print(f'1. L0 and KL same and correct: {adv_same_correct_l0_attack} ({adv_same_correct_l0_attack_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {adv_same_incorrect_l0_attack} ({adv_same_incorrect_l0_attack_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {adv_diff_one_correct_l0_attack} ({adv_diff_one_correct_l0_attack_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {adv_diff_both_wrong_l0_attack} ({adv_diff_both_wrong_l0_attack_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {adv_consensus_rate_l0_attack:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {adv_consensus_acc_l0_attack:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {adv_true_consensus_acc_l0_attack:.2f}% (Group 1 / Total)')
    
    print(f'\n=== ADVERSARIAL PREDICTIONS BREAKDOWN (KL ATTACK) ===')
    print(f'1. L0 and KL same and correct: {adv_same_correct_kl_attack} ({adv_same_correct_kl_attack_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {adv_same_incorrect_kl_attack} ({adv_same_incorrect_kl_attack_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {adv_diff_one_correct_kl_attack} ({adv_diff_one_correct_kl_attack_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {adv_diff_both_wrong_kl_attack} ({adv_diff_both_wrong_kl_attack_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {adv_consensus_rate_kl_attack:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {adv_consensus_acc_kl_attack:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {adv_true_consensus_acc_kl_attack:.2f}% (Group 1 / Total)')
    
    print(f'\n=== ADVERSARIAL PREDICTIONS BREAKDOWN (ENSEMBLE ATTACK) ===')
    print(f'1. L0 and KL same and correct: {adv_same_correct_ensemble_attack} ({adv_same_correct_ensemble_attack_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {adv_same_incorrect_ensemble_attack} ({adv_same_incorrect_ensemble_attack_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {adv_diff_one_correct_ensemble_attack} ({adv_diff_one_correct_ensemble_attack_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {adv_diff_both_wrong_ensemble_attack} ({adv_diff_both_wrong_ensemble_attack_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {adv_consensus_rate_ensemble_attack:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {adv_consensus_acc_ensemble_attack:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {adv_true_consensus_acc_ensemble_attack:.2f}% (Group 1 / Total)')
    
    print(f'\n=== INDIVIDUAL METRIC PERFORMANCE (L0 ATTACK) ===')
    print(f'L0 Accuracy: {adv_l0_acc_l0_attack:.2f}% ({adv_l0_correct_l0_attack}/{total})')
    print(f'KL Accuracy: {adv_kl_acc_l0_attack:.2f}% ({adv_kl_correct_l0_attack}/{total})')
    
    print(f'\n=== INDIVIDUAL METRIC PERFORMANCE (KL ATTACK) ===')
    print(f'L0 Accuracy: {adv_l0_acc_kl_attack:.2f}% ({adv_l0_correct_kl_attack}/{total})')
    print(f'KL Accuracy: {adv_kl_acc_kl_attack:.2f}% ({adv_kl_correct_kl_attack}/{total})')
    
    print(f'\n=== INDIVIDUAL METRIC PERFORMANCE (ENSEMBLE ATTACK) ===')
    print(f'L0 Accuracy: {adv_l0_acc_ensemble_attack:.2f}% ({adv_l0_correct_ensemble_attack}/{total})')
    print(f'KL Accuracy: {adv_kl_acc_ensemble_attack:.2f}% ({adv_kl_correct_ensemble_attack}/{total})')
    
    print(f'\n=== ROBUSTNESS ANALYSIS (L0 ATTACK) ===')
    print(f'Consensus Robustness Drop: {consensus_robustness_drop_l0_attack:.2f}%')
    print(f'TRUE Consensus Robustness Drop: {true_consensus_robustness_drop_l0_attack:.2f}%')
    print(f'L0 Robustness Drop: {l0_robustness_drop_l0_attack:.2f}%')
    print(f'KL Robustness Drop: {kl_robustness_drop_l0_attack:.2f}%')
    
    print(f'\n=== ROBUSTNESS ANALYSIS (KL ATTACK) ===')
    print(f'Consensus Robustness Drop: {consensus_robustness_drop_kl_attack:.2f}%')
    print(f'TRUE Consensus Robustness Drop: {true_consensus_robustness_drop_kl_attack:.2f}%')
    print(f'L0 Robustness Drop: {l0_robustness_drop_kl_attack:.2f}%')
    print(f'KL Robustness Drop: {kl_robustness_drop_kl_attack:.2f}%')
    
    print(f'\n=== ROBUSTNESS ANALYSIS (ENSEMBLE ATTACK) ===')
    print(f'Consensus Robustness Drop: {consensus_robustness_drop_ensemble_attack:.2f}%')
    print(f'TRUE Consensus Robustness Drop: {true_consensus_robustness_drop_ensemble_attack:.2f}%')
    print(f'L0 Robustness Drop: {l0_robustness_drop_ensemble_attack:.2f}%')
    print(f'KL Robustness Drop: {kl_robustness_drop_ensemble_attack:.2f}%')
    
    print(f'')
    print(f'KEY INSIGHT:')
    print(f'  L0 Attack - Adversarial accuracy = 1 - (Group 2 rate) = {100-adv_same_incorrect_l0_attack_pct:.2f}%')
    print(f'  KL Attack - Adversarial accuracy = 1 - (Group 2 rate) = {100-adv_same_incorrect_kl_attack_pct:.2f}%')
    print(f'  Ensemble Attack - Adversarial accuracy = 1 - (Group 2 rate) = {100-adv_same_incorrect_ensemble_attack_pct:.2f}%')
    print(f'This means: When L0 and KL agree under attack, they are correct:')
    print(f'  - L0 Attack: {100-adv_same_incorrect_l0_attack_pct:.2f}% of the time')
    print(f'  - KL Attack: {100-adv_same_incorrect_kl_attack_pct:.2f}% of the time')
    print(f'  - Ensemble Attack: {100-adv_same_incorrect_ensemble_attack_pct:.2f}% of the time')
    
    # Restore original gradient state
    torch.set_grad_enabled(False)
    print("✅ Restored original gradient state")
    
    return {
        'clean_same_correct': clean_same_correct,
        'clean_same_incorrect': clean_same_incorrect,
        'clean_diff_one_correct': clean_diff_one_correct,
        'clean_diff_both_wrong': clean_diff_both_wrong,
        'clean_consensus_rate': clean_consensus_rate,
        'clean_consensus_acc': clean_consensus_acc,
        'clean_true_consensus_acc': clean_true_consensus_acc,
        'clean_l0_acc': clean_l0_acc,
        'clean_kl_acc': clean_kl_acc,
        
        # L0 Attack Results
        'adv_same_correct_l0_attack': adv_same_correct_l0_attack,
        'adv_same_incorrect_l0_attack': adv_same_incorrect_l0_attack,
        'adv_diff_one_correct_l0_attack': adv_diff_one_correct_l0_attack,
        'adv_diff_both_wrong_l0_attack': adv_diff_both_wrong_l0_attack,
        'adv_consensus_rate_l0_attack': adv_consensus_rate_l0_attack,
        'adv_consensus_acc_l0_attack': adv_consensus_acc_l0_attack,
        'adv_true_consensus_acc_l0_attack': adv_true_consensus_acc_l0_attack,
        'adv_l0_acc_l0_attack': adv_l0_acc_l0_attack,
        'adv_kl_acc_l0_attack': adv_kl_acc_l0_attack,
        'consensus_robustness_drop_l0_attack': consensus_robustness_drop_l0_attack,
        'true_consensus_robustness_drop_l0_attack': true_consensus_robustness_drop_l0_attack,
        'l0_robustness_drop_l0_attack': l0_robustness_drop_l0_attack,
        'kl_robustness_drop_l0_attack': kl_robustness_drop_l0_attack,
        
        # KL Attack Results
        'adv_same_correct_kl_attack': adv_same_correct_kl_attack,
        'adv_same_incorrect_kl_attack': adv_same_incorrect_kl_attack,
        'adv_diff_one_correct_kl_attack': adv_diff_one_correct_kl_attack,
        'adv_diff_both_wrong_kl_attack': adv_diff_both_wrong_kl_attack,
        'adv_consensus_rate_kl_attack': adv_consensus_rate_kl_attack,
        'adv_consensus_acc_kl_attack': adv_consensus_acc_kl_attack,
        'adv_true_consensus_acc_kl_attack': adv_true_consensus_acc_kl_attack,
        'adv_l0_acc_kl_attack': adv_l0_acc_kl_attack,
        'adv_kl_acc_kl_attack': adv_kl_acc_kl_attack,
        'consensus_robustness_drop_kl_attack': consensus_robustness_drop_kl_attack,
        'true_consensus_robustness_drop_kl_attack': true_consensus_robustness_drop_kl_attack,
        'l0_robustness_drop_kl_attack': l0_robustness_drop_kl_attack,
        'kl_robustness_drop_kl_attack': kl_robustness_drop_kl_attack,
        
        # Ensemble Attack Results
        'adv_same_correct_ensemble_attack': adv_same_correct_ensemble_attack,
        'adv_same_incorrect_ensemble_attack': adv_same_incorrect_ensemble_attack,
        'adv_diff_one_correct_ensemble_attack': adv_diff_one_correct_ensemble_attack,
        'adv_diff_both_wrong_ensemble_attack': adv_diff_both_wrong_ensemble_attack,
        'adv_consensus_rate_ensemble_attack': adv_consensus_rate_ensemble_attack,
        'adv_consensus_acc_ensemble_attack': adv_consensus_acc_ensemble_attack,
        'adv_true_consensus_acc_ensemble_attack': adv_true_consensus_acc_ensemble_attack,
        'adv_l0_acc_ensemble_attack': adv_l0_acc_ensemble_attack,
        'adv_kl_acc_ensemble_attack': adv_kl_acc_ensemble_attack,
        'consensus_robustness_drop_ensemble_attack': consensus_robustness_drop_ensemble_attack,
        'true_consensus_robustness_drop_ensemble_attack': true_consensus_robustness_drop_ensemble_attack,
        'l0_robustness_drop_ensemble_attack': l0_robustness_drop_ensemble_attack,
        'kl_robustness_drop_ensemble_attack': kl_robustness_drop_ensemble_attack,
        
        'adversarial_accuracy_l0_attack': 100 - adv_same_incorrect_l0_attack_pct,
        'adversarial_accuracy_kl_attack': 100 - adv_same_incorrect_kl_attack_pct,
        'adversarial_accuracy_ensemble_attack': 100 - adv_same_incorrect_ensemble_attack_pct
    }

def _build_data_loaders():
    """Create the CIFAR-10 train, validation and test DataLoaders.

    The official 10k-image test split is divided 50/50 into a validation subset
    and a test subset. ``torch.manual_seed(42)`` is set right before the split so
    that it is reproducible across runs.

    Returns:
        (trainloader, valloader, testloader)
    """
    print('==> Preparing data..')
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)

    # Split test set into validation and test sets
    testset_full = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    # NOTE: this seeds the global RNG, which also makes the later backbone
    # initialization deterministic.
    torch.manual_seed(42)
    val_size = len(testset_full) // 2
    test_size = len(testset_full) - val_size
    valset, testset = torch.utils.data.random_split(testset_full, [val_size, test_size])

    valloader = torch.utils.data.DataLoader(
        valset, batch_size=100, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)

    print(f'Training set: {len(trainset)} images')
    print(f'Validation set: {len(valset)} images')
    print(f'Test set: {len(testset)} images')
    return trainloader, valloader, testloader


def _build_backbone(name, pretrained, load_imagenet_weights=True):
    """Construct a feature-extractor backbone with its classification head removed.

    Args:
        name: one of the ``--backbone`` choices.
        pretrained: if True, use the torchvision ImageNet architecture (head
            ``fc``/``classifier`` replaced by ``nn.Identity``); otherwise use the
            CIFAR architectures from ``models``.
        load_imagenet_weights: only used when ``pretrained`` is True. When False,
            the torchvision architecture is built without ImageNet weights. This is
            used when the weights are about to be overwritten by a checkpoint.

    Returns:
        The backbone ``nn.Module``.
    """
    # NOTE(review): the prototype models are built with embedding_dim=512.
    # Backbones whose pooled feature size is not 512 (e.g. torchvision
    # resnet50/vgg19/densenet121) presumably do not match it -- verify.
    if pretrained:
        factories = {
            'resnet18': torchvision.models.resnet18,
            'resnet50': torchvision.models.resnet50,
            'vgg19': torchvision.models.vgg19,
            'densenet121': torchvision.models.densenet121,
        }
        factory = factories.get(name)
        backbone = ResNet18() if factory is None else factory(pretrained=load_imagenet_weights)

        # Remove the final classification layer
        if hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()
        elif hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()
        return backbone

    if name == 'vgg19':
        backbone = VGG('VGG19')
        backbone.classifier = nn.Identity()
    elif name == 'densenet121':
        backbone = DenseNet121()
        backbone.classifier = nn.Identity()
    else:
        # 'resnet18' and (without --pretrained) 'resnet50' both use the CIFAR ResNet18.
        backbone = ResNet18()
        backbone.linear = nn.Identity()
    return backbone


def _load_backbone_weights(backbone, checkpoint_path):
    """Copy backbone weights from a main.py checkpoint, skipping its classifier head.

    Keys prefixed with ``module.`` are unwrapped. Keys under ``linear.`` or
    ``classifier.`` are dropped. A missing file only produces a warning.
    """
    print(f'==> Loading model from {checkpoint_path}..')
    if not os.path.isfile(checkpoint_path):
        print(f'Warning: checkpoint not found at {checkpoint_path}; keeping current backbone weights')
        return

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict copies the tensors into the backbone's own parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    backbone_state_dict = {}
    for key, value in checkpoint['net'].items():
        clean_key = key[len('module.'):] if key.startswith('module.') else key
        if not clean_key.startswith(('linear.', 'classifier.')):
            backbone_state_dict[clean_key] = value

    backbone.load_state_dict(backbone_state_dict)
    print('==> Loaded backbone weights from checkpoint')


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, if it has one."""
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises, so skip bare filenames
        os.makedirs(parent, exist_ok=True)


def _train_prototype_model(label, model_cls, train_fn, backbone, checkpoint_path,
                           loaders, device, args):
    """Build, initialize and train one prototype model on its own backbone copy.

    Args:
        label: display name used in log messages ('L0' or 'KL').
        model_cls: prototype model class to instantiate.
        train_fn: training routine called as
            ``train_fn(model, trainloader, valloader, testloader, device, args)``;
            its return value is printed as the best accuracy.
        backbone: template backbone. It is deep-copied, never trained in place.
        checkpoint_path: path whose parent directory is created before training.
        loaders: ``(trainloader, valloader, testloader)``.
    """
    import copy

    trainloader, valloader, testloader = loaders
    print(f"==> Training {label} Model...")
    print(f"==> Using dropout rate: {args.dropout_rate} to combat overfitting")
    # Deep-copy so that training one model never mutates the backbone another
    # model starts from (L0 and KL are meant to be trained independently).
    model = model_cls(copy.deepcopy(backbone), num_classes=10, embedding_dim=512,
                      dropout_rate=args.dropout_rate)
    model = model.to(device)

    # Initialize prototypes from data centroids
    model.initialize_prototypes_from_data(trainloader, device)

    _ensure_parent_dir(checkpoint_path)

    best_acc = train_fn(model, trainloader, valloader, testloader, device, args)
    print(f"{label} training completed with best accuracy: {best_acc:.2f}%")


def _load_trained_model(label, model_cls, checkpoint_path, args, device):
    """Rebuild a prototype model on a fresh backbone and load its trained weights.

    The backbone is rebuilt from ``args.backbone``/``args.pretrained`` so that
    its architecture matches the one used for training.

    Returns:
        The loaded model, or None if ``checkpoint_path`` does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"{label} checkpoint not found at {checkpoint_path}")
        return None

    print(f"==> Loading {label} model...")
    model_backbone = _build_backbone(args.backbone, args.pretrained, load_imagenet_weights=False)
    model = model_cls(model_backbone, num_classes=10, embedding_dim=512,
                      dropout_rate=args.dropout_rate)
    model = model.to(device)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['net'])

    # DO NOT reinitialize prototypes - use the trained ones!
    print(f"Loaded {label} model with accuracy: {checkpoint['acc']:.2f}%")
    return model


def main():
    """Command-line entry point.

    Parses arguments and prepares the CIFAR-10 loaders. It then optionally:
    trains the L0 model (``--train_l0``), trains the KL model (``--train_kl``),
    evaluates their consensus on the test split (``--evaluate``), and evaluates
    adversarial robustness (``--eval_robustness``).
    """
    parser = argparse.ArgumentParser(description='Train L0 and KL models separately, then evaluate with consensus')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--alpha', type=float, default=1.0, help='weight for classification loss')
    parser.add_argument('--beta', type=float, default=0.01, help='weight for separation regularization')
    parser.add_argument('--margin', type=float, default=0.2, help='minimum distance between prototypes')
    parser.add_argument('--tau', type=float, default=0.75, help='threshold parameter for L0 similarity')
    parser.add_argument('--class_boost', type=float, default=0.5, help='boost value for correct class similarities')
    parser.add_argument('--patience', type=int, default=10, help='early stopping patience')  # Reduced from 20
    parser.add_argument('--backbone', type=str, default='resnet18',
                        choices=['resnet18', 'resnet50', 'vgg19', 'densenet121'],
                        help='backbone architecture')
    parser.add_argument('--pretrained', action='store_true', help='use pretrained ImageNet weights')
    parser.add_argument('--load_model', type=str, help='load pretrained model from main.py checkpoint')
    parser.add_argument('--l0_checkpoint', type=str, default='./checkpoint_l0/ckpt_l0.pth',
                        help='L0 model checkpoint path')
    parser.add_argument('--kl_checkpoint', type=str, default='./checkpoint_kl/ckpt_kl.pth',
                        help='KL model checkpoint path')
    parser.add_argument('--train_l0', action='store_true', help='train L0 model')
    parser.add_argument('--train_kl', action='store_true', help='train KL model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate consensus on test set')
    parser.add_argument('--eval_robustness', action='store_true', help='evaluate adversarial robustness')
    parser.add_argument('--attack_epsilon', type=int, default=2, help='perturbation budget (will be divided by 255)')
    parser.add_argument('--attack_stepsize', type=int, default=2, help='attack step size (will be divided by 255)')
    parser.add_argument('--attack_type', type=str, default='pgd', choices=['pgd', 'cw','auto'], help='attack type')
    parser.add_argument('--norm', type=str, default='l_inf', choices=['l_inf', 'l_2'], help='norm for attack')
    parser.add_argument('--dropout_rate', type=float, default=0.3, help='dropout rate for L0 and KL models')

    args = parser.parse_args()

    # Attack budgets are given on the 0-255 pixel scale; convert to [0, 1].
    args.attack_epsilon = args.attack_epsilon / 255.0
    args.attack_stepsize = args.attack_stepsize / 255.0

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    trainloader, valloader, testloader = _build_data_loaders()
    loaders = (trainloader, valloader, testloader)

    # Build backbone
    print('==> Building backbone..')
    if args.pretrained:
        print('==> Loading pretrained ImageNet weights for backbone..')
    else:
        print('==> Using random initialization for backbone..')
    backbone = _build_backbone(args.backbone, args.pretrained)
    if args.pretrained:
        print(f'==> Loaded pretrained {args.backbone} backbone')

    # Load pretrained weights if specified
    if args.load_model:
        _load_backbone_weights(backbone, args.load_model)

    if args.train_l0:
        _train_prototype_model('L0', DynamicPrototypeModelL0, train_l0_model, backbone,
                               args.l0_checkpoint, loaders, device, args)

    if args.train_kl:
        _train_prototype_model('KL', DynamicPrototypeModelKL, train_kl_model, backbone,
                               args.kl_checkpoint, loaders, device, args)

    if not (args.evaluate or args.eval_robustness):
        return

    print("==> Loading trained models for evaluation...")
    l0_model = _load_trained_model('L0', DynamicPrototypeModelL0, args.l0_checkpoint, args, device)
    if l0_model is None:
        return
    kl_model = _load_trained_model('KL', DynamicPrototypeModelKL, args.kl_checkpoint, args, device)
    if kl_model is None:
        return

    # Create consensus evaluator
    evaluator = ConsensusEvaluator(l0_model, kl_model, device)

    if args.evaluate:
        print("==> Evaluating consensus on test set...")
        evaluator.evaluate_accuracy(testloader, "Test")

    if args.eval_robustness:
        print("==> Evaluating adversarial robustness...")
        evaluate_adversarial_robustness(evaluator, testloader, device, args)


if __name__ == '__main__':
    main()
