import torch
import torch.nn.functional as F
from typing import Optional

def masked_cross_entropy(logits: torch.Tensor, 
                         labels: torch.Tensor, 
                         masks: Optional[torch.Tensor]=None):
    """Compute the mean cross-entropy over sequence positions, optionally masked.

    Args:
        logits (Tensor (batch_size, seq_len, num_classes)): The model's predictions.
        labels (Tensor (batch_size, seq_len)): The true labels (class indices).
        masks (Tensor (batch_size, seq_len) or None): Per-position weights.
            Positions whose mask value is 0 do not contribute to the loss.
            When None, every position is averaged with equal weight.

    Returns:
        Tensor: A scalar. With a mask, this is the mask-weighted sum of the
        per-position losses divided by the sum of the mask; it is 0 (rather
        than NaN) when the mask sums to zero.
    """
    # F.cross_entropy expects the class dimension at index 1.
    class_first_logits = logits.transpose(1, 2)
    if masks is None:
        return F.cross_entropy(class_first_logits, labels)

    token_loss = F.cross_entropy(class_first_logits, labels, reduction='none')
    # Integer/bool masks are cast so the arithmetic below stays in floating point.
    weights = masks if masks.is_floating_point() else masks.to(token_loss.dtype)
    # Zero out masked positions explicitly: a non-finite loss there would
    # otherwise turn into NaN when multiplied by a 0 weight.
    weighted_loss = token_loss.masked_fill(weights == 0, 0.0) * weights
    total_weight = weights.sum()
    # Only an exactly-zero denominator is replaced, so every other mask
    # yields the same value as a plain division.
    safe_total = torch.where(total_weight == 0,
                             torch.ones_like(total_weight),
                             total_weight)
    return weighted_loss.sum() / safe_total

def masked_accuracy(logits: torch.Tensor, 
                    labels: torch.Tensor, 
                    masks: Optional[torch.Tensor]=None):
    """Compute the fraction of positions whose argmax prediction equals the label.

    Args:
        logits (Tensor): The model's predictions; the predicted class is the
            argmax over the last dimension.
        labels (Tensor): The true labels, compared element-wise against the
            argmax of ``logits``.
        masks (Tensor or None): Per-position weights matching ``labels``.
            Positions whose mask value is 0 do not contribute. When None,
            every position is averaged with equal weight.

    Returns:
        Tensor: A scalar. With a mask, this is the mask-weighted count of
        correct predictions divided by the sum of the mask; it is 0 (rather
        than NaN) when the mask sums to zero.
    """
    hits = (logits.argmax(dim=-1) == labels).float()
    if masks is None:
        return hits.mean()

    # Integer/bool masks are cast so the arithmetic below stays in floating point.
    weights = masks if masks.is_floating_point() else masks.to(hits.dtype)
    total_weight = weights.sum()
    # Only an exactly-zero denominator is replaced, so every other mask
    # yields the same value as a plain division.
    safe_total = torch.where(total_weight == 0,
                             torch.ones_like(total_weight),
                             total_weight)
    return (hits * weights).sum() / safe_total

if __name__ == "__main__":
    # Smoke test: random logits for 2 sequences of 3 positions over 4 classes,
    # with the last position of the second sequence masked out.
    sample_logits = torch.randn(2, 3, 4)
    sample_labels = torch.tensor([[1, 2, 3], [2, 3, 0]])
    sample_masks = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.float32)
    print(masked_cross_entropy(sample_logits, sample_labels, sample_masks))
    print(masked_accuracy(sample_logits, sample_labels, sample_masks))