import torch
from torch import nn
from copy import deepcopy

def compute_accuracy(logits, labels, top_k: int = 5):
    """Return top-1 and top-k accuracy (as percentages) over valid positions.

    Positions whose label is negative (e.g. -100) are excluded from both
    metrics.

    Args:
        logits: Scores whose last dimension indexes the classes. The leading
            dimensions must match the shape of ``labels``.
        labels: Integer class ids; negative values mark positions to skip.
        top_k: Number of highest-scoring classes considered for the second
            metric. It is clamped to the size of the class dimension so that
            small vocabularies do not make ``topk`` fail.

    Returns:
        A ``(top1, topk)`` tuple of Python floats in the range [0, 100].
        ``(0.0, 0.0)`` is returned when no position has a valid label, instead
        of the NaN that averaging an empty selection would produce.
    """
    # Mask out invalid tokens (e.g., -100)
    valid_mask = labels >= 0
    valid_labels = labels[valid_mask]
    if valid_labels.numel() == 0:
        return 0.0, 0.0

    # Metrics never need gradients; avoid recording anything in the graph.
    with torch.no_grad():
        # Softmax is order-preserving, so ranking the raw logits yields the
        # same predictions without materializing a probability tensor.
        valid_logits = logits[valid_mask]

        # Top-1 accuracy
        top1_pred = valid_logits.argmax(dim=-1)
        correct_top1 = (top1_pred == valid_labels).float().mean()

        # Top-k accuracy (k cannot exceed the number of classes)
        k = min(top_k, valid_logits.size(-1))
        topk_preds = valid_logits.topk(k, dim=-1).indices
        correct_topk = topk_preds.eq(valid_labels.unsqueeze(-1)).any(dim=-1).float().mean()

    return correct_top1.item() * 100, correct_topk.item() * 100

def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy with an optional externally supplied normalizer.

    When ``num_items_in_batch`` is ``None`` the loss is the mean over the
    non-ignored targets. Otherwise the per-target losses are summed and the
    sum is divided by ``num_items_in_batch``.

    Args:
        source: Class scores passed to ``nn.functional.cross_entropy``.
        target: Target class ids; entries equal to ``ignore_index`` are skipped.
        num_items_in_batch: Divisor applied to the summed loss, or ``None``.
        ignore_index: Target value that does not contribute to the loss.
        **kwargs: Accepted and ignored.

    Returns:
        The scalar loss tensor.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")

    # Sum first, then normalize by the caller-provided item count.
    summed = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch

def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Compute the next-token loss together with top-1 / top-5 accuracy.

    Args:
        logits: Model scores; the last dimension is flattened to ``vocab_size``
            and the second-to-last dimension is treated as the sequence axis.
        labels: Target token ids; the last dimension is the sequence axis.
        vocab_size: Size used to flatten ``logits`` into ``(-1, vocab_size)``.
        num_items_in_batch: Optional normalizer forwarded to
            ``fixed_cross_entropy``.
        ignore_index: Label value excluded from the loss and from the
            accuracy metrics.
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        ``(loss, top_1_acc, top_5_acc)`` where ``loss`` is a scalar tensor and
        the accuracies are the values returned by ``compute_accuracy``.
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)

    # The accuracies are reporting-only values: keep them out of the autograd
    # graph so they do not hold extra activations alive until backward.
    with torch.no_grad():
        # compute_accuracy only skips negative labels, so map a custom
        # ignore_index onto -100 to keep the metrics consistent with the loss.
        metric_labels = shift_labels.masked_fill(shift_labels == ignore_index, -100)
        top_1_acc, top_5_acc = compute_accuracy(shift_logits.detach(), metric_labels)
    return loss, top_1_acc, top_5_acc
