import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import functools

# Try to import AutoAttack
try:
    from autoattack import AutoAttack
    AUTOATTACK_AVAILABLE = True
    print("AutoAttack imported successfully")
except ImportError:
    AUTOATTACK_AVAILABLE = False
    print("Warning: AutoAttack not available. Install with: pip install autoattack")

from models import *
try:
    from utils import progress_bar
    PROGRESS_BAR_AVAILABLE = True
except (ImportError, ValueError, OSError):
    # Fallback progress bar if utils.py is missing (ImportError) or fails
    # while it is being imported (ValueError / OSError).
    PROGRESS_BAR_AVAILABLE = False
    print("Warning: Progress bar not available, using simple progress display")

    def progress_bar(current, total, msg=None):
        """Print a plain-text progress line.

        A line is printed when ``current`` is a multiple of 10 and when it is
        the last index (``total - 1``); other calls print nothing.

        Args:
            current: Zero-based index of the step that just finished.
            total: Total number of steps.
            msg: Optional text appended to the progress line.
        """
        if current % 10 == 0 or current == total - 1:
            print(f"Progress: {current+1}/{total} {msg or ''}")

# Adversarial attack parameters
# Per-channel statistics reshaped to (1, 3, 1, 1); the same numbers are passed
# to transforms.Normalize in main(). Presumably shaped this way to broadcast
# against (batch, 3, height, width) image batches -- TODO confirm.
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR_STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(1, 3, 1, 1)
# Bounds applied to X + delta by the attacks, which call
# clamp(delta, lower_limit - X, upper_limit - X).
# NOTE(review): main() feeds the model mean/std-normalized images, whose values
# are not confined to [0, 1] -- confirm these bounds are intended for
# normalized inputs.
lower_limit, upper_limit = 0, 1
def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise to ``[lower_limit, upper_limit]``.

    Args:
        X: Tensor to clip.
        lower_limit: Lower bound, either a tensor broadcastable against ``X``
            or a plain Python number.
        upper_limit: Upper bound, same accepted types as ``lower_limit``.

    Returns:
        ``max(min(X, upper_limit), lower_limit)`` evaluated element-wise, so
        the lower bound wins wherever the two bounds cross.
    """
    # torch.min(X, 1) / torch.max(X, 0) would treat a plain integer as the
    # dimension to reduce over, not as a bound, so numbers are wrapped in
    # 0-dim tensors first. Tensor bounds are passed through untouched.
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, device=X.device)
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

class DynamicPrototypeModel(nn.Module):
    """Feature backbone paired with one learnable prototype vector per class.

    The forward pass scores every input by the cosine similarity between its
    backbone embedding and each class prototype.
    """

    def __init__(self, backbone, num_classes=10, embedding_dim=512):
        super().__init__()
        self.backbone = backbone
        # One trainable row per class, drawn from a standard normal and scaled
        # by 0.1 so the starting values sit close to zero.
        initial_prototypes = torch.randn(num_classes, embedding_dim) * 0.1
        self.class_prototypes = nn.Parameter(initial_prototypes)
        self.num_classes = num_classes

    def get_prototypes(self):
        """Return the class prototypes with every row scaled to unit L2 norm."""
        return F.normalize(self.class_prototypes, p=2, dim=1)

    def get_image_embeddings(self, x):
        """Run ``x`` through the backbone and L2-normalize along dim 1."""
        features = self.backbone(x)
        return F.normalize(features, p=2, dim=1)

    def forward(self, x):
        """Return the cosine similarity of each input to each class prototype.

        The result is the matrix product of the unit-norm embeddings with the
        transposed unit-norm prototypes, i.e. one column per class.
        """
        unit_embeddings = self.get_image_embeddings(x)
        unit_prototypes = self.get_prototypes()
        return torch.mm(unit_embeddings, unit_prototypes.transpose(0, 1))

def separation_regularizer(class_prototypes, margin=0.5, temperature=1.0):
    """
    Push class prototypes further apart using cosine distance

    Every ordered pair of *distinct* prototypes contributes a hinge penalty
    ``relu(margin - (1 - similarity))``; the penalties are averaged over those
    pairs only. Self-pairs (the diagonal of the similarity matrix) are left
    out, so they neither dilute the mean nor add a constant offset when
    ``margin`` exceeds 1.

    Args:
        class_prototypes: Normalized prototypes (num_classes, embedding_dim)
        margin: Minimum desired distance between prototypes
        temperature: Scaling factor for the regularization strength

    Returns:
        Scalar tensor holding the mean pairwise margin violation multiplied
        by ``temperature``. With fewer than two prototypes there are no pairs
        and a zero scalar is returned.
    """
    num_classes = class_prototypes.size(0)
    if num_classes < 2:
        # No pairs to separate; the mean of an empty selection would be NaN.
        return class_prototypes.new_zeros(())

    # Rows are expected to be unit-norm, so the dot products below are cosine
    # similarities.
    similarity_matrix = torch.mm(class_prototypes, class_prototypes.t())

    # Boolean selector for the entries that compare two different prototypes.
    off_diagonal = ~torch.eye(num_classes, dtype=torch.bool,
                              device=class_prototypes.device)
    pair_distances = 1 - similarity_matrix[off_diagonal]

    # Only pairs that are closer than the margin are penalized.
    violation = F.relu(margin - pair_distances)

    return violation.mean() * temperature

def prototype_loss_with_regularization(outputs, targets, model, 
                                     alpha=1.0, beta=0.1, margin=0.5):
    """
    Cross-entropy on the model outputs plus a prototype separation penalty

    Args:
        outputs: Model predictions (batch_size, num_classes)
        targets: Ground truth labels
        model: Model instance (optionally wrapped so that the real network is
            reachable as ``model.module``) providing ``get_prototypes()``
        alpha: Weight for the classification (cross-entropy) term
        beta: Weight for the separation regularization term
        margin: Minimum distance between prototypes

    Returns:
        Tuple ``(total_loss, classification_loss, separation_loss)`` where
        ``total_loss = alpha * classification_loss + beta * separation_loss``.
    """
    # A DataParallel wrapper keeps the underlying network under .module.
    if hasattr(model, 'module'):
        core_model = model.module
    else:
        core_model = model
    normalized_prototypes = core_model.get_prototypes()

    ce_term = F.cross_entropy(outputs, targets)
    sep_term = separation_regularizer(normalized_prototypes, margin=margin)

    weighted_sum = alpha * ce_term + beta * sep_term
    return weighted_sum, ce_term, sep_term

def visualize_prototypes(model, epoch, save_dir='./prototype_vis'):
    """Plot the pairwise cosine distances between class prototypes.

    Saves a heat map to ``{save_dir}/prototype_distances_epoch_{epoch}.png``
    and prints the average separation.

    Args:
        model: Model providing ``get_prototypes()``, either directly or via
            ``model.module`` when wrapped (e.g. by DataParallel).
        epoch: Epoch number used in the plot title and the file name.
        save_dir: Directory for the image; created if it does not exist.

    Returns:
        Mean cosine distance over all pairs of *distinct* prototypes as a
        float. The zero self-distances on the diagonal are excluded so they
        do not pull the average down. Returns 0.0 when there is only one
        prototype.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Handle DataParallel
    core_model = model.module if hasattr(model, 'module') else model
    prototypes = core_model.get_prototypes().detach().cpu().float()
    num_classes = prototypes.size(0)

    # All pairwise cosine distances in one matrix product; the diagonal
    # (prototype vs. itself) is forced to exactly zero.
    off_diagonal = ~torch.eye(num_classes, dtype=torch.bool)
    distance_matrix = (1 - torch.mm(prototypes, prototypes.t())) * off_diagonal.to(prototypes.dtype)

    # Plot distance matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(distance_matrix.numpy(), cmap='viridis')
    plt.colorbar(label='Cosine Distance')
    plt.title(f'Class Prototype Distances - Epoch {epoch}')
    plt.xlabel('Class Index')
    plt.ylabel('Class Index')

    # Add class labels; fall back to plain indices when the number of
    # prototypes does not match the ten CIFAR-10 names (mismatched tick and
    # label counts would make matplotlib raise).
    class_names = ['plane', 'car', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']
    if num_classes != len(class_names):
        class_names = [str(index) for index in range(num_classes)]
    plt.xticks(range(num_classes), class_names, rotation=45)
    plt.yticks(range(num_classes), class_names)

    plt.tight_layout()
    plt.savefig(f'{save_dir}/prototype_distances_epoch_{epoch}.png')
    plt.close()

    # Print average separation (distinct pairs only)
    if num_classes > 1:
        avg_distance = distance_matrix[off_diagonal].mean().item()
    else:
        avg_distance = 0.0
    print(f'Epoch {epoch}: Average prototype separation: {avg_distance:.4f}')

    return avg_distance

def attack_pgd(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, mean=None, std=None):
    """Projected Gradient Descent attack

    Starts from a random perturbation inside the epsilon-ball and takes
    ``attack_iters`` ascent steps on the cross-entropy loss with respect to
    the perturbation. After every step the perturbation is projected back
    onto the ball and clipped so that ``X + delta`` stays within the
    module-level ``[lower_limit, upper_limit]`` range.

    The model is switched to eval mode while the attack runs and put back in
    its previous mode afterwards. Gradients are requested for the
    perturbation only, so parameter ``requires_grad`` flags and ``.grad``
    buffers are left untouched.

    Args:
        model: Network under attack.
        X: Input batch; expected to be 4-D (the l_2 branch reshapes per-sample
            norms to ``(batch, 1, 1, 1)``).
        target: Ground-truth labels for ``X``.
        alpha: Step size of each iteration.
        attack_iters: Number of gradient steps.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device ``X`` and ``target`` are moved to.
        epsilon: Radius of the perturbation ball.
        restarts: Accepted for call compatibility; not used.
        mean: Accepted for call compatibility; not used.
        std: Accepted for call compatibility; not used.

    Returns:
        Detached perturbation tensor with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")

    # Ensure inputs are on the correct device
    X = X.to(device)
    target = target.to(device)

    # keep model in eval mode during the attack; restored in `finally`
    was_training = model.training
    model.eval()
    try:
        # Initialize perturbation
        with torch.no_grad():
            delta = torch.zeros_like(X, device=device)
            if norm == "l_inf":
                delta.uniform_(-epsilon, epsilon)
            else:
                # Random direction scaled to a random radius in [0, epsilon]
                delta.normal_()
                init_norm = delta.view(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
                radius = torch.zeros_like(init_norm).uniform_(0, 1)
                delta *= radius / init_norm * epsilon
            delta = clamp(delta, lower_limit - X, upper_limit - X)

        for _ in range(attack_iters):
            # Fresh leaf every step so each iteration builds its own small graph
            delta.requires_grad_(True)
            loss = F.cross_entropy(model(X + delta), target)

            # Gradient w.r.t. the perturbation only; model parameters are not
            # frozen and their .grad buffers are not written.
            grad, = torch.autograd.grad(loss, delta)

            with torch.no_grad():
                if norm == "l_inf":
                    stepped = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
                else:
                    grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                    unit_grad = grad / (grad_norm + 1e-10)
                    stepped = (delta + unit_grad * alpha)
                    # renorm shrinks each sample's perturbation to at most epsilon
                    stepped = stepped.view(stepped.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(stepped)

                delta = clamp(stepped, lower_limit - X, upper_limit - X)
    finally:
        model.train(was_training)

    return delta.detach()

def attack_cw(model, X, target, alpha, attack_iters, norm, device, epsilon=0, restarts=1, mean=None, std=None):
    """Iterative attack driven by a Carlini & Wagner style margin loss.

    Uses the same random start, step rule and projections as a PGD attack,
    but the objective is the margin between the logit of the true class and
    the largest other logit: each step moves the perturbation so that
    ``relu(correct_logit - wrong_logit + 50)`` decreases.

    The model is switched to eval mode while the attack runs and put back in
    its previous mode afterwards. Gradients are taken with respect to the
    perturbation only, so nothing is accumulated into the parameters' ``.grad``.

    Args:
        model: Network under attack.
        X: Input batch; expected to be 4-D (the l_2 branch reshapes per-sample
            norms to ``(batch, 1, 1, 1)``).
        target: Ground-truth labels for ``X``.
        alpha: Step size of each iteration.
        attack_iters: Number of gradient steps.
        norm: ``"l_inf"`` or ``"l_2"``.
        device: Device ``X`` and ``target`` are moved to.
        epsilon: Radius of the perturbation ball.
        restarts: Accepted for call compatibility; not used.
        mean: Accepted for call compatibility; not used.
        std: Accepted for call compatibility; not used.

    Returns:
        Detached perturbation tensor with the same shape as ``X``.

    Raises:
        ValueError: If ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"Norm {norm} not supported")

    X = X.to(device)
    target = target.to(device)

    was_training = model.training
    model.eval()
    try:
        # Initialize perturbation
        with torch.no_grad():
            delta = torch.zeros_like(X, device=device)
            if norm == "l_inf":
                delta.uniform_(-epsilon, epsilon)
            else:
                # Random direction scaled to a random radius in [0, epsilon]
                delta.normal_()
                init_norm = delta.view(delta.size(0), -1).norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
                radius = torch.zeros_like(init_norm).uniform_(0, 1)
                delta *= radius / init_norm * epsilon
            # Project to valid bounds
            delta = clamp(delta, lower_limit - X, upper_limit - X)

        for _ in range(attack_iters):
            delta.requires_grad_(True)
            output = model(X + delta)

            # One-hot mask of the true class, used to split the logits
            label_mask = torch.zeros_like(output)
            label_mask.scatter_(1, target.unsqueeze(1), 1)

            correct_logit = torch.sum(label_mask * output, dim=1)
            # Subtracting 1e4 at the true-class position keeps it out of the max
            wrong_logit, _ = torch.max((1 - label_mask) * output - 1e4 * label_mask, dim=1)

            # Negated margin: ascending this loss shrinks the margin
            loss = -torch.sum(F.relu(correct_logit - wrong_logit + 50))

            # Gradient w.r.t. the perturbation only
            grad, = torch.autograd.grad(loss, delta)

            with torch.no_grad():
                if norm == "l_inf":
                    stepped = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
                else:
                    grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                    unit_grad = grad / (grad_norm + 1e-10)
                    stepped = (delta + unit_grad * alpha)
                    stepped = stepped.view(stepped.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(stepped)

                # Project to valid bounds
                delta = clamp(stepped, lower_limit - X, upper_limit - X)
    finally:
        model.train(was_training)

    return delta.detach()

def attack_auto(model, X, target, epsilon, device, attacks_to_run=None):
    """AutoAttack - ensemble of multiple attacks

    Runs AutoAttack (Linf norm, 'standard' version) restricted to
    ``attacks_to_run`` and returns the perturbation it found. When AutoAttack
    could not be imported, or raises while running, a 10-step l_inf PGD attack
    with step size 2/255 is used instead.

    Args:
        model: Network under attack.
        X: Input batch.
        target: Ground-truth labels for ``X``.
        epsilon: Perturbation budget passed to the attack.
        device: Device passed to AutoAttack / the PGD fallback.
        attacks_to_run: Names of the AutoAttack components to run; defaults to
            ``['apgd-ce', 'apgd-dlr']`` when omitted.

    Returns:
        Perturbation tensor ``x_adv - X`` (or the PGD perturbation on
        fallback).
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if attacks_to_run is None:
        attacks_to_run = ['apgd-ce', 'apgd-dlr']

    def pgd_fallback():
        # Single place for the fallback attack configuration.
        return attack_pgd(model, X, target, 2/255, 10, 'l_inf', device, epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)

    if not AUTOATTACK_AVAILABLE:
        print("AutoAttack not available, falling back to PGD")
        return pgd_fallback()

    # Create forward pass function for AutoAttack
    def forward_pass(images):
        return model(images)

    # Initialize AutoAttack adversary
    adversary = AutoAttack(forward_pass, norm='Linf', eps=epsilon, version='standard', verbose=False, device=device)
    adversary.attacks_to_run = list(attacks_to_run)

    # Run AutoAttack; any failure is reported and answered with PGD so that a
    # long evaluation is not aborted by one batch (deliberate best effort).
    try:
        x_adv = adversary.run_standard_evaluation(X, target, bs=X.shape[0])
        # AutoAttack returns perturbed images, so we need to compute the perturbation
        return x_adv - X
    except Exception as e:
        print(f"AutoAttack failed: {e}, falling back to PGD")
        return pgd_fallback()

def evaluate_robustness(model, testloader, device, attack_type='pgd', epsilon=8/255, 
                       alpha=2/255, attack_iters=10, norm='l_inf', auto_attacks=None):
    """Evaluate model robustness against adversarial attacks

    For every batch the clean accuracy is measured, a perturbation is
    generated with the selected attack, and the accuracy on the perturbed
    inputs is measured. Running totals are printed every 50 batches and the
    final figures at the end.

    Args:
        model: Network to evaluate; put into eval mode.
        testloader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        attack_type: ``'pgd'``, ``'cw'`` or ``'auto'``.
        epsilon: Perturbation budget.
        alpha: Attack step size (PGD / CW).
        attack_iters: Number of attack iterations (PGD / CW).
        norm: Norm passed to the PGD / CW attacks.
        auto_attacks: AutoAttack components to run; defaults to
            ``['apgd-ce', 'apgd-dlr']`` when omitted.

    Returns:
        Tuple ``(clean_acc, adv_acc)`` in percent. ``(0.0, 0.0)`` is returned
        when the loader yields no samples.

    Raises:
        ValueError: If ``attack_type`` is not one of the supported names.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if auto_attacks is None:
        auto_attacks = ['apgd-ce', 'apgd-dlr']

    # Fail before any work is done rather than after the first forward pass.
    if attack_type not in ('pgd', 'cw', 'auto'):
        raise ValueError(f"Attack type {attack_type} not supported")

    model.eval()

    # Metrics
    clean_correct = 0
    adv_correct = 0
    total = 0

    print(f'Evaluating robustness with {attack_type.upper()} attack (ε={epsilon:.3f})')
    if attack_type == 'auto':
        print(f'AutoAttack methods: {auto_attacks}')

    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader, desc='Evaluating')):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # Clean accuracy
        with torch.no_grad():
            clean_outputs = model(inputs)
            _, clean_predicted = clean_outputs.max(1)
            clean_correct += clean_predicted.eq(targets).sum().item()

        # Generate adversarial examples
        if attack_type == 'pgd':
            delta = attack_pgd(model, inputs, targets, alpha, attack_iters, norm, device, epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)
        elif attack_type == 'cw':
            delta = attack_cw(model, inputs, targets, alpha, attack_iters, norm, device, epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)
        else:
            delta = attack_auto(model, inputs, targets, epsilon, device, auto_attacks)

        # Adversarial accuracy
        with torch.no_grad():
            adv_outputs = model(inputs + delta)
            _, adv_predicted = adv_outputs.max(1)
            adv_correct += adv_predicted.eq(targets).sum().item()

        total += batch_size

        # Progress update
        if batch_idx % 50 == 0:
            clean_acc = 100. * clean_correct / total
            adv_acc = 100. * adv_correct / total
            print(f'Batch {batch_idx}: Clean Acc: {clean_acc:.2f}%, Adv Acc: {adv_acc:.2f}%')

    # An empty loader would otherwise divide by zero below.
    if total == 0:
        print('Warning: no samples were evaluated (empty data loader)')
        return 0.0, 0.0

    # Final results
    clean_acc = 100. * clean_correct / total
    adv_acc = 100. * adv_correct / total

    print(f'\nFinal Results:')
    print(f'Clean Accuracy: {clean_acc:.2f}%')
    print(f'Adversarial Accuracy: {adv_acc:.2f}%')
    print(f'Robustness Drop: {clean_acc - adv_acc:.2f}%')

    return clean_acc, adv_acc

def main():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training with Dynamic Prototypes')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--load_model', '-l', type=str, help='load pretrained model from main.py checkpoint and convert to prototype architecture')
    parser.add_argument('--checkpoint', type=str, default='./checkpoint_prototype/ckpt_prototype.pth', help='checkpoint path')
    parser.add_argument('--epochs', type=int, default=200, help='number of training epochs')
    parser.add_argument('--alpha', type=float, default=1.0, help='weight for classification loss')
    parser.add_argument('--beta', type=float, default=0.1, help='weight for separation regularization')
    parser.add_argument('--margin', type=float, default=0.5, help='minimum distance between prototypes')
    parser.add_argument('--eval_robustness', action='store_true', help='evaluate robustness after training')
    parser.add_argument('--attack', type=str, default='pgd', choices=['pgd', 'auto', 'cw'], help='attack type for robustness evaluation')
    parser.add_argument('--attack_epsilon', type=int, default=2, help='perturbation budget for robustness evaluation (will be divided by 255)')
    parser.add_argument('--attack_iters', type=int, default=100, help='number of attack iterations')
    parser.add_argument('--attack_stepsize', type=int, default=2, help='attack step size (will be divided by 255)')
    parser.add_argument('--norm', type=str, default='l_2', choices=['l_inf', 'l_2'], help='norm for robustness evaluation')
    parser.add_argument('--pretrained', action='store_true', help='use pretrained ImageNet weights for backbone')
    parser.add_argument('--backbone', type=str, default='resnet18', 
                       choices=['resnet18', 'resnet50', 'vgg19', 'densenet121'], 
                       help='backbone architecture')
    parser.add_argument('--auto_attacks', type=str, nargs='+', default=['apgd-ce', 'apgd-dlr'], 
                       help='AutoAttack methods to use (default: apgd-ce apgd-dlr)')
    parser.add_argument('--patience', type=int, default=20, help='number of epochs to wait for validation accuracy improvement before early stopping')
    parser.add_argument('--eval_only', action='store_true', help='evaluate existing checkpoint without training')
    
    args = parser.parse_args()
    
    # Convert epsilon from integer to float (divide by 255)
    args.attack_epsilon = args.attack_epsilon / 255.0
    print(f'Using epsilon: {args.attack_epsilon:.6f} ({args.attack_epsilon*255:.0f}/255)')
    
    # Convert attack_alpha from integer to float (divide by 255)
    args.attack_stepsize = args.attack_stepsize / 255.0
    print(f'Using attack_stepsize: {args.attack_stepsize:.6f} ({args.attack_stepsize*255:.0f}/255)')
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    
    print(f'Using device: {device}')
    print(f'Backbone: {args.backbone}')
    print(f'Pretrained: {args.pretrained}')
    print(f'Alpha (classification): {args.alpha}, Beta (separation): {args.beta}, Margin: {args.margin}')
    print(f'Early stopping patience: {args.patience} epochs')
    
    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)
    
    # Split test set into validation and test sets
    testset_full = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    
    # Set random seed for reproducible split
    torch.manual_seed(42)
    test_size = len(testset_full)
    val_size = test_size // 2  # 5000 images
    test_size = test_size - val_size  # 5000 images
    
    valset, testset = torch.utils.data.random_split(testset_full, [val_size, test_size])
    
    valloader = torch.utils.data.DataLoader(
        valset, batch_size=100, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    
    print(f'Training set: {len(trainset)} images')
    print(f'Validation set: {len(valset)} images')
    print(f'Test set: {len(testset)} images')
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    # Model
    print('==> Building model..')
    
    # Load pretrained backbone
    if args.pretrained:
        print('==> Loading pretrained ImageNet weights for backbone..')
        if args.backbone == 'resnet18':
            backbone = torchvision.models.resnet18(pretrained=True)
        elif args.backbone == 'resnet50':
            backbone = torchvision.models.resnet50(pretrained=True)
        elif args.backbone == 'vgg19':
            backbone = torchvision.models.vgg19(pretrained=True)
        elif args.backbone == 'densenet121':
            backbone = torchvision.models.densenet121(pretrained=True)
        else:
            print(f'Pretrained {args.backbone} not supported, using random initialization')
            backbone = ResNet18()
        
        # Remove the final classification layer to get embeddings
        if hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()  # ResNet
        elif hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()  # VGG
        elif hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()  # DenseNet
        
        print(f'==> Loaded pretrained {args.backbone} backbone')
    else:
        print('==> Using random initialization for backbone..')
        if args.backbone == 'resnet18':
            backbone = ResNet18()
            # Remove the final classification layer to get features
            backbone.linear = nn.Identity()
            print('==> Modified ResNet18 to extract features (removed final layer)')
        elif args.backbone == 'vgg19':
            backbone = VGG('VGG19')
            # Remove the final classification layer
            backbone.classifier = nn.Identity()
            print('==> Modified VGG19 to extract features (removed final layer)')
        elif args.backbone == 'densenet121':
            backbone = DenseNet121()
            # Remove the final classification layer
            backbone.classifier = nn.Identity()
            print('==> Modified DenseNet121 to extract features (removed final layer)')
        else:
            backbone = ResNet18()
            backbone.linear = nn.Identity()
            print('==> Modified ResNet18 to extract features (removed final layer)')
    
    net = DynamicPrototypeModel(backbone, num_classes=10, embedding_dim=512)
    net = net.to(device)
    
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    
    # Load checkpoint
    if args.resume:
        print(f'==> Loading checkpoint from {args.checkpoint}..')
        if os.path.isfile(args.checkpoint):
            checkpoint = torch.load(args.checkpoint)
            net.load_state_dict(checkpoint['net'])
            best_acc = checkpoint['acc']
            start_epoch = checkpoint['epoch']
            
            # Check if checkpoint has backbone info
            if 'backbone' in checkpoint:
                print(f'==> Loaded checkpoint from epoch {checkpoint["epoch"]} with accuracy {checkpoint["acc"]:.2f}%')
                print(f'==> Checkpoint backbone: {checkpoint["backbone"]}, pretrained: {checkpoint["pretrained"]}')
            else:
                print(f'==> Loaded checkpoint from epoch {checkpoint["epoch"]} with accuracy {checkpoint["acc"]:.2f}%')
        else:
            print('==> No checkpoint found, starting from scratch')

    if args.load_model:
        print(f'==> Loading model from {args.load_model}..')
        if os.path.isfile(args.load_model):
            checkpoint = torch.load(args.load_model)
            
            # Check if this is a standard model checkpoint (from main.py)
            if 'class_prototypes' not in checkpoint['net']:
                print('==> Loading standard model checkpoint (from main.py)...')
                print('==> Converting to prototype model architecture...')
                print('==> Note: This preserves your trained backbone weights but creates new learnable prototypes')
                
                # Create a new backbone with the same architecture but without final layer

                backbone = ResNet18()
                # Remove the final classification layer to get embeddings
                backbone.linear = nn.Identity()
                print('==> Modified ResNet18 to extract features (removed final layer)')

                # Load the backbone weights (excluding the final layer)
                backbone_state_dict = {}
                for key, value in checkpoint['net'].items():
                    # Remove 'module.' prefix if it exists (from DataParallel)
                    if key.startswith('module.'):
                        clean_key = key[7:]  # Remove 'module.' prefix
                    else:
                        clean_key = key
                    
                    # Skip the final classification layer
                    if not clean_key.startswith('linear.') and not clean_key.startswith('classifier.'):
                        backbone_state_dict[clean_key] = value
                
                backbone.load_state_dict(backbone_state_dict)
                print(f'==> Loaded backbone weights from checkpoint')
                print(f'==> Loaded {len(backbone_state_dict)} layers')
                
                # Create new prototype model with the loaded backbone
                net = DynamicPrototypeModel(backbone, num_classes=10, embedding_dim=512)
                net = net.to(device)
                
                if device == 'cuda':
                    net = torch.nn.DataParallel(net)
                    cudnn.benchmark = True
                
                print(f'==> Successfully converted standard model to prototype model')
                print(f'==> Backbone weights loaded from pretrained model')
                
            else:
                print('==> This appears to be a prototype model checkpoint, using --checkpoint instead')
                print('==> Please use --checkpoint for prototype model checkpoints')
                
        else:
            print(f'==> No model file found at {args.load_model}')
            print('==> Starting with random initialization')


    
    criterion = nn.CrossEntropyLoss()
    
    # Handle DataParallel for parameter separation
    if hasattr(net, 'module'):
        # DataParallel case: access underlying model
        backbone_params = list(net.module.backbone.parameters())
        prototype_params = [net.module.class_prototypes]  # This is already a Parameter
    else:
        # Single GPU case: direct access
        backbone_params = list(net.backbone.parameters())
        prototype_params = [net.class_prototypes]  # This is already a Parameter
    
    # Backbone: lower learning rate, more stable
    backbone_optimizer = optim.SGD(backbone_params, lr=args.lr * 0.1,
                                   momentum=0.9, weight_decay=5e-4)
    
    # Prototypes: higher learning rate, faster adaptation
    prototype_optimizer = optim.SGD(prototype_params, lr=args.lr,
                                    momentum=0.9, weight_decay=1e-4)
    
    # Schedulers
    backbone_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(backbone_optimizer, T_max=args.epochs)
    prototype_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(prototype_optimizer, T_max=args.epochs)
    
    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        classification_loss_sum = 0
        separation_loss_sum = 0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Zero gradients for both optimizers
            backbone_optimizer.zero_grad()
            prototype_optimizer.zero_grad()
            
            # Forward pass
            outputs = net(inputs)
            
            # Compute loss with regularization
            loss, cls_loss, sep_loss = prototype_loss_with_regularization(
                outputs, targets, net, 
                alpha=args.alpha, beta=args.beta, margin=args.margin
            )
            
            loss.backward()
            
            # Step both optimizers
            backbone_optimizer.step()
            prototype_optimizer.step()
    
            # Update metrics
            train_loss += loss.item()
            classification_loss_sum += cls_loss.item()
            separation_loss_sum += sep_loss.item()
            
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
            # Progress bar with fallback
            if PROGRESS_BAR_AVAILABLE:
                progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Cls: %.3f | Sep: %.3f | Acc: %.3f%% (%d/%d)'
                             % (train_loss/(batch_idx+1), 
                                classification_loss_sum/(batch_idx+1),
                                separation_loss_sum/(batch_idx+1),
                                100.*correct/total, correct, total))
            else:
                if batch_idx % 50 == 0:
                    print(f'Batch {batch_idx}: Loss: {train_loss/(batch_idx+1):.3f} | Cls: {classification_loss_sum/(batch_idx+1):.3f} | Sep: {separation_loss_sum/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')
    
    def test(epoch, dataloader=None, is_validation=False):
        nonlocal best_acc
        if dataloader is None:
            dataloader = testloader
            set_name = "Test"
        else:
            set_name = "Validation" if is_validation else "Test"
        
        print(f'==> {set_name}ing..')
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(dataloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                
                # Compute loss with regularization for monitoring
                loss, cls_loss, sep_loss = prototype_loss_with_regularization(
                    outputs, targets, net, 
                    alpha=args.alpha, beta=args.beta, margin=args.margin
                )
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        # Calculate overall accuracy
        acc = 100.*correct/total
        avg_loss = test_loss / len(dataloader)
        
        print(f'{set_name} Results - Loss: {avg_loss:.3f} | Accuracy: {acc:.2f}% ({correct}/{total})')
        
        # Save checkpoint if validation accuracy improved (only during training)
        if is_validation and acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
                'alpha': args.alpha,
                'beta': args.beta,
                'margin': args.margin,
                'backbone': args.backbone,
                'pretrained': args.pretrained,
            }
            if not os.path.isdir('checkpoint_prototype'):
                os.makedirs('checkpoint_prototype')
            torch.save(state, args.checkpoint)
            best_acc = acc
            print(f'New best validation accuracy: {best_acc:.2f}%')
        
        return acc
    
    def evaluate_adversarial_robustness(epoch, dataloader=None, is_validation=False):
        """Measure clean and adversarial accuracy of ``net`` on one data split.

        The attack is selected by ``args.attack`` ('pgd', 'cw' or 'auto') and
        configured from ``args.attack_stepsize``, ``args.attack_iters``,
        ``args.norm`` and ``args.attack_epsilon``.

        Args:
            epoch: Current epoch index. Not used inside this function.
            dataloader: Loader yielding ``(inputs, targets)`` batches. When
                ``None``, ``testloader`` is used and the split is reported as
                "test".
            is_validation: When ``dataloader`` is given, selects the split name
                used in log messages ("validation" vs. "test").

        Returns:
            Tuple ``(clean_acc, adv_acc, robustness_drop)``: both accuracies as
            percentages of the samples counted in the clean pass, and
            ``clean_acc - adv_acc``.

        Raises:
            ValueError: If ``args.attack`` is not a supported attack name.
        """
        # Fail early with a clear message; otherwise an unknown attack name
        # would leave `delta` unassigned and crash on the first batch.
        supported_attacks = ('pgd', 'cw', 'auto')
        if args.attack not in supported_attacks:
            raise ValueError(
                f'Unsupported attack {args.attack!r}; expected one of {supported_attacks}'
            )

        if dataloader is None:
            dataloader = testloader
            set_name = "test"
        else:
            set_name = "validation" if is_validation else "test"

        print(f'==> Evaluating adversarial robustness on {set_name} set..')

        clean_correct = 0
        total = 0

        # First pass: clean accuracy over the entire set (no gradients needed).
        print(f'==> Computing clean accuracy on {set_name} set...')
        with torch.no_grad():
            net.eval()
            for batch_idx, (inputs, targets) in enumerate(dataloader):
                inputs, targets = inputs.to(device), targets.to(device)
                batch_size = inputs.size(0)

                clean_outputs = net(inputs)
                _, clean_predicted = clean_outputs.max(1)
                clean_correct += clean_predicted.eq(targets).sum().item()
                total += batch_size

        clean_acc = 100. * clean_correct / total
        print(f'Clean accuracy on {set_name} set: {clean_acc:.2f}%')

        # Second pass: adversarial accuracy. `total` from the clean pass is
        # reused as the denominator, so both passes must see the same samples.
        adv_correct = 0
        print(f'==> Computing adversarial accuracy on {set_name} set...')
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            print(f'Batch {batch_idx} of {len(dataloader)}')
            inputs, targets = inputs.to(device), targets.to(device)
            print(f'Inputs shape: {inputs.shape}')
            batch_size = inputs.size(0)
            print(f'Batch size: {batch_size}')

            # The attack runs outside torch.no_grad(); the model is put in
            # eval mode before every attack call.
            net.eval()
            print(f'Net is in eval mode')
            if args.attack == 'pgd':
                delta = attack_pgd(net, inputs, targets, args.attack_stepsize, args.attack_iters, args.norm, device, args.attack_epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)
            elif args.attack == 'cw':
                delta = attack_cw(net, inputs, targets, args.attack_stepsize, args.attack_iters, args.norm, device, args.attack_epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)
            else:  # 'auto' (validated above)
                delta = attack_auto(net, inputs, targets, args.attack_stepsize, args.attack_iters, args.norm, device, args.attack_epsilon, mean=CIFAR_MEAN, std=CIFAR_STD)

            # Score the perturbed inputs; no gradients are needed here.
            with torch.no_grad():
                adv_outputs = net(inputs + delta)
                _, adv_predicted = adv_outputs.max(1)
                adv_correct += adv_predicted.eq(targets).sum().item()

        adv_acc = 100. * adv_correct / total
        robustness_drop = clean_acc - adv_acc

        print(f'Adversarial Results on {set_name} set - Clean: {clean_acc:.2f}% | Adversarial: {adv_acc:.2f}% | Drop: {robustness_drop:.2f}%')

        return clean_acc, adv_acc, robustness_drop
    
    
    if args.eval_only:
        #load checkpoint
        print(f'==> Loading checkpoint from {args.checkpoint}..')
        checkpoint = torch.load(args.checkpoint)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        
        print('==> Evaluating on test set..')
        clean_acc, adv_acc = evaluate_robustness(
            net, testloader, device, 
            attack_type=args.attack,
            epsilon=args.attack_epsilon,
            alpha=args.attack_stepsize,
            attack_iters=args.attack_iters,
            norm=args.norm,
            auto_attacks=args.auto_attacks
        )
        
        # Save robustness results
        robustness_results = {
            'clean_accuracy': clean_acc,
            'adversarial_accuracy': adv_acc,
            'robustness_drop': clean_acc - adv_acc,
            'attack_type': args.attack,
            'epsilon': args.attack_epsilon,
            'final_accuracy': best_acc,
            'auto_attacks': args.auto_attacks if args.attack == 'auto' else None
        }
        print(f'\nRobustness Results: {robustness_results} with epsilon {args.attack_epsilon} and alpha {args.attack_stepsize}')
        return
        
    # Training loop
    print('==> Starting training..')

    # Early-stopping state: best validation accuracy seen so far and the
    # number of consecutive epochs that failed to beat it.
    best_val_acc, patience_counter = 0, 0

    end_epoch = start_epoch + args.epochs
    for epoch in range(start_epoch, end_epoch):
        train(epoch)
        # Model selection is driven by the validation loader.
        val_acc = test(epoch, valloader, is_validation=True)

        improved = val_acc > best_val_acc
        if improved:
            best_val_acc, patience_counter = val_acc, 0
            print(f'Epoch {epoch}: Validation accuracy improved to {best_val_acc:.2f}%')
        else:
            patience_counter += 1
            print(f'Epoch {epoch}: No improvement for {patience_counter} epochs (best: {best_val_acc:.2f}%)')

        # Stop before the periodic evaluation and scheduler steps once the
        # patience budget is exhausted.
        if not improved and patience_counter >= args.patience:
            print(f'Early stopping triggered after {patience_counter} epochs without improvement')
            print(f'Best validation accuracy: {best_val_acc:.2f}%')
            break

        # Adversarial robustness on the validation set, every 5th epoch.
        if epoch % 5 == 0:
            val_clean, val_adv, val_drop = evaluate_adversarial_robustness(
                epoch, valloader, is_validation=True
            )
            print(f'Epoch {epoch}: Val Acc: {val_clean:.2f}% | Val Adv Acc: {val_adv:.2f}% | Val Robustness Drop: {val_drop:.2f}%')

        # Advance both schedulers once per completed epoch.
        for lr_sched in (backbone_scheduler, prototype_scheduler):
            lr_sched.step()

        # Prototype visualization every 10th epoch; a failure here is reported
        # but never interrupts training.
        if epoch % 10 == 0:
            try:
                separation = visualize_prototypes(net, epoch)
                print(f'Epoch {epoch}: Average prototype separation: {separation:.4f}')
            except Exception as err:
                print(f'Epoch {epoch}: Prototype visualization failed: {err}')
                print(f'Epoch {epoch}: Continuing without visualization...')

    print('==> Training completed!')
    
    # Final evaluation on test set (unseen during training)
    print('==> Final evaluation on test set..')
    final_test_acc = test(epoch, testloader, is_validation=False)
    print(f'Final test accuracy: {final_test_acc:.2f}%')

    # Evaluate robustness if requested
    if args.eval_robustness:
        # Two runs: the configured attack budget, then twice that budget
        # (epsilon and step size are doubled together).
        attack_budgets = (
            (args.attack_epsilon, args.attack_stepsize),
            (args.attack_epsilon * 2, args.attack_stepsize * 2),
        )
        for epsilon, step_size in attack_budgets:
            print('==> Evaluating robustness on test set..')
            clean_acc, adv_acc = evaluate_robustness(
                net, testloader, device,
                attack_type=args.attack,
                epsilon=epsilon,
                alpha=step_size,
                attack_iters=args.attack_iters,
                norm=args.norm,
                auto_attacks=args.auto_attacks
            )

            # Record the epsilon actually used for this run, so the summary
            # (and the saved file) matches the evaluation it describes.
            # NOTE(review): 'final_accuracy' stores best_acc, not
            # final_test_acc - confirm that is intended.
            robustness_results = {
                'clean_accuracy': clean_acc,
                'adversarial_accuracy': adv_acc,
                'robustness_drop': clean_acc - adv_acc,
                'attack_type': args.attack,
                'epsilon': epsilon,
                'final_accuracy': best_acc,
                'auto_attacks': args.auto_attacks if args.attack == 'auto' else None
            }

            print(f'\nRobustness Results: {robustness_results} with epsilon {epsilon} and alpha {step_size}')

        # Save to file: only the last run (the doubled budget) is persisted.
        results_file = args.checkpoint.replace('.pth', '_robustness.txt')
        if results_file == args.checkpoint:
            # The path contains no '.pth', so the replace was a no-op; append
            # the suffix instead of overwriting the checkpoint with text.
            results_file = f'{args.checkpoint}_robustness.txt'
        with open(results_file, 'w') as f:
            for key, value in robustness_results.items():
                f.write(f'{key}: {value}\n')
        print(f'Robustness results saved to: {results_file}')

# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
