# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import math
import torch
import torch.distributed as dist
from config.config_template import ConfigTemplate
from tqdm import tqdm


def validation(config: ConfigTemplate, model, dataloader_val) -> float:
    """Evaluate ``model`` on ``dataloader_val`` and return the validation perplexity.

    Each batch is split with ``torch.chunk`` into up to ``config.accu_steps``
    micro-batches, mirroring gradient accumulation. Every micro-batch is run
    through the model without gradient tracking. The ``"loss_lm"`` telemetry
    values are averaged over all micro-batches (on all ranks when
    ``torch.distributed`` is initialized), and ``exp`` of that mean is returned.

    The model is put in eval mode for the duration of the call. Its previous
    train/eval mode is restored afterwards, even if the forward pass raises.

    Args:
        config: Run configuration; only ``config.accu_steps`` is read here.
        model: Module called as ``model(inputs, targets)``; it must return a
            ``(loss, telemetry)`` pair where ``telemetry["loss_lm"]`` is added
            to a float32 accumulator.
        dataloader_val: Sized iterable yielding ``(inputs, targets)`` batches
            that are split along dim 0.

    Returns:
        Perplexity ``exp(mean loss_lm)``. Returns ``inf`` if the exponent
        overflows a Python float, and ``nan`` if no micro-batch was evaluated
        on any rank.
    """
    accu_steps = config.accu_steps
    distributed = dist.is_available() and dist.is_initialized()
    # Running [sum of loss_lm, number of micro-batches] for this rank.
    # SUM-reducing both and dividing afterwards gives the exact global mean even
    # when ranks see different numbers of micro-batches, and ReduceOp.SUM is
    # supported on every backend (ReduceOp.AVG is NCCL-only).
    stats = torch.zeros(2, dtype=torch.float32, device="cuda")
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in tqdm(
                iterable=dataloader_val,
                desc="Validating model",
                total=len(dataloader_val),
                # Show the progress bar only on rank 0 (or when not distributed).
                disable=distributed and dist.get_rank() != 0,
            ):
                # torch.chunk can return FEWER than `accu_steps` chunks (e.g. a
                # short final batch), so iterate over the chunks actually produced
                # instead of indexing range(accu_steps).
                for inputs_chunk, targets_chunk in zip(
                    torch.chunk(inputs, accu_steps, dim=0),
                    torch.chunk(targets, accu_steps, dim=0),
                ):
                    _, telemetry = model(
                        inputs_chunk.to(device="cuda", non_blocking=True),
                        targets_chunk.to(device="cuda", non_blocking=True),
                    )
                    stats[0] += telemetry["loss_lm"]
                    # Count on-device to avoid a host sync per micro-batch.
                    stats[1] += 1
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if distributed:
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    loss_sum, count = stats.tolist()
    if count == 0:
        # Nothing was evaluated anywhere; the mean is undefined.
        return float("nan")
    try:
        return math.exp(loss_sum / count)
    except OverflowError:
        # A diverged model (mean loss > ~709) should not crash the training loop.
        return float("inf")
