import torch
import torch.nn.functional as F

from src.model.regress_lm.vocabs import DecoderVocab
from src.utils.number_token_selector import NumberTokenSelector


class NumberTokenLoss:
    """Regression-style loss on the numeric value of predicted number tokens.

    For every position whose label is a digit or an exponent token, the
    logits are restricted to the matching token family, turned into a
    probability distribution, and reduced to an expected numeric value. That
    expectation is compared against the numeric value of the label with
    ``loss_function``.

    The token masks and per-token numeric values come from a
    ``NumberTokenSelector`` (defined elsewhere); this class reads its
    ``digit_token_mask``, ``exponent_token_mask``, ``digit_vocab``,
    ``exponent_vocab`` and ``nvocab`` attributes.
    """

    def __init__(
        self, vocab: "DecoderVocab[float]", device, loss_function=F.mse_loss, weight=0.5
    ):
        """Store the loss configuration and build the token selector.

        Args:
            vocab: Decoder vocabulary handed to ``NumberTokenSelector``.
            device: Device handed to ``NumberTokenSelector``.
            loss_function: Callable ``(prediction, target) -> loss`` applied to
                the 1-D tensors of expected and true values.
            weight: Stored on the instance only; ``forward`` does not apply it
                (presumably the caller scales the loss - confirm at call site).
        """
        self.loss_function = loss_function
        self.weight = weight

        self.selector = NumberTokenSelector(vocab, device)
        self.nvocab = self.selector.nvocab

    def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Allow the loss object to be called directly; delegates to ``forward``."""
        return self.forward(logits, labels)

    @staticmethod
    def _expected_value(
        logits: torch.Tensor, mask: torch.Tensor, safe_values: torch.Tensor
    ) -> torch.Tensor:
        """Return the expected numeric value under a softmax restricted to ``mask``."""
        # Same -1e9 fill as before, but capped at the dtype's finite minimum so
        # the constant stays representable for float16 logits.
        fill_value = max(-1e9, torch.finfo(logits.dtype).min)
        masked_logits = logits.masked_fill(~mask, fill_value)
        probs = F.softmax(masked_logits, dim=-1)
        # ``safe_values`` has one entry per vocabulary id and broadcasts over
        # every leading dimension of ``probs``.
        return torch.sum(probs * safe_values, dim=-1)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the number-token loss.

        Args:
            logits: Floating-point scores whose last dimension indexes the
                vocabulary, e.g. ``(batch, length, vocab)``.
            labels: Integer token ids with the shape of ``logits`` minus the
                last dimension. Negative ids (e.g. an ignore index such as
                -100) are skipped.

        Returns:
            ``loss_function`` evaluated over all positions whose label is a
            digit or exponent token. If there is no such position, a constant
            scalar zero on the logits' device/dtype (not attached to the
            autograd graph).

        Raises:
            ValueError: If ``logits`` or ``labels`` has no elements.
        """
        if logits.numel() == 0:
            raise ValueError("Logits passed to the NumberTokenLoss are empty!")
        if labels.numel() == 0:
            raise ValueError("Labels passed to the NumberTokenLoss are empty!")

        digit_mask = self.selector.digit_token_mask
        exp_mask = self.selector.exponent_token_mask
        digit_values = self.selector.digit_vocab
        exp_values = self.selector.exponent_vocab

        # Zero the values outside each family so that placeholder entries
        # (NaN/inf, if the selector uses them) cannot poison the expectation
        # when multiplied by a zero probability.
        digit_values_safe = torch.where(
            digit_mask, digit_values, torch.zeros_like(digit_values)
        )
        exp_values_safe = torch.where(
            exp_mask, exp_values, torch.zeros_like(exp_values)
        )

        # Negative labels must not be used as indices: they would silently wrap
        # around to unrelated vocabulary entries (or raise for small vocabs).
        ignored = labels < 0
        safe_labels = labels.clamp(min=0)

        is_exp_pos = exp_mask[safe_labels]
        is_digit_pos = digit_mask[safe_labels]

        yhat_digit = self._expected_value(logits, digit_mask, digit_values_safe)
        yhat_exp = self._expected_value(logits, exp_mask, exp_values_safe)
        yhat = torch.where(is_exp_pos, yhat_exp, yhat_digit)

        y = torch.where(
            is_exp_pos, exp_values_safe[safe_labels], digit_values_safe[safe_labels]
        )

        # Only positions whose label really is a number token contribute. The
        # zeroed "safe" values make ``isfinite`` alone unable to detect
        # non-number labels, so validity is derived from the masks.
        valid = (is_exp_pos | is_digit_pos) & ~ignored & torch.isfinite(y)
        if valid.any():
            loss = self.loss_function(yhat[valid], y[valid])
        else:
            loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
        return loss
