"""
Run SimLAP on CIFAR-10: python main_minimal.py --use_filter
Run Supervised on CIFAR-10: python main_minimal.py
"""
from collections import defaultdict
import math
from timm.data import transforms
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import argparse
import datetime
import time
import json
import os
import sys
from pathlib import Path
from typing import Iterable, Optional
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from timm.data import create_transform

# Optional wandb import
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None



class SmoothedValue:
    """
    Minimal metric holder that remembers only the most recent value.

    No smoothing is performed: ``update`` overwrites the stored value and
    ``global_avg`` returns it unchanged. ``window_size`` and ``fmt`` are kept as
    attributes for interface compatibility but are not used by this class.

    Args:
        window_size (int): Stored but unused (default: 1)
        fmt (str): Display format string, stored but unused here (default: '{value:.6f}')
    """
    def __init__(self, window_size=1, fmt='{value:.6f}'):
        self.window_size, self.fmt, self.value = window_size, fmt, 0.0

    def update(self, value):
        """Replace the stored value with ``value``."""
        self.value = value

    @property
    def global_avg(self):
        """The most recently stored value (0.0 before any update)."""
        latest = self.value
        return latest
    
def build_head(num_layers, input_dim, mlp_dim, output_dim, hidden_bn=True, activation=nn.ReLU, last_norm='bn'):
    """
    Assemble an MLP projection head as an ``nn.Sequential``.

    Every layer except the last is ``Linear(bias=True) [-> BatchNorm1d] -> activation()``.
    The last layer is ``Linear(bias=False)`` followed by the normalization selected
    by ``last_norm``. The first layer consumes ``input_dim``, hidden widths are
    ``mlp_dim``, and the last layer emits ``output_dim``.

    Args:
        num_layers (int): Number of Linear layers (0 yields an empty Sequential)
        input_dim (int): Width of the incoming features
        mlp_dim (int): Width of every hidden layer
        output_dim (int): Width of the produced features
        hidden_bn (bool): Insert BatchNorm1d after each hidden Linear
        activation: Zero-argument module factory used after hidden layers (default: nn.ReLU)
        last_norm (str): 'bn' (BatchNorm1d with affine=False), 'ln' (LayerNorm) or 'none'

    Returns:
        nn.Sequential: The assembled head

    Raises:
        NotImplementedError: If ``last_norm`` is unrecognized and ``num_layers >= 1``
    """
    def tail_norm(width):
        # Normalization appended after the final Linear layer.
        if last_norm == 'bn':
            return [nn.BatchNorm1d(width, affine=False)]  # no learnable scale/shift
        if last_norm == 'ln':
            return [nn.LayerNorm(width)]
        if last_norm == 'none':
            return []
        raise NotImplementedError(f"last_norm={last_norm} not implemented")

    modules = []
    final_idx = num_layers - 1
    for idx in range(num_layers):
        is_final = idx == final_idx
        fan_in = mlp_dim if idx else input_dim
        fan_out = output_dim if is_final else mlp_dim
        modules.append(nn.Linear(fan_in, fan_out, bias=not is_final))
        if is_final:
            modules.extend(tail_norm(fan_out))
            continue
        if hidden_bn:
            modules.append(nn.BatchNorm1d(fan_out))
        modules.append(activation())

    return nn.Sequential(*modules)

def multipos_ce_loss(logits, pos_mask, neg_mask=None):
    """
    Multi-positive contrastive loss (extended InfoNCE).

    Every (anchor i, positive p) pair contributes its own InfoNCE term, contrasted
    only against anchor i's negatives, and the terms are averaged over all
    positive pairs in the batch:

        L = (1/|P|) * sum_{(i,p) in P} -log( exp(s_ip) / (exp(s_ip) + sum_{n in Neg(i)} exp(s_in)) )

    with P = {(i, p) : pos_mask[i, p]} and Neg(i) = {n : neg_mask[i, n]}.
    Entries outside both masks are ignored.

    Args:
        logits (torch.Tensor): Similarity scores (already temperature-scaled) [N, M]
        pos_mask (torch.Tensor): Boolean (or 0/1) mask of positive pairs [N, M]
        neg_mask (torch.Tensor, optional): Boolean (or 0/1) mask of negative pairs [N, M].
            If None, computed as ~pos_mask.

    Returns:
        torch.Tensor: Scalar loss. When ``pos_mask`` selects no pair the result is
        0 (with zero gradient) instead of the NaN a 0/0 division would produce.

    Notes:
        The loss is evaluated in log space (logsumexp / logaddexp), so large logits,
        e.g. from a small temperature, cannot overflow ``exp()`` to inf and turn
        the loss into NaN. Subtracting a per-row constant from the logits does not
        change the value, so no explicit centering is required.
    """
    pos_mask = pos_mask.bool()
    neg_mask = ~pos_mask if neg_mask is None else neg_mask.bool()

    # Upcast half-precision logits (e.g. produced under autocast); fp32/fp64 untouched.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()

    # log(sum_neg exp(logits)) per anchor; -inf for anchors that have no negatives.
    neg_logits = logits.masked_fill(~neg_mask, float('-inf'))
    log_neg = torch.logsumexp(neg_logits, dim=1, keepdim=True)

    # Per-pair term: log(exp(l_ip) + sum_neg exp(l_in)) - l_ip  (>= 0, finite)
    per_pair = torch.logaddexp(logits, log_neg) - logits

    pos_weight = pos_mask.to(per_pair.dtype)
    num_pos = pos_weight.sum()
    # clamp avoids 0/0 -> NaN when there are no positive pairs (numerator is 0 then).
    return (per_pair * pos_weight).sum() / num_pos.clamp(min=1)

class OpenGate(nn.Module):
    """
    Pass-through gate: every feature dimension is kept with weight 1.

    Useful as a no-filtering baseline. The labels only determine the batch size
    and device of the produced gate; their values are ignored.

    Args:
        embed_dim (int): Number of feature dimensions the gate covers
        num_classes (int): Stored for subclasses; unused here
    """
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes

    def forward(self, y1, y2=None, log=None):
        """
        Produce an all-ones gate.

        Args:
            y1 (torch.Tensor): Labels whose first dimension is the batch size
            y2 (torch.Tensor, optional): Ignored
            log (dict, optional): Ignored

        Returns:
            torch.Tensor: Ones of shape [y1.shape[0], embed_dim] on y1's device,
            in the default floating dtype
        """
        rows = y1.shape[0]
        return torch.ones((rows, self.embed_dim), device=y1.device)

class BasicGate(OpenGate):
    """
    Learnable, label-conditioned gate in the range (0, 1).

    Each label is embedded with ``nn.Embedding`` and mapped by a small MLP
    (ReLU -> BN -> Linear -> ReLU -> BN -> Linear) to one logit per feature
    dimension; a sigmoid turns the logits into gate values.

    Args:
        embed_dim (int): Number of gated feature dimensions
        num_classes (int): Size of the label vocabulary
        in_dim (int): Label embedding width / MLP input width
        mlp_dim (int): Hidden width of the MLP
        lam (float): Stored but unused
        fuse (bool): With two label sets, average their embeddings before the MLP
            (True) or multiply the two separately computed gates (False)
    """
    def __init__(self, embed_dim, num_classes=1000, in_dim=512, mlp_dim=1024, lam=0, fuse=True):
        super().__init__(embed_dim, num_classes)

        # MLP is built before the embedding, matching the parameter-init RNG order.
        self.mlp = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, mlp_dim),
            nn.ReLU(),
            nn.BatchNorm1d(mlp_dim),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.label_embedding = nn.Embedding(num_classes, in_dim)
        self.lam = lam
        self.fuse = fuse

    def forward(self, y1, y2=None, log=None):
        """
        Compute gate values from one or two label tensors.

        Args:
            y1 (torch.Tensor): Integer labels [batch_size]
            y2 (torch.Tensor, optional): Second integer labels [batch_size]
            log (dict, optional): Ignored

        Returns:
            torch.Tensor: Gate values [batch_size, embed_dim] in (0, 1)
        """
        embed = self.label_embedding
        if y2 is None:
            return self.mlp(embed(y1)).sigmoid()
        if self.fuse:
            blended = (embed(y1) + embed(y2)) / 2
            return self.mlp(blended).sigmoid()
        # Two MLP passes (y1 first, then y2); matters for BatchNorm running stats.
        return self.mlp(embed(y1)).sigmoid() * self.mlp(embed(y2)).sigmoid()

class Filter(nn.Module):
    """
    Applies a label-conditioned gate to anchor and candidate features.

    For each anchor b the gate ``g_b`` (built by ``gate_fn``) is multiplied into the
    anchor's own features and into every candidate's features; both results are
    L2-normalized so that ``contrast`` yields cosine similarities.

    Args:
        num_classes (int): Label vocabulary size forwarded to the gate
        embed_dim (int): Feature width being gated
        gate_fn: Gate class constructed as ``gate_fn(embed_dim, num_classes=...)``
            (default: BasicGate)
    """
    def __init__(self, num_classes, embed_dim, gate_fn=BasicGate):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate = gate_fn(embed_dim, num_classes=num_classes)

    def forward(self, x1, x2, y1, y2=None):
        """
        Gate and normalize features.

        Args:
            x1 (torch.Tensor): Anchor features [batch_size, embed_dim]
            x2 (torch.Tensor): Candidate features [num_candidates, embed_dim]
            y1 (torch.Tensor): Anchor labels [batch_size]
            y2 (torch.Tensor, optional): Second label per anchor, forwarded to the gate

        Returns:
            tuple: (gated anchors [batch_size, embed_dim],
                    per-anchor gated candidates [batch_size, num_candidates, embed_dim]),
                    both L2-normalized along the last dimension
        """
        weights = self.gate(y1, y2)
        gated_anchor = torch.einsum("bk,bk->bk", x1, weights)
        # Each anchor sees the whole candidate bank through its own gate.
        gated_bank = torch.einsum("nk,bk->bnk", x2, weights)
        return tuple(F.normalize(t, p=2, dim=-1) for t in (gated_anchor, gated_bank))

    def contrast(self, x1, x2):
        """
        Dot products between each anchor and its own gated candidate bank.

        Args:
            x1 (torch.Tensor): Anchors [batch_size, embed_dim]
            x2 (torch.Tensor): Candidates [batch_size, num_candidates, embed_dim]

        Returns:
            torch.Tensor: Scores [batch_size, num_candidates] (cosines when inputs
            come from ``forward``)
        """
        return torch.einsum("bd,bnd->bn", x1, x2)

class SimpleBackbone(nn.Module):
    """
    Tiny CNN feature extractor for quick experiments.

    Layout (all held in ``self.conv`` so checkpoint keys are ``conv.0.*`` and ``conv.4.*``):
    - Conv2d(3 -> 64, kernel 3, padding 1) + ReLU
    - AdaptiveAvgPool2d to a fixed 7x7 grid (not global pooling)
    - Flatten + Linear(64*7*7 -> embed_dim)

    Because of the adaptive pooling, any input spatial size works.

    Args:
        embed_dim (int): Output feature width (default: 512)
    """
    def __init__(self, embed_dim=512):
        super().__init__()
        self.embed_dim = embed_dim
        channels, grid = 64, 7
        stages = [
            nn.Conv2d(3, channels, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((grid, grid)),
            nn.Flatten(),
            nn.Linear(channels * grid * grid, embed_dim),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """
        Map RGB images [batch_size, 3, H, W] to features [batch_size, embed_dim].
        """
        features = self.conv(x)
        return features

class SimLAP(nn.Module):
    """
    Simplified SimLAP (Similarity Learning with Arbitrary Positives) model.

    Projected features of each anchor are compared with the batch through a
    label-conditioned gate (``Filter``). For every anchor a positive label ``y2``
    is drawn according to ``type``. Samples whose label equals ``y2`` are the
    positives. Samples sharing the anchor's label but not ``y2`` are ignored.
    All remaining samples are negatives (see ``disparate_loss``).

    Key Components:
    1. Backbone: Feature extractor (SimpleBackbone in this implementation)
    2. Projector: MLP head mapping backbone features to the contrastive space
    3. Filter: Learnable label-conditioned gate applied before the similarity
    4. Classification Head: Linear probe trained on *detached* backbone features,
       so its loss does not shape the representation

    Args:
        out_dim (int): Contrastive feature width (default: 256)
        embed_dim (int): Backbone feature width (default: 512)
        mlp_dim (int): Hidden width of the projector MLP (default: 512)
        type (str): How the positive label is drawn ('arbitrary', 'identical', 'distinct')
        temperature (float): Logits are scaled by 1/temperature (default: 0.1)
        num_classes (int): Number of classes (default: 1000)

    Positive Sampling Types (choice of ``y2`` for each anchor):
    - 'arbitrary': labels of a random permutation of the batch
    - 'identical': the anchor's own label
    - 'distinct': a uniformly random label different from the anchor's own
      (needs num_classes >= 2)
    """
    def __init__(self, out_dim=256, embed_dim=512, mlp_dim=512, type='arbitrary', 
                 temperature=0.1, num_classes=1000):
        super(SimLAP, self).__init__()
        assert type in ['arbitrary', 'identical', 'distinct']
        self.type = type
        self.s = 1/temperature  # scale applied to cosine similarities
        self.num_classes = num_classes
        self.out_dim = out_dim
        self.embed_dim = embed_dim
        # Per-step metrics; reset by criterion(), initialized here so that
        # disparate_loss() can also be called on its own.
        self.log = {}

        # Construction order is kept fixed: it determines parameter-init RNG order.
        self.backbone = SimpleBackbone(embed_dim)
        self.projector = build_head(2, embed_dim, mlp_dim, out_dim, last_norm='ln')
        self.filter = Filter(num_classes=num_classes, embed_dim=out_dim)
        self.cls_head = nn.Linear(embed_dim, num_classes)

    @torch.no_grad()
    def representation(self, x):
        """
        Extract representations without gradient tracking (for evaluation).

        Args:
            x (torch.Tensor or list/tuple): Images [batch_size, 3, H, W]; for a
                list/tuple only the first element is used

        Returns:
            dict: 'latent' -> backbone features [batch_size, embed_dim],
                  'z' -> projected features [batch_size, out_dim]
        """
        if isinstance(x, list) or isinstance(x, tuple):
            x = x[0]
        latent = self.backbone(x)
        proj = self.projector(latent)
        rep = dict(latent=latent, z=proj)
        return rep

    def criterion(self, samples, targets, **kwargs):
        """
        Training objective: gated contrastive loss plus a linear-probe CE loss.

        Args:
            samples (torch.Tensor): Images [batch_size, 3, H, W]
            targets (torch.Tensor): Integer labels [batch_size]
            **kwargs: Ignored

        Returns:
            tuple: (total_loss, log_dict) where log_dict holds scalar metrics
        """
        self.log = {}

        rep = self.backbone(samples)
        z = self.projector(rep)

        # Draw the positive label y2 for every anchor.
        y1 = targets
        if self.type == 'identical':
            y2 = targets
        elif self.type == 'distinct':
            # Offset in [1, num_classes-1] guarantees y2 != y1.
            y2 = (targets + torch.randint(1, self.num_classes, (len(targets),), device=targets.device)) % self.num_classes
        else:  # 'arbitrary'
            y2 = targets[torch.randperm(len(targets), device=targets.device)]

        loss = self.disparate_loss(z, z, y1, y2)

        # Linear probe on detached features: does not train the representation.
        predict = self.cls_head(rep.detach())
        loss += F.cross_entropy(predict, targets)

        self.log['z@std'] = z.std(0).mean().item()
        return loss, self.log

    def forward(self, samples, **kwargs):
        """
        Inference: class logits [batch_size, num_classes] from the linear probe.
        """
        rep = self.backbone(samples)
        predict = self.cls_head(rep.detach())
        return predict

    def disparate_loss(self, z1, k2, y1, posy):
        """
        Gated multi-positive contrastive loss.

        Args:
            z1 (torch.Tensor): Anchor features [batch_size, out_dim]
            k2 (torch.Tensor): Candidate features [num_candidates, out_dim]
                (criterion passes z1 again)
            y1 (torch.Tensor): Anchor labels; also used as candidate labels, so
                num_candidates must equal len(y1)
            posy (torch.Tensor): Positive label per anchor [batch_size]

        Returns:
            torch.Tensor: Scalar contrastive loss

        Side effects:
            Writes 'scale' and, for every non-empty mask, 'cosine@c1',
            'cosine@c2' and 'cosine@neg' into ``self.log``.
        """
        fz1, fz2 = self.filter(z1, k2, y1, posy)

        scale = self.s
        self.log['scale'] = scale

        cosine = self.filter.contrast(fz1, fz2)
        logits = scale * cosine

        # c1: candidate shares the anchor's label; c2: candidate has the positive label.
        c1_mask = (y1.unsqueeze(1) == (y1).unsqueeze(0))
        c2_mask = (posy.unsqueeze(1) == (y1).unsqueeze(0))
        class_mask = c1_mask | c2_mask
        neg_mask = ~class_mask  # c1-only pairs are neither positive nor negative

        loss = multipos_ce_loss(logits, c2_mask, neg_mask)

        # Mean cosine per pair type, for monitoring only. Empty masks (e.g. a batch
        # with a single class has no negatives) are skipped instead of logging 0/0=NaN.
        with torch.no_grad():
            for key, mask in (('cosine@c1', c1_mask), ('cosine@c2', c2_mask), ('cosine@neg', neg_mask)):
                count = mask.sum()
                if count > 0:
                    self.log[key] = ((cosine * mask).sum() / count).item()
        return loss

class Supervised(nn.Module):
    """
    Supervised cross-entropy baseline to compare against SimLAP.

    Two heads read the backbone features:
    - ``projector``: 2-layer MLP trained with CE; its gradients train the backbone
    - ``cls_head``: linear probe trained with CE on *detached* features; used at inference

    Args:
        nb_classes (int): Number of classes
        embed_dim (int): Backbone feature width (default: 512)
    """
    def __init__(self, nb_classes, embed_dim=512):
        super().__init__()
        # Submodules created in a fixed order (parameter-init RNG order).
        self.backbone = SimpleBackbone(embed_dim)
        self.cls_head = nn.Linear(embed_dim, nb_classes)
        self.projector = build_head(2, embed_dim, 2048, nb_classes, last_norm='none')
        self.nb_classes = nb_classes

    def criterion(self, images, targets):
        """
        Sum of the projector CE loss and the detached linear-probe CE loss.

        Args:
            images (torch.Tensor): Images [batch_size, 3, H, W]
            targets (torch.Tensor): Integer labels [batch_size]

        Returns:
            tuple: (total_loss, {'loss': float value of total_loss})
        """
        feats = self.backbone(images)
        aux_logits = self.projector(feats)
        probe_logits = self.cls_head(feats.detach())
        total = F.cross_entropy(aux_logits, targets) + F.cross_entropy(probe_logits, targets)
        return total, {'loss': total.item()}

    def forward(self, images):
        """
        Class logits [batch_size, nb_classes] from the linear probe.
        """
        return self.cls_head(self.backbone(images))



def get_args_parser():
    """
    Build the help-less parent parser holding the minimal SimLAP training options.

    ``add_help=False`` allows this parser to be passed via ``parents=[...]`` to a
    top-level ``ArgumentParser`` without a duplicate ``-h`` option.

    Returns:
        argparse.ArgumentParser: Parser with batch size, epochs, learning rate,
        device, seed, model choice, class count and output directory options
    """
    parser = argparse.ArgumentParser('Minimal SimLAP training script', add_help=False)
    opt = parser.add_argument
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    opt('--batch_size', type=int, default=32, help='Batch size for training')
    opt('--epochs', type=int, default=10, help='Number of training epochs')
    opt('--lr', type=float, default=1e-3, help='Learning rate for optimizer')
    opt('--device', default=fallback_device, help='Device to use (cuda/cpu)')
    opt('--seed', type=int, default=0, help='Random seed for reproducibility')
    opt('--use_filter', action='store_true',
        help='Use SimLAP model with filtering (otherwise use supervised baseline)')
    opt('--num_classes', type=int, default=10, help='Number of classes in dataset')
    opt('--output_dir', type=str, default=None, help='Directory to save checkpoints and logs')
    return parser

def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, lr_scheduler):
    """
    Train the model for one epoch.

    Per batch: move data to ``device``, step the LR scheduler to the fractional
    epoch, compute the loss under autocast, abort on a non-finite loss, record
    metrics, then backpropagate and update the parameters.

    Args:
        model: Module being trained (switched to train mode)
        criterion: Callable ``(samples, targets) -> (loss, log_dict)``
        data_loader: Sized iterable of ``(samples, targets)`` batches
        optimizer: Optimizer updating the model parameters
        device: ``torch.device`` or device string such as 'cuda' / 'cpu'
        epoch (int): Index of the current epoch
        loss_scaler: Optional callable ``(loss, optimizer, parameters=..., need_update=...)``
            that performs backward + step; when None, plain backward()/step() is used
        lr_scheduler: Scheduler stepped with the fractional epoch every iteration

    Returns:
        dict: Mean of every recorded metric ('loss', 'lr' and the criterion's log keys)

    Raises:
        SystemExit: If a loss is NaN/Inf (raised before that batch updates the weights)
    """
    model.train(True)
    metric_logger = defaultdict(list)

    device = torch.device(device)  # also accept plain device strings
    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'
    num_iters = len(data_loader)

    for itr, (samples, targets) in enumerate(data_loader):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Per-iteration cosine annealing (no warmup): the scheduler is given the
        # fractional epoch. NOTE: passing an epoch to step() is deprecated in
        # PyTorch and emits a warning.
        lr_scheduler.step(epoch + itr / num_iters)

        # NOTE(review): on CUDA this autocasts to float16; without a loss scaler
        # small gradients may underflow. Consider passing a GradScaler-based loss_scaler.
        with torch.amp.autocast(autocast_device):
            loss, log = criterion(samples, targets)
        loss_value = loss.item()

        # Stop before backward()/step() so a NaN/Inf loss cannot corrupt the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        metric_logger['loss'].append(loss_value)
        metric_logger['lr'].append(lr_scheduler.get_last_lr()[0])
        for k, v in log.items():
            metric_logger[k].append(v)

        if loss_scaler is not None:
            loss_scaler(loss, optimizer, parameters=model.parameters(), need_update=True)
        else:
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()

        # Optional wandb logging; log the Python float, not the graph-holding tensor.
        if WANDB_AVAILABLE and wandb.run:
            wandb.log({'loss': loss_value, 'lr': optimizer.param_groups[-1]["lr"]})

    return {k: np.mean(v) for k, v in metric_logger.items()}

@torch.no_grad()
def evaluate(data_loader, model, device):
    """
    Evaluate classification accuracy and cross-entropy loss.

    Metrics are weighted by batch size, so the result is the exact per-sample
    mean over the whole dataset; a smaller final batch is not over-weighted.

    Args:
        data_loader: Iterable of ``(images, target)`` batches
        model: Module returning class logits [batch_size, num_classes]
            (switched to eval mode and left there)
        device: ``torch.device`` or device string to run evaluation on

    Returns:
        dict: {'acc1': top-1 accuracy in percent, 'loss': mean cross-entropy};
        both are NaN when the loader yields no samples
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_count = 0

    for images, target in data_loader:
        batch_size = images.shape[0]
        if batch_size == 0:
            continue  # empty batch: mean CE would be NaN and poison the totals

        images = images.to(device, non_blocking=True, dtype=torch.float32)
        target = target.to(device, non_blocking=True, dtype=torch.long)

        output = model(images)
        # criterion returns the batch mean; rescale to a sum for exact weighting.
        total_loss += criterion(output, target).item() * batch_size
        total_correct += (output.argmax(1) == target).sum().item()
        total_count += batch_size

    if total_count == 0:
        acc1 = loss = float('nan')
    else:
        acc1 = 100.0 * total_correct / total_count
        loss = total_loss / total_count

    print('* Acc@1 {acc1:.3f} loss {loss:.3f}'.format(acc1=acc1, loss=loss))
    return {'acc1': acc1, 'loss': loss}

def main(args):
    """
    Train SimLAP (``--use_filter``) or the supervised baseline on CIFAR-10.

    Steps: seed torch/numpy, prepare the output directory, build CIFAR-10
    loaders, build the model, create AdamW + cosine LR schedule, then alternate
    one training epoch with one evaluation. After each epoch the train/test stats
    are appended to ``<output_dir>/log.txt`` and a checkpoint
    ``checkpoint-XXX.pth`` is written, when ``--output_dir`` is given.

    Args:
        args: Parsed command-line namespace (see ``get_args_parser``)
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Arguments:", args)

    # --output_dir arrives as a str: wrap it in Path (needed for the `/` joins
    # below) and create it before the first log.txt append.
    output_dir = None
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Training transforms include RandAugment; normalization uses ImageNet statistics.
    transform_train = create_transform(
        input_size=32,
        is_training=True,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        auto_augment='rand-m9-mstd0.5'
    )
    # Validation transforms: no augmentation, just resize/normalize.
    transform_val = create_transform(
        input_size=32,
        is_training=False,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )

    train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    val_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform_val)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    if args.use_filter:
        model = SimLAP(num_classes=args.num_classes)
    else:
        model = Supervised(args.num_classes)

    print(f"Built Model: {model}")

    device = torch.device(args.device)
    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of parameters:', n_parameters)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.05)
    # T_max is in epochs; train_one_epoch steps it with fractional epochs.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    criterion = model.criterion
    print("Criterion:", criterion)

    start_time = time.time()
    max_accuracy = 0.0

    for epoch in tqdm(range(args.epochs)):
        train_stats = train_one_epoch(
            model, criterion, train_loader, optimizer, device, epoch, None, lr_scheduler)

        test_stats = evaluate(val_loader, model, device)

        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if output_dir is not None:
            with open(output_dir / 'log.txt', 'a', encoding='utf-8') as f:
                f.write(json.dumps(train_stats) + "\n")
                f.write(json.dumps(test_stats) + "\n")

            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'epoch': epoch,
            }
            torch.save(state, output_dir / f'checkpoint-{epoch:03d}.pth')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time:', total_time_str)

if __name__ == '__main__':
    # Top-level parser (with -h/--help) inheriting the shared help-less option set.
    args = argparse.ArgumentParser(
        'Minimal SimLAP training script', parents=[get_args_parser()]
    ).parse_args()
    main(args)
