import torch

from model.regress_lm.vocabs import DecoderVocab


class NumberTokenSelector:
    """Lookup tables mapping decoder-vocabulary ids to the numbers they encode.

    All per-vocabulary tensors have length ``len(vocab)``, are placed on
    ``device`` and hold NaN for every token that does not encode a number:

    * ``nvocab`` -- value of number tokens and of exponent tokens;
    * ``digit_vocab`` -- value of number tokens only;
    * ``exponent_vocab`` -- value of exponent tokens only.

    Number tokens come from ``tokenizer.get_num_tokens()`` and are decoded with
    ``tokenizer.token_to_number``. Exponent tokens are optional: they are used
    only when the tokenizer provides ``get_exponent_tokens`` and
    ``token_to_exponent``. A token listed as both is treated as a number token.

    Derived attributes: boolean masks (``number_token_mask``,
    ``digit_token_mask``, ``exponent_token_mask``), the 1-D index tensor
    ``number_token_indices`` and the compacted value tensors
    (``number_token_values``, ``digit_token_values``,
    ``exponent_token_values``).
    """

    def __init__(self, vocab: "DecoderVocab[float]", device):
        """Build the lookup tables.

        Args:
            vocab: decoder vocabulary exposing ``tokenizer``, ``stoi``
                (token -> id) and ``__len__``.
            device: torch device on which every table is allocated.
        """
        self.tokenizer = vocab.tokenizer
        self.vocab = vocab

        number_tokens = set(self.tokenizer.get_num_tokens())
        exponent_tokens = self._exponent_tokens(self.tokenizer)
        to_exponent = getattr(self.tokenizer, "token_to_exponent", None)

        # Fill the tables as Python lists and upload each one in a single
        # call; element-wise writes into a device tensor would issue one
        # device operation per vocabulary entry.
        size = len(vocab)
        all_values = [float("nan")] * size
        digit_values = [float("nan")] * size
        exponent_values = [float("nan")] * size

        for token, idx in self.vocab.stoi.items():
            if token in number_tokens:
                value = float(self.tokenizer.token_to_number(token))
                all_values[idx] = value
                digit_values[idx] = value
            elif token in exponent_tokens and to_exponent is not None:
                try:
                    value = float(to_exponent(token))
                except Exception:
                    # Best effort: an exponent token the tokenizer cannot
                    # decode is left out of every table.
                    continue
                all_values[idx] = value
                exponent_values[idx] = value

        self.nvocab = torch.tensor(all_values, device=device)
        self.digit_vocab = torch.tensor(digit_values, device=device)
        self.exponent_vocab = torch.tensor(exponent_values, device=device)

        self.number_token_mask = ~torch.isnan(self.nvocab)
        self.digit_token_mask = ~torch.isnan(self.digit_vocab)
        self.exponent_token_mask = ~torch.isnan(self.exponent_vocab)

        # nonzero() on a 1-D mask returns shape (k, 1). squeeze(-1) keeps the
        # result 1-D for every k; a bare squeeze() would turn the k == 1 case
        # into a 0-d tensor and make number_token_values a scalar.
        self.number_token_indices = torch.nonzero(
            self.number_token_mask, as_tuple=False
        ).squeeze(-1)

        self.number_token_values = self.nvocab[self.number_token_indices]
        self.digit_token_values = self.digit_vocab[self.digit_token_mask]
        self.exponent_token_values = self.exponent_vocab[self.exponent_token_mask]

    @staticmethod
    def _exponent_tokens(tokenizer) -> set:
        """Return the tokenizer's exponent tokens, or an empty set if it has none.

        A tokenizer lacking ``get_exponent_tokens``, or whose call raises, is
        treated as having no exponent tokens.
        """
        getter = getattr(tokenizer, "get_exponent_tokens", None)
        if getter is None:
            return set()
        try:
            tokens = list(getter())
        except Exception:
            return set()
        return set(tokens)

    def select_number_tokens(self, logits: torch.Tensor):
        """Keep only the logits of number and exponent tokens.

        Args:
            logits: tensor indexed as ``logits[:, :, mask]``, so its third
                dimension must have size ``len(vocab)`` (presumably
                ``(batch, seq, vocab)`` -- confirm against callers).

        Returns:
            ``(selected_logits, number_token_mask)`` -- the logits restricted
            to entries where ``number_token_mask`` is True, and that mask.
        """
        logits = logits[:, :, self.number_token_mask]
        return logits, self.number_token_mask
