from typing import List, Optional, Sequence

import torch
from transformers import add_start_docstrings, StoppingCriteria
from transformers.generation.stopping_criteria import STOPPING_CRITERIA_INPUTS_DOCSTRING


class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    Stop generation as soon as the generated sequence ends with a given
    run of token ids.

    ``token_ids_list`` is matched as a contiguous *suffix*: generation stops
    when the last ``len(token_ids_list)`` tokens of ``input_ids[0]`` equal
    ``token_ids_list`` exactly. A single-element list therefore stops on the
    first occurrence of that one token.

    Only the first row of the batch (``input_ids[0]``) is inspected, so this
    criterion is intended for batch size 1; other rows are ignored.
    ---------------
    ver: 2023-08-02
    by: changhongyu
    """

    def __init__(self, token_ids_list: Optional[Sequence[int]] = None):
        """
        :param token_ids_list: ids of the token sequence that ends generation.
            Any sequence of ints (list, tuple, ...) is accepted and copied into
            a list. ``None`` or an empty sequence disables the criterion
            (it never requests a stop).
        """
        # Normalise to a plain list: the comparison in ``__call__`` is against
        # ``Tensor.tolist()`` (a list), and a list never equals a tuple, so a
        # tuple argument would otherwise never trigger.
        self.token_ids_list: List[int] = [] if token_ids_list is None else list(token_ids_list)

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_len = len(self.token_ids_list)
        # No stop sequence configured: never stop. This also avoids the
        # ``[-0:]`` slice, which would select the whole row instead of nothing.
        if stop_len == 0:
            return False
        # If the row is shorter than the stop sequence, the slice yields a
        # shorter list that cannot compare equal, so this returns False.
        return input_ids[0][-stop_len:].tolist() == self.token_ids_list
