# engine.py

import time
import torch

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, log_freq, use_closure=True):
    """
    Trains the model for a single epoch.

    Args:
        model (torch.nn.Module): The model to be trained.
        train_loader (Iterable): Yields ``(images, targets)`` batches. It does
            not need to support ``len()``.
        criterion: The loss function, called as ``criterion(outputs, targets)``.
        optimizer: The optimizer.
        device: The device for training ('cuda' or 'cpu').
        epoch (int): The current epoch number (printed as ``epoch + 1``).
        log_freq (int): Print a progress line every ``log_freq`` batches.
            A value <= 0 disables per-batch logging.
        use_closure (bool): Whether to use an optimizer that requires a closure (e.g., CLAGR).

    Returns:
        dict: ``'train_loss'``, ``'train_acc1'`` and ``'train_acc5'``. These are
        sample-weighted averages over the epoch, with accuracies in percent.
        All values are NaN if the loader yielded no samples.
    """
    model.train()

    # Simple running sums; averages are computed once the epoch ends.
    total_loss = 0.0
    total_acc1 = 0.0
    total_acc5 = 0.0
    total_samples = 0

    # Not every Iterable has a length (e.g. generators), so don't let logging crash.
    num_batches = len(train_loader) if hasattr(train_loader, '__len__') else '?'

    epoch_start_time = time.time()

    for batch_idx, (images, targets) in enumerate(train_loader):
        batch_start_time = time.time()
        images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        # Both optimizer styles start with one forward/backward pass.
        # Standard optimizers step on these gradients directly.
        # Closure-based optimizers (e.g. CLAGR) need them as initial gradients.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()

        if use_closure:
            # Re-evaluates loss and gradients on the current batch; the
            # optimizer's step may call it several times.
            # Assumes the optimizer zeroes gradients itself before calling
            # this (stated for CLAGR); confirm for other closure-based optimizers.
            def closure():
                closure_output = model(images)
                closure_loss = criterion(closure_output, targets)
                closure_loss.backward()
                return closure_loss

            optimizer.step(closure)
        else:  # Standard optimizers (like SGD, Adam)
            optimizer.step()

        # Accuracy is measured on the outputs from before the step.
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        batch_size = images.shape[0]

        # Read each scalar once, since .item() forces a device sync.
        loss_val = loss.item()
        acc1_val = acc1.item()
        acc5_val = acc5.item()

        # Weight by batch size so a short final batch is not over-counted.
        total_loss += loss_val * batch_size
        total_acc1 += acc1_val * batch_size
        total_acc5 += acc5_val * batch_size
        total_samples += batch_size

        if log_freq > 0 and batch_idx % log_freq == 0:
            batch_time = time.time() - batch_start_time
            current_lr = optimizer.param_groups[0]['lr']
            print(
                f"Epoch: [{epoch+1}][{batch_idx}/{num_batches}] | "
                f"LR: {current_lr:.6f} | "
                f"Loss: {loss_val:.4f} | "
                f"Acc@1: {acc1_val:.2f}% | "
                f"Time: {batch_time:.3f}s"
            )

    # Averages over an empty epoch are undefined; report NaN rather than crash.
    if total_samples == 0:
        avg_loss = avg_acc1 = avg_acc5 = float('nan')
    else:
        avg_loss = total_loss / total_samples
        avg_acc1 = total_acc1 / total_samples
        avg_acc5 = total_acc5 / total_samples

    print(f"--- Epoch {epoch+1} Summary ---")
    print(f"Train Loss: {avg_loss:.4f}, Train Acc@1: {avg_acc1:.2f}%, Train Acc@5: {avg_acc5:.2f}%")
    print(f"Epoch Time: {time.time() - epoch_start_time:.2f}s")

    return {'train_loss': avg_loss, 'train_acc1': avg_acc1, 'train_acc5': avg_acc5}


@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
    """
    Evaluates the model's performance on the validation set.

    Gradient tracking is disabled for the whole call, and the model is
    switched to eval mode.

    Args:
        model (torch.nn.Module): The model to evaluate.
        val_loader (Iterable): Yields ``(images, targets)`` batches.
        criterion: The loss function, called as ``criterion(outputs, targets)``.
        device: The device to move each batch to ('cuda' or 'cpu').

    Returns:
        dict: ``'test_loss'``, ``'test_acc1'`` and ``'test_acc5'``. These are
        sample-weighted averages over the loader, with accuracies in percent.
        All values are NaN if the loader yielded no samples.
    """
    model.eval()

    total_loss = 0.0
    total_acc1 = 0.0
    total_acc5 = 0.0
    total_samples = 0

    for images, targets in val_loader:
        images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, targets)

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        batch_size = images.shape[0]

        # Weight by batch size so a short final batch is not over-counted.
        # This assumes criterion returns a per-batch mean (TODO: confirm reduction).
        total_loss += loss.item() * batch_size
        total_acc1 += acc1.item() * batch_size
        total_acc5 += acc5.item() * batch_size
        total_samples += batch_size

    # Averages over an empty loader are undefined; report NaN rather than crash.
    if total_samples == 0:
        avg_loss = avg_acc1 = avg_acc5 = float('nan')
    else:
        avg_loss = total_loss / total_samples
        avg_acc1 = total_acc1 / total_samples
        avg_acc5 = total_acc5 / total_samples

    print(f"Test Loss: {avg_loss:.4f}, Test Acc@1: {avg_acc1:.2f}%, Test Acc@5: {avg_acc5:.2f}%")

    return {'test_loss': avg_loss, 'test_acc1': avg_acc1, 'test_acc5': avg_acc5}


def accuracy(output, target, topk=(1,)):
    """
    Computes the Top-k accuracy based on the model's output and true labels.

    Args:
        output (torch.Tensor): Class scores, ranked along dim 1 (assumed shape
            ``(batch_size, num_classes)``).
        target (torch.Tensor): Ground-truth class indices, one per sample.
        topk (tuple[int, ...]): The k values to evaluate.

    Returns:
        list[torch.Tensor]: One single-element float tensor per k. Each holds
        the percentage of samples whose target is among the k highest scores.
        If k exceeds the number of classes, every class counts as a guess,
        giving 100%. An empty batch yields 0 for every k.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        if batch_size == 0:
            # Nothing to score; avoid dividing by zero below.
            return [torch.zeros(1, device=output.device) for _ in topk]

        # torch.topk raises when k exceeds the dimension size. Clamp it so the
        # common topk=(1, 5) also works for models with fewer than 5 classes.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch_size): row i holds each sample's i-th best guess
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk rows simply takes all available rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res