import torch
import torch.nn.functional as F

from src.model.regress_lm.vocabs import DecoderVocab
from src.utils.number_token_selector import NumberTokenSelector


class NumberTokenLoss_WAS:
    """Expected absolute-error loss over number tokens.

    For every position whose label is a number token, the loss is
    ``sum_i p_i * |y - v_i|``: ``y`` is the numeric value of the label,
    ``v_i`` the numeric value of the i-th number token and ``p_i`` the
    probability the model assigns to that token. The result is the mean of
    this quantity over all number-token positions.

    Note:
        ``p_i`` comes from a softmax over the *full* vocabulary; it is not
        renormalised over number tokens, so probability mass placed on
        non-number tokens contributes nothing to the loss.
    """

    def __init__(
        self, vocab: "DecoderVocab[float]", device, loss_function=F.mse_loss, weight=0.5
    ):
        """Build the number-token lookup tables for ``vocab`` on ``device``.

        Args:
            vocab: Decoder vocabulary handed to ``NumberTokenSelector``.
            device: Device handed to ``NumberTokenSelector``.
            loss_function: Stored as ``self.loss_function``; not used by
                ``forward`` in this class.
            weight: Stored as ``self.weight``; not applied by ``forward``,
                so any scaling is left to the caller.
        """
        self.loss_function = loss_function
        self.weight = weight

        self.selector = NumberTokenSelector(vocab, device)
        self.nvocab = self.selector.nvocab

    def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Allow ``loss_fn(logits, labels)`` as shorthand for ``forward``."""
        return self.forward(logits, labels)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean expected absolute error over number-token positions.

        Args:
            logits: Unnormalised scores shaped like ``labels`` plus one
                trailing vocabulary dimension.
            labels: Integer token ids. Negative ids (e.g. an ignore index
                such as -100) are treated as non-number positions.

        Returns:
            A scalar tensor. It is a constant ``0.0`` (same device and dtype
            as ``logits``) when no label is a number token.
        """
        # Negative labels must not be used as indices: they would either wrap
        # around to an unrelated vocabulary entry or raise an IndexError.
        # Clamp them for the lookup only and mask them out afterwards.
        # Assumes number_token_mask is a boolean tensor indexed by token id.
        valid = labels >= 0
        safe_labels = labels.clamp(min=0)
        is_number_pos = self.selector.number_token_mask[safe_labels] & valid

        if not is_number_pos.any():
            return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)

        # Boolean-mask indexing flattens the selected positions into one axis:
        # (num_number_positions, vocab_size) and (num_number_positions,).
        relevant_logits = logits[is_number_pos]
        relevant_labels = labels[is_number_pos]

        # Numeric value represented by each ground-truth token.
        y_true_values = self.selector.nvocab[relevant_labels]

        # Full-vocabulary softmax, then keep only the number-token columns.
        probs = F.softmax(relevant_logits, dim=-1)
        number_probs = probs[:, self.selector.number_token_indices]

        all_number_values = self.selector.number_token_values

        # |y - v_i| for every (position, number token) pair.
        abs_diff = torch.abs(
            y_true_values.unsqueeze(-1) - all_number_values.unsqueeze(0)
        )

        per_token_loss = torch.sum(number_probs * abs_diff, dim=-1)

        return per_token_loss.mean()