import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import math proof analysis function
from math_part import math_proof_analysis

# Try to import AutoAttack
try:
    from autoattack import AutoAttack
    AUTOATTACK_AVAILABLE = True
    print("AutoAttack imported successfully")
except ImportError:
    AUTOATTACK_AVAILABLE = False
    print("Warning: AutoAttack not available. Install with: pip install autoattack")

from models import *
try:
    from utils import progress_bar
    PROGRESS_BAR_AVAILABLE = True
except (ImportError, ValueError, OSError):
    # ImportError: `utils` is not installed / not on the path.
    # ValueError / OSError were already handled here -- presumably `utils`
    # probes the terminal at import time; TODO confirm.
    PROGRESS_BAR_AVAILABLE = False
    print("Warning: Progress bar not available, using simple progress display")

    # The fallback is defined only when the import failed. Defining it
    # unconditionally at module level would always shadow the imported
    # `utils.progress_bar`, even when that import succeeded.
    def progress_bar(current, total, msg=None):
        """Print a plain-text progress line.

        A line is printed for every 10th step (``current`` divisible by 10)
        and for the final step (``current == total - 1``).

        Args:
            current: Zero-based index of the step that just finished.
            total: Total number of steps.
            msg: Optional text appended to the line.
        """
        if current % 10 == 0 or current == total - 1:
            print(f"Progress: {current+1}/{total} {msg or ''}")

# Adversarial attack parameters
# CIFAR-10 normalization constants: one value per channel, reshaped to
# (1, 3, 1, 1) so they broadcast against batches of 3-channel images.
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR_STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(1, 3, 1, 1)
# Scalar bounds -- presumably the valid pixel range used when clamping
# adversarial examples; TODO confirm against the attack code.
lower_limit, upper_limit = 0, 1

def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` element-wise into ``[lower_limit, upper_limit]``.

    Args:
        X: Tensor to clamp.
        lower_limit: Lower bound; a tensor broadcastable to ``X`` or a Python
            scalar.
        upper_limit: Upper bound; a tensor broadcastable to ``X`` or a Python
            scalar.

    Returns:
        ``max(min(X, upper_limit), lower_limit)`` computed element-wise.
    """
    # torch.min / torch.max interpret a plain int second argument as a
    # reduction dimension, so scalar bounds must be turned into tensors first.
    # Tensor bounds are passed through untouched.
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

def corr_penalty(a, b):
    """Return a batch-level correlation score between two (B, C) logit tensors.

    Each row is mean-centred; the result is the batch mean of the row-wise
    dot products divided by the batch mean of the products of the row norms
    (plus 1e-8). Note this is a ratio of means, not a mean of per-row
    Pearson coefficients.
    """
    a_centered = a - a.mean(dim=1, keepdim=True)
    b_centered = b - b.mean(dim=1, keepdim=True)
    covariance = torch.sum(a_centered * b_centered, dim=1).mean()
    scale = torch.mean(a_centered.norm(dim=1) * b_centered.norm(dim=1)) + 1e-8
    return covariance / scale

def ortho_reg(W1, W2):
    """Return the mean squared entry of ``W1^T @ W2`` for two 2-D weight matrices.

    The value is zero exactly when every column of ``W1`` is orthogonal to
    every column of ``W2``.
    """
    cross_gram = torch.matmul(W1.t(), W2)
    return torch.mean(cross_gram * cross_gram)

def _norm_logits_for_loss(z):
    """Standardize logits for loss computation to equalize gradient scales"""
    z = z - z.mean(dim=1, keepdim=True)
    z = z / (z.std(dim=1, keepdim=True) + 1e-8)
    return z
def replace_activations(module: nn.Module, new_act: nn.Module):
    """Recursively replace every ``nn.ReLU`` child of ``module`` in place.

    Each replaced position receives its own deep copy of ``new_act``, so the
    activation's configuration (e.g. ``negative_slope``) and any parameters
    (e.g. ``PReLU.weight``) are preserved and are not shared between
    positions.

    Args:
        module: Root module whose children are rewritten in place.
        new_act: Template activation instance to copy into each ReLU slot.
    """
    import copy  # local import: keeps this block self-contained

    for name, child in module.named_children():
        # torchvision ResNet uses nn.ReLU(inplace=True)
        if isinstance(child, nn.ReLU):
            # Re-instantiating from `_parameters` would drop constructor
            # arguments and fail for parametrised activations; copy instead.
            setattr(module, name, copy.deepcopy(new_act))
        else:
            replace_activations(child, new_act)

class DynamicPrototypeModelL0KL(nn.Module):
    """Prototype classifier that scores images against learnable class prototypes.

    The backbone output is dropout-regularised and L2-normalised. For the L0
    and KL heads, both the embedding and the prototypes are mapped to
    distributions with a temperature softmax (``_VIEW_TEMPERATURE``). Three
    per-class scores are produced:

    * L0 head  -- soft count of dimensions selected by a sigmoid gate
      (``compute_l0_similarity``), scaled by ``l0_logit_scale`` (init 6.0).
    * KL head  -- ``1 - KL(prototype || image)`` (``compute_kl_similarity``),
      scaled by ``kl_logit_scale`` (init 1.0).
    * dot head -- cosine similarity between embedding and raw prototypes.

    NOTE(review): ``proj_l0``, ``proj_kl``, ``proto_proj_l0``,
    ``proto_proj_kl``, ``dropout_l0`` and ``dropout_kl`` are created (and
    toggled by ``set_training_mode``) but are not used by ``forward`` in this
    class. They are kept so attribute names and ``state_dict`` keys stay
    unchanged.
    """

    # Temperature of the softmax that turns embeddings and prototypes into
    # distributions for the L0 and KL heads.
    _VIEW_TEMPERATURE = 0.05
    # Debug statistics are printed once every this many logged forward passes.
    _LOG_EVERY = 100

    def __init__(self, backbone, num_classes=10, embedding_dim=512, dropout_rate=0.3):
        """Build the model.

        Args:
            backbone: Module mapping an input batch to ``(B, embedding_dim)``
                features (shape assumed from how ``forward`` uses the output).
            num_classes: Number of class prototypes.
            embedding_dim: Dimensionality of features and prototypes.
            dropout_rate: Dropout probability for the two feature dropouts.
        """
        super().__init__()
        self.backbone = backbone
        # Random start; can later be replaced by data centroids via
        # initialize_prototypes_from_data().
        self.class_prototypes = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self._prototypes_initialized = False
        # Dropout against overfitting: after the backbone and after normalisation.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Two lightweight per-head projections (image feature views).
        self.proj_l0 = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        self.proj_kl = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        # L0 projection starts as identity plus tiny noise.
        with torch.no_grad():
            self.proj_l0.weight.zero_()
            eye = torch.eye(self.embedding_dim, device=self.proj_l0.weight.device)
            self.proj_l0.weight.add_(eye)
            self.proj_l0.weight.add_(0.01 * torch.randn_like(eye))
        nn.init.orthogonal_(self.proj_kl.weight, gain=1.0)

        # Prototype projections.
        self.proto_proj_l0 = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        self.proto_proj_kl = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        nn.init.orthogonal_(self.proto_proj_l0.weight, gain=1.0)
        nn.init.orthogonal_(self.proto_proj_kl.weight, gain=1.0)

        # Learnable logit scales; L0 starts larger so its logits spread wider.
        self.l0_logit_scale = nn.Parameter(torch.tensor(6.0))
        self.kl_logit_scale = nn.Parameter(torch.tensor(1.0))

        # Per-head dropouts (asymmetric rates on purpose).
        self.dropout_l0 = nn.Dropout(p=dropout_rate)
        self.dropout_kl = nn.Dropout(p=dropout_rate * 0.7)
        self._prototypes_frozen = False

    # ------------------------------------------------------------------
    # Prototype freezing
    # ------------------------------------------------------------------
    def freeze_prototypes(self):
        """Stop gradient updates to the class prototypes."""
        # class_prototypes is a Parameter, not a module.
        self.class_prototypes.requires_grad_(False)
        self._prototypes_frozen = True
        print("🔒 Prototypes frozen - no further updates")

    def unfreeze_prototypes(self):
        """Re-enable gradient updates to the class prototypes."""
        self.class_prototypes.requires_grad_(True)
        self._prototypes_frozen = False
        print("🔄 Prototypes unfrozen - can be updated")

    def get_prototype_status(self):
        """Return a dict with the frozen flag, ``requires_grad`` and mean L2 norm."""
        return {
            'frozen': self._prototypes_frozen,
            'requires_grad': self.class_prototypes.requires_grad,
            'norm': torch.norm(self.class_prototypes, p=2, dim=1).mean().item()
        }

    # ------------------------------------------------------------------
    # Shared feature pipeline
    # ------------------------------------------------------------------
    def _embed(self, x):
        """Backbone features -> dropout -> L2 normalisation -> dropout."""
        feats = self.dropout1(self.backbone(x))
        feats = F.normalize(feats, p=2, dim=1)
        return self.dropout2(feats)

    def _head_view(self, image_embeddings):
        """Return (embedding distribution, prototype distribution) for one head."""
        # Re-normalise because dropout2 rescales the surviving entries.
        z = F.normalize(image_embeddings, p=2, dim=1)
        z_processed = F.softmax(z / self._VIEW_TEMPERATURE, dim=1)
        protos_processed = F.softmax(self.class_prototypes / self._VIEW_TEMPERATURE, dim=1)
        return z_processed, protos_processed

    def forward(self, x, targets=None, return_individual=False, tau=0.75,
                compute_separation=False, margin=0.2, class_boost=0.5,
                l0_weight=1.0, kl_weight=0.0, dot_weight=0.0):
        """Score a batch against the class prototypes.

        Args:
            x: Input batch for the backbone.
            targets: Accepted for interface compatibility; not used.
            return_individual: Also return per-head logits/distances and
                emit periodic debug output.
            tau: Forwarded to ``compute_l0_similarity`` (which ignores it).
            compute_separation: Also return the prototype separation loss.
            margin: Forwarded to ``_compute_separation_loss`` (which ignores it).
            class_boost: Accepted for interface compatibility; not used.
            l0_weight: Weight of the L0 head; the head is skipped when <= 0.
            kl_weight: Weight of the KL head; the head is skipped when <= 0.
            dot_weight: Weight of the cosine head.

        Returns:
            * default: ``combined``
            * ``compute_separation``: ``(combined, separation_loss)``
            * ``return_individual``:
              ``(l0, kl, combined, l0_dist, kl_dist, dot)``
            * both: ``(l0, kl, combined, separation_loss, l0_dist, kl_dist, dot)``

            ``combined`` is the weighted sum of the three heads, except that
            with exactly ``l0_weight == 1.0, kl_weight == 0.0`` (or the
            reverse) it is that single head's logits and ``dot_weight`` is
            ignored.
        """
        image_embeddings = self._embed(x)
        batch_size = image_embeddings.size(0)
        device = image_embeddings.device

        # A disabled head contributes all-zero similarities/distances; its
        # softmax views are not computed at all.
        if l0_weight > 0:
            z_l0, protos_l0 = self._head_view(image_embeddings)
            l0_similarities, l0_distances = self.compute_l0_similarity(z_l0, protos_l0, tau=tau)
        else:
            l0_similarities = torch.zeros(batch_size, self.num_classes, device=device)
            l0_distances = torch.zeros(batch_size, self.num_classes, device=device)

        if kl_weight > 0:
            z_kl, protos_kl = self._head_view(image_embeddings)
            kl_similarities, kl_distances = self.compute_kl_similarity(z_kl, protos_kl)
        else:
            kl_similarities = torch.zeros(batch_size, self.num_classes, device=device)
            kl_distances = torch.zeros(batch_size, self.num_classes, device=device)

        # Cosine baseline on the (post-dropout) embeddings and raw prototypes.
        # Gradients do flow through this branch.
        base_feats = F.normalize(image_embeddings, p=2, dim=1)
        base_protos = F.normalize(self.class_prototypes, p=2, dim=1)
        dot_similarities = base_feats @ base_protos.t()

        # Apply learnable logit scales.
        l0_normalized = l0_similarities * self.l0_logit_scale
        kl_normalized = kl_similarities * self.kl_logit_scale

        if return_individual:
            self._log_metric_dominance(l0_normalized, kl_normalized, l0_weight, kl_weight)
            self._log_similarity_debug(
                l0_similarities, l0_normalized, kl_similarities, kl_normalized,
                dot_similarities, l0_weight, kl_weight,
            )

        if l0_weight == 1.0 and kl_weight == 0.0:
            combined_similarities = l0_normalized
            self._announce_single_head('_l0_only_warning_shown', 'L0', l0_weight, kl_weight)
        elif l0_weight == 0.0 and kl_weight == 1.0:
            combined_similarities = kl_normalized
            self._announce_single_head('_kl_only_warning_shown', 'KL', l0_weight, kl_weight)
        else:
            combined_similarities = (
                l0_weight * l0_normalized
                + kl_weight * kl_normalized
                + dot_weight * dot_similarities
            )

        separation_loss = None
        if compute_separation:
            separation_loss = self._compute_separation_loss(self.class_prototypes, margin)

        if return_individual:
            if compute_separation:
                return (l0_normalized, kl_normalized, combined_similarities, separation_loss,
                        l0_distances, kl_distances, dot_similarities)
            return (l0_normalized, kl_normalized, combined_similarities,
                    l0_distances, kl_distances, dot_similarities)
        if compute_separation:
            return combined_similarities, separation_loss
        return combined_similarities

    def _announce_single_head(self, flag_name, head, l0_weight, kl_weight):
        """Print the single-head banner once per model instance."""
        if hasattr(self, flag_name):
            return
        print(f"🚀 {head}-ONLY MODE ACTIVATED: Using only {head} similarities")
        print(f" This should behave exactly like respective {head} training")
        print(f" L0 weight: {l0_weight}, KL weight: {kl_weight}")
        setattr(self, flag_name, True)

    def _log_similarity_debug(self, l0_similarities, l0_normalized, kl_similarities,
                              kl_normalized, dot_similarities, l0_weight, kl_weight):
        """Print min/max and sample predictions of every head every ``_LOG_EVERY`` calls."""
        self._l0_debug_counter = getattr(self, '_l0_debug_counter', 0) + 1
        if self._l0_debug_counter % self._LOG_EVERY != 0:
            return
        print(f"\n=== Similarity Debug (Batch {self._l0_debug_counter}) ===")
        print(f"Raw L0 similarities - Min: {l0_similarities.min().item():.6f}, Max: {l0_similarities.max().item():.6f}")
        print(f"Normalized L0 - Min: {l0_normalized.min().item():.6f}, Max: {l0_normalized.max().item():.6f}")
        print(f"KL similarities - Min: {kl_similarities.min().item():.6f}, Max: {kl_similarities.max().item():.6f}")
        print(f"Dot similarities - Min: {dot_similarities.min().item():.6f}, Max: {dot_similarities.max().item():.6f}")
        print(f"L0 predictions: {torch.argmax(l0_normalized, dim=1)[:5].tolist()}")
        print(f"KL predictions: {torch.argmax(kl_normalized, dim=1)[:5].tolist()}")
        print(f"Dot predictions: {torch.argmax(dot_similarities, dim=1)[:5].tolist()}")
        print(f"L0 weight: {l0_weight}, KL weight: {kl_weight}")
        print(f"Using projected features: L0 (z_l0), KL (z_kl), Dot (z_l0), shared prototypes")
        print("=" * 50)

    def _compute_separation_loss(self, class_prototypes, margin=0.2):
        """Penalise prototypes that coincide in any single coordinate.

        For every pair of distinct prototypes the smallest per-dimension
        absolute difference is taken; ``min_sep`` is the minimum of those over
        all pairs. The loss is ``exp(-10 * min_sep) + 1 / (min_sep + 1e-8)``.

        Args:
            class_prototypes: ``(num_classes, embedding_dim)`` tensor.
            margin: Not used by the current loss; kept for the call signature.

        Returns:
            Scalar tensor attached to the autograd graph of ``class_prototypes``.
        """
        num_classes = class_prototypes.size(0)
        # (C, 1, D) - (1, C, D) -> (C, C, D) per-dimension absolute differences.
        abs_diffs = (class_prototypes.unsqueeze(1) - class_prototypes.unsqueeze(0)).abs()
        min_diffs_per_pair = abs_diffs.min(dim=2)[0]  # (C, C)
        # Self-pairs are trivially 0 and must not win the global minimum.
        diagonal_mask = torch.eye(num_classes, device=class_prototypes.device).bool()
        min_sep = min_diffs_per_pair.masked_fill(diagonal_mask, float('inf')).min()
        return torch.exp(-min_sep * 10) + 1.0 / (min_sep + 1e-8)

    def _log_metric_dominance(self, l0_normalized, kl_normalized, l0_weight, kl_weight):
        """Every ``_LOG_EVERY`` calls, print which head has the larger logit spread.

        A head's score is the batch mean of per-sample variance + range + std
        across classes. One head "dominates" when its score exceeds 1.2x the
        other's.
        """
        self._batch_counter = getattr(self, '_batch_counter', 0) + 1
        if self._batch_counter % self._LOG_EVERY != 0:
            # Skip the statistics (and their .item() host syncs) when nothing
            # will be printed.
            return

        def spread_score(logits):
            variance = torch.var(logits, dim=1).mean().item()
            value_range = (logits.max(dim=1)[0] - logits.min(dim=1)[0]).mean().item()
            std = torch.std(logits, dim=1).mean().item()
            return variance + value_range + std

        l0_dominance_score = spread_score(l0_normalized)
        kl_dominance_score = spread_score(kl_normalized)
        print(f"L0 vs KL Dominance - L0: {l0_dominance_score:.3f}, KL: {kl_dominance_score:.3f}")
        print(f"Weights - L0: {l0_weight}, KL: {kl_weight}")
        if l0_weight > 0 and kl_weight > 0:
            if l0_dominance_score > kl_dominance_score * 1.2:
                print("⚠️ L0 is DOMINATING")
            elif kl_dominance_score > l0_dominance_score * 1.2:
                print("⚠️ KL is DOMINATING")
            else:
                print("✅ Metrics are BALANCED")
        elif l0_weight > 0:
            print("✅ L0-ONLY mode")
        elif kl_weight > 0:
            print("✅ KL-ONLY mode")

    # ------------------------------------------------------------------
    # Similarity heads
    # ------------------------------------------------------------------
    def compute_l0_similarity(self, img_features, text_features, tau=0.75):
        """Soft-L0 score between ``(B, D)`` features and ``(C, D)`` prototypes.

        For every (sample, class) pair the per-dimension absolute difference
        is compared to a pair-specific threshold ``mean + 0.5 * std`` through
        a sigmoid gate (temperature 0.2). The logit is ``10 *`` the mean gate
        value.

        NOTE(review): the gate is close to 1 where the difference is ABOVE
        the threshold, although the original comments described it as
        selecting small differences. Behaviour is left unchanged; confirm the
        intended polarity.

        Args:
            img_features: ``(B, D)`` tensor.
            text_features: ``(C, D)`` tensor of prototypes.
            tau: Not used by the current implementation.

        Returns:
            ``(logits, l0_distance)`` both ``(B, C)``; ``l0_distance`` is the
            hard count of dimensions whose difference exceeds the threshold.
        """
        diff = (img_features.unsqueeze(1) - text_features.unsqueeze(0)).abs()  # (B, C, D)

        mu = diff.mean(dim=2, keepdim=True)
        sigma = diff.std(dim=2, keepdim=True)
        thresholds = mu + 0.5 * sigma  # tune 0.3-0.8

        temperature = 0.2
        gate = torch.sigmoid((diff - thresholds) / temperature)
        kept = gate.sum(dim=2)                # (B, C)
        sim = kept / diff.size(2)             # in [0, 1]
        logits = 10 * sim                     # further scaled by l0_logit_scale in forward()
        l0_distance = (diff > thresholds).float().sum(dim=2)
        return logits, l0_distance

    def compute_kl_similarity(self, img_features, text_features):
        """KL-based score between ``(B, D)`` features and ``(C, D)`` prototypes.

        Inputs are treated as distributions along the last dimension and
        clamped to ``[1e-10, 1]`` to avoid ``log(0)``.

        Returns:
            ``(kl_similarity, kl_div)`` both ``(B, C)`` where
            ``kl_div = KL(prototype || image)`` and
            ``kl_similarity = 1 - kl_div``.
        """
        epsilon = 1e-10
        img_dist_safe = torch.clamp(img_features.unsqueeze(1), min=epsilon, max=1.0)    # (B, 1, D)
        text_dist_safe = torch.clamp(text_features.unsqueeze(0), min=epsilon, max=1.0)  # (1, C, D)
        kl_div = torch.sum(text_dist_safe * torch.log(text_dist_safe / img_dist_safe), dim=2)
        # Lower divergence -> higher similarity (linear map, not exp(-kl)).
        kl_similarity = 1 - kl_div
        return kl_similarity, kl_div

    def compute_dot_similarity(self, img_features, text_features):
        """Return ``3.0 *`` the dot product of features with L2-normalised prototypes.

        This is a cosine similarity only if ``img_features`` is already
        L2-normalised; the features are not normalised here.
        """
        text_features_norm = F.normalize(text_features, p=2, dim=1)
        dot_similarities = torch.mm(img_features, text_features_norm.t())
        temperature = 3.0
        return dot_similarities * temperature

    # ------------------------------------------------------------------
    # Prototype initialisation
    # ------------------------------------------------------------------
    def initialize_prototypes_from_data(self, dataloader, device):
        """Set each class prototype to the mean backbone embedding of its samples.

        Uses at most the first 100 batches. Batches that are not a list/tuple
        of at least ``(inputs, targets)`` are skipped. A class with no samples
        keeps its current prototype (a warning is printed). The call is a
        no-op once prototypes have been initialised. The train/eval state of
        every submodule is restored afterwards.

        Args:
            dataloader: Iterable yielding ``(inputs, targets, ...)`` batches.
            device: Device the batches are moved to before the backbone runs.

        Raises:
            KeyError: If a target label is outside ``range(num_classes)``.
        """
        if self._prototypes_initialized:
            print("Prototypes already initialized from data")
            return
        print("Initializing prototypes from data centroids...")

        class_embeddings = {i: [] for i in range(self.num_classes)}
        # Remember per-module modes so eval() below is not a lasting side effect.
        previous_modes = [(m, m.training) for m in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                for batch_idx, batch_data in enumerate(dataloader):
                    if batch_idx >= 100:  # first 100 batches are enough
                        break
                    if not (isinstance(batch_data, (list, tuple)) and len(batch_data) >= 2):
                        continue
                    inputs = batch_data[0].to(device)
                    labels = batch_data[1].to(device).reshape(-1)
                    embeddings = self.backbone(inputs)
                    # One boolean-mask gather per class present in the batch
                    # instead of one host sync per sample.
                    for class_idx in labels.unique().tolist():
                        class_embeddings[class_idx].append(embeddings[labels == class_idx])

                # Start from the current prototypes so classes without samples
                # are left untouched rather than zeroed.
                centroids = self.class_prototypes.detach().clone().to(device)
                valid_classes = 0
                for class_idx in range(self.num_classes):
                    chunks = class_embeddings[class_idx]
                    if chunks:
                        class_emb = torch.cat(chunks)
                        centroid = class_emb.mean(dim=0)
                        centroids[class_idx] = centroid
                        valid_classes += 1
                        print(f"Class {class_idx}: {class_emb.size(0)} samples, centroid norm: {torch.norm(centroid, p=2).item():.4f}")
                    else:
                        print(f"Warning: No samples found for class {class_idx}")

                self.class_prototypes.copy_(centroids)
                print(f"Initialized {valid_classes}/{self.num_classes} prototypes from data centroids")
                self._verify_prototype_separation()
                self._prototypes_initialized = True
        finally:
            for module, mode in previous_modes:
                module.training = mode

    def _verify_prototype_separation(self):
        """Print min/max/mean Euclidean distance between distinct prototypes.

        Warns when the closest pair is less than 1.0 apart. Self-distances
        (the zero diagonal) are excluded from all three statistics.
        """
        with torch.no_grad():
            if self.num_classes < 2:
                print("Prototype separation - skipped: fewer than 2 prototypes")
                return
            distances = torch.cdist(self.class_prototypes, self.class_prototypes, p=2)
            off_diagonal_mask = ~torch.eye(
                self.num_classes, dtype=torch.bool, device=distances.device
            )
            pair_distances = distances[off_diagonal_mask]
            min_separation = pair_distances.min().item()
            max_separation = pair_distances.max().item()
            avg_separation = pair_distances.mean().item()
            print(f"Prototype separation - Min: {min_separation:.4f}, Max: {max_separation:.4f}, Avg: {avg_separation:.4f}")
            if min_separation < 1.0:
                print("⚠️ Warning: Some prototypes are very close together!")

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------
    def get_prototypes(self):
        """Return the raw (un-normalised) prototype parameter."""
        return self.class_prototypes

    def get_image_embeddings(self, x):
        """Return L2-normalised backbone embeddings (no dropout applied)."""
        embeddings = self.backbone(x)
        return F.normalize(embeddings, p=2, dim=1)

    def get_processed_embeddings(self, x):
        """Return the softmax "distribution" views used by the L0 and KL heads.

        Runs the same pipeline as ``forward`` (including dropout when the
        model is in training mode).

        Returns:
            ``(z_l0, z_kl, protos_l0, protos_kl)`` -- embedding views of shape
            ``(B, D)`` and prototype views of shape ``(C, D)``. The two heads
            currently share the same transformation, so the L0 and KL views
            hold equal values.
        """
        image_embeddings = self._embed(x)
        z_l0_processed, protos_l0_processed = self._head_view(image_embeddings)
        z_kl_processed, protos_kl_processed = self._head_view(image_embeddings)
        return z_l0_processed, z_kl_processed, protos_l0_processed, protos_kl_processed

    # ------------------------------------------------------------------
    # Training-mode helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _set_trainable(modules, trainable):
        """Set ``requires_grad`` on every parameter of the given modules."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad_(trainable)

    def set_training_mode(self, l0_weight, kl_weight, dot_weight):
        """Freeze/unfreeze the per-head projection layers according to the weights.

        * ``l0_weight == 0.0``  -> KL-only: L0 projections frozen.
        * ``kl_weight == 0.0``  -> L0-only: KL projections frozen.
        * otherwise             -> both trainable.

        ``dot_weight`` is accepted but does not influence the mode.
        """
        if l0_weight == 0.0:
            l0_trainable, kl_trainable = False, True
            message = "✅ Training mode set to KL-ONLY: L0 frozen, KL trainable"
        elif kl_weight == 0.0:
            l0_trainable, kl_trainable = True, False
            message = "✅ Training mode set to L0-ONLY: L0 trainable, KL frozen"
        else:
            l0_trainable, kl_trainable = True, True
            message = "✅ Training mode set to COMBINED: Both L0 and KL trainable"

        self._set_trainable((self.proj_l0, self.proto_proj_l0), l0_trainable)
        self._set_trainable((self.proj_kl, self.proto_proj_kl), kl_trainable)
        print(message)

    def get_training_mode_status(self):
        """Return "COMBINED", "L0-ONLY", "KL-ONLY" or "UNKNOWN" from the
        ``requires_grad`` state of ``proj_l0`` / ``proj_kl``."""
        l0_trainable = any(p.requires_grad for p in self.proj_l0.parameters())
        kl_trainable = any(p.requires_grad for p in self.proj_kl.parameters())

        if l0_trainable and kl_trainable:
            return "COMBINED"
        if l0_trainable:
            return "L0-ONLY"
        if kl_trainable:
            return "KL-ONLY"
        return "UNKNOWN"

    def predict_with_consensus(self, x, tau=0.75):
        """Predict a class only where the L0 and KL heads agree.

        Args:
            x: Input images.
            tau: Forwarded to ``forward``.

        Returns:
            predictions: Class index where L0 and KL agree, ``-1`` otherwise.
            consensus_flags: Boolean tensor, True where L0 and KL agree.
            l0_preds: Arg-max class of the L0 head.
            kl_preds: Arg-max class of the KL head.
            dot_preds: Arg-max class of the cosine head.
        """
        # Both heads must be enabled explicitly: forward() defaults to
        # kl_weight=0.0, which yields all-zero KL logits and would make the
        # KL prediction class 0 for every sample.
        l0_sims, kl_sims, _combined, _l0_dists, _kl_dists, dot_sims = self.forward(
            x, return_individual=True, tau=tau, l0_weight=1.0, kl_weight=1.0
        )
        l0_preds = torch.argmax(l0_sims, dim=1)
        kl_preds = torch.argmax(kl_sims, dim=1)
        dot_preds = torch.argmax(dot_sims, dim=1)
        consensus_flags = (l0_preds == kl_preds)
        predictions = torch.where(consensus_flags, l0_preds, torch.full_like(l0_preds, -1))
        return predictions, consensus_flags, l0_preds, kl_preds, dot_preds

def check_prototype_status(model):
    """Print a short report on ``model.class_prototypes`` and return its ``requires_grad``.

    The report shows whether the prototypes are trainable, whether a gradient
    is currently stored (``True``, or the string ``None`` when absent), the
    per-prototype L2 norms and their mean.
    """
    prototypes = model.class_prototypes
    # Mirrors the original output: True when a gradient exists, 'None' otherwise.
    grad_state = True if prototypes.grad is not None else 'None'

    print(f"\n=== PROTOTYPE STATUS ===")
    print(f"Prototypes trainable: {prototypes.requires_grad}")
    print(f"Prototype gradients: {grad_state}")

    # no_grad keeps the printed tensors free of autograd metadata.
    with torch.no_grad():
        norms = torch.norm(prototypes, p=2, dim=1)
        print(f"Prototype norms: {norms}")
        print(f"Prototype mean norm: {norms.mean():.4f}")

    return model.class_prototypes.requires_grad


def train_model(model, trainloader, valloader, testloader, device, args):
    """Train the combined L0+KL model in two stages.

    Stage 1 (first ``args.epochs // 4`` epochs) trains only
    ``model.class_prototypes`` on the separation loss returned by the model.
    The checkpoint with the best minimum dimension-wise prototype separation
    is written next to ``args.checkpoint`` with a ``_stage1_best`` suffix.

    Stage 2 (remaining epochs) reloads that checkpoint when it exists, freezes
    everything except parameters whose name contains ``'backbone'`` and the
    ``l0_logit_scale`` / ``kl_logit_scale`` parameters (each only when its head
    weight is positive), and trains on the weighted per-head cross-entropy
    plus the decorrelation penalty and, when ``args.ortho_weight > 0``, the
    orthogonality regulariser.

    Args:
        model: Network exposing ``class_prototypes``, ``get_prototypes()``,
            ``l0_logit_scale``, ``kl_logit_scale`` and, when the orthogonality
            term is enabled, ``proj_l0`` / ``proj_kl``.
        trainloader: Iterable of ``(inputs, targets)`` training batches.
        valloader: Iterable of ``(inputs, targets)`` validation batches.
        testloader: Unused; kept for interface compatibility.
        device: Device the batches (and the reloaded checkpoint) are moved to.
        args: Namespace providing ``epochs``, ``lr``, ``tau``, ``margin``,
            ``class_boost``, ``l0_weight``, ``kl_weight``, ``dot_weight``,
            ``alpha``, ``beta``, ``deco_weight``, ``ortho_weight``,
            ``patience`` and ``checkpoint``.

    Returns:
        Best Stage 2 validation accuracy in percent (0 if Stage 2 never ran).
    """
    print(f"==> Training Combined L0+KL Model...")
    print(f" L0 weight: {args.l0_weight}, KL weight: {args.kl_weight}, Dot weight: {args.dot_weight}")
    print(f"NEW: Using decorrelated heads with per-head projections for better adversarial robustness")

    # Epoch budget: the first quarter goes to the prototypes, the rest to Stage 2.
    stage1_epochs = args.epochs // 4
    stage2_epochs = args.epochs - stage1_epochs
    stage1_checkpoint = args.checkpoint.replace('.pth', '_stage1_best.pth')
    last_checkpoint = args.checkpoint.replace('.pth', '_last.pth')

    # Stage 1: Train prototypes for maximum separation
    print("\n" + "="*60)
    print("STAGE 1: TRAINING PROTOTYPES FOR MAXIMUM SEPARATION")
    print("="*60)

    # Freeze all parameters except prototypes
    for name, param in model.named_parameters():
        if 'class_prototypes' in name:
            param.requires_grad_(True)
            print(f"✅ {name}: trainable")
        else:
            param.requires_grad_(False)
            print(f"🔒 {name}: frozen")

    # Stage 1 optimizer (only prototypes)
    prototype_optimizer = optim.Adam([model.class_prototypes], lr=args.lr * 10, weight_decay=1e-4)
    # max(1, ...) keeps the scheduler well-defined when the stage has no epochs.
    prototype_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        prototype_optimizer, T_max=max(1, stage1_epochs)
    )

    best_separation = 0
    patience_counter = 0

    for epoch in range(stage1_epochs):
        print(f'\nStage 1 - Epoch: {epoch}')
        model.train()

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            prototype_optimizer.zero_grad()

            # Forward pass; only the separation loss is used in this stage.
            _, _, _, separation_loss, _, _, _ = model(
                inputs, targets=targets, tau=args.tau,
                compute_separation=True, margin=args.margin,
                class_boost=args.class_boost, l0_weight=args.l0_weight,
                kl_weight=args.kl_weight, dot_weight=args.dot_weight,
                return_individual=True
            )
            total_loss = args.beta * separation_loss

            if batch_idx % 50 == 0:
                print(f"Stage 1 - Batch {batch_idx}: Loss: {total_loss.item():.4f} (separation loss: {separation_loss.item():.4f})")

            total_loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_([model.class_prototypes], max_norm=1.0)
            prototype_optimizer.step()

        val_acc = _combined_val_accuracy(model, valloader, device, args)
        print(f'Stage 1 - Validation Accuracy: {val_acc:.2f}%')

        # Check prototype separation. All three statistics are taken over
        # distinct class pairs only, so Max/Avg are finite numbers.
        with torch.no_grad():
            min_separation, max_separation, avg_separation = _dimwise_prototype_separation(
                model.get_prototypes()
            )
        print(f"Stage 1 - Prototype separation (dimension-wise) - Min: {min_separation:.4f}, Max: {max_separation:.4f}, Avg: {avg_separation:.4f}")

        if min_separation > best_separation:
            best_separation = min_separation
            patience_counter = 0
            print(f"✅ New best separation: {best_separation:.4f}")

            # Save best separation checkpoint
            state = {
                'net': model.state_dict(),
                'separation': best_separation,
                'epoch': epoch,
                'val_acc': val_acc,
                'stage': 'stage1_best_separation'
            }
            torch.save(state, stage1_checkpoint)
            print(f'💾 Saved best separation checkpoint to {stage1_checkpoint}')
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping after {patience_counter} epochs without improvement')
                break

        prototype_scheduler.step()
    print(f"\nStage 1 completed! Best separation: {best_separation:.4f}")

    # Stage 2: Freeze prototypes and train backbone + logit scales
    print("\n" + "="*60)
    print("STAGE 2: FREEZING PROTOTYPES AND TRAINING L0/KL PROJECTIONS")
    print("="*60)

    # Load best separation checkpoint for Stage 2
    if os.path.isfile(stage1_checkpoint):
        print(f"🔄 Loading best separation checkpoint from Stage 1...")
        # map_location keeps the reload working when the checkpoint tensors
        # were saved from a different device than the current one.
        stage1_state = torch.load(stage1_checkpoint, map_location=device)
        model.load_state_dict(stage1_state['net'])
        print(f"✅ Loaded Stage 1 best separation: {stage1_state['separation']:.4f} at epoch {stage1_state['epoch']}")
    else:
        print(f"⚠️ Stage 1 checkpoint not found, using current prototypes")

    # Freeze prototypes
    model.class_prototypes.requires_grad_(False)
    print("🔒 Prototypes frozen")

    # Unfreeze the backbone, freeze everything else, and collect the
    # trainable parameters for the Stage 2 optimizer along the way.
    stage2_params = []
    for name, param in model.named_parameters():
        if 'backbone' in name:
            param.requires_grad_(True)
            stage2_params.append(param)
            print(f"✅ {name}: trainable")
        else:
            param.requires_grad_(False)
            print(f"🔒 {name}: frozen")

    # Re-enable the per-head logit scales for the heads that are in use.
    if args.l0_weight > 0:
        model.l0_logit_scale.requires_grad_(True)
        stage2_params.append(model.l0_logit_scale)
    if args.kl_weight > 0:
        model.kl_logit_scale.requires_grad_(True)
        stage2_params.append(model.kl_logit_scale)

    stage2_optimizer = optim.SGD(stage2_params, lr=args.lr * 0.1, momentum=0.9, weight_decay=1e-3)
    # T_max matches the number of Stage 2 epochs; with the shorter Stage 1
    # length the cosine schedule would pass its minimum and climb back up.
    stage2_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        stage2_optimizer, T_max=max(1, stage2_epochs)
    )

    best_acc = 0
    patience_counter = 0
    last_epoch = None  # stays None when no Stage 2 epoch completes

    def _final_state(val_acc, epoch):
        # Payload shared by the "best" and "last" Stage 2 checkpoints.
        return {
            'net': model.state_dict(),
            'acc': val_acc,
            'epoch': epoch,
            'l0_weight': args.l0_weight,
            'kl_weight': args.kl_weight,
            'dot_weight': args.dot_weight,
            'alpha': args.alpha,
            'beta': args.beta,
            'margin': args.margin,
            'tau': args.tau,
            'class_boost': args.class_boost,
            'model_type': 'combined_l0_kl',
            'stage1_separation': best_separation
        }

    for epoch in range(stage1_epochs, args.epochs):
        print(f'\nStage 2 - Epoch: {epoch}')
        model.train()
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            stage2_optimizer.zero_grad()

            # Forward pass; the separation loss is ignored (prototypes are frozen).
            l0_sims, kl_sims, _, _, _, _, dot_sims = model(
                inputs, targets=targets, tau=args.tau,
                compute_separation=True, margin=args.margin,
                class_boost=args.class_boost, l0_weight=args.l0_weight,
                kl_weight=args.kl_weight, dot_weight=args.dot_weight,
                return_individual=True
            )

            # Per-head classification losses on standardized logits
            loss_l0 = F.cross_entropy(_norm_logits_for_loss(l0_sims), targets)
            loss_kl = F.cross_entropy(_norm_logits_for_loss(kl_sims), targets)
            loss_dot = F.cross_entropy(_norm_logits_for_loss(dot_sims), targets)
            cls_loss = args.l0_weight * loss_l0 + args.kl_weight * loss_kl + args.dot_weight * loss_dot

            # Decorrelation loss (L0 side detached, so only the KL side is pushed)
            deco = corr_penalty(l0_sims.detach(), kl_sims)

            total_loss = args.alpha * cls_loss + args.deco_weight * deco

            # Orthogonality regularization: actually add the weighted term.
            # NOTE(review): the projection weights are frozen by the loop above
            # unless their names contain 'backbone'; confirm they are meant to
            # receive gradients from this term.
            if args.ortho_weight > 0:
                ortho_loss = ortho_reg(model.proj_l0.weight, model.proj_kl.weight)
                total_loss = total_loss + args.ortho_weight * ortho_loss

            if batch_idx % 50 == 0:
                print(f"Stage 2 - Batch {batch_idx}: Loss: {total_loss.item():.4f} (L0: {loss_l0.item():.4f}, KL: {loss_kl.item():.4f}, Dot: {loss_dot.item():.4f}, deco: {deco.item():.4f})")

            total_loss.backward()
            stage2_optimizer.step()

        val_acc = _combined_val_accuracy(model, valloader, device, args)
        print(f'Stage 2 - Validation Accuracy: {val_acc:.2f}%')

        # Save best model / track patience
        stop_early = False
        if val_acc > best_acc:
            print('Saving best model..')
            torch.save(_final_state(val_acc, epoch), args.checkpoint)
            print(f'Saved best model to {args.checkpoint}')
            best_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            stop_early = patience_counter >= args.patience

        # Save last epoch checkpoint before a possible early stop, so the
        # "_last" file always reflects the final epoch that was trained.
        last_epoch = epoch
        torch.save(_final_state(val_acc, epoch), last_checkpoint)
        print(f'Saved last epoch checkpoint to {last_checkpoint}')

        if stop_early:
            print(f'Early stopping after {patience_counter} epochs without improvement')
            break

        stage2_scheduler.step()

    print(f'\nTraining completed!')
    print(f'Stage 1 - Best separation: {best_separation:.4f}')
    print(f'Stage 2 - Best validation accuracy: {best_acc:.2f}%')
    if last_epoch is None:
        print('No Stage 2 epoch was run; no last-epoch checkpoint was written')
    else:
        print(f'Last epoch: {last_epoch}, Last checkpoint saved to: {last_checkpoint}')

    return best_acc


def _combined_val_accuracy(model, valloader, device, args):
    """Return the accuracy (percent) of the combined similarities on ``valloader``."""
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in valloader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, _, combined_sims, _, _, _ = model(
                inputs, tau=args.tau, l0_weight=args.l0_weight,
                kl_weight=args.kl_weight, dot_weight=args.dot_weight,
                return_individual=True
            )
            _, predicted = combined_sims.max(1)
            seen += targets.size(0)
            hits += predicted.eq(targets).sum().item()
    # An empty loader yields 0% rather than a ZeroDivisionError.
    return 100. * hits / seen if seen > 0 else 0.0


def _dimwise_prototype_separation(prototypes):
    """Return (min, max, mean) over distinct class pairs of the smallest per-dimension gap."""
    num_classes = prototypes.size(0)
    # |p_i - p_j| for every pair and every embedding dimension: (C, C, D)
    abs_diffs = torch.abs(prototypes.unsqueeze(1) - prototypes.unsqueeze(0))
    # Smallest gap across dimensions for each pair: (C, C)
    pair_gaps = abs_diffs.min(dim=2)[0]
    # Drop the self-comparisons on the diagonal.
    off_diagonal = ~torch.eye(num_classes, dtype=torch.bool, device=prototypes.device)
    gaps = pair_gaps[off_diagonal]
    if gaps.numel() == 0:
        # Fewer than two prototypes: there is no pair to compare.
        inf = float('inf')
        return inf, inf, inf
    return gaps.min().item(), gaps.max().item(), gaps.mean().item()



def evaluate_model(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate the trained model on the test set and report per-head accuracy.

    Besides the accuracies, a pairwise L2 distance summary of the class
    prototypes is printed. The summary is computed over distinct prototype
    pairs only, so the zero self-distances do not mask the real minimum.

    Args:
        model: Callable network; also provides ``eval()`` and ``get_prototypes()``.
        testloader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to.
        args: Namespace; only ``args.tau`` is read here.
        l0_weight: Head weight forwarded to the model (``None`` is forwarded as is).
        kl_weight: Head weight forwarded to the model (``None`` is forwarded as is).
        dot_weight: Head weight forwarded to the model (``None`` is forwarded as is).

    Returns:
        Tuple ``(combined_acc, l0_acc, kl_acc, dot_acc)`` in percent; all
        zeros when the loader yields no samples.
    """
    print("==> Evaluating model on test set...")
    model.eval()
    # NOTE(review): nothing is read from a checkpoint here; missing weights
    # are simply forwarded to the model as None.
    if l0_weight is None or kl_weight is None:
        print(f"Using weights from checkpoint: L0={l0_weight}, KL={kl_weight}")
    total = 0
    # Individual metric accuracies
    l0_correct = 0
    kl_correct = 0
    dot_correct = 0
    combined_correct = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            # Get individual similarities and combined output
            l0_sims, kl_sims, combined_sims, _, _, dot_sims = model(inputs, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            # Count correct predictions of each head
            l0_correct += (torch.argmax(l0_sims, dim=1) == targets).sum().item()
            kl_correct += (torch.argmax(kl_sims, dim=1) == targets).sum().item()
            dot_correct += (torch.argmax(dot_sims, dim=1) == targets).sum().item()
            combined_correct += (torch.argmax(combined_sims, dim=1) == targets).sum().item()
            total += targets.size(0)

    def _pct(count):
        # An empty loader yields 0% rather than a ZeroDivisionError.
        return 100. * count / total if total > 0 else 0.0

    # Calculate accuracies
    l0_acc = _pct(l0_correct)
    kl_acc = _pct(kl_correct)
    dot_acc = _pct(dot_correct)
    combined_acc = _pct(combined_correct)
    print(f'\n=== TEST RESULTS ===')
    print(f'Total test images: {total}')
    print(f'L0 Accuracy: {l0_acc:.2f}% ({l0_correct}/{total})')
    print(f'KL Accuracy: {kl_acc:.2f}% ({kl_correct}/{total})')
    print(f'Dot Accuracy: {dot_acc:.2f}% ({dot_correct}/{total})')
    print(f'Combined Accuracy: {combined_acc:.2f}% ({combined_correct}/{total})')
    print(f'L0 weight: {l0_weight}, KL weight: {kl_weight}')

    # Display prototype separation analysis during evaluation
    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (EVALUATION) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        num_classes = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        # Select the off-diagonal entries instead of zeroing the diagonal:
        # zeroed self-distances would always win the minimum and dilute the mean.
        off_diagonal_mask = ~torch.eye(num_classes, dtype=torch.bool, device=prototypes.device)
        pair_distances = distances[off_diagonal_mask]

        if pair_distances.numel() == 0:
            print("Fewer than two prototypes: no pairwise separation to report")
        else:
            min_separation = pair_distances.min().item()
            print(f"Min separation: {min_separation:.4f}")
            print(f"Max separation: {pair_distances.max().item():.4f}")
            print(f"Avg separation: {pair_distances.mean().item():.4f}")

            # Show distance matrix (compact format, first 5x5 to avoid clutter)
            print("Distance Matrix:")
            shown = min(5, num_classes)
            for i in range(shown):
                cells = [
                    " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                    for j in range(shown)
                ]
                print("  " + "".join(cells))
            if num_classes > 5:
                print(f"  ... (showing first 5x5, total: {prototypes.size(0)}x{prototypes.size(0)})")

            # Warning if separation is poor
            if min_separation < 1.0:
                print("Warning: Some prototypes are very close together!")
            elif min_separation < 2.0:
                print("Caution: Prototype separation could be improved")
            else:
                print("Good prototype separation maintained")
    print("=" * 60)

    return combined_acc, l0_acc, kl_acc, dot_acc

def evaluate_consensus_accuracy(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate L0/KL consensus prediction with a 4-category breakdown.

    Every test sample falls into exactly one category:
        1. L0 and KL predict the same class and it is correct.
        2. L0 and KL predict the same class and it is wrong.
        3. L0 and KL disagree and exactly one of them is correct.
        4. L0 and KL disagree and both are wrong.

    A pairwise L2 distance summary of the class prototypes is printed as
    well, computed over distinct prototype pairs only.

    Args:
        model: Callable network; also provides ``eval()`` and ``get_prototypes()``.
        testloader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to.
        args: Namespace; only ``args.tau`` is read here.
        l0_weight: Head weight forwarded to the model.
        kl_weight: Head weight forwarded to the model.
        dot_weight: Head weight forwarded to the model.

    Returns:
        Dict with the four category counts (``same_correct``,
        ``same_incorrect``, ``diff_one_correct``, ``diff_both_wrong``),
        ``consensus_rate``, ``consensus_acc``, ``true_consensus_acc``
        (percentages) and ``total``.
    """
    print("==> Evaluating consensus accuracy with 4-category breakdown...")
    model.eval()
    total = 0
    # Category counters for predictions
    same_correct = 0  # L0 and KL same and correct
    same_incorrect = 0  # L0 and KL same but incorrect
    diff_one_correct = 0  # L0 and KL different but one correct
    diff_both_wrong = 0  # L0 and KL different and both incorrect
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, desc='Evaluating consensus'):
            inputs, targets = inputs.to(device), targets.to(device)
            # Get individual similarities and predictions
            l0_sims, kl_sims, _, _, _, _ = model(inputs, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            l0_preds = torch.argmax(l0_sims, dim=1)
            kl_preds = torch.argmax(kl_sims, dim=1)

            # Categorize the whole batch with boolean masks; this yields the
            # same counts as a per-sample loop without a host sync per image.
            agree = l0_preds == kl_preds
            l0_hit = l0_preds == targets
            any_hit = l0_hit | (kl_preds == targets)
            same_correct += (agree & l0_hit).sum().item()
            same_incorrect += (agree & ~l0_hit).sum().item()
            diff_one_correct += (~agree & any_hit).sum().item()
            diff_both_wrong += (~agree & ~any_hit).sum().item()
            total += inputs.size(0)

    def _pct(count):
        # An empty loader yields 0% rather than a ZeroDivisionError.
        return 100. * count / total if total > 0 else 0.0

    agreed = same_correct + same_incorrect
    # Calculate percentages
    same_correct_pct = _pct(same_correct)
    same_incorrect_pct = _pct(same_incorrect)
    diff_one_correct_pct = _pct(diff_one_correct)
    diff_both_wrong_pct = _pct(diff_both_wrong)
    # Calculate consensus rate
    consensus_rate = _pct(agreed)
    # Calculate consensus accuracy (only when L0 and KL agree)
    consensus_acc = 100. * same_correct / agreed if agreed > 0 else 0
    # Calculate true consensus accuracy (entire test set)
    true_consensus_acc = _pct(same_correct)
    print(f'\n=== CONSENSUS EVALUATION RESULTS ===')
    print(f'Total test images: {total}')
    print(f'')
    print(f'1. L0 and KL same and correct: {same_correct} ({same_correct_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {same_incorrect} ({same_incorrect_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {diff_one_correct} ({diff_one_correct_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {diff_both_wrong} ({diff_both_wrong_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {consensus_rate:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {consensus_acc:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {true_consensus_acc:.2f}% (Group 1 / Total)')
    print(f'')
    print(f'Note: "Consensus Accuracy" only counts predictions where L0 and KL agree.')
    print(f' "TRUE Consensus Accuracy" counts all predictions (unknown = incorrect).')
    print(f'Breakdown:')
    print(f' - L0 and KL agree: {same_correct + same_incorrect} images ({consensus_rate:.1f}%)')
    print(f' - L0 and KL disagree: {diff_one_correct + diff_both_wrong} images ({100-consensus_rate:.1f}%)')
    print(f' - When they agree, correct: {same_correct} images ({consensus_acc:.1f}%)')
    print(f' - When they disagree, all marked as unknown (incorrect)')
    print(f' - Total correct: {same_correct} out of {total} images')
    print(f' - TRUE accuracy: {true_consensus_acc:.2f}%')

    # Display prototype separation analysis during consensus evaluation
    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (CONSENSUS EVALUATION) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        num_classes = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        # Select the off-diagonal entries instead of zeroing the diagonal:
        # zeroed self-distances would always win the minimum and dilute the mean.
        off_diagonal_mask = ~torch.eye(num_classes, dtype=torch.bool, device=prototypes.device)
        pair_distances = distances[off_diagonal_mask]

        if pair_distances.numel() == 0:
            print("Fewer than two prototypes: no pairwise separation to report")
        else:
            min_separation = pair_distances.min().item()
            print(f"Min separation: {min_separation:.4f}")
            print(f"Max separation: {pair_distances.max().item():.4f}")
            print(f"Avg separation: {pair_distances.mean().item():.4f}")

            # Show distance matrix (compact format, first 5x5 to avoid clutter)
            print("Distance Matrix:")
            shown = min(5, num_classes)
            for i in range(shown):
                cells = [
                    " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                    for j in range(shown)
                ]
                print("  " + "".join(cells))
            if num_classes > 5:
                print(f"  ... (showing first 5x5, total: {prototypes.size(0)}x{prototypes.size(0)})")

            # Warning if separation is poor
            if min_separation < 1.0:
                print("Warning: Some prototypes are very close together!")
            elif min_separation < 2.0:
                print("Caution: Prototype separation could be improved")
            else:
                print("Good prototype separation maintained")
    print("=" * 60)

    return {
        'same_correct': same_correct,
        'same_incorrect': same_incorrect,
        'diff_one_correct': diff_one_correct,
        'diff_both_wrong': diff_both_wrong,
        'consensus_rate': consensus_rate,
        'consensus_acc': consensus_acc,
        'true_consensus_acc': true_consensus_acc,
        'total': total
    }

def attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, tau=0.75, mean=None, std=None):
    """Projected Gradient Descent attack on the model's dot-product head.

    Starting from a random perturbation inside the ``epsilon`` ball, the
    attack takes ``attack_iters`` ascent steps on the cross-entropy between
    the dot-product similarities of ``X + delta`` and ``target``. After every
    step ``delta`` is projected back into the norm ball and clipped so that
    ``X + delta`` stays within the module-level ``lower_limit`` /
    ``upper_limit`` bounds.

    Gradients are taken with respect to ``delta`` only, so the attack does
    not accumulate anything into the ``.grad`` of the model parameters.

    Args:
        model: Callable returning six tensors when called with
            ``return_individual=True``; the last one (dot similarities) is
            the attack target.
        X: Batch of input images.
        target: Ground-truth labels for ``X``.
        alpha: Step size, used as given (no rescaling is applied).
        attack_iters: Number of PGD steps.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device ``X`` and ``target`` are moved to.
        epsilon: Radius of the perturbation ball, used as given.
        restarts: Currently ignored; a single run is always performed.
        tau: Threshold forwarded to the model.
        mean: Currently ignored (normalized-space rescaling is disabled).
        std: Currently ignored (normalized-space rescaling is disabled).

    Returns:
        The perturbation ``delta`` (same shape as ``X``, ``requires_grad=True``).

    Raises:
        ValueError: If ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    X = X.to(device)
    target = target.to(device)
    # The per-channel rescaling for normalized inputs is disabled: epsilon and
    # alpha are applied in the same space as X.
    epsilon_scaled = epsilon
    alpha_scaled = alpha

    # Initialize perturbation
    delta = torch.zeros_like(X, device=device)
    if norm == "l_inf":
        delta.uniform_(-epsilon_scaled, epsilon_scaled)
    elif norm == "l_2":
        # Random direction, random radius in [0, epsilon].
        delta.normal_()
        d_flat = delta.reshape(delta.size(0), -1)
        n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon_scaled
    else:
        raise ValueError(f"Norm {norm} not supported")

    # Project to valid bounds
    delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
    delta.requires_grad = True

    for _ in range(attack_iters):
        # NOTE(review): older comments here said the combined similarities
        # should drive the attack; the code has always used the dot-product
        # head, which is kept. Confirm which head is intended.
        _, _, _, _, _, dot_sims = model(X + delta, return_individual=True, tau=tau)
        loss = F.cross_entropy(dot_sims, target)

        # Gradient w.r.t. delta only: unlike loss.backward(), this leaves the
        # gradients of the model parameters untouched.
        grad = torch.autograd.grad(loss, delta)[0].detach()
        d = delta.detach()

        if norm == "l_inf":
            d = torch.clamp(d + alpha_scaled * torch.sign(grad), min=-epsilon_scaled, max=epsilon_scaled)
        elif norm == "l_2":
            g_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
            scaled_g = grad / (g_norm + 1e-10)
            d = (d + scaled_g * alpha_scaled).reshape(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon_scaled).view_as(d)

        # Project to valid bounds and write the step back in place.
        d = torch.max(torch.min(d, upper_limit - X), lower_limit - X)
        delta.data.copy_(d)

    return delta



def evaluate_adversarial_robustness(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate adversarial robustness using consensus prediction with 4-category breakdown and Proposition 4 analysis"""
    print("==> Evaluating adversarial robustness...")
    
    # Get clean predictions directly instead of calling full consensus evaluation
    print("==> Getting clean predictions for baseline...")
    model.eval()
    total = 0
    clean_same_correct = 0
    clean_same_incorrect = 0
    clean_diff_one_correct = 0
    clean_diff_both_wrong = 0
    
    # Track individual metric accuracies for clean data
    clean_l0_correct = 0
    clean_kl_correct = 0
    clean_dot_correct = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(testloader, desc='Clean baseline')):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            # Get individual similarities and predictions
            l0_sims, kl_sims, combined_sims, l0_dists, kl_dists, dot_sims = model(inputs, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            # Get predictions from each metric
            l0_preds = torch.argmax(l0_sims, dim=1)
            kl_preds = torch.argmax(kl_sims, dim=1)
            dot_preds = torch.argmax(dot_sims, dim=1)
            # Count individual metric accuracies
            clean_l0_correct += (l0_preds == targets).sum().item()
            clean_kl_correct += (kl_preds == targets).sum().item()
            clean_dot_correct += (dot_preds == targets).sum().item()
            # Categorize clean predictions
            for i in range(batch_size):
                l0_pred = l0_preds[i].item()
                kl_pred = kl_preds[i].item()
                target = targets[i].item()
                if l0_pred == kl_pred:
                    if l0_pred == target:
                        clean_same_correct += 1
                    else:
                        clean_same_incorrect += 1
                else:
                    if l0_pred == target or kl_pred == target:
                        clean_diff_one_correct += 1
                    else:
                        clean_diff_both_wrong += 1
            total += batch_size
    
    # Calculate clean percentages
    clean_same_correct_pct = 100. * clean_same_correct / total
    clean_same_incorrect_pct = 100. * clean_same_incorrect / total
    clean_diff_one_correct_pct = 100. * clean_diff_one_correct / total
    clean_diff_both_wrong_pct = 100. * clean_diff_both_wrong / total
    clean_consensus_rate = 100. * (clean_same_correct + clean_same_incorrect) / total
    clean_consensus_acc = 100. * clean_same_correct / (clean_same_correct + clean_same_incorrect) if (clean_same_correct + clean_same_incorrect) > 0 else 0
    clean_true_consensus_acc = 100. * clean_same_correct / total
    
    # Calculate individual metric accuracies for clean data
    clean_l0_acc = 100. * clean_l0_correct / total
    clean_kl_acc = 100. * clean_kl_correct / total
    clean_dot_acc = 100. * clean_dot_correct / total
    
    # Calculate detailed metric combination percentages for clean data
    clean_l0_dot_same_correct = 0
    clean_l0_dot_same_incorrect = 0
    clean_l0_dot_diff_correct = 0
    clean_l0_dot_diff_incorrect = 0
    
    clean_kl_dot_same_correct = 0
    clean_kl_dot_same_incorrect = 0
    clean_kl_dot_diff_correct = 0
    clean_kl_dot_diff_incorrect = 0
    
    clean_kl_l0_dot_same_correct = 0
    clean_kl_l0_dot_same_incorrect = 0
    clean_kl_l0_dot_diff_correct = 0
    clean_kl_l0_dot_diff_incorrect = 0
    
    # Re-calculate clean metric combinations
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            l0_sims, kl_sims, combined_sims, l0_dists, kl_dists, dot_sims = model(inputs, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            l0_preds = torch.argmax(l0_sims, dim=1)
            kl_preds = torch.argmax(kl_sims, dim=1)
            dot_preds = torch.argmax(dot_sims, dim=1)
            
            for i in range(batch_size):
                l0_pred = l0_preds[i].item()
                kl_pred = kl_preds[i].item()
                dot_pred = dot_preds[i].item()
                target = targets[i].item()
                
                # L0 & Dot combinations
                if l0_pred == dot_pred:
                    if l0_pred == target:
                        clean_l0_dot_same_correct += 1
                    else:
                        clean_l0_dot_same_incorrect += 1
                else:
                    if l0_pred == target or dot_pred == target:
                        clean_l0_dot_diff_correct += 1
                    else:
                        clean_l0_dot_diff_incorrect += 1
                
                # KL & Dot combinations
                if kl_pred == dot_pred:
                    if kl_pred == target:
                        clean_kl_dot_same_correct += 1
                    else:
                        clean_kl_dot_same_incorrect += 1
                else:
                    if kl_pred == target or dot_pred == target:
                        clean_kl_dot_diff_correct += 1
                    else:
                        clean_kl_dot_diff_incorrect += 1
                
                # KL & L0 & Dot combinations (all three agree)
                if l0_pred == kl_pred == dot_pred:
                    if l0_pred == target:
                        clean_kl_l0_dot_same_correct += 1
                    else:
                        clean_kl_l0_dot_same_incorrect += 1
                else:
                    # At least one disagrees
                    correct_count = sum([l0_pred == target, kl_pred == target, dot_pred == target])
                    if correct_count >= 1:
                        clean_kl_l0_dot_diff_correct += 1
                    else:
                        clean_kl_l0_dot_diff_incorrect += 1
    
    # Adversarial evaluation with 4-category breakdown
    # Category counters for adversarial predictions
    adv_same_correct = 0  # L0 and KL same and correct
    adv_same_incorrect = 0  # L0 and KL same but incorrect
    adv_diff_one_correct = 0  # L0 and KL different but one correct
    adv_diff_both_wrong = 0  # L0 and KL different and both incorrect
    
    # Track individual metric accuracies for adversarial data
    adv_l0_correct = 0
    adv_kl_correct = 0
    adv_dot_correct = 0
    
    # Track detailed metric combination percentages for adversarial data
    adv_l0_dot_same_correct = 0
    adv_l0_dot_same_incorrect = 0
    adv_l0_dot_diff_correct = 0
    adv_l0_dot_diff_incorrect = 0
    
    adv_kl_dot_same_correct = 0
    adv_kl_dot_same_incorrect = 0
    adv_kl_dot_diff_correct = 0
    adv_kl_dot_diff_incorrect = 0
    
    adv_kl_l0_dot_same_correct = 0
    adv_kl_l0_dot_same_incorrect = 0
    adv_kl_l0_dot_diff_correct = 0
    adv_kl_l0_dot_diff_incorrect = 0
    
    # Track Equation (46) fulfillment for each case
    case_a_equation_fulfilled = 0  # Case a: same & incorrect
    case_a_total = 0
    case_c_equation_fulfilled = 0  # Case c: different & one correct
    case_c_total = 0
    case_d_equation_fulfilled = 0  # Case d: different & both wrong
    case_d_total = 0
    
    print(f"\n==> Generating adversarial examples with PGD attack...")
    print(f"Attack parameters: epsilon={args.attack_epsilon:.6f}, steps=10, norm={args.norm}")
    model.eval()
    
    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader, desc='Adversarial evaluation')):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)
        
        # Generate adversarial examples with PGD
        delta = attack_pgd(model, inputs, targets, args.attack_stepsize, 10, args.norm, device, args.attack_epsilon, tau=args.tau, mean=CIFAR_MEAN, std=CIFAR_STD)
        
        # Get clean predictions for this batch
        with torch.no_grad():
            l0_sims_clean, kl_sims_clean, combined_sims_clean, l0_dists_clean, kl_dists_clean, dot_sims_clean = model(inputs, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            l0_preds_clean = torch.argmax(l0_sims_clean, dim=1)
            kl_preds_clean = torch.argmax(kl_sims_clean, dim=1)
            dot_preds_clean = torch.argmax(dot_sims_clean, dim=1)
        
        # Adversarial consensus prediction
        with torch.no_grad():
            l0_sims, kl_sims, combined_sims, l0_dists, kl_dists, dot_sims = model(inputs + delta, return_individual=True, tau=args.tau, l0_weight=l0_weight, kl_weight=kl_weight,dot_weight=dot_weight)
            # Get predictions from each metric
            l0_preds = torch.argmax(l0_sims, dim=1)
            kl_preds = torch.argmax(kl_sims, dim=1)
            dot_preds = torch.argmax(dot_sims, dim=1)
            
            # Count individual metric accuracies for adversarial data
            adv_l0_correct += (l0_preds == targets).sum().item()
            adv_kl_correct += (kl_preds == targets).sum().item()
            adv_dot_correct += (dot_preds == targets).sum().item()
            
            # Calculate detailed metric combinations for adversarial data
            for i in range(batch_size):
                l0_pred = l0_preds[i].item()
                kl_pred = kl_preds[i].item()
                dot_pred = dot_preds[i].item()
                target = targets[i].item()
                
                # L0 & Dot combinations
                if l0_pred == dot_pred:
                    if l0_pred == target:
                        adv_l0_dot_same_correct += 1
                    else:
                        adv_l0_dot_same_incorrect += 1
                else:
                    if l0_pred == target or dot_pred == target:
                        adv_l0_dot_diff_correct += 1
                    else:
                        adv_l0_dot_diff_incorrect += 1
                
                # KL & Dot combinations
                if kl_pred == dot_pred:
                    if kl_pred == target:
                        adv_kl_dot_same_correct += 1
                    else:
                        adv_kl_dot_same_incorrect += 1
                else:
                    if kl_pred == target or dot_pred == target:
                        adv_kl_dot_diff_correct += 1
                    else:
                        adv_kl_dot_diff_incorrect += 1
                
                # KL & L0 & Dot combinations (all three agree)
                if l0_pred == kl_pred == dot_pred:
                    if l0_pred == target:
                        adv_kl_l0_dot_same_correct += 1
                    else:
                        adv_kl_l0_dot_same_incorrect += 1
                else:
                    # At least one disagrees
                    correct_count = sum([l0_pred == target, kl_pred == target, dot_pred == target])
                    if correct_count >= 1:
                        adv_kl_l0_dot_diff_correct += 1
                    else:
                        adv_kl_l0_dot_diff_incorrect += 1
            
            # Categorize adversarial predictions
            for i in range(batch_size):
                l0_pred = l0_preds[i].item()
                kl_pred = kl_preds[i].item()
                target = targets[i].item()
                
                # Check if this sample was correct before attack
                l0_correct_before = (l0_preds_clean[i].item() == target)
                kl_correct_before = (kl_preds_clean[i].item() == target)
                both_correct_before = l0_correct_before and kl_correct_before
                
                if l0_pred == kl_pred:
                    if l0_pred == target:
                        adv_same_correct += 1
                    else:
                        adv_same_incorrect += 1
                        
                        # Case a: L0 and KL same but incorrect
                        if both_correct_before:
                            
                            print(f"\n  Case a analysis for sample {batch_idx * batch_size + i}:")
                            print(f"    Target: {target}, Pred: {l0_pred}")
                            print(f"    Was correct before attack: L0={l0_correct_before}, KL={kl_correct_before}")
                            
                            # Find conflicting class (the predicted class)
                            pred_class = l0_pred
                            
                            # Get KL distances (not similarities) for mathematical proof analysis
                            KL_y_prime_p_star = kl_dists_clean[i, pred_class]
                            KL_y_star_p_star = kl_dists_clean[i, target]
                            KL_y_prime_p_prime = kl_dists[i, pred_class]
                            KL_y_star_p_prime = kl_dists[i, target]
                            
                            # Get L0 distances (not similarities) for mathematical proof analysis
                            L0_y_prime_p_star = l0_dists_clean[i, pred_class]
                            L0_y_star_p_star = l0_dists_clean[i, target]
                            L0_y_prime_p_prime = l0_dists[i, pred_class]
                            L0_y_star_p_prime = l0_dists[i, target]

                            # Check mathematical conditions before analysis
                            Delta_KL_p_star = KL_y_prime_p_star - KL_y_star_p_star
                            Delta_KL_p_prime = KL_y_prime_p_prime - KL_y_star_p_prime
                            Delta_L0_p_star = L0_y_prime_p_star - L0_y_star_p_star
                            
                            # Only analyze if conditions are met
                            if Delta_KL_p_star > 0 and Delta_KL_p_prime < 0 and Delta_L0_p_star > 0:
                                case_a_total += 1
                                
                                
                                # Get prototypes and embeddings for analysis
                                # Get processed embeddings and prototypes to avoid zero elements and match dimensions
                                p_star_l0, p_star_kl, protos_l0, protos_kl = model.get_processed_embeddings(inputs[i:i+1])  # Clean image embeddings
                                p_prime_l0, p_prime_kl, _, _ = model.get_processed_embeddings(inputs[i:i+1] + delta[i:i+1])  # Adversarial image embeddings
                                # Use KL embeddings and prototypes for mathematical proof analysis
                                p_star = p_star_kl[0]  # Clean image embedding (KL)
                                p_prime = p_prime_kl[0]  # Adversarial image embedding (KL)
                                y_star = protos_kl[target]  # Target class prototype (processed)
                                y_prime = protos_kl[pred_class]  # Conflicting class prototype (processed)
                            
                                # Check if Equation (46) is fulfilled
                                if args.math_proof:
                                    output = math_proof_analysis(
                                        target, pred_class, y_star, y_prime, p_star, p_prime, args.attack_epsilon,
                                        KL_y_prime_p_star, KL_y_star_p_star, KL_y_prime_p_prime, KL_y_star_p_prime,
                                        L0_y_prime_p_star, L0_y_star_p_star, L0_y_prime_p_prime, L0_y_star_p_prime, args
                                    )
                                else:
                                    output = None
                                if output is None:
                                    print(f"    Sample skipped due to invalid mathematical conditions")
                                elif output:
                                    print(f"    ✅ Equation (46) fulfilled!")
                                    case_a_equation_fulfilled += 1
                                else:
                                    print(f"    ❌ Equation (46) not fulfilled")
                            else:
                                print(f"\n  Case a sample {batch_idx * batch_size + i} skipped - mathematical conditions not met:")
                                print(f"    Delta_KL_p_star: {Delta_KL_p_star:.6f} (need > 0)")
                                print(f"    Delta_KL_p_prime: {Delta_KL_p_prime:.6f} (need < 0)")
                                print(f"    Delta_L0_p_star: {Delta_L0_p_star:.6f} (need > 0)")
                else:
                    if l0_pred == target or kl_pred == target:
                        adv_diff_one_correct += 1
                        
                        # Case c: L0 and KL different but one correct
                        if both_correct_before:
                            # Determine which prediction is correct and which is conflicting
                            if l0_pred == target:
                                correct_pred = l0_pred
                                pred_class = kl_pred
                                correct_metric = "L0"
                            else:
                                correct_pred = kl_pred
                                pred_class = l0_pred
                                correct_metric = "KL"

                            print(f"    {correct_metric} is correct ({correct_pred}), {pred_class} is conflicting class")
                        
                            
                            # Get KL distances (not similarities) for mathematical proof analysis
                            KL_y_prime_p_star = kl_dists_clean[i, pred_class]
                            KL_y_star_p_star = kl_dists_clean[i, target]
                            KL_y_prime_p_prime = kl_dists[i, pred_class]
                            KL_y_star_p_prime = kl_dists[i, target]
                            
                            # Get L0 distances (not similarities) for mathematical proof analysis
                            L0_y_prime_p_star = l0_dists_clean[i, pred_class]
                            L0_y_star_p_star = l0_dists_clean[i, target]
                            L0_y_prime_p_prime = l0_dists[i, pred_class]
                            L0_y_star_p_prime = l0_dists[i, target]

                            # Check mathematical conditions before analysis
                            Delta_KL_p_star = KL_y_prime_p_star - KL_y_star_p_star
                            Delta_KL_p_prime = KL_y_prime_p_prime - KL_y_star_p_prime
                            Delta_L0_p_star = L0_y_prime_p_star - L0_y_star_p_star
                                

                            # Only analyze if conditions are met
                            if Delta_KL_p_star > 0 and Delta_KL_p_prime < 0 and Delta_L0_p_star > 0:
                                case_c_total += 1
                                
                                print(f"\n  Case c analysis for sample {batch_idx * batch_size + i}:")
                                print(f"    Target: {target}, L0 Pred: {l0_pred}, KL Pred: {kl_pred}")
                                print(f"    Was correct before attack: L0={l0_correct_before}, KL={kl_correct_before}")
                        
                                
                                # Get prototypes and embeddings for analysis
                                # Get processed embeddings and prototypes to avoid zero elements and match dimensions
                                p_star_l0, p_star_kl, protos_l0, protos_kl = model.get_processed_embeddings(inputs[i:i+1])  # Clean image embeddings
                                p_prime_l0, p_prime_kl, _, _ = model.get_processed_embeddings(inputs[i:i+1] + delta[i:i+1])  # Adversarial image embeddings
                                # Use KL embeddings and prototypes for mathematical proof analysis
                                p_star = p_star_kl[0]  # Clean image embedding (KL)
                                p_prime = p_prime_kl[0]  # Adversarial image embedding (KL)
                                y_star = protos_kl[target]  # Target class prototype (processed)
                                y_prime = protos_kl[pred_class]  # Conflicting class prototype (processed)
                                
                                # Check if Equation (46) is fulfilled
                                if args.math_proof:
                                    output = math_proof_analysis(
                                        target, pred_class, y_star, y_prime, p_star, p_prime, args.attack_epsilon,
                                        KL_y_prime_p_star, KL_y_star_p_star, KL_y_prime_p_prime, KL_y_star_p_prime,
                                        L0_y_prime_p_star, L0_y_star_p_star, L0_y_prime_p_prime, L0_y_star_p_prime, args
                                    )
                                else:
                                    output = None
                                
                                if output is None:
                                    print(f"    Sample skipped due to invalid mathematical conditions")
                                elif output:
                                    print(f"    ✅ Equation (46) fulfilled!")
                                    case_c_equation_fulfilled += 1
                                else:
                                    print(f"    ❌ Equation (46) not fulfilled")
                            else:
                                print(f"\n  Case c sample {batch_idx * batch_size + i} skipped - mathematical conditions not met:")
                                print(f"    Delta_KL_p_star: {Delta_KL_p_star:.6f} (need > 0)")
                                print(f"    Delta_KL_p_prime: {Delta_KL_p_prime:.6f} (need < 0)")
                                print(f"    Delta_L0_p_star: {Delta_L0_p_star:.6f} (need > 0)")
                        
                    else:
                        adv_diff_both_wrong += 1
                        
                        # Case d: L0 and KL different and both wrong
                        if both_correct_before:
                            # Find the most likely conflicting class (highest similarity)
                            l0_conf_sim = l0_sims[i, l0_pred]
                            kl_conf_sim = kl_sims[i, kl_pred]
                            
                            if l0_conf_sim > kl_conf_sim:
                                pred_class = l0_pred
                                print(f"    Using L0 prediction {pred_class} as conflicting class (higher similarity)")
                            else:
                                pred_class = kl_pred
                                print(f"    Using KL prediction {pred_class} as conflicting class (higher similarity)")
                            
                            # Get KL distances (not similarities) for mathematical proof analysis
                            KL_y_prime_p_star = kl_dists_clean[i, pred_class]
                            KL_y_star_p_star = kl_dists_clean[i, target]
                            KL_y_prime_p_prime = kl_dists[i, pred_class]
                            KL_y_star_p_prime = kl_dists[i, target]
                            
                            # Get L0 distances (not similarities) for mathematical proof analysis
                            L0_y_prime_p_star = l0_dists_clean[i, pred_class]
                            L0_y_star_p_star = l0_dists_clean[i, target]
                            L0_y_prime_p_prime = l0_dists[i, pred_class]
                            L0_y_star_p_prime = l0_dists[i, target]

                            # Check mathematical conditions before analysis
                            Delta_KL_p_star = KL_y_prime_p_star - KL_y_star_p_star
                            Delta_KL_p_prime = KL_y_prime_p_prime - KL_y_star_p_prime
                            Delta_L0_p_star = L0_y_prime_p_star - L0_y_star_p_star
                            
                            # Only analyze if conditions are met
                            if Delta_KL_p_star > 0 and Delta_KL_p_prime < 0 and Delta_L0_p_star > 0:
                                case_d_total += 1
                                print(f"\n  Case d analysis for sample {batch_idx * batch_size + i}:")
                                print(f"    Target: {target}, L0 Pred: {l0_pred}, KL Pred: {kl_pred}")
                                print(f"    Was correct before attack: L0={l0_correct_before}, KL={kl_correct_before}")
                                
                                # Get features for analysis
                                # Get processed embeddings and prototypes to avoid zero elements and match dimensions
                                p_star_l0, p_star_kl, protos_l0, protos_kl = model.get_processed_embeddings(inputs[i:i+1])  # Clean image embeddings
                                p_prime_l0, p_prime_kl, _, _ = model.get_processed_embeddings(inputs[i:i+1] + delta[i:i+1])  # Adversarial image embeddings
                                # Use KL embeddings and prototypes for mathematical proof analysis
                                p_star = p_star_kl[0]  # Clean image embedding (KL)
                                p_prime = p_prime_kl[0]  # Adversarial image embedding (KL)
                                y_star = protos_kl[target]  # Target class prototype (processed)
                                y_prime = protos_kl[pred_class]  # Conflicting class prototype (processed)
                                
                            
                                # Check if Equation (46) is fulfilled
                                if args.math_proof:
                                    output = math_proof_analysis(
                                        target, pred_class, y_star, y_prime, p_star, p_prime, args.attack_epsilon,
                                        KL_y_prime_p_star, KL_y_star_p_star, KL_y_prime_p_prime, KL_y_star_p_prime,
                                        L0_y_prime_p_star, L0_y_star_p_star, L0_y_prime_p_prime, L0_y_star_p_prime, args
                                    )
                                else:
                                    output = None
                                
                                if output is None:
                                    print(f"    Sample skipped due to invalid mathematical conditions")
                                elif output:
                                    print(f"    ✅ Equation (46) fulfilled!")
                                    case_d_equation_fulfilled += 1
                                else:
                                    print(f"    ❌ Equation (46) not fulfilled")
                            else:
                                print(f"\n  Case d sample {batch_idx * batch_size + i} skipped - mathematical conditions not met:")
                                print(f"    Delta_KL_p_star: {Delta_KL_p_star:.6f} (need > 0)")
                                print(f"    Delta_KL_p_prime: {Delta_KL_p_prime:.6f} (need < 0)")
                                print(f"    Delta_L0_p_star: {Delta_L0_p_star:.6f} (need > 0)")
    
    # Calculate adversarial percentages
    adv_same_correct_pct = 100. * adv_same_correct / total
    adv_same_incorrect_pct = 100. * adv_same_incorrect / total
    adv_diff_one_correct_pct = 100. * adv_diff_one_correct / total
    adv_diff_both_wrong_pct = 100. * adv_diff_both_wrong / total
    # Calculate adversarial consensus rates
    adv_consensus_rate = 100. * (adv_same_correct + adv_same_incorrect) / total
    # Calculate adversarial accuracies
    adv_consensus_acc = 100. * adv_same_correct / (adv_same_correct + adv_same_incorrect) if (adv_same_correct + adv_same_incorrect) > 0 else 0
    adv_true_consensus_acc = 100. * adv_same_correct / total
    
    # Calculate individual metric accuracies for adversarial data
    adv_l0_acc = 100. * adv_l0_correct / total
    adv_kl_acc = 100. * adv_kl_correct / total
    adv_dot_acc = 100. * adv_dot_correct / total
    
    # Calculate detailed metric combination percentages
    # Clean data percentages
    clean_l0_dot_same_correct_pct = 100. * clean_l0_dot_same_correct / total
    clean_l0_dot_same_incorrect_pct = 100. * clean_l0_dot_same_incorrect / total
    clean_l0_dot_diff_correct_pct = 100. * clean_l0_dot_diff_correct / total
    clean_l0_dot_diff_incorrect_pct = 100. * clean_l0_dot_diff_incorrect / total
    
    clean_kl_dot_same_correct_pct = 100. * clean_kl_dot_same_correct / total
    clean_kl_dot_same_incorrect_pct = 100. * clean_kl_dot_same_incorrect / total
    clean_kl_dot_diff_correct_pct = 100. * clean_kl_dot_diff_correct / total
    clean_kl_dot_diff_incorrect_pct = 100. * clean_kl_dot_diff_incorrect / total
    
    clean_kl_l0_dot_same_correct_pct = 100. * clean_kl_l0_dot_same_correct / total
    clean_kl_l0_dot_same_incorrect_pct = 100. * clean_kl_l0_dot_same_incorrect / total
    clean_kl_l0_dot_diff_correct_pct = 100. * clean_kl_l0_dot_diff_correct / total
    clean_kl_l0_dot_diff_incorrect_pct = 100. * clean_kl_l0_dot_diff_incorrect / total
    
    # Adversarial data percentages
    adv_l0_dot_same_correct_pct = 100. * adv_l0_dot_same_correct / total
    adv_l0_dot_same_incorrect_pct = 100. * adv_l0_dot_same_incorrect / total
    adv_l0_dot_diff_correct_pct = 100. * adv_l0_dot_diff_correct / total
    adv_l0_dot_diff_incorrect_pct = 100. * adv_l0_dot_diff_incorrect / total
    
    adv_kl_dot_same_correct_pct = 100. * adv_kl_dot_same_correct / total
    adv_kl_dot_same_incorrect_pct = 100. * adv_kl_dot_same_incorrect / total
    adv_kl_dot_diff_correct_pct = 100. * adv_kl_dot_diff_correct / total
    adv_kl_dot_diff_incorrect_pct = 100. * adv_kl_dot_diff_incorrect / total
    
    adv_kl_l0_dot_same_correct_pct = 100. * adv_kl_l0_dot_same_correct / total
    adv_kl_l0_dot_same_incorrect_pct = 100. * adv_kl_l0_dot_same_incorrect / total
    adv_kl_l0_dot_diff_correct_pct = 100. * adv_kl_l0_dot_diff_correct / total
    adv_kl_l0_dot_diff_incorrect_pct = 100. * adv_kl_l0_dot_diff_incorrect / total
    
    # Calculate robustness drops
    consensus_robustness_drop = clean_consensus_acc - adv_consensus_acc
    true_consensus_robustness_drop = clean_true_consensus_acc - adv_true_consensus_acc
    l0_robustness_drop = clean_l0_acc - adv_l0_acc
    kl_robustness_drop = clean_kl_acc - adv_kl_acc
    dot_robustness_drop = clean_dot_acc - adv_dot_acc
    print(f'\n=== ADVERSARIAL ROBUSTNESS RESULTS ===')
    print(f'Attack: PGD with ε={args.attack_epsilon:.6f}, {args.norm} norm')
    print(f'')
    print(f'CLEAN PREDICTIONS:')
    print(f'1. L0 and KL same and correct: {clean_same_correct} ({clean_same_correct_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {clean_same_incorrect} ({clean_same_incorrect_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {clean_diff_one_correct} ({clean_diff_one_correct_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {clean_diff_both_wrong} ({clean_diff_both_wrong_pct:.2f}%)')
    print(f'Clean Consensus Rate: {clean_consensus_rate:.2f}%')
    print(f'Clean Consensus Accuracy: {clean_consensus_acc:.2f}%')
    print(f'Clean TRUE Consensus Accuracy: {clean_true_consensus_acc:.2f}%')
    print(f'')
    print(f'INDIVIDUAL METRIC PERFORMANCE (CLEAN):')
    print(f'L0 Accuracy: {clean_l0_acc:.2f}% ({clean_l0_correct}/{total})')
    print(f'KL Accuracy: {clean_kl_acc:.2f}% ({clean_kl_correct}/{total})')
    print(f'Dot Accuracy: {clean_dot_acc:.2f}% ({clean_dot_correct}/{total})')
    print(f'')
    
    print(f'DETAILED METRIC COMBINATIONS (CLEAN):')
    print(f'L0 & Dot - Same & Correct: {clean_l0_dot_same_correct_pct:.2f}% ({clean_l0_dot_same_correct}/{total})')
    print(f'L0 & Dot - Same & Incorrect: {clean_l0_dot_same_incorrect_pct:.2f}% ({clean_l0_dot_same_incorrect}/{total})')
    print(f'L0 & Dot - Different & Correct: {clean_l0_dot_diff_correct_pct:.2f}% ({clean_l0_dot_diff_correct}/{total})')
    print(f'L0 & Dot - Different & Incorrect: {clean_l0_dot_diff_incorrect_pct:.2f}% ({clean_l0_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'KL & Dot - Same & Correct: {clean_kl_dot_same_correct_pct:.2f}% ({clean_kl_dot_same_correct}/{total})')
    print(f'KL & Dot - Same & Incorrect: {clean_kl_dot_same_incorrect_pct:.2f}% ({clean_kl_dot_same_incorrect}/{total})')
    print(f'KL & Dot - Different & Correct: {clean_kl_dot_diff_correct_pct:.2f}% ({clean_kl_dot_diff_correct}/{total})')
    print(f'KL & Dot - Different & Incorrect: {clean_kl_dot_diff_incorrect_pct:.2f}% ({clean_kl_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'KL & L0 & Dot - Same & Correct: {clean_kl_l0_dot_same_correct_pct:.2f}% ({clean_kl_l0_dot_same_correct}/{total})')
    print(f'KL & L0 & Dot - Same & Incorrect: {clean_kl_l0_dot_same_incorrect_pct:.2f}% ({clean_kl_l0_dot_same_incorrect}/{total})')
    print(f'KL & L0 & Dot - Different & Correct: {clean_kl_l0_dot_diff_correct_pct:.2f}% ({clean_kl_l0_dot_diff_correct}/{total})')
    print(f'KL & L0 & Dot - Different & Incorrect: {clean_kl_l0_dot_diff_incorrect_pct:.2f}% ({clean_kl_l0_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'ADVERSARIAL PREDICTIONS:')
    print(f'1. L0 and KL same and correct: {adv_same_correct} ({adv_same_correct_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {adv_same_incorrect} ({adv_same_incorrect_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {adv_diff_one_correct} ({adv_diff_one_correct_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {adv_diff_both_wrong} ({adv_diff_both_wrong_pct:.2f}%)')
    print(f'Adversarial Consensus Rate: {adv_consensus_rate:.2f}%')
    print(f'Adversarial Consensus Accuracy: {adv_consensus_acc:.2f}%')
    print(f'Adversarial TRUE Consensus Accuracy: {adv_true_consensus_acc:.2f}%')
    print(f'')
    print(f'INDIVIDUAL METRIC PERFORMANCE (ADVERSARIAL):')
    print(f'L0 Accuracy: {adv_l0_acc:.2f}% ({adv_l0_correct}/{total})')
    print(f'KL Accuracy: {adv_kl_acc:.2f}% ({adv_kl_correct}/{total})')
    print(f'Dot Accuracy: {adv_dot_acc:.2f}% ({adv_dot_correct}/{total})')
    print(f'')
    
    print(f'DETAILED METRIC COMBINATIONS (ADVERSARIAL):')
    print(f'L0 & Dot - Same & Correct: {adv_l0_dot_same_correct_pct:.2f}% ({adv_l0_dot_same_correct}/{total})')
    print(f'L0 & Dot - Same & Incorrect: {adv_l0_dot_same_incorrect_pct:.2f}% ({adv_l0_dot_same_incorrect}/{total})')
    print(f'L0 & Dot - Different & Correct: {adv_l0_dot_diff_correct_pct:.2f}% ({adv_l0_dot_diff_correct}/{total})')
    print(f'L0 & Dot - Different & Incorrect: {adv_l0_dot_diff_incorrect_pct:.2f}% ({adv_l0_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'KL & Dot - Same & Correct: {adv_kl_dot_same_correct_pct:.2f}% ({adv_kl_dot_same_correct}/{total})')
    print(f'KL & Dot - Same & Incorrect: {adv_kl_dot_same_incorrect_pct:.2f}% ({adv_kl_dot_same_incorrect}/{total})')
    print(f'KL & Dot - Different & Correct: {adv_kl_dot_diff_correct_pct:.2f}% ({adv_kl_dot_diff_correct}/{total})')
    print(f'KL & Dot - Different & Incorrect: {adv_kl_dot_diff_incorrect_pct:.2f}% ({adv_kl_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'KL & L0 & Dot - Same & Correct: {adv_kl_l0_dot_same_correct_pct:.2f}% ({adv_kl_l0_dot_same_correct}/{total})')
    print(f'KL & L0 & Dot - Same & Incorrect: {adv_kl_l0_dot_same_incorrect_pct:.2f}% ({adv_kl_l0_dot_same_incorrect}/{total})')
    print(f'KL & L0 & Dot - Different & Correct: {adv_kl_l0_dot_diff_correct_pct:.2f}% ({adv_kl_l0_dot_diff_correct}/{total})')
    print(f'KL & L0 & Dot - Different & Incorrect: {adv_kl_l0_dot_diff_incorrect_pct:.2f}% ({adv_kl_l0_dot_diff_incorrect}/{total})')
    print(f'')
    print(f'=== ROBUSTNESS ANALYSIS ===')
    print(f'Consensus Robustness Drop: {consensus_robustness_drop:.2f}%')
    print(f'TRUE Consensus Robustness Drop: {true_consensus_robustness_drop:.2f}%')
    print(f'L0 Robustness Drop: {l0_robustness_drop:.2f}%')
    print(f'KL Robustness Drop: {kl_robustness_drop:.2f}%')
    print(f'Dot Robustness Drop: {dot_robustness_drop:.2f}%')
    print(f'')
    print(f'KEY INSIGHT: Adversarial accuracy = 1 - (Group 2 rate) = {100-adv_same_incorrect_pct:.2f}%')
    print(f'This means: When L0 and KL agree under attack, they are correct {100-adv_same_incorrect_pct:.2f}% of the time')
    print(f'')
    
    print(f'=== METRIC COMBINATION CHANGES (CLEAN → ADVERSARIAL) ===')
    print(f'L0 & Dot Changes:')
    print(f'  Same & Correct: {clean_l0_dot_same_correct_pct:.2f}% → {adv_l0_dot_same_correct_pct:.2f}% (Δ{adv_l0_dot_same_correct_pct-clean_l0_dot_same_correct_pct:+.2f}%)')
    print(f'  Same & Incorrect: {clean_l0_dot_same_incorrect_pct:.2f}% → {adv_l0_dot_same_incorrect_pct:.2f}% (Δ{adv_l0_dot_same_incorrect_pct-clean_l0_dot_same_incorrect_pct:+.2f}%)')
    print(f'  Different & Correct: {clean_l0_dot_diff_correct_pct:.2f}% → {adv_l0_dot_diff_correct_pct:.2f}% (Δ{adv_l0_dot_diff_correct_pct-clean_l0_dot_diff_correct_pct:+.2f}%)')
    print(f'  Different & Incorrect: {clean_l0_dot_diff_incorrect_pct:.2f}% → {adv_l0_dot_diff_incorrect_pct:.2f}% (Δ{adv_l0_dot_diff_incorrect_pct-clean_l0_dot_diff_incorrect_pct:+.2f}%)')
    print(f'')
    print(f'KL & Dot Changes:')
    print(f'  Same & Correct: {clean_kl_dot_same_correct_pct:.2f}% → {adv_kl_dot_same_correct_pct:.2f}% (Δ{adv_kl_dot_same_correct_pct-clean_kl_dot_same_correct_pct:+.2f}%)')
    print(f'  Same & Incorrect: {clean_kl_dot_same_incorrect_pct:.2f}% → {adv_kl_dot_same_incorrect_pct:.2f}% (Δ{adv_kl_dot_same_incorrect_pct-clean_kl_dot_same_incorrect_pct:+.2f}%)')
    print(f'  Different & Correct: {clean_kl_dot_diff_correct_pct:.2f}% → {adv_kl_dot_diff_correct_pct:.2f}% (Δ{adv_kl_dot_diff_correct_pct-clean_kl_dot_diff_correct_pct:+.2f}%)')
    print(f'  Different & Incorrect: {clean_kl_dot_diff_incorrect_pct:.2f}% → {adv_kl_dot_diff_incorrect_pct:.2f}% (Δ{adv_kl_dot_diff_incorrect_pct-clean_kl_dot_diff_incorrect_pct:+.2f}%)')
    print(f'')
    print(f'KL & L0 & Dot Changes:')
    print(f'  Same & Correct: {clean_kl_l0_dot_same_correct_pct:.2f}% → {adv_kl_l0_dot_same_correct_pct:.2f}% (Δ{adv_kl_l0_dot_same_correct_pct-clean_kl_l0_dot_same_correct_pct:+.2f}%)')
    print(f'  Same & Incorrect: {clean_kl_l0_dot_same_incorrect_pct:.2f}% → {adv_kl_l0_dot_same_incorrect_pct:.2f}% (Δ{adv_kl_l0_dot_same_incorrect_pct-clean_kl_l0_dot_same_incorrect_pct:+.2f}%)')
    print(f'  Different & Correct: {clean_kl_l0_dot_diff_correct_pct:.2f}% → {adv_kl_l0_dot_diff_correct_pct:.2f}% (Δ{adv_kl_l0_dot_diff_correct_pct-clean_kl_l0_dot_diff_correct_pct:+.2f}%)')
    print(f'  Different & Incorrect: {clean_kl_l0_dot_diff_incorrect_pct:.2f}% → {adv_kl_l0_dot_diff_incorrect_pct:.2f}% (Δ{adv_kl_l0_dot_diff_incorrect_pct-clean_kl_l0_dot_diff_incorrect_pct:+.2f}%)')
    
    # Print Equation (46) fulfillment analysis
    print(f'\n=== EQUATION (46) FULFILLMENT ANALYSIS ===')
    print(f'Proposition 4: KL(y\', p\') - KL(y*, p\') > KL(y\', p*) - KL(y*, p*)')
    print(f'')
    
    if case_a_total > 0:
        case_a_fulfillment_percentage = (case_a_equation_fulfilled / case_a_total) * 100
        print(f'Case a (L0 and KL same, not correct): {case_a_equation_fulfilled}/{case_a_total} = {case_a_fulfillment_percentage:.2f}% fulfill Equation (46)')
    else:
        print(f'Case a: No samples with both L0 and KL correct before attack')
    
    if case_c_total > 0:
        case_c_fulfillment_percentage = (case_c_equation_fulfilled / case_c_total) * 100
        print(f'Case c (L0 and KL different, one correct): {case_c_equation_fulfilled}/{case_c_total} = {case_c_fulfillment_percentage:.2f}% fulfill Equation (46)')
    else:
        print(f'Case c: No samples with both L0 and KL correct before attack')
    
    if case_d_total > 0:
        case_d_fulfillment_percentage = (case_d_equation_fulfilled / case_d_total) * 100
        print(f'Case d (L0 and KL different, both wrong): {case_d_equation_fulfilled}/{case_d_total} = {case_d_fulfillment_percentage:.2f}% fulfill Equation (46)')
    else:
        print(f'Case d: No samples with both L0 and KL correct before attack')
    
    # Overall fulfillment statistics
    total_conflicting_samples = case_a_total + case_c_total + case_d_total
    total_fulfilled = case_a_equation_fulfilled + case_c_equation_fulfilled + case_d_equation_fulfilled
    if total_conflicting_samples > 0:
        overall_fulfillment_percentage = (total_fulfilled / total_conflicting_samples) * 100
        print(f'')
        print(f'Overall: {total_fulfilled}/{total_conflicting_samples} = {overall_fulfillment_percentage:.2f}% of samples fulfill Equation (46)')
        print(f'This indicates the theoretical robustness guarantee from Proposition 4 is satisfied for {overall_fulfillment_percentage:.2f}% of cases')
    else:
        print(f'')
        print(f'No samples were correct before attack, skipping Equation (46) analysis')
    
    return {
        'clean_results': {
            'same_correct': clean_same_correct,
            'same_incorrect': clean_same_incorrect,
            'diff_one_correct': clean_diff_one_correct,
            'diff_both_wrong': clean_diff_both_wrong,
            'consensus_rate': clean_consensus_rate,
            'consensus_acc': clean_consensus_acc,
            'true_consensus_acc': clean_true_consensus_acc,
            'l0_acc': clean_l0_acc,
            'kl_acc': clean_kl_acc,
            'dot_acc': clean_dot_acc,
            'total': total
        },
        'adv_same_correct': adv_same_correct,
        'adv_same_incorrect': adv_same_incorrect,
        'adv_diff_one_correct': adv_diff_one_correct,
        'adv_diff_both_wrong': adv_diff_both_wrong,
        'adv_consensus_rate': adv_consensus_rate,
        'adv_consensus_acc': adv_consensus_acc,
        'adv_true_consensus_acc': adv_true_consensus_acc,
        'adv_l0_acc': adv_l0_acc,
        'adv_kl_acc': adv_kl_acc,
        'adv_dot_acc': adv_dot_acc,
        'consensus_robustness_drop': consensus_robustness_drop,
        'true_consensus_robustness_drop': true_consensus_robustness_drop,
        'l0_robustness_drop': l0_robustness_drop,
        'kl_robustness_drop': kl_robustness_drop,
        'dot_robustness_drop': dot_robustness_drop,
        'adversarial_accuracy': 100 - adv_same_incorrect_pct,
        'equation_46_analysis': {
            'case_a_fulfilled': case_a_equation_fulfilled,
            'case_a_total': case_a_total,
            'case_c_fulfilled': case_c_equation_fulfilled,
            'case_c_total': case_c_total,
            'case_d_fulfilled': case_d_equation_fulfilled,
            'case_d_total': case_d_total,
            'total_fulfilled': total_fulfilled,
            'total_conflicting_samples': total_conflicting_samples
        }
    }

def _build_arg_parser():
    """Create the command-line parser for training and evaluation runs."""
    parser = argparse.ArgumentParser(description='Train CIFAR-10 with Combined L0+KL Prototypes')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='number of training epochs')
    parser.add_argument('--alpha', type=float, default=1.0, help='weight for classification loss')
    parser.add_argument('--beta', type=float, default=1.0, help='weight for separation regularization')
    parser.add_argument('--margin', type=float, default=0.1, help='minimum distance between prototypes')
    parser.add_argument('--tau', type=float, default=0.75, help='threshold parameter for L0 similarity')
    parser.add_argument('--class_boost', type=float, default=0.5, help='boost value for correct class similarities')
    parser.add_argument('--patience', type=int, default=10, help='early stopping patience')
    parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vgg19', 'densenet121'], help='backbone architecture')
    parser.add_argument('--pretrained', action='store_true', help='use pretrained ImageNet weights for backbone')
    parser.add_argument('--checkpoint', type=str, default='./checkpoint_combined/ckpt_combined.pth', help='checkpoint path')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate trained model')
    parser.add_argument('--consensus_eval', action='store_true', help='evaluate using consensus prediction (4-category breakdown)')
    parser.add_argument('--eval_robustness', action='store_true', help='evaluate adversarial robustness with PGD attack')
    parser.add_argument('--attack_epsilon', type=int, default=2, help='perturbation budget for adversarial attack (will be divided by 255)')
    parser.add_argument('--attack_stepsize', type=int, default=2, help='attack step size for adversarial attack (will be divided by 255)')
    parser.add_argument('--norm', type=str, default='l_inf', choices=['l_inf', 'l_2'], help='norm for adversarial attack')
    parser.add_argument('--l0_weight', type=float, default=1.0, help='weight for L0 similarity (0.0 to 1.0)')
    parser.add_argument('--kl_weight', type=float, default=0.0, help='weight for KL similarity (0.0 to 1.0)')
    parser.add_argument('--dropout_rate', type=float, default=0.3, help='dropout rate')
    # Model loading arguments for continuing training from pre-trained models
    parser.add_argument('--l0_load_model', type=str, help='load L0 model from checkpoint to continue training')
    parser.add_argument('--kl_load_model', type=str, help='load KL model from checkpoint to continue training')
    parser.add_argument('--load_model', type=str, help='load standard model checkpoint (from main.py) and convert to prototype architecture')
    parser.add_argument('--freeze_prototypes', action='store_true',
                        help='freeze prototypes after initialization and only train encoder')
    # Math proof argument
    parser.add_argument('--math_proof', action='store_true', help='run mathematical proof analysis during adversarial evaluation')
    # Decorrelation and orthogonality arguments
    parser.add_argument('--deco_weight', type=float, default=0.0,  # 0.15,
                        help='weight for L0/KL logit decorrelation')
    parser.add_argument('--ortho_weight', type=float, default=0.0,  # 1e-3,
                        help='weight for proj_l0/proj_kl orthogonality')
    parser.add_argument('--dot_weight', type=float, default=1.0, help='weight for dot similarity (0.0 to 1.0)')
    parser.add_argument('--use_last_epoch', action='store_true', help='use last epoch checkpoint')
    return parser


def _validate_and_scale_args(args):
    """Validate the similarity weights and rescale attack parameters in place.

    The L0/KL weights are normalized to sum to 1 when both are positive, and
    the integer attack epsilon/step size are divided by 255.

    Raises:
        ValueError: if any similarity weight is negative, or if both the L0
            and KL weights are zero.
    """
    # dot_weight is checked alongside the other two: its help text documents
    # the same 0.0-1.0 range, so a negative value is a usage error as well.
    if args.l0_weight < 0 or args.kl_weight < 0 or args.dot_weight < 0:
        raise ValueError("Weights must be non-negative")
    if args.l0_weight == 0 and args.kl_weight == 0:
        raise ValueError("At least one weight must be greater than 0")
    # Normalize weights if both are > 0
    if args.l0_weight > 0 and args.kl_weight > 0:
        total_weight = args.l0_weight + args.kl_weight
        args.l0_weight /= total_weight
        args.kl_weight /= total_weight
        print(f"Normalized weights: L0={args.l0_weight:.3f}, KL={args.kl_weight:.3f}")
    # Convert attack parameters from integer to float (divide by 255)
    args.attack_epsilon = args.attack_epsilon / 255.0
    args.attack_stepsize = args.attack_stepsize / 255.0
    print(f"Attack parameters: epsilon={args.attack_epsilon:.6f}, stepsize={args.attack_stepsize:.6f}, norm={args.norm}")


def _build_dataloaders():
    """Build the CIFAR-10 train, validation and test loaders.

    The official test split is divided in half (seeded with 42) into a
    validation set and a held-out test set.

    Returns:
        Tuple ``(trainloader, valloader, testloader)``.
    """
    print('==> Preparing data..')
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)
    # Split test set into validation and test sets
    testset_full = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    # Seed right before the split so the val/test partition is reproducible.
    torch.manual_seed(42)
    val_size = len(testset_full) // 2
    test_size = len(testset_full) - val_size
    valset, testset = torch.utils.data.random_split(testset_full, [val_size, test_size])
    valloader = torch.utils.data.DataLoader(
        valset, batch_size=100, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    print(f'Training set: {len(trainset)} images')
    print(f'Validation set: {len(valset)} images')
    print(f'Test set: {len(testset)} images')
    return trainloader, valloader, testloader


def _build_backbone(args):
    """Construct the feature backbone selected by ``args`` with its head removed.

    The final classification layer is replaced by ``nn.Identity`` and ReLU
    activations are swapped for GELU via ``replace_activations``.
    """
    print('==> Building backbone..')
    if args.pretrained:
        print('==> Loading pretrained ImageNet weights for backbone..')
        pretrained_factories = {
            'resnet18': torchvision.models.resnet18,
            'resnet50': torchvision.models.resnet50,
            'vgg19': torchvision.models.vgg19,
            'densenet121': torchvision.models.densenet121,
        }
        factory = pretrained_factories.get(args.backbone)
        backbone = factory(pretrained=True) if factory is not None else ResNet18()
        # Remove the final classification layer
        if hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()
        elif hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()
        print(f'==> Loaded pretrained {args.backbone} backbone')
    else:
        print('==> Using random initialization for backbone..')
        if args.backbone == 'vgg19':
            backbone = VGG('VGG19')
            backbone.classifier = nn.Identity()
        elif args.backbone == 'densenet121':
            backbone = DenseNet121()
            backbone.classifier = nn.Identity()
        else:
            if args.backbone != 'resnet18':
                # Make the existing silent fallback visible to the user.
                print(f'Warning: no randomly initialized {args.backbone} backbone is wired up here; '
                      f'falling back to ResNet18')
            backbone = ResNet18()
            backbone.linear = nn.Identity()
    # Replace ReLU with GELU
    print('==> Replacing ReLU activations with GELU...')
    replace_activations(backbone, nn.GELU())
    print('==> Activation replacement completed')
    return backbone


def _load_backbone_weights(backbone, path, device):
    """Load backbone weights from a standard checkpoint, dropping its classifier head."""
    print(f'==> Loading model from {path}..')
    if not os.path.isfile(path):
        print(f'==> No model file found at {path}')
        print('==> Starting with random initialization')
        return
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(path, map_location=device)
    # Load backbone weights (excluding final layer)
    backbone_state_dict = {}
    for key, value in checkpoint['net'].items():
        # Strip the 'module.' prefix if present.
        clean_key = key[7:] if key.startswith('module.') else key
        if not clean_key.startswith(('linear.', 'classifier.')):
            backbone_state_dict[clean_key] = value
    backbone.load_state_dict(backbone_state_dict)
    print(f'==> Loaded backbone weights from checkpoint')


def _load_prototype_checkpoint(model, path, device, label):
    """Load a full prototype-model checkpoint into ``model``.

    Args:
        model: model whose ``load_state_dict`` receives ``checkpoint['net']``.
        path: checkpoint file path.
        device: target passed to ``torch.load`` as ``map_location``.
        label: short tag ('L0' or 'KL') used in the progress messages.

    Returns:
        True if the checkpoint file existed and was loaded, False otherwise.
    """
    print(f'==> Loading {label} model from {path}..')
    if not os.path.isfile(path):
        print(f'==> {label} checkpoint not found at {path}')
        print('==> Starting with random initialization')
        return False
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['net'])
    print(f'==> Loaded {label} model with accuracy: {checkpoint["acc"]:.2f}%')
    print(f'==> Continuing training from {label} checkpoint')
    return True


def _evaluate_from_checkpoint(model, testloader, device, args):
    """Load the selected checkpoint and run the requested evaluations on the test set."""
    print("==> Loading trained model for evaluation...")
    # Choose checkpoint based on argument
    if args.use_last_epoch:
        checkpoint_path = args.checkpoint.replace('.pth', '_last.pth')
        print(f"Using last epoch checkpoint: {checkpoint_path}")
    else:
        checkpoint_path = args.checkpoint
        print(f"Using best model checkpoint: {checkpoint_path}")

    if not os.path.isfile(checkpoint_path):
        print(f"Checkpoint not found at {checkpoint_path}")
        return

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['net'])
    print(f"Loaded model with accuracy: {checkpoint['acc']:.2f}%")
    # Prefer the weights stored with the checkpoint over the command line.
    if 'l0_weight' in checkpoint and 'kl_weight' in checkpoint:
        l0_weight = checkpoint['l0_weight']
        kl_weight = checkpoint['kl_weight']
        # Check if dot_weight exists in checkpoint, otherwise use default
        if 'dot_weight' in checkpoint:
            dot_weight = checkpoint['dot_weight']
            print(f"Loaded weights from checkpoint: L0={l0_weight}, KL={kl_weight}, Dot={dot_weight}")
        else:
            dot_weight = args.dot_weight
            print(f"Loaded weights from checkpoint: L0={l0_weight}, KL={kl_weight}")
            print(f"Using command line dot_weight: {dot_weight} (not found in checkpoint)")
    else:
        # Fallback to command line arguments if not in checkpoint
        l0_weight = args.l0_weight
        kl_weight = args.kl_weight
        dot_weight = args.dot_weight
        print(f"Using command line weights: L0={l0_weight}, KL={kl_weight}, Dot={dot_weight}")
        print(f"Note: Consider retraining to save weights in checkpoint")
    # Evaluate on test set
    evaluate_model(model, testloader, device, args, l0_weight, kl_weight, dot_weight)
    if args.consensus_eval:
        evaluate_consensus_accuracy(model, testloader, device, args, l0_weight, kl_weight, dot_weight)
    if args.eval_robustness:
        evaluate_adversarial_robustness(model, testloader, device, args, l0_weight, kl_weight, dot_weight)


def main():
    """Command-line entry point: build data and model, then train and/or evaluate.

    Steps, in order: parse and validate arguments, prepare the CIFAR-10
    loaders, build the backbone (optionally loading weights via
    ``--load_model``), wrap it in ``DynamicPrototypeModelL0KL``, optionally
    restore L0/KL checkpoints, initialize prototypes from data when no
    checkpoint was restored, then run training (``--train``) and/or
    evaluation (``--evaluate``).

    Raises:
        ValueError: if the similarity weights given on the command line are
            invalid (see ``_validate_and_scale_args``).
    """
    args = _build_arg_parser().parse_args()
    _validate_and_scale_args(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    # Data preparation
    trainloader, valloader, testloader = _build_dataloaders()

    # Build backbone
    backbone = _build_backbone(args)

    # Load pretrained weights if specified
    if args.load_model:
        _load_backbone_weights(backbone, args.load_model, device)

    # Create model
    model = DynamicPrototypeModelL0KL(backbone, num_classes=10, embedding_dim=512, dropout_rate=args.dropout_rate)
    model = model.to(device)

    # Track whether a checkpoint was actually restored: a path that does not
    # exist must not suppress prototype initialization below.
    restored_from_checkpoint = False
    if args.l0_load_model:
        if _load_prototype_checkpoint(model, args.l0_load_model, device, 'L0'):
            restored_from_checkpoint = True
    if args.kl_load_model:
        if _load_prototype_checkpoint(model, args.kl_load_model, device, 'KL'):
            restored_from_checkpoint = True

    # Initialize prototypes from data centroids (only if not loading from checkpoint)
    if not restored_from_checkpoint:
        print("==> Initializing prototypes from data centroids...")
        model.initialize_prototypes_from_data(trainloader, device)

        # Freeze prototypes if requested
        if args.freeze_prototypes:
            model.freeze_prototypes()
            print("✅ Prototypes initialized and frozen")
        else:
            print("✅ Prototypes initialized and trainable")
    else:
        print("==> Using prototypes from loaded checkpoint (skipping initialization)")
        if args.freeze_prototypes:
            model.freeze_prototypes()

    # Training
    if args.train:
        print("==> Starting training...")
        print(f"Training mode: L0 weight={args.l0_weight}, KL weight={args.kl_weight}, Dot weight={args.dot_weight}")
        print(f"Prototype training: {'FROZEN' if args.freeze_prototypes else 'TRAINABLE'}")
        check_prototype_status(model)
        # Create checkpoint directory; a bare filename has an empty dirname,
        # which os.makedirs would reject.
        checkpoint_dir = os.path.dirname(args.checkpoint)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        # Train the model
        best_acc = train_model(model, trainloader, valloader, testloader, device, args)
        print(f"Training completed with best accuracy: {best_acc:.2f}%")

    # Evaluation
    if args.evaluate:
        _evaluate_from_checkpoint(model, testloader, device, args)
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()