import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import math proof analysis function
from math_part import math_proof_analysis

# Try to import AutoAttack
try:
    from autoattack import AutoAttack
    AUTOATTACK_AVAILABLE = True
    print("AutoAttack imported successfully")
except ImportError:
    AUTOATTACK_AVAILABLE = False
    print("Warning: AutoAttack not available. Install with: pip install autoattack")

from models import *
try:
    from utils import progress_bar
    PROGRESS_BAR_AVAILABLE = True
except (ValueError, OSError):
    PROGRESS_BAR_AVAILABLE = False
    print("Warning: Progress bar not available, using simple progress display")

def progress_bar(current, total, msg=None):
    """Print a one-line textual progress report.

    A line is emitted only on every 10th step (``current % 10 == 0``) and on the
    final step (``current == total - 1``); all other calls are silent.

    Note: this definition unconditionally rebinds ``progress_bar``, so it also
    shadows ``utils.progress_bar`` when that import above succeeded.
    """
    is_final_step = current == total - 1
    if current % 10 != 0 and not is_final_step:
        return
    suffix = msg or ''
    print(f"Progress: {current+1}/{total} {suffix}")

# CIFAR-10 per-channel normalization statistics, shaped (1, 3, 1, 1) so they
# broadcast over channel-first image batches (presumably (N, C, H, W)).
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, -1, 1, 1)
CIFAR_STD = torch.tensor([0.2023, 0.1994, 0.2010]).reshape(1, -1, 1, 1)
# Adversarial attack parameters: valid pixel range of un-normalized images
# (presumably the bounds passed to clamp()).
lower_limit, upper_limit = 0, 1

def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise to ``[lower_limit, upper_limit]``.

    Bounds may be tensors broadcastable to ``X`` (e.g. per-channel limits) or
    plain Python scalars such as the module-level ``lower_limit``/``upper_limit``.

    Scalar bounds are converted to tensors on ``X``'s dtype/device first. Without
    that, ``torch.min(X, 1)`` selects the *reduction over dim 1* overload instead
    of the element-wise minimum, and the subsequent ``torch.max`` call fails.
    Tensor bounds are used unchanged, so their behavior is as before.

    The upper bound is applied first, so if ``lower_limit > upper_limit`` the
    result equals ``lower_limit``.
    """
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

def corr_penalty(a, b):
    """Correlation-style penalty between two logit tensors.

    Each row of ``a`` and ``b`` (one row per sample) is mean-centred across its
    columns. The result is

        mean_i(<a_i, b_i>) / (mean_i(|a_i| * |b_i|) + 1e-8)

    which is a ratio of batch means, i.e. a norm-weighted average of the per-row
    Pearson correlations. It therefore lies in [-1, 1].
    """
    a_centered, b_centered = (t - t.mean(dim=1, keepdim=True) for t in (a, b))
    covariance_term = (a_centered * b_centered).sum(dim=1).mean()
    scale_term = (a_centered.norm(dim=1) * b_centered.norm(dim=1)).mean() + 1e-8
    return covariance_term / scale_term

def ortho_reg(W1, W2):
    """Mean squared entry of ``W1^T @ W2``.

    The value is 0 exactly when every column of ``W1`` is orthogonal to every
    column of ``W2``.
    """
    cross = torch.matmul(W1.transpose(0, 1), W2)
    return (cross * cross).mean()

def _norm_logits_for_loss(z):
    """Standardize logits for loss computation to equalize gradient scales"""
    z = z - z.mean(dim=1, keepdim=True)
    z = z / (z.std(dim=1, keepdim=True) + 1e-8)
    return z
def replace_activations(module: nn.Module, new_act: nn.Module):
    """Recursively replace every ``nn.ReLU`` below ``module`` with a copy of ``new_act``.

    Each replacement is an independent ``copy.deepcopy`` of ``new_act``. This
    preserves its constructor configuration (e.g. ``LeakyReLU(negative_slope=...)``,
    ``GELU(approximate='tanh')``) and gives activations that own parameters
    (e.g. ``PReLU``) separate parameters at every location.

    Rebuilding via ``new_act.__class__(**new_act._parameters)`` would drop those
    constructor arguments and raise ``TypeError`` for parameterised activations.

    ``module`` itself is never replaced, even if it is a ReLU; only its
    descendants are. The tree is modified in place.
    """
    import copy

    for name, child in module.named_children():
        # torchvision ResNet uses nn.ReLU(inplace=True)
        if isinstance(child, nn.ReLU):
            setattr(module, name, copy.deepcopy(new_act))
        else:
            replace_activations(child, new_act)
class DynamicPrototypeModelL0KL(nn.Module):
    """Prototype classifier scoring backbone embeddings against learnable class prototypes.

    ``forward`` compares L2-normalised image embeddings with ``class_prototypes``
    using three heads:
      * a soft "L0" head (a soft count of thresholded per-dimension differences);
      * a KL-divergence head;
      * a detached cosine ("dot") baseline.
    The heads are mixed according to caller-supplied weights.

    NOTE(review): ``proj_l0``, ``proj_kl``, ``proto_proj_l0``, ``proto_proj_kl``,
    ``dropout_l0`` and ``dropout_kl`` are created (the projections are also toggled
    by ``set_training_mode``), but ``forward`` does not use them. They are kept so
    that existing checkpoints and external code referencing them keep working.
    """

    def __init__(self, backbone, num_classes=10, embedding_dim=512, dropout_rate=0.3):
        """
        Args:
            backbone: module mapping an input batch to embeddings; presumably it
                returns (batch, embedding_dim) — TODO confirm against models.*.
            num_classes: number of class prototypes.
            embedding_dim: dimensionality of each prototype.
            dropout_rate: dropout probability used on the embeddings in forward().
        """
        super().__init__()
        self.backbone = backbone
        # Learnable prototypes; overwritten with class centroids by
        # initialize_prototypes_from_data().
        self.class_prototypes = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self._prototypes_initialized = False
        # Dropout on the embeddings to combat overfitting.
        self.dropout1 = nn.Dropout(dropout_rate)  # after the backbone
        self.dropout2 = nn.Dropout(dropout_rate)  # after L2 normalisation, before similarities

        # Per-head image projections (currently unused by forward(); see class docstring).
        self.proj_l0 = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        self.proj_kl = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        # L0 projection starts as identity plus tiny noise so it is easy to learn.
        with torch.no_grad():
            self.proj_l0.weight.zero_()
            eye = torch.eye(self.embedding_dim, device=self.proj_l0.weight.device)
            self.proj_l0.weight.add_(eye)                            # identity
            self.proj_l0.weight.add_(0.01 * torch.randn_like(eye))   # tiny noise

        nn.init.orthogonal_(self.proj_kl.weight, gain=1.0)

        # Per-head prototype projections (currently unused by forward()).
        self.proto_proj_l0 = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        self.proto_proj_kl = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        nn.init.orthogonal_(self.proto_proj_l0.weight, gain=1.0)
        nn.init.orthogonal_(self.proto_proj_kl.weight, gain=1.0)

        # Learnable logit scale (inverse temperature) for each head.
        self.l0_logit_scale = nn.Parameter(torch.tensor(1.0))
        self.kl_logit_scale = nn.Parameter(torch.tensor(1.0))
        # Start L0 with a larger scale so its logits have a wider spread.
        with torch.no_grad():
            self.l0_logit_scale.copy_(torch.tensor(6.0))  # try 4–12
            self.kl_logit_scale.copy_(torch.tensor(1.0))

        # Per-head dropout (currently unused by forward()).
        self.dropout_l0 = nn.Dropout(p=dropout_rate)
        self.dropout_kl = nn.Dropout(p=dropout_rate * 0.7)   # asymmetric on purpose
        self._prototypes_frozen = False

    def freeze_prototypes(self):
        """Stop gradient updates to ``class_prototypes`` (a Parameter, not a module)."""
        self.class_prototypes.requires_grad_(False)
        self._prototypes_frozen = True
        print("🔒 Prototypes frozen - no further updates")

    def unfreeze_prototypes(self):
        """Re-enable gradient updates to ``class_prototypes``."""
        self.class_prototypes.requires_grad_(True)
        self._prototypes_frozen = False
        print("🔄 Prototypes unfrozen - can be updated")

    def get_prototype_status(self):
        """Return the frozen flag, requires_grad, and mean L2 norm of the prototypes."""
        return {
            'frozen': self._prototypes_frozen,
            'requires_grad': self.class_prototypes.requires_grad,
            'norm': torch.norm(self.class_prototypes, p=2, dim=1).mean().item()
        }

    def forward(self, x, targets=None, return_individual=False, tau=0.75, compute_separation=False,
                margin=0.2, class_boost=0.5, l0_weight=1.0, kl_weight=0.0, dot_weight=0.0):
        """Score a batch against every class prototype.

        Args:
            x: input batch passed to the backbone.
            targets: accepted for API compatibility; not used.
            return_individual: also return the per-head logits and distances, and
                print debug statistics every 100 such calls.
            tau: forwarded to compute_l0_similarity (which ignores it).
            compute_separation: also return the prototype separation loss.
            margin: margin of the separation loss.
            class_boost: accepted for API compatibility; not used.
            l0_weight, kl_weight, dot_weight: mixing weights of the combined logits.
                The L0/KL heads are skipped (zeros) when their weight is <= 0; the
                dot baseline is always computed.

        Returns:
            One of:
              * ``combined``;
              * ``(combined, separation_loss)`` if ``compute_separation``;
              * if ``return_individual``: ``(l0, kl, combined, [separation_loss,]
                l0_distances, kl_distances, dot)``.
            Every logit tensor has one row per sample and one column per class.
        """
        # Embed, regularise and L2-normalise the images; prototypes stay raw here.
        image_embeddings = self.backbone(x)
        image_embeddings = self.dropout1(image_embeddings)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
        # Dropout after normalisation rescales/zeros entries, hence the re-normalisation below.
        image_embeddings = self.dropout2(image_embeddings)

        z_l0 = F.normalize(image_embeddings, p=2, dim=1)
        z_kl = F.normalize(image_embeddings, p=2, dim=1)

        # A sharp softmax (temperature 0.05) turns each vector into a strictly
        # positive distribution, so the KL head never takes log(0).
        z_l0_processed = F.softmax(z_l0 / 0.05, dim=1)
        z_kl_processed = F.softmax(z_kl / 0.05, dim=1)

        # NOTE(review): prototypes are softmaxed WITHOUT the L2 normalisation that
        # get_processed_embeddings() applies before its softmax — confirm which is intended.
        protos_l0 = F.softmax(self.class_prototypes / 0.05, dim=1)
        protos_kl = F.softmax(self.class_prototypes / 0.05, dim=1)

        batch_size = image_embeddings.size(0)
        device = image_embeddings.device
        if l0_weight > 0:
            l0_similarities, l0_distances = self.compute_l0_similarity(z_l0_processed, protos_l0, tau=tau)
        else:
            l0_similarities = torch.zeros(batch_size, self.num_classes, device=device)
            l0_distances = torch.zeros(batch_size, self.num_classes, device=device)

        if kl_weight > 0:
            kl_similarities, kl_distances = self.compute_kl_similarity(z_kl_processed, protos_kl)
        else:
            kl_similarities = torch.zeros(batch_size, self.num_classes, device=device)
            kl_distances = torch.zeros(batch_size, self.num_classes, device=device)

        # Detached cosine baseline: it enters the logits but never the gradients.
        # image_embeddings has already been through both dropout layers.
        with torch.no_grad():
            base_feats = F.normalize(image_embeddings, p=2, dim=1)
            base_protos = F.normalize(self.class_prototypes, p=2, dim=1)
            dot_similarities = base_feats @ base_protos.t()

        # Apply learnable temperature scales
        l0_normalized = l0_similarities * self.l0_logit_scale
        kl_normalized = kl_similarities * self.kl_logit_scale

        if return_individual:
            self._log_metric_dominance(l0_normalized, kl_normalized, l0_weight, kl_weight)
            # Periodic similarity debug output.
            if not hasattr(self, '_l0_debug_counter'):
                self._l0_debug_counter = 0
            self._l0_debug_counter += 1
            if self._l0_debug_counter % 100 == 0:
                print(f"\n=== Similarity Debug (Batch {self._l0_debug_counter}) ===")
                print(f"Raw L0 similarities - Min: {l0_similarities.min().item():.6f}, Max: {l0_similarities.max().item():.6f}")
                print(f"Normalized L0 - Min: {l0_normalized.min().item():.6f}, Max: {l0_normalized.max().item():.6f}")
                print(f"KL similarities - Min: {kl_similarities.min().item():.6f}, Max: {kl_similarities.max().item():.6f}")
                print(f"Dot similarities - Min: {dot_similarities.min().item():.6f}, Max: {dot_similarities.max().item():.6f}")
                print(f"L0 predictions: {torch.argmax(l0_normalized, dim=1)[:5].tolist()}")
                print(f"KL predictions: {torch.argmax(kl_normalized, dim=1)[:5].tolist()}")
                print(f"Dot predictions: {torch.argmax(dot_similarities, dim=1)[:5].tolist()}")
                print(f"L0 weight: {l0_weight}, KL weight: {kl_weight}")
                print(f"Using projected features: L0 (z_l0), KL (z_kl), Dot (z_l0), shared prototypes")
                print("=" * 50)

        combined_similarities = l0_weight * l0_normalized + kl_weight * kl_normalized + dot_weight * dot_similarities

        # Pure single-head modes return that head's logits directly. They apply only
        # when dot_weight is 0; otherwise the requested dot contribution would be
        # silently discarded.
        if l0_weight == 1.0 and kl_weight == 0.0 and dot_weight == 0.0:
            combined_similarities = l0_normalized
            if not hasattr(self, '_l0_only_warning_shown'):
                print(f"🚀 L0-ONLY MODE ACTIVATED: Using only L0 similarities")
                print(f" This should behave exactly like respective L0 training")
                print(f" L0 weight: {l0_weight}, KL weight: {kl_weight}")
                self._l0_only_warning_shown = True

        elif l0_weight == 0.0 and kl_weight == 1.0 and dot_weight == 0.0:
            combined_similarities = kl_normalized
            if not hasattr(self, '_kl_only_warning_shown'):
                print(f"🚀 KL-ONLY MODE ACTIVATED: Using only KL similarities")
                print(f" This should behave exactly like respective KL training")
                print(f" L0 weight: {l0_weight}, KL weight: {kl_weight}")
                self._kl_only_warning_shown = True

        separation_loss = None
        if compute_separation:
            # "+ 0.0" yields a fresh graph node that still backpropagates into the prototypes.
            forced_prototypes = self.class_prototypes + 0.0
            separation_loss = self._compute_separation_loss(forced_prototypes, margin)

        if return_individual:
            if compute_separation:
                return l0_normalized, kl_normalized, combined_similarities, separation_loss, l0_distances, kl_distances, dot_similarities
            return l0_normalized, kl_normalized, combined_similarities, l0_distances, kl_distances, dot_similarities
        if compute_separation:
            return combined_similarities, separation_loss
        return combined_similarities

    def _compute_separation_loss(self, class_prototypes, margin=0.2):
        """Hinge penalty pushing each prototype at least ``margin`` from its nearest neighbour.

        Uses pairwise L1 distances (``cdist`` with ``p=1``). Returns the sum over
        classes of ``relu(margin - nearest_other_prototype_distance)``.
        """
        distances = torch.cdist(class_prototypes, class_prototypes, p=1)
        diag = torch.eye(distances.size(0), device=distances.device).bool()
        # Exclude self-distances from the nearest-neighbour search.
        off_for_min = distances.masked_fill(diag, float('inf'))
        min_distances_per_class = off_for_min.min(dim=1).values
        violations_per_class = F.relu(margin - min_distances_per_class)
        return violations_per_class.sum()

    def _log_metric_dominance(self, l0_normalized, kl_normalized, l0_weight, kl_weight):
        """Every 100 calls, print which head's logits have the larger spread.

        Spread score = per-row variance + range + std, each averaged over the batch.
        """
        l0_variance = torch.var(l0_normalized, dim=1).mean().item()
        kl_variance = torch.var(kl_normalized, dim=1).mean().item()
        l0_range = (l0_normalized.max(dim=1)[0] - l0_normalized.min(dim=1)[0]).mean().item()
        kl_range = (kl_normalized.max(dim=1)[0] - kl_normalized.min(dim=1)[0]).mean().item()
        l0_std = torch.std(l0_normalized, dim=1).mean().item()
        kl_std = torch.std(kl_normalized, dim=1).mean().item()
        l0_dominance_score = l0_variance + l0_range + l0_std
        kl_dominance_score = kl_variance + kl_range + kl_std
        # Log every 100 batches to avoid spam
        if not hasattr(self, '_batch_counter'):
            self._batch_counter = 0
        self._batch_counter += 1
        if self._batch_counter % 100 == 0:
            print(f"L0 vs KL Dominance - L0: {l0_dominance_score:.3f}, KL: {kl_dominance_score:.3f}")
            print(f"Weights - L0: {l0_weight}, KL: {kl_weight}")
            if l0_weight > 0 and kl_weight > 0:
                if l0_dominance_score > kl_dominance_score * 1.2:
                    print("⚠️ L0 is DOMINATING")
                elif kl_dominance_score > l0_dominance_score * 1.2:
                    print("⚠️ KL is DOMINATING")
                else:
                    print("✅ Metrics are BALANCED")
            elif l0_weight > 0:
                print("✅ L0-ONLY mode")
            elif kl_weight > 0:
                print("✅ KL-ONLY mode")

    def compute_l0_similarity(self, img_features, text_features, tau=0.75):
        """Soft count of coordinates whose image/prototype difference exceeds a per-pair threshold.

        For each (image, prototype) pair, the absolute per-dimension differences are
        compared with ``mean + 0.5 * std`` of those same differences. A sigmoid gate
        (temperature 0.2) softly counts the dimensions ABOVE that threshold.

        Args:
            img_features: (B, D) image vectors.
            text_features: (C, D) prototype vectors.
            tau: unused; kept for API compatibility.

        Returns:
            ``(logits, l0_distance)``, both (B, C):
              * ``logits``: soft fraction of dimensions above threshold;
              * ``l0_distance``: hard count of those dimensions.

        NOTE(review): an earlier inline comment described the gate as "1 if |Δ|
        below thr". However, ``sigmoid((diff - thr) / T)`` tends to 1 when |Δ| is
        ABOVE the threshold, so ``logits`` tracks the same count as ``l0_distance``.
        Confirm the intended sign before reading it as a similarity.
        """
        img_expanded = img_features.unsqueeze(1)     # (B, 1, D)
        text_expanded = text_features.unsqueeze(0)   # (1, C, D)

        diff = (img_expanded - text_expanded).abs()  # (B, C, D)

        mu = diff.mean(dim=2, keepdim=True)
        sigma = diff.std(dim=2, keepdim=True)
        thresholds = mu + 0.5 * sigma                # tune 0.3–0.8

        # sharp but *not* the same temperature as KL
        temperature = 0.2
        gate = torch.sigmoid((diff - thresholds) / temperature)
        kept = gate.sum(dim=2)                       # (B, C)
        logits = kept / diff.size(2)                 # in [0, 1]; scaled later by l0_logit_scale
        l0_distance = (diff > thresholds).float().sum(dim=2)
        return logits, l0_distance

    def compute_kl_similarity(self, img_features, text_features):
        """KL-divergence similarity ``1 - KL(prototype || image)`` for every pair.

        Inputs are expected to already be probability vectors (forward() softmaxes
        them). Values are clamped to [1e-10, 1] before the log.

        Args:
            img_features: (B, D) image distributions.
            text_features: (C, D) prototype distributions.

        Returns:
            ``(kl_similarity, kl_div)``, both (B, C).
        """
        img_expanded = img_features.unsqueeze(1)    # (B, 1, D)
        text_expanded = text_features.unsqueeze(0)  # (1, C, D)
        epsilon = 1e-10
        img_dist_safe = torch.clamp(img_expanded, min=epsilon, max=1.0)
        text_dist_safe = torch.clamp(text_expanded, min=epsilon, max=1.0)
        # KL(text || img) = sum(text * log(text / img))
        kl_div = torch.sum(text_dist_safe * torch.log(text_dist_safe / img_dist_safe), dim=2)
        # Lower KL -> higher similarity; scaled later by kl_logit_scale.
        kl_similarity = 1 - kl_div
        return kl_similarity, kl_div

    def compute_dot_similarity(self, img_features, text_features):
        """Return ``3.0 * img_features @ normalize(text_features).T``.

        This is cosine similarity when ``img_features`` rows are unit-norm.
        """
        text_features_norm = F.normalize(text_features, p=2, dim=1)
        dot_similarities = torch.mm(img_features, text_features_norm.t())
        # Fixed temperature to make differences clearer.
        temperature = 3.0
        return dot_similarities * temperature

    def initialize_prototypes_from_data(self, dataloader, device, max_batches=100):
        """Set each class prototype to the mean raw backbone embedding of that class.

        Args:
            dataloader: iterable of batches. Batches that are not lists/tuples of
                at least ``(inputs, targets)`` are skipped.
            device: device that inputs, targets and the centroids are placed on.
            max_batches: number of leading batches to scan (default 100).

        Runs only once: later calls print a message and return. The module is put
        in eval mode while embedding, and its previous train/eval mode is restored
        afterwards so dropout is not silently left disabled for later training.
        Classes absent from the scanned batches get an all-zero prototype.
        """
        if self._prototypes_initialized:
            print("Prototypes already initialized from data")
            return
        print("Initializing prototypes from data centroids...")
        class_embeddings = {i: [] for i in range(self.num_classes)}
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for batch_idx, batch_data in enumerate(dataloader):
                    if batch_idx >= max_batches:
                        break
                    if not (isinstance(batch_data, (list, tuple)) and len(batch_data) >= 2):
                        continue
                    inputs, targets = batch_data[0].to(device), batch_data[1].to(device)
                    embeddings = self.backbone(inputs)
                    # Group embeddings by class
                    for i, target in enumerate(targets):
                        class_embeddings[target.item()].append(embeddings[i])

                centroids = torch.zeros(self.num_classes, self.embedding_dim, device=device)
                valid_classes = 0
                for class_idx in range(self.num_classes):
                    samples = class_embeddings[class_idx]
                    if samples:
                        centroid = torch.stack(samples).mean(dim=0)
                        centroids[class_idx] = centroid
                        valid_classes += 1
                        print(f"Class {class_idx}: {len(class_embeddings[class_idx])} samples, centroid norm: {torch.norm(centroid, p=2).item():.4f}")
                    else:
                        print(f"Warning: No samples found for class {class_idx}")
                self.class_prototypes.data = centroids
                print(f"Initialized {valid_classes}/{self.num_classes} prototypes from data centroids")
                self._verify_prototype_separation()
                self._prototypes_initialized = True
        finally:
            self.train(was_training)

    def _verify_prototype_separation(self):
        """Print min/max/mean pairwise L2 distance between distinct prototypes.

        Warns when the closest pair is nearer than 1.0. Only off-diagonal pairs are
        considered; including the zero self-distances would force the minimum to 0
        and trigger the warning unconditionally. Does nothing with fewer than two
        prototypes.
        """
        with torch.no_grad():
            num_protos = self.class_prototypes.size(0)
            if num_protos < 2:
                return
            distances = torch.cdist(self.class_prototypes, self.class_prototypes, p=2)
            off_diag_mask = ~torch.eye(num_protos, dtype=torch.bool, device=distances.device)
            pairwise = distances[off_diag_mask]
            min_separation = pairwise.min().item()
            max_separation = pairwise.max().item()
            avg_separation = pairwise.mean().item()
            print(f"Prototype separation - Min: {min_separation:.4f}, Max: {max_separation:.4f}, Avg: {avg_separation:.4f}")
            if min_separation < 1.0:
                print("⚠️ Warning: Some prototypes are very close together!")

    def get_prototypes(self):
        """Return the raw (un-normalised) prototype Parameter, e.g. for regularisation."""
        return self.class_prototypes

    def get_image_embeddings(self, x):
        """Return L2-normalised backbone embeddings (no dropout applied)."""
        embeddings = self.backbone(x)
        return F.normalize(embeddings, p=2, dim=1)

    def get_processed_embeddings(self, x):
        """Return the softmaxed image and prototype vectors used for the proof analysis.

        Returns:
            ``(z_l0_processed, z_kl_processed, protos_l0_processed, protos_kl_processed)``.

        NOTE(review): unlike forward(), the prototypes here are L2-normalised before
        the softmax. The analysed prototype distributions therefore differ from the
        ones forward() scores against — confirm which is intended.
        """
        image_embeddings = self.backbone(x)
        image_embeddings = self.dropout1(image_embeddings)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
        image_embeddings = self.dropout2(image_embeddings)

        z_l0 = F.normalize(image_embeddings, p=2, dim=1)
        z_kl = F.normalize(image_embeddings, p=2, dim=1)

        z_l0_processed = F.softmax(z_l0 / 0.05, dim=1)
        z_kl_processed = F.softmax(z_kl / 0.05, dim=1)

        protos_l0_projected = F.normalize(self.class_prototypes, p=2, dim=1)
        protos_kl_projected = F.normalize(self.class_prototypes, p=2, dim=1)
        protos_l0_processed = F.softmax(protos_l0_projected / 0.05, dim=1)
        protos_kl_processed = F.softmax(protos_kl_projected / 0.05, dim=1)

        return z_l0_processed, z_kl_processed, protos_l0_processed, protos_kl_processed

    def set_training_mode(self, l0_weight, kl_weight, dot_weight):
        """Freeze/unfreeze the per-head projection layers according to the head weights.

        * ``l0_weight == 0``: KL-only (L0 projections frozen).
        * otherwise ``kl_weight == 0``: L0-only (KL projections frozen).
        * otherwise: both trainable.

        ``dot_weight`` is accepted for API symmetry and ignored.
        """
        if l0_weight == 0.0:
            l0_trainable, kl_trainable = False, True
            message = "✅ Training mode set to KL-ONLY: L0 frozen, KL trainable"
        elif kl_weight == 0.0:
            l0_trainable, kl_trainable = True, False
            message = "✅ Training mode set to L0-ONLY: L0 trainable, KL frozen"
        else:
            l0_trainable, kl_trainable = True, True
            message = "✅ Training mode set to COMBINED: Both L0 and KL trainable"
        for layer in (self.proj_l0, self.proto_proj_l0):
            for param in layer.parameters():
                param.requires_grad_(l0_trainable)
        for layer in (self.proj_kl, self.proto_proj_kl):
            for param in layer.parameters():
                param.requires_grad_(kl_trainable)
        print(message)

    def get_training_mode_status(self):
        """Infer the training mode from which image projections require grad."""
        l0_trainable = any(p.requires_grad for p in self.proj_l0.parameters())
        kl_trainable = any(p.requires_grad for p in self.proj_kl.parameters())

        if l0_trainable and kl_trainable:
            return "COMBINED"
        elif l0_trainable and not kl_trainable:
            return "L0-ONLY"
        elif kl_trainable and not l0_trainable:
            return "KL-ONLY"
        else:
            return "UNKNOWN"

    def predict_with_consensus(self, x, tau=0.75, l0_weight=1.0, kl_weight=1.0):
        """Predict a class only where the L0 and KL heads agree.

        Both heads must be active (weight > 0) for the agreement check to be
        meaningful. With forward()'s default ``kl_weight=0.0``, the KL logits would
        be all zeros and ``argmax`` would always return class 0. That is why both
        weights default to 1.0 here.

        Returns:
            ``(predictions, consensus_flags, l0_preds, kl_preds, dot_preds)``.
            ``predictions`` is -1 where the heads disagree.
        """
        l0_sims, kl_sims, _combined, _l0_dists, _kl_dists, dot_sims = self.forward(
            x, return_individual=True, tau=tau, l0_weight=l0_weight, kl_weight=kl_weight)
        l0_preds = torch.argmax(l0_sims, dim=1)
        kl_preds = torch.argmax(kl_sims, dim=1)
        dot_preds = torch.argmax(dot_sims, dim=1)
        consensus_flags = (l0_preds == kl_preds)
        predictions = torch.where(consensus_flags, l0_preds, torch.full_like(l0_preds, -1))
        return predictions, consensus_flags, l0_preds, kl_preds, dot_preds

def check_prototype_status(model):
    """Print the trainability, gradient presence and L2 norms of ``model.class_prototypes``.

    The gradient line shows ``True`` when a gradient exists and ``None`` otherwise.

    Returns:
        ``model.class_prototypes.requires_grad``.
    """
    protos = model.class_prototypes
    is_trainable = protos.requires_grad
    grad_state = 'None' if protos.grad is None else True
    print("\n=== PROTOTYPE STATUS ===")
    print(f"Prototypes trainable: {is_trainable}")
    print(f"Prototype gradients: {grad_state}")

    with torch.no_grad():
        row_norms = torch.norm(protos, p=2, dim=1)
        print(f"Prototype norms: {row_norms}")
        print(f"Prototype mean norm: {row_norms.mean():.4f}")

    return is_trainable


def _select_backbone_params(model, args):
    """Collect the encoder-side parameters (backbone, projections, logit scales) to optimize.

    L0 projections are left out when ``args.l0_weight`` is 0. Otherwise KL
    projections are left out when ``args.kl_weight`` is 0. Otherwise both are included.
    """
    params = list(model.backbone.parameters())
    if args.l0_weight == 0.0:
        print("🚀 PURE KL-ONLY MODE: Excluding L0 projections from optimizer")
        return (params
                + list(model.proj_kl.parameters())
                + list(model.proto_proj_kl.parameters())
                + [model.kl_logit_scale])
    if args.kl_weight == 0.0:
        print("🚀 PURE L0-ONLY MODE: Excluding KL projections from optimizer")
        return (params
                + list(model.proj_l0.parameters())
                + list(model.proto_proj_l0.parameters())
                + [model.l0_logit_scale])
    print("🚀 COMBINED MODE: Including both L0 and KL projections in optimizer")
    return (params
            + list(model.proj_l0.parameters())
            + list(model.proto_proj_l0.parameters())
            + list(model.proj_kl.parameters())
            + list(model.proto_proj_kl.parameters())
            + [model.l0_logit_scale, model.kl_logit_scale])


def _print_prototype_separation(model, title):
    """Print pairwise L2 distances between the rows of ``model.get_prototypes()``.

    Min/max/avg are computed over off-diagonal pairs only. The zero
    self-distances on the diagonal would otherwise pin the minimum at 0 and
    bias the average low.
    """
    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS ({title}) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        num_protos = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        off_diag_mask = ~torch.eye(num_protos, dtype=torch.bool, device=prototypes.device)
        pair_dists = distances[off_diag_mask]

        min_sep = None
        if pair_dists.numel() > 0:
            min_sep = pair_dists.min().item()
            print(f"Min separation: {min_sep:.4f}")
            print(f"Max separation: {pair_dists.max().item():.4f}")
            print(f"Avg separation: {pair_dists.mean().item():.4f}")
        else:
            print("Only one prototype: pairwise separation is undefined")

        # Compact view: only the top-left 5x5 corner of the distance matrix.
        print("Distance Matrix:")
        n_show = min(5, num_protos)
        for i in range(n_show):
            print("  " + "".join(
                " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                for j in range(n_show)))
        if num_protos > 5:
            print(f"  ... (showing first 5x5, total: {num_protos}x{num_protos})")

        if min_sep is not None:
            if min_sep < 1.0:
                print("Warning: Some prototypes are very close together!")
            elif min_sep < 2.0:
                print("Caution: Prototype separation could be improved")
            else:
                print("Good prototype separation maintained")
    print("=" * 60)


def train_model(model, trainloader, valloader, testloader, device, args):
    """Train the combined L0+KL(+dot) prototype model and checkpoint the best epoch.

    Args:
        model: prototype model. It must expose ``backbone``, ``proj_l0``/``proj_kl``,
            ``proto_proj_l0``/``proto_proj_kl``, ``l0_logit_scale``/``kl_logit_scale``,
            ``class_prototypes``, ``set_training_mode`` and ``get_prototypes``.
        trainloader: iterable of ``(inputs, targets)`` training batches.
        valloader: iterable of ``(inputs, targets)`` validation batches used for
            model selection and early stopping.
        testloader: accepted for interface compatibility; not used here.
        device: device the batches are moved to.
        args: namespace providing lr, epochs, alpha, beta, margin, tau, class_boost,
            patience, checkpoint, l0_weight, kl_weight, dot_weight, deco_weight,
            ortho_weight and freeze_prototypes.

    Returns:
        Best validation accuracy (percent) reached before completion or early stop.
    """
    print(f"==> Training Combined L0+KL Model...")
    print(f" L0 weight: {args.l0_weight}, KL weight: {args.kl_weight}, Dot weight: {args.dot_weight}")
    if args.freeze_prototypes:
        print("�� PROTOTYPE FREEZING MODE: Prototypes are frozen, only training encoder")
        model.class_prototypes.requires_grad_(False)
        print(f"✅ Prototypes frozen: {model.class_prototypes.requires_grad}")
    else:
        print("�� FULL TRAINING MODE: Training both encoder and prototypes")

    # Let the model toggle requires_grad on its branches for pure single-metric training.
    model.set_training_mode(args.l0_weight, args.kl_weight, args.dot_weight)

    backbone_params = _select_backbone_params(model, args)
    if args.freeze_prototypes:
        prototype_params = []
        print("🚀 ENCODER-ONLY TRAINING: Prototypes excluded from optimizer")
    else:
        prototype_params = [model.class_prototypes]
        print("🚀 FULL TRAINING: Both encoder and prototypes included in optimizer")

    backbone_optimizer = optim.SGD(backbone_params, lr=args.lr * 0.1, momentum=0.9, weight_decay=1e-3)
    backbone_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(backbone_optimizer, T_max=args.epochs)
    # Exactly one optimizer/scheduler pair for the prototypes (none when they are frozen).
    if prototype_params:
        prototype_optimizer = optim.SGD(prototype_params, lr=args.lr * 0.1, momentum=0.9, weight_decay=1e-3)
        prototype_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(prototype_optimizer, T_max=args.epochs)
    else:
        prototype_optimizer = None
        prototype_scheduler = None

    print(f"\n📊 OPTIMIZER SETUP:")
    print(f"   Backbone optimizer: {len(backbone_params)} parameter tensors")
    print(f"   Prototype optimizer: {len(prototype_params)} parameter tensors")

    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {total_trainable:,}")
    print(f"   Frozen parameters: {total_params - total_trainable:,}")

    best_acc = 0
    patience_counter = 0
    for epoch in range(args.epochs):
        print(f'\nEpoch: {epoch}')
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            backbone_optimizer.zero_grad()
            if prototype_optimizer is not None:
                prototype_optimizer.zero_grad()

            l0_sims, kl_sims, combined_sims, separation_loss, l0_dists, kl_dists, dot_sims = model(
                inputs, targets=targets, tau=args.tau,
                compute_separation=True, margin=args.margin,
                class_boost=args.class_boost, l0_weight=args.l0_weight,
                kl_weight=args.kl_weight, dot_weight=args.dot_weight, return_individual=True)

            # Per-metric cross-entropy on standardized logits (equalizes gradient scales).
            loss_l0 = F.cross_entropy(_norm_logits_for_loss(l0_sims), targets)
            loss_kl = F.cross_entropy(_norm_logits_for_loss(kl_sims), targets)
            loss_dot = F.cross_entropy(_norm_logits_for_loss(dot_sims), targets)
            cls_loss = args.l0_weight * loss_l0 + args.kl_weight * loss_kl + args.dot_weight * loss_dot

            # Decorrelate L0 and KL rankings; L0 is detached so this term's gradient
            # flows only through kl_sims.
            deco = corr_penalty(l0_sims.detach(), kl_sims)

            total_loss = args.alpha * cls_loss
            if not args.freeze_prototypes:
                total_loss = total_loss + args.beta * separation_loss
            total_loss = total_loss + args.deco_weight * deco

            # Optional orthogonality regularization between the two projection layers.
            if args.ortho_weight > 0:
                ortho_loss = ortho_reg(model.proj_l0.weight, model.proj_kl.weight)
                total_loss = total_loss + args.ortho_weight * ortho_loss

            if batch_idx % 50 == 0:
                if args.freeze_prototypes:
                    print(f"Loss: {total_loss.item():.4f} (L0: {loss_l0.item():.4f}, KL: {loss_kl.item():.4f}, Dot: {loss_dot.item():.4f}, deco: {deco.item():.4f}, prototypes frozen)")
                else:
                    print(f"Loss: {total_loss.item():.4f} (L0: {loss_l0.item():.4f}, KL: {loss_kl.item():.4f}, Dot: {loss_dot.item():.4f}, separation: {separation_loss.item():.4f}, deco: {deco.item():.4f})")
                if args.ortho_weight > 0:
                    print(f"  Ortho reg: {ortho_loss.item():.6f}")

            total_loss.backward()
            backbone_optimizer.step()
            if prototype_optimizer is not None:
                prototype_optimizer.step()

            train_loss += total_loss.item()
            _, predicted = combined_sims.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}: Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

        # Validation uses the same metric weights as training.
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                _, _, combined_sims, _, _, _ = model(
                    inputs, tau=args.tau, l0_weight=args.l0_weight, kl_weight=args.kl_weight,
                    dot_weight=args.dot_weight, return_individual=True)
                _, predicted = combined_sims.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        val_acc = 100. * val_correct / val_total
        print(f'Validation Accuracy: {val_acc:.2f}%')

        _print_prototype_separation(model, f'Epoch {epoch}')

        if val_acc > best_acc:
            print('Saving best model..')
            state = {
                'net': model.state_dict(),
                'acc': val_acc,
                'epoch': epoch,
                'l0_weight': args.l0_weight,
                'kl_weight': args.kl_weight,
                'dot_weight': args.dot_weight,
                'alpha': args.alpha,
                'beta': args.beta,
                'margin': args.margin,
                'tau': args.tau,
                'class_boost': args.class_boost,
                'model_type': 'combined_l0_kl'
            }
            # torch.save does not create missing directories.
            ckpt_dir = os.path.dirname(args.checkpoint)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(state, args.checkpoint)
            print(f'Saved best model to {args.checkpoint}')
            best_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping after {patience_counter} epochs without improvement')
                break

        backbone_scheduler.step()
        if prototype_scheduler is not None:
            prototype_scheduler.step()

    print(f'Training completed! Best validation accuracy: {best_acc:.2f}%')
    return best_acc

def evaluate_model(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate per-metric and combined test accuracy, then report prototype separation.

    Args:
        model: prototype model. When called with ``return_individual=True`` it
            returns ``(l0_sims, kl_sims, combined_sims, l0_dists, kl_dists, dot_sims)``,
            and it exposes ``get_prototypes()``.
        testloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        args: namespace providing ``tau``.
        l0_weight, kl_weight, dot_weight: metric weights forwarded to the model.

    Returns:
        ``(combined_acc, l0_acc, kl_acc, dot_acc)`` in percent. Each metric's
        prediction is the argmax of its similarities.
    """
    print("==> Evaluating model on test set...")
    model.eval()
    if l0_weight is None or kl_weight is None:
        # NOTE(review): None is forwarded to model(); presumably the model then falls
        # back to its own configured weights — confirm in the model class.
        print(f"Using weights from checkpoint: L0={l0_weight}, KL={kl_weight}")

    total = 0
    l0_correct = kl_correct = dot_correct = combined_correct = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            l0_sims, kl_sims, combined_sims, _, _, dot_sims = model(
                inputs, return_individual=True, tau=args.tau,
                l0_weight=l0_weight, kl_weight=kl_weight, dot_weight=dot_weight)
            l0_correct += (torch.argmax(l0_sims, dim=1) == targets).sum().item()
            kl_correct += (torch.argmax(kl_sims, dim=1) == targets).sum().item()
            dot_correct += (torch.argmax(dot_sims, dim=1) == targets).sum().item()
            combined_correct += (torch.argmax(combined_sims, dim=1) == targets).sum().item()
            total += targets.size(0)

    l0_acc = 100. * l0_correct / total
    kl_acc = 100. * kl_correct / total
    dot_acc = 100. * dot_correct / total
    combined_acc = 100. * combined_correct / total
    print(f'\n=== TEST RESULTS ===')
    print(f'Total test images: {total}')
    print(f'L0 Accuracy: {l0_acc:.2f}% ({l0_correct}/{total})')
    print(f'KL Accuracy: {kl_acc:.2f}% ({kl_correct}/{total})')
    print(f'Dot Accuracy: {dot_acc:.2f}% ({dot_correct}/{total})')
    print(f'Combined Accuracy: {combined_acc:.2f}% ({combined_correct}/{total})')
    print(f'L0 weight: {l0_weight}, KL weight: {kl_weight}, Dot weight: {dot_weight}')

    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (EVALUATION) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        num_protos = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        # Only off-diagonal pairs count. Including the zero self-distances would
        # pin the minimum at 0 and bias the average low.
        off_diag_mask = ~torch.eye(num_protos, dtype=torch.bool, device=prototypes.device)
        pair_dists = distances[off_diag_mask]

        min_sep = None
        if pair_dists.numel() > 0:
            min_sep = pair_dists.min().item()
            print(f"Min separation: {min_sep:.4f}")
            print(f"Max separation: {pair_dists.max().item():.4f}")
            print(f"Avg separation: {pair_dists.mean().item():.4f}")
        else:
            print("Only one prototype: pairwise separation is undefined")

        # Compact view: only the top-left 5x5 corner of the distance matrix.
        print("Distance Matrix:")
        n_show = min(5, num_protos)
        for i in range(n_show):
            print("  " + "".join(
                " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                for j in range(n_show)))
        if num_protos > 5:
            print(f"  ... (showing first 5x5, total: {num_protos}x{num_protos})")

        if min_sep is not None:
            if min_sep < 1.0:
                print("Warning: Some prototypes are very close together!")
            elif min_sep < 2.0:
                print("Caution: Prototype separation could be improved")
            else:
                print("Good prototype separation maintained")
    print("=" * 60)

    return combined_acc, l0_acc, kl_acc, dot_acc

def evaluate_consensus_accuracy(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate L0/KL consensus prediction with a 4-category breakdown.

    Each test sample falls into exactly one category, based on the argmax
    predictions of the L0 and KL similarity heads:
      1. both agree and are correct;
      2. both agree but are wrong;
      3. they disagree and one of them is correct;
      4. they disagree and both are wrong.
    Under consensus prediction, disagreements are treated as "unknown" (incorrect).

    Args:
        model: prototype model that returns 6 tensors with ``return_individual=True``
            and exposes ``get_prototypes()``.
        testloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        args: namespace providing ``tau``.
        l0_weight, kl_weight, dot_weight: metric weights forwarded to the model.

    Returns:
        dict with the four category counts, consensus rate, consensus accuracy
        (among agreeing samples), true consensus accuracy (over all samples) and total.
    """
    print("==> Evaluating consensus accuracy with 4-category breakdown...")
    model.eval()
    total = 0
    same_correct = 0  # L0 and KL same and correct
    same_incorrect = 0  # L0 and KL same but incorrect
    diff_one_correct = 0  # L0 and KL different but one correct
    diff_both_wrong = 0  # L0 and KL different and both incorrect
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, desc='Evaluating consensus'):
            inputs, targets = inputs.to(device), targets.to(device)
            l0_sims, kl_sims, _, _, _, _ = model(
                inputs, return_individual=True, tau=args.tau,
                l0_weight=l0_weight, kl_weight=kl_weight, dot_weight=dot_weight)
            l0_preds = torch.argmax(l0_sims, dim=1)
            kl_preds = torch.argmax(kl_sims, dim=1)

            # Vectorized categorization; the masks partition the batch into the 4 groups.
            agree = l0_preds == kl_preds
            l0_hit = l0_preds == targets
            kl_hit = kl_preds == targets
            same_correct += (agree & l0_hit).sum().item()
            same_incorrect += (agree & ~l0_hit).sum().item()
            diff_one_correct += (~agree & (l0_hit | kl_hit)).sum().item()
            diff_both_wrong += (~agree & ~l0_hit & ~kl_hit).sum().item()
            total += inputs.size(0)

    same_correct_pct = 100. * same_correct / total
    same_incorrect_pct = 100. * same_incorrect / total
    diff_one_correct_pct = 100. * diff_one_correct / total
    diff_both_wrong_pct = 100. * diff_both_wrong / total
    agreed = same_correct + same_incorrect
    consensus_rate = 100. * agreed / total
    # Accuracy restricted to samples where L0 and KL agree.
    consensus_acc = 100. * same_correct / agreed if agreed > 0 else 0
    # Accuracy over the whole test set (disagreements count as wrong).
    true_consensus_acc = 100. * same_correct / total

    print(f'\n=== CONSENSUS EVALUATION RESULTS ===')
    print(f'Total test images: {total}')
    print(f'')
    print(f'1. L0 and KL same and correct: {same_correct} ({same_correct_pct:.2f}%)')
    print(f'2. L0 and KL same but incorrect: {same_incorrect} ({same_incorrect_pct:.2f}%)')
    print(f'3. L0 and KL different but one correct: {diff_one_correct} ({diff_one_correct_pct:.2f}%)')
    print(f'4. L0 and KL different and both incorrect: {diff_both_wrong} ({diff_both_wrong_pct:.2f}%)')
    print(f'')
    print(f'Consensus Rate: {consensus_rate:.2f}% (Groups 1+2)')
    print(f'Consensus Accuracy: {consensus_acc:.2f}% (Group 1 / Groups 1+2)')
    print(f'TRUE Consensus Accuracy: {true_consensus_acc:.2f}% (Group 1 / Total)')
    print(f'')
    print(f'Note: "Consensus Accuracy" only counts predictions where L0 and KL agree.')
    print(f' "TRUE Consensus Accuracy" counts all predictions (unknown = incorrect).')
    print(f'Breakdown:')
    print(f' - L0 and KL agree: {same_correct + same_incorrect} images ({consensus_rate:.1f}%)')
    print(f' - L0 and KL disagree: {diff_one_correct + diff_both_wrong} images ({100-consensus_rate:.1f}%)')
    print(f' - When they agree, correct: {same_correct} images ({consensus_acc:.1f}%)')
    print(f' - When they disagree, all marked as unknown (incorrect)')
    print(f' - Total correct: {same_correct} out of {total} images')
    print(f' - TRUE accuracy: {true_consensus_acc:.2f}%')

    print(f'\n=== PROTOTYPE SEPARATION ANALYSIS (CONSENSUS EVALUATION) ===')
    with torch.no_grad():
        prototypes = model.get_prototypes()
        num_protos = prototypes.size(0)
        distances = torch.cdist(prototypes, prototypes, p=2)
        # Only off-diagonal pairs count. Including the zero self-distances would
        # pin the minimum at 0 and bias the average low.
        off_diag_mask = ~torch.eye(num_protos, dtype=torch.bool, device=prototypes.device)
        pair_dists = distances[off_diag_mask]

        min_sep = None
        if pair_dists.numel() > 0:
            min_sep = pair_dists.min().item()
            print(f"Min separation: {min_sep:.4f}")
            print(f"Max separation: {pair_dists.max().item():.4f}")
            print(f"Avg separation: {pair_dists.mean().item():.4f}")
        else:
            print("Only one prototype: pairwise separation is undefined")

        # Compact view: only the top-left 5x5 corner of the distance matrix.
        print("Distance Matrix:")
        n_show = min(5, num_protos)
        for i in range(n_show):
            print("  " + "".join(
                " 0.0000" if i == j else f" {distances[i, j].item():.4f}"
                for j in range(n_show)))
        if num_protos > 5:
            print(f"  ... (showing first 5x5, total: {num_protos}x{num_protos})")

        if min_sep is not None:
            if min_sep < 1.0:
                print("Warning: Some prototypes are very close together!")
            elif min_sep < 2.0:
                print("Caution: Prototype separation could be improved")
            else:
                print("Good prototype separation maintained")
    print("=" * 60)

    return {
        'same_correct': same_correct,
        'same_incorrect': same_incorrect,
        'diff_one_correct': diff_one_correct,
        'diff_both_wrong': diff_both_wrong,
        'consensus_rate': consensus_rate,
        'consensus_acc': consensus_acc,
        'true_consensus_acc': true_consensus_acc,
        'total': total
    }

def _pgd_single_run(model, X, target, alpha, attack_iters, norm, epsilon, tau):
    """Run one randomly initialised PGD trajectory and return the detached perturbation."""
    delta = torch.zeros_like(X)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    else:  # "l_2": random direction, radius drawn uniformly from [0, epsilon]
        delta.normal_()
        norms = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).view(-1, 1, 1, 1)
        radius = torch.zeros_like(norms).uniform_(0, 1)
        delta *= radius / norms * epsilon
    # Keep X + delta inside [lower_limit, upper_limit].
    delta = torch.max(torch.min(delta, upper_limit - X), lower_limit - X)
    delta.requires_grad_(True)

    for _ in range(attack_iters):
        _, _, combined_sims, _, _, _ = model(X + delta, return_individual=True, tau=tau)
        loss = F.cross_entropy(combined_sims, target)
        # Differentiate w.r.t. delta only, so the model parameters' .grad buffers
        # are neither computed nor accumulated by the attack.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            if norm == "l_inf":
                stepped = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
            else:
                g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
                stepped = (delta + grad / (g_norm + 1e-10) * alpha).reshape(delta.size(0), -1)
                stepped = stepped.renorm(p=2, dim=0, maxnorm=epsilon).view_as(delta)
            delta.copy_(torch.max(torch.min(stepped, upper_limit - X), lower_limit - X))
    return delta.detach()


def _pgd_sample_losses(model, X, delta, target, tau):
    """Per-sample cross-entropy of the combined logits at X + delta (no gradients)."""
    with torch.no_grad():
        _, _, combined_sims, _, _, _ = model(X + delta, return_individual=True, tau=tau)
        return F.cross_entropy(combined_sims, target, reduction='none')


def attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, tau=0.75, mean=None, std=None):
    """Untargeted PGD attack that maximizes cross-entropy of the model's combined logits.

    The attack works directly in input space: ``epsilon`` and ``alpha`` are used
    as given, and ``X + delta`` is kept inside ``[lower_limit, upper_limit]``.
    ``X`` is assumed to be a 4-D ``(N, C, H, W)`` batch.

    Args:
        model: callable returning 6 tensors with ``return_individual=True``; the
            third one (combined similarities) is used as logits.
        X: clean input batch.
        target: true labels, shape ``(N,)``.
        alpha: step size per iteration.
        attack_iters: number of PGD steps per restart.
        norm: ``"l_inf"`` or ``"l_2"`` threat model.
        device: device ``X`` and ``target`` are moved to.
        epsilon: perturbation budget.
        restarts: number of random restarts. For each sample, the perturbation
            with the highest final loss across restarts is kept.
        tau: forwarded to the model.
        mean, std: accepted for interface compatibility; unused because the attack
            operates in unnormalized space.

    Returns:
        Detached perturbation ``delta`` with the same shape as ``X``.

    Raises:
        ValueError: if ``norm`` is not supported.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")
    X = X.to(device)
    target = target.to(device)

    best_delta = _pgd_single_run(model, X, target, alpha, attack_iters, norm, epsilon, tau)
    if restarts <= 1:
        return best_delta

    best_loss = _pgd_sample_losses(model, X, best_delta, target, tau)
    for _ in range(restarts - 1):
        delta = _pgd_single_run(model, X, target, alpha, attack_iters, norm, epsilon, tau)
        loss = _pgd_sample_losses(model, X, delta, target, tau)
        improved = loss > best_loss
        best_delta[improved] = delta[improved]
        best_loss = torch.where(improved, loss, best_loss)
    return best_delta



def _agreement_counts(l0_preds, kl_preds, targets):
    """Split one batch into the four L0/KL agreement categories.

    Returns:
        ``(same_correct, same_incorrect, diff_one_correct, diff_both_wrong)`` as ints.
    """
    agree = l0_preds == kl_preds
    l0_hit = l0_preds == targets
    kl_hit = kl_preds == targets
    return (
        (agree & l0_hit).sum().item(),
        (agree & ~l0_hit).sum().item(),
        (~agree & (l0_hit | kl_hit)).sum().item(),
        (~agree & ~l0_hit & ~kl_hit).sum().item(),
    )


def _safe_ratio(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 so summary prints never divide by zero."""
    return num / den if den else float('nan')


def evaluate_adversarial_robustness(model, testloader, device, args, l0_weight=None, kl_weight=None,dot_weight=None):
    """Evaluate L0/KL agreement under a PGD attack and tally the math-proof condition.

    Steps:
      1. Clean baseline. Compute per-metric accuracies (L0/KL predict the
         nearest prototype, i.e. argmin distance; dot predicts argmax similarity)
         and the 4-category L0/KL agreement breakdown.
      2. For every batch, craft a 10-step PGD perturbation with ``attack_pgd``
         and compute the same agreement breakdown on the adversarial inputs.
      3. For samples that L0 and KL both classify correctly on the clean input,
         call ``math_proof_analysis`` with the clean/adversarial KL and L0
         distances and the processed KL embeddings/prototypes. Count how often
         it returns truthy ("equation fulfilled") versus falsy.

    Args:
        model: prototype model returning 6 tensors with ``return_individual=True``
            and exposing ``get_processed_embeddings``.
        testloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        args: namespace providing tau, attack_stepsize, attack_epsilon and norm
            (plus whatever ``math_proof_analysis`` reads).
        l0_weight, kl_weight, dot_weight: metric weights forwarded to the model.

    Returns:
        dict with the clean and adversarial category counts and the equation counters.
    """
    print("==> Evaluating adversarial robustness...")
    print("==> Getting clean predictions for baseline...")
    model.eval()
    fwd_kwargs = dict(return_individual=True, tau=args.tau,
                      l0_weight=l0_weight, kl_weight=kl_weight, dot_weight=dot_weight)

    clean_counts = [0, 0, 0, 0]
    clean_l0_correct = 0
    clean_kl_correct = 0
    clean_dot_correct = 0
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, desc='Clean baseline'):
            inputs, targets = inputs.to(device), targets.to(device)
            _, _, _, l0_dists, kl_dists, dot_sims = model(inputs, **fwd_kwargs)
            l0_preds = torch.argmin(l0_dists, dim=1)
            kl_preds = torch.argmin(kl_dists, dim=1)
            dot_preds = torch.argmax(dot_sims, dim=1)
            clean_l0_correct += (l0_preds == targets).sum().item()
            clean_kl_correct += (kl_preds == targets).sum().item()
            clean_dot_correct += (dot_preds == targets).sum().item()
            for k, count in enumerate(_agreement_counts(l0_preds, kl_preds, targets)):
                clean_counts[k] += count
    clean_same_correct, clean_same_incorrect, clean_diff_one_correct, clean_diff_both_wrong = clean_counts

    print(f"Clean same correct: {clean_same_correct}, Clean same incorrect: {clean_same_incorrect}, Clean diff one correct: {clean_diff_one_correct}, Clean diff both wrong: {clean_diff_both_wrong}")
    print(f"Clean l0 correct: {clean_l0_correct}, Clean kl correct: {clean_kl_correct}, Clean dot correct: {clean_dot_correct}")

    equation_fulfilled = 0
    equation_not_fulfilled = 0
    equation_fulfilled_clean_correct = 0
    equation_fulfilled_adv_same_incorrect = 0
    equation_not_fulfilled_clean_correct = 0
    equation_not_fulfilled_adv_same_incorrect = 0
    adv_counts = [0, 0, 0, 0]

    for inputs, targets in tqdm(testloader, desc='Adversarial evaluation'):
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            _, _, _, l0_dists_clean, kl_dists_clean, _ = model(inputs, **fwd_kwargs)
            l0_preds_clean = torch.argmin(l0_dists_clean, dim=1)
            kl_preds_clean = torch.argmin(kl_dists_clean, dim=1)

        # PGD needs input gradients, so it runs outside no_grad.
        delta = attack_pgd(model, inputs, targets, args.attack_stepsize, 10, args.norm, device,
                           args.attack_epsilon, tau=args.tau, mean=CIFAR_MEAN, std=CIFAR_STD)

        with torch.no_grad():
            _, _, _, l0_dists_adv, kl_dists_adv, _ = model(inputs + delta, **fwd_kwargs)
            l0_preds_adv = torch.argmin(l0_dists_adv, dim=1)
            kl_preds_adv = torch.argmin(kl_dists_adv, dim=1)
            for k, count in enumerate(_agreement_counts(l0_preds_adv, kl_preds_adv, targets)):
                adv_counts[k] += count

            for i in range(inputs.size(0)):
                target = targets[i].item()
                pred_class_l0 = l0_preds_clean[i].item()
                pred_class_kl = kl_preds_clean[i].item()
                pred_class_l0_adv = l0_preds_adv[i].item()

                # Only samples both metrics classify correctly on the clean input are analysed.
                if not (pred_class_l0 == pred_class_kl == target):
                    continue
                pred_class_clean = pred_class_kl
                # y' is the class KL predicts on the adversarial input.
                pred_class = kl_preds_adv[i].item()

                # KL distances (not similarities) for the mathematical proof analysis.
                KL_y_prime_p_star = kl_dists_clean[i, pred_class]
                KL_y_star_p_star = kl_dists_clean[i, target]
                KL_y_prime_p_prime = kl_dists_adv[i, pred_class]
                KL_y_star_p_prime = kl_dists_adv[i, target]

                # L0 distances (not similarities) for the mathematical proof analysis.
                L0_y_prime_p_star = l0_dists_clean[i, pred_class]
                L0_y_star_p_star = l0_dists_clean[i, target]
                L0_y_prime_p_prime = l0_dists_adv[i, pred_class]
                L0_y_star_p_prime = l0_dists_adv[i, target]

                # KL-branch embeddings of the clean (p*) and adversarial (p') image.
                _, p_star_kl, _, protos_kl = model.get_processed_embeddings(inputs[i:i+1])
                _, p_prime_kl, _, _ = model.get_processed_embeddings(inputs[i:i+1] + delta[i:i+1])
                p_star = p_star_kl[0]
                p_prime = p_prime_kl[0]
                y_star = protos_kl[target]  # target class prototype (processed)
                y_prime = protos_kl[pred_class]  # conflicting class prototype (processed)

                # Check whether Equation (46) is fulfilled.
                output = math_proof_analysis(
                    target, pred_class, y_star, y_prime, p_star, p_prime, args.attack_epsilon,
                    KL_y_prime_p_star, KL_y_star_p_star, KL_y_prime_p_prime, KL_y_star_p_prime,
                    L0_y_prime_p_star, L0_y_star_p_star, L0_y_prime_p_prime, L0_y_star_p_prime, args
                )
                if output:
                    equation_fulfilled += 1
                    # Diagnostic only: report the L0 distances when the expected ordering fails.
                    try:
                        if not (L0_y_prime_p_star >= L0_y_star_p_star and L0_y_prime_p_prime >= L0_y_star_p_prime):
                            print(f"L0 condition not satisfied - L0_y_prime_p_star: {L0_y_prime_p_star}, L0_y_star_p_star: {L0_y_star_p_star}, L0_y_prime_p_prime: {L0_y_prime_p_prime}, L0_y_star_p_prime: {L0_y_star_p_prime}")
                    except Exception:
                        print(f"L0_y_prime_p_star: {L0_y_prime_p_star}, L0_y_star_p_star: {L0_y_star_p_star}, L0_y_prime_p_prime: {L0_y_prime_p_prime}, L0_y_star_p_prime: {L0_y_star_p_prime}")
                    # Always true given the guard above; kept as an explicit tally.
                    if pred_class_clean == pred_class_l0 == target:
                        equation_fulfilled_clean_correct += 1
                    if (pred_class == pred_class_l0_adv) and (pred_class != target) and not (L0_y_prime_p_prime == L0_y_star_p_prime) and not (L0_y_prime_p_star == L0_y_star_p_star):
                        try:
                            if not (L0_y_prime_p_star > L0_y_star_p_star and L0_y_prime_p_prime <= L0_y_star_p_prime):
                                print(f"L0 condition not satisfied! - L0_y_prime_p_star: {L0_y_prime_p_star}, L0_y_star_p_star: {L0_y_star_p_star}, L0_y_prime_p_prime: {L0_y_prime_p_prime}, L0_y_star_p_prime: {L0_y_star_p_prime}")
                        except Exception:
                            print(f"L0_y_prime_p_star: {L0_y_prime_p_star}, L0_y_star_p_star: {L0_y_star_p_star}, L0_y_prime_p_prime: {L0_y_prime_p_prime}, L0_y_star_p_prime: {L0_y_star_p_prime}")
                        equation_fulfilled_adv_same_incorrect += 1
                else:
                    equation_not_fulfilled += 1
                    if pred_class_clean == pred_class_l0 == target:
                        equation_not_fulfilled_clean_correct += 1
                    if (pred_class == pred_class_l0_adv) and (pred_class != target):
                        equation_not_fulfilled_adv_same_incorrect += 1

    adv_same_correct, adv_same_incorrect, adv_diff_one_correct, adv_diff_both_wrong = adv_counts
    print(f"Adv same correct: {adv_same_correct}, Adv same incorrect: {adv_same_incorrect}, Adv diff one correct: {adv_diff_one_correct}, Adv diff both wrong: {adv_diff_both_wrong}")

    print(f"Equation fulfilled: {equation_fulfilled}, Equation not fulfilled: {equation_not_fulfilled}, equation fulfilled and clean correct: {equation_fulfilled_clean_correct}, equation fulfilled and adv same incorrect: {equation_fulfilled_adv_same_incorrect}, equation not fulfilled and clean correct: {equation_not_fulfilled_clean_correct}, equation not fulfilled and adv same incorrect: {equation_not_fulfilled_adv_same_incorrect}")
    # Ratios are NaN (instead of raising ZeroDivisionError) when a group is empty.
    print(f"Equation fulfilled and clean correct percentage: {_safe_ratio(equation_fulfilled_clean_correct, equation_fulfilled)},clean incorrect: {_safe_ratio(equation_fulfilled - equation_fulfilled_clean_correct, equation_fulfilled)}, Equation fulfilled and adv same incorrect: {_safe_ratio(equation_fulfilled_adv_same_incorrect, equation_fulfilled)},adv not same incorrect: {_safe_ratio(equation_fulfilled - equation_fulfilled_adv_same_incorrect, equation_fulfilled)}")
    print(f"Equation not fulfilled and clean correct percentage: {_safe_ratio(equation_not_fulfilled_clean_correct, equation_not_fulfilled)},clean incorrect: {_safe_ratio(equation_not_fulfilled - equation_not_fulfilled_clean_correct, equation_not_fulfilled)}, Equation not fulfilled and adv same incorrect: {_safe_ratio(equation_not_fulfilled_adv_same_incorrect, equation_not_fulfilled)},adv not same incorrect: {_safe_ratio(equation_not_fulfilled - equation_not_fulfilled_adv_same_incorrect, equation_not_fulfilled)}")

    return {
        'clean_same_correct': clean_same_correct,
        'clean_same_incorrect': clean_same_incorrect,
        'clean_diff_one_correct': clean_diff_one_correct,
        'clean_diff_both_wrong': clean_diff_both_wrong,
        'clean_l0_correct': clean_l0_correct,
        'clean_kl_correct': clean_kl_correct,
        'clean_dot_correct': clean_dot_correct,
        'adv_same_correct': adv_same_correct,
        'adv_same_incorrect': adv_same_incorrect,
        'adv_diff_one_correct': adv_diff_one_correct,
        'adv_diff_both_wrong': adv_diff_both_wrong,
        'equation_fulfilled': equation_fulfilled,
        'equation_not_fulfilled': equation_not_fulfilled,
        'equation_fulfilled_clean_correct': equation_fulfilled_clean_correct,
        'equation_fulfilled_adv_same_incorrect': equation_fulfilled_adv_same_incorrect,
        'equation_not_fulfilled_clean_correct': equation_not_fulfilled_clean_correct,
        'equation_not_fulfilled_adv_same_incorrect': equation_not_fulfilled_adv_same_incorrect,
    }

def _parse_and_validate_args():
    """Parse CLI flags, validate/normalize the L0/KL weights, and rescale attack budgets.

    Returns:
        argparse.Namespace with ``l0_weight``/``kl_weight`` normalized to sum to 1
        when both are positive, and ``attack_epsilon``/``attack_stepsize`` divided by 255.

    Raises:
        ValueError: if a similarity weight is negative or both L0 and KL weights are 0.
    """
    parser = argparse.ArgumentParser(description='Train CIFAR-10 with Combined L0+KL Prototypes')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--alpha', type=float, default=1.0, help='weight for classification loss')
    parser.add_argument('--beta', type=float, default=0.01, help='weight for separation regularization')
    parser.add_argument('--margin', type=float, default=0.2, help='minimum distance between prototypes')
    parser.add_argument('--tau', type=float, default=0.75, help='threshold parameter for L0 similarity')
    parser.add_argument('--class_boost', type=float, default=0.5, help='boost value for correct class similarities')
    parser.add_argument('--patience', type=int, default=10, help='early stopping patience')
    parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vgg19', 'densenet121'], help='backbone architecture')
    parser.add_argument('--pretrained', action='store_true', help='use pretrained ImageNet weights for backbone')
    parser.add_argument('--checkpoint', type=str, default='./checkpoint_combined/ckpt_combined.pth', help='checkpoint path')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--evaluate', action='store_true', help='evaluate trained model')
    parser.add_argument('--consensus_eval', action='store_true', help='evaluate using consensus prediction (4-category breakdown)')
    parser.add_argument('--eval_robustness', action='store_true', help='evaluate adversarial robustness with PGD attack')
    parser.add_argument('--attack_epsilon', type=float, default=2, help='perturbation budget for adversarial attack (will be divided by 255)')
    parser.add_argument('--attack_stepsize', type=float, default=2, help='attack step size for adversarial attack (will be divided by 255)')
    parser.add_argument('--norm', type=str, default='l_inf', choices=['l_inf', 'l_2'], help='norm for adversarial attack')
    parser.add_argument('--l0_weight', type=float, default=1.0, help='weight for L0 similarity (0.0 to 1.0)')
    parser.add_argument('--kl_weight', type=float, default=0.0, help='weight for KL similarity (0.0 to 1.0)')
    parser.add_argument('--dropout_rate', type=float, default=0.3, help='dropout rate')
    # Model loading arguments for continuing training from pre-trained models
    parser.add_argument('--l0_load_model', type=str, help='load L0 model from checkpoint to continue training')
    parser.add_argument('--kl_load_model', type=str, help='load KL model from checkpoint to continue training')
    parser.add_argument('--load_model', type=str, help='load standard model checkpoint (from main.py) and convert to prototype architecture')
    parser.add_argument('--freeze_prototypes', action='store_true',
                        help='freeze prototypes after initialization and only train encoder')
    # Math proof analysis
    parser.add_argument('--math_proof', action='store_true', help='run mathematical proof analysis during adversarial evaluation')
    # Decorrelation and orthogonality regularizers
    parser.add_argument('--deco_weight', type=float, default=0.15,
                        help='weight for L0/KL logit decorrelation')
    parser.add_argument('--ortho_weight', type=float, default=1e-3,
                        help='weight for proj_l0/proj_kl orthogonality')
    parser.add_argument('--dot_weight', type=float, default=1.0, help='weight for dot similarity (0.0 to 1.0)')

    args = parser.parse_args()

    # Validate weights
    if args.l0_weight < 0 or args.kl_weight < 0:
        raise ValueError("Weights must be non-negative")
    if args.l0_weight == 0 and args.kl_weight == 0:
        raise ValueError("At least one weight must be greater than 0")
    # Normalize L0/KL so they sum to 1 only when both are active (dot_weight is left untouched).
    if args.l0_weight > 0 and args.kl_weight > 0:
        total_weight = args.l0_weight + args.kl_weight
        args.l0_weight /= total_weight
        args.kl_weight /= total_weight
        print(f"Normalized weights: L0={args.l0_weight:.3f}, KL={args.kl_weight:.3f}")
    # CLI attack budgets are given in /255 units; rescale them here.
    args.attack_epsilon = args.attack_epsilon / 255.0
    args.attack_stepsize = args.attack_stepsize / 255.0
    print(f"Attack parameters: epsilon={args.attack_epsilon:.6f}, stepsize={args.attack_stepsize:.6f}, norm={args.norm}")
    return args


def _build_cifar10_loaders():
    """Build the CIFAR-10 train loader plus a seeded 50/50 val/test split of the official test set.

    Returns:
        (trainloader, valloader, testloader)
    """
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)
    # Split test set into validation and test sets
    testset_full = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    # Seeds the global RNG: makes the split reproducible and also fixes any RNG use that follows.
    torch.manual_seed(42)
    n_total = len(testset_full)
    val_size = n_total // 2
    test_size = n_total - val_size
    valset, testset = torch.utils.data.random_split(testset_full, [val_size, test_size])
    valloader = torch.utils.data.DataLoader(
        valset, batch_size=100, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    print(f'Training set: {len(trainset)} images')
    print(f'Validation set: {len(valset)} images')
    print(f'Test set: {len(testset)} images')
    return trainloader, valloader, testloader


def _build_backbone_for_args(args):
    """Create the feature backbone chosen by ``args``, drop its classifier head, and swap ReLU for GELU."""
    print('==> Building backbone..')
    if args.pretrained:
        print('==> Loading pretrained ImageNet weights for backbone..')
        if args.backbone == 'resnet18':
            backbone = torchvision.models.resnet18(pretrained=True)
        elif args.backbone == 'resnet50':
            backbone = torchvision.models.resnet50(pretrained=True)
        elif args.backbone == 'vgg19':
            backbone = torchvision.models.vgg19(pretrained=True)
        elif args.backbone == 'densenet121':
            backbone = torchvision.models.densenet121(pretrained=True)
        else:
            backbone = ResNet18()
        # Remove the final classification layer (torchvision names it `fc` or `classifier`).
        if hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()
        elif hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()
        print(f'==> Loaded pretrained {args.backbone} backbone')
    else:
        print('==> Using random initialization for backbone..')
        if args.backbone == 'vgg19':
            backbone = VGG('VGG19')
            backbone.classifier = nn.Identity()
        elif args.backbone == 'densenet121':
            backbone = DenseNet121()
            backbone.classifier = nn.Identity()
        else:
            # Only resnet18 is wired up here; make any other fallback visible instead of silent.
            if args.backbone != 'resnet18':
                print(f"Warning: no randomly-initialized '{args.backbone}' backbone is configured; falling back to ResNet18")
            backbone = ResNet18()
            backbone.linear = nn.Identity()

    # Replace ReLU with GELU
    print('==> Replacing ReLU activations with GELU...')
    replace_activations(backbone, nn.GELU())
    return backbone


def _torch_load_to_device(path, device):
    """Load a checkpoint, remapping stored tensors onto ``device``.

    ``map_location`` allows a GPU-saved checkpoint to be opened on a CPU-only host.
    """
    return torch.load(path, map_location=device)


def _load_standard_backbone_weights(backbone, path, device):
    """Copy backbone weights from a standard classifier checkpoint (``ckpt['net']``), skipping its head."""
    print(f'==> Loading model from {path}..')
    if not os.path.isfile(path):
        print(f'==> No model file found at {path}')
        print('==> Starting with random initialization')
        return
    checkpoint = _torch_load_to_device(path, device)
    backbone_state_dict = {}
    for key, value in checkpoint['net'].items():
        # Strip a DataParallel-style 'module.' prefix so keys match the bare backbone.
        clean_key = key[len('module.'):] if key.startswith('module.') else key
        # The head was replaced by nn.Identity, so its weights have no target.
        if not clean_key.startswith(('linear.', 'classifier.')):
            backbone_state_dict[clean_key] = value
    backbone.load_state_dict(backbone_state_dict)
    print('==> Loaded backbone weights from checkpoint')


def _load_prototype_model_checkpoint(model, path, label, device):
    """Load full prototype-model weights from ``path``; return True only if a checkpoint was loaded.

    ``label`` ('L0' or 'KL') is used only in the log messages.
    """
    print(f'==> Loading {label} model from {path}..')
    if not os.path.isfile(path):
        print(f'==> {label} checkpoint not found at {path}')
        print('==> Starting with random initialization')
        return False
    checkpoint = _torch_load_to_device(path, device)
    model.load_state_dict(checkpoint['net'])
    print(f'==> Loaded {label} model with accuracy: {checkpoint["acc"]:.2f}%')
    print(f'==> Continuing training from {label} checkpoint')
    return True


def _resolve_eval_weights(checkpoint, args):
    """Return (l0_weight, kl_weight, dot_weight), preferring values saved in ``checkpoint`` over CLI args."""
    if 'l0_weight' in checkpoint and 'kl_weight' in checkpoint:
        l0_weight = checkpoint['l0_weight']
        kl_weight = checkpoint['kl_weight']
        # Older checkpoints may predate dot_weight; fall back to the CLI value then.
        if 'dot_weight' in checkpoint:
            dot_weight = checkpoint['dot_weight']
            print(f"Loaded weights from checkpoint: L0={l0_weight}, KL={kl_weight}, Dot={dot_weight}")
        else:
            dot_weight = args.dot_weight
            print(f"Loaded weights from checkpoint: L0={l0_weight}, KL={kl_weight}")
            print(f"Using command line dot_weight: {dot_weight} (not found in checkpoint)")
    else:
        # Fallback to command line arguments if not in checkpoint
        l0_weight = args.l0_weight
        kl_weight = args.kl_weight
        dot_weight = args.dot_weight
        print(f"Using command line weights: L0={l0_weight}, KL={kl_weight}, Dot={dot_weight}")
        print("Note: Consider retraining to save weights in checkpoint")
    return l0_weight, kl_weight, dot_weight


def main():
    """CLI entry point: train and/or evaluate the combined L0+KL(+dot) prototype model on CIFAR-10.

    Parses flags, builds data loaders and the backbone, optionally imports
    weights (``--load_model``, ``--l0_load_model``, ``--kl_load_model``),
    initializes prototypes from data centroids when no prototype checkpoint
    was actually loaded, then runs ``--train`` and/or ``--evaluate``
    (with optional consensus and adversarial-robustness evaluation).
    """
    args = _parse_and_validate_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    trainloader, valloader, testloader = _build_cifar10_loaders()
    backbone = _build_backbone_for_args(args)

    # Load pretrained backbone weights from a standard (non-prototype) checkpoint if specified
    if args.load_model:
        _load_standard_backbone_weights(backbone, args.load_model, device)

    model = DynamicPrototypeModelL0KL(backbone, num_classes=10, embedding_dim=512, dropout_rate=args.dropout_rate)
    model = model.to(device)

    # Track whether a prototype checkpoint was actually loaded; a missing file must not
    # suppress data-centroid initialization below. If both load, KL overwrites L0.
    loaded_from_checkpoint = False
    if args.l0_load_model:
        loaded_from_checkpoint |= _load_prototype_model_checkpoint(model, args.l0_load_model, 'L0', device)
    if args.kl_load_model:
        loaded_from_checkpoint |= _load_prototype_model_checkpoint(model, args.kl_load_model, 'KL', device)

    if not loaded_from_checkpoint:
        print("==> Initializing prototypes from data centroids...")
        model.initialize_prototypes_from_data(trainloader, device)
        if args.freeze_prototypes:
            model.freeze_prototypes()
            print("✅ Prototypes initialized and frozen")
        else:
            print("✅ Prototypes initialized and trainable")
    else:
        print("==> Using prototypes from loaded checkpoint (skipping initialization)")
        if args.freeze_prototypes:
            model.freeze_prototypes()

    # Training
    if args.train:
        print("==> Starting training...")
        print(f"Training mode: L0 weight={args.l0_weight}, KL weight={args.kl_weight}, Dot weight={args.dot_weight}")
        print(f"Prototype training: {'FROZEN' if args.freeze_prototypes else 'TRAINABLE'}")
        check_prototype_status(model)
        # os.makedirs('') raises, so only create a directory when the path has one.
        checkpoint_dir = os.path.dirname(args.checkpoint)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        best_acc = train_model(model, trainloader, valloader, testloader, device, args)
        print(f"Training completed with best accuracy: {best_acc:.2f}%")

    # Evaluation
    if args.evaluate:
        print("==> Loading trained model for evaluation...")
        if not os.path.isfile(args.checkpoint):
            print(f"Checkpoint not found at {args.checkpoint}")
            return
        checkpoint = _torch_load_to_device(args.checkpoint, device)
        model.load_state_dict(checkpoint['net'])
        print(f"Loaded model with accuracy: {checkpoint['acc']:.2f}%")
        l0_weight, kl_weight, dot_weight = _resolve_eval_weights(checkpoint, args)
        # Evaluate on test set
        evaluate_model(model, testloader, device, args, l0_weight, kl_weight, dot_weight)
        if args.consensus_eval:
            evaluate_consensus_accuracy(model, testloader, device, args, l0_weight, kl_weight, dot_weight)
        if args.eval_robustness:
            evaluate_adversarial_robustness(model, testloader, device, args, l0_weight, kl_weight, dot_weight)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()