from transformers import LogitsProcessor
import torch
import numpy as np
import warnings

class ScalingProcessor(LogitsProcessor):
    """Logits processor that rescales scores by a temperature and optionally adds a bias.

    Depending on ``add_bias_first`` the bias is added either before or after
    the temperature is applied. Tensor-valued parameters are moved to the
    device of the incoming scores and, for floating-point scores, cast to
    their dtype on every call, so the returned scores keep the dtype of the
    input.
    """

    def __init__(self, temp, bias=None, add_bias_first=True):
        """
        Args:
            temp: temperature. A Python scalar, 0-D tensor or 1-D tensor
                divides the scores (broadcast against
                ``[batch_size, vocab_size]``). A 2-D tensor ``W`` is applied
                as the matrix product ``scores @ W.T`` instead of a division.
                NumPy arrays, lists and tuples are converted to float32
                tensors.
            bias: None, a scalar, or a tensor broadcastable to the scores
                (e.g. 1-D of size vocab_size). NumPy arrays, lists and tuples
                are converted to float32 tensors.
            add_bias_first: bool, if True the bias is added before the
                temperature is applied, otherwise after.
        """
        super().__init__()
        self.temp = self._to_tensor(temp)
        self.bias = self._to_tensor(bias)
        self.add_bias_first = add_bias_first

    @staticmethod
    def _to_tensor(value):
        """Convert NumPy arrays and lists/tuples to float32 tensors; pass anything else through."""
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).float()
        if isinstance(value, (list, tuple)):
            return torch.as_tensor(value, dtype=torch.float32)
        return value

    @staticmethod
    def _match_scores(value, scores):
        """Move a tensor parameter to the device (and floating dtype) of ``scores``.

        Non-tensor values (None, Python scalars) are returned unchanged.
        """
        if not isinstance(value, torch.Tensor):
            return value
        if scores.is_floating_point():
            # Matching the dtype keeps torch.matmul from failing on mixed
            # dtypes and stops type promotion from changing the dtype of the
            # returned scores.
            return value.to(device=scores.device, dtype=scores.dtype)
        return value.to(scores.device)

    def __call__(self, input_ids, scores):
        """Apply the bias and temperature to ``scores``.

        Args:
            input_ids: unused by this processor.
            scores: tensor of shape [batch_size, vocab_size].

        Returns:
            A new tensor with the processed scores; the input tensor is not
            modified in place.

        Raises:
            ValueError: if ``temp`` is a tensor with more than 2 dimensions.
        """
        bias = self._match_scores(self.bias, scores)
        temp = self._match_scores(self.temp, scores)

        if self.add_bias_first and bias is not None:
            scores = scores + bias

        if isinstance(temp, torch.Tensor) and temp.dim() >= 2:
            if temp.dim() > 2:
                raise ValueError(f"Unsupported temp dim: {temp.dim()}")
            # A 2-D temperature is applied as a matrix product, not a division.
            scores = torch.matmul(scores, temp.T)
        else:
            # Scalars, 0-D and 1-D tensors all divide via broadcasting.
            scores = scores / temp

        if not self.add_bias_first and bias is not None:
            scores = scores + bias

        return scores
