import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Assuming necessary imports and configurations are handled externally

def train_one_epoch(model, criterion, optimizer, data_loader, device, scaler, epoch):
    """
    Run one training epoch over ``data_loader``, updating ``model`` in place.

    :param model: Network called as ``model(inputs)``; must return a
        ``(logits, spike_info)`` pair
    :param criterion: Loss callable invoked as
        ``criterion(logits, targets, spike_info=spike_info)``
    :param optimizer: Optimizer that steps the model parameters
    :param data_loader: Iterable yielding ``(inputs, targets)`` batches
    :param device: Device the batches are moved to
    :param scaler: Gradient scaler for mixed precision, or ``None`` to train
        without autocast and without loss scaling
    :param epoch: Current epoch number (only used in the progress-bar label)
    :return: Tuple ``(avg_loss, accuracy)`` where ``avg_loss`` is the
        batch-size-weighted mean of the per-batch loss values and
        ``accuracy`` is a percentage; ``(0.0, 0.0)`` when the loader
        yields no samples
    """
    model.train()
    use_amp = scaler is not None
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for inputs, targets in tqdm(data_loader, desc=f"Train[{epoch}]"):
        inputs = inputs.to(device).float()
        targets = targets.to(device)
        batch_size = inputs.size(0)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            logits, spike_info = model(inputs)

        # The loss is deliberately evaluated with autocast switched off.
        # NOTE(review): under AMP the logits may still be reduced precision
        # here; confirm the criterion upcasts them if that matters.
        with autocast(enabled=False):
            loss = criterion(logits, targets, spike_info=spike_info)

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale before clipping so the 1.0 threshold applies to the
            # real gradient magnitudes rather than the scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        # Weight each batch's loss by its size so a smaller final batch
        # does not skew the epoch average.
        total_loss += loss.item() * batch_size
        preds = logits.argmax(dim=1)
        total_correct += (preds == targets).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        # Empty loader: nothing was processed, so avoid dividing by zero.
        return 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100
    return avg_loss, accuracy

def validate_one_epoch(model, criterion, data_loader, device, epoch):
    """
    Run one validation epoch over ``data_loader`` without gradient tracking.

    :param model: Network called as ``model(inputs)``; must return a
        ``(logits, spike_info)`` pair
    :param criterion: Loss callable invoked as
        ``criterion(logits, targets, spike_info=spike_info)``
    :param data_loader: Iterable yielding ``(inputs, targets)`` batches
    :param device: Device the batches are moved to
    :param epoch: Current epoch number (only used in the progress-bar label)
    :return: Tuple ``(avg_loss, accuracy)`` where ``avg_loss`` is the
        batch-size-weighted mean of the per-batch loss values and
        ``accuracy`` is a percentage; ``(0.0, 0.0)`` when the loader
        yields no samples
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, targets in tqdm(data_loader, desc=f"Val[{epoch}]"):
            inputs = inputs.to(device).float()
            targets = targets.to(device)
            batch_size = inputs.size(0)

            logits, spike_info = model(inputs)
            loss = criterion(logits, targets, spike_info=spike_info)

            # Weight each batch's loss by its size so a smaller final batch
            # does not skew the epoch average.
            total_loss += loss.item() * batch_size
            preds = logits.argmax(dim=1)
            total_correct += (preds == targets).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        # Empty loader: nothing was evaluated, so avoid dividing by zero.
        return 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100
    return avg_loss, accuracy
