"""Functions for computing losses during training and evaluation.
"""
from typing import Union, List
import torch
import torch.nn.functional as F

def get_all_special_tokens(tokenizer):
    """
    Collects the ids of the special tokens exposed by a transformers tokenizer.

    For each attribute name listed in ``tokenizer.SPECIAL_TOKENS_ATTRIBUTES`` the
    matching ``<name>_id`` attribute is read from the tokenizer. Names that have no
    such attribute, and ids that are ``None``, are skipped. As a consequence only
    [spec token, spec_token_id] style pairs contribute; added special token ids
    are ignored.

    Args:
        tokenizer: Object providing ``SPECIAL_TOKENS_ATTRIBUTES`` (an iterable of
            attribute names) and, optionally, the corresponding ``<name>_id``
            attributes.

    Returns:
        list: The distinct special token ids that were found (no particular order).
    """
    # A missing ``<name>_id`` attribute is treated the same as an unset (None) id.
    candidate_ids = (
        getattr(tokenizer, f"{attr_name}_id", None)
        for attr_name in tokenizer.SPECIAL_TOKENS_ATTRIBUTES
    )
    # The set removes duplicates, e.g. several special tokens sharing one id.
    return list({token_id for token_id in candidate_ids if token_id is not None})

def get_ignore_mask(tokens: torch.Tensor, ignore_index: Union[int, List[int]]):
    """Builds a mask over ``tokens`` marking every entry whose value is in ``ignore_index``.

    Args:
        tokens (torch.Tensor): Tensor of token ids to test.
        ignore_index (int or list of int): A single id, or a list of ids, to flag.

    Returns:
        torch.Tensor: Result of ``torch.isin(tokens, ids)``, i.e. a mask that is
        True wherever the token is one of the ids to ignore.
    """
    # Wrap a lone id in a list so both input forms share one tensor construction.
    ids_to_ignore = [ignore_index] if isinstance(ignore_index, int) else ignore_index
    # Build the lookup tensor on the same device as the tokens being tested.
    lookup = torch.tensor(ids_to_ignore, device=tokens.device).int()
    return torch.isin(tokens, lookup)

def next_token_cross_entropy_loss(logits, targets, ignore_index=0, reduction='mean'):
    """
    Computes the next token cross-entropy loss for language models.

    The logits at position ``t`` are scored against the token at position ``t + 1``.
    A (current, next) pair only contributes when neither of its two tokens is
    listed in ``ignore_index``.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, nseq, vocab_size).
        targets (torch.Tensor): Actual token indices of shape (batch_size, nseq).
        ignore_index (int or list of int): Token id(s), e.g. padding, to leave out
            of the loss (optional).
        reduction (str): Passed through to ``F.cross_entropy`` (optional). The
            kept positions are gathered with a boolean mask first, so the loss is
            computed over a flattened selection rather than per (batch, position).

    Returns:
        torch.Tensor: Cross-entropy loss.

    Note:
        If every pair is dropped the loss is computed over an empty selection;
        with ``reduction='mean'`` this is expected to be NaN -- TODO confirm.
    """
    is_ignored = get_ignore_mask(targets, ignore_index)

    # Keep a pair only when both the current and the next token are real tokens
    # (De Morgan form of "drop if either current or next is padding").
    keep = ~is_ignored[:, :-1] & ~is_ignored[:, 1:]

    # Drop the last step of the logits and the first step of the targets so that
    # each prediction lines up with the token that follows it.
    predictions = logits[:, :-1, :][keep]
    labels = targets[:, 1:][keep]

    return F.cross_entropy(predictions, labels, reduction=reduction)

def kl_div_loss(logits, true_logits, reduction='batchmean'):
    """
    Computes the KL divergence KL(P || Q) between two sets of logits.

    P is the softmax of ``true_logits`` (the reference distribution) and Q is the
    softmax of ``logits`` (the predicted distribution), both taken over the last
    dimension. Pointwise the loss is ``P * (log(P) - log(Q))``.

    Assumes any attention masking has already been applied by the caller.

    Args:
        logits (torch.Tensor): Predicted logits.
        true_logits (torch.Tensor): Reference logits.
        reduction (str): Passed through to ``F.kl_div``. With ``'batchmean'`` the
            summed pointwise loss is divided by the size of the first dimension.

    Returns:
        torch.Tensor: The reduced KL divergence.
    """
    log_q = F.log_softmax(logits, dim=-1)       # predicted distribution, log space
    log_p = F.log_softmax(true_logits, dim=-1)  # reference distribution, log space
    # F.kl_div expects the prediction first; log_target=True because the
    # reference is supplied as log-probabilities as well.
    return F.kl_div(input=log_q, target=log_p, reduction=reduction, log_target=True)

def mse_loss(logits, true_logits):
    """
    Computes the squared-error loss between predicted and reference logits.

    The squared differences are summed over the last dimension and the resulting
    values are averaged over all remaining dimensions.

    Args:
        logits (torch.Tensor): Predicted values; assumed (batch, seq, dim).
        true_logits (torch.Tensor): Reference values; assumed (batch, seq, dim).

    Returns:
        torch.Tensor: Scalar loss.
    """
    residual = logits - true_logits
    # Sum over the feature axis first, then average the per-position totals.
    per_position = residual.pow(2).sum(dim=-1)
    return per_position.mean()
