from typing import Callable, Optional

from transformers import PreTrainedTokenizer

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus, Logprob

import logging
logger = logging.getLogger(__name__)


class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence],
                                                 PreTrainedTokenizer]):
        """Store the default length cap and the tokenizer lookup callable.

        Args:
            max_model_len: Length cap used when a request carries no
                LoRA-specific override (see `_get_max_model_len`).
            get_tokenizer_for_seq: Callable mapping a sequence to its
                tokenizer. It is only stored here; no method of this class
                calls it.
        """
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]) -> int:
        """Return the length cap that applies to a request.

        A truthy `lora_req.long_lora_max_len` takes precedence; otherwise
        (no LoRA request, or the attribute is unset/zero) the default given
        at construction time is returned.
        """
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def _change_last_token(self, seq: Sequence) -> None:
        """Overwrite the most recently generated token of `seq` with EOS.

        Three views of the sequence are patched so that they keep agreeing
        with each other: the last logical token block, `seq.data` (output
        token ids and cumulative logprob) and `seq.output_logprobs`.
        """
        eos_token_id = seq.eos_token_id

        # Logical block: the slot right before the empty ones holds the last
        # written token (assumes empty slots sit at the tail of `token_ids`
        # -- TODO confirm against the block implementation).
        last_block = seq.logical_token_blocks[-1]
        num_empty_slots = last_block.get_num_empty_slots()
        last_block.token_ids[-num_empty_slots - 1] = eos_token_id

        # seq.data: take the replaced token's logprob back out of the running
        # total before swapping the token id. The substituted EOS entry below
        # has logprob 0.0, so nothing needs to be added in its place.
        replaced_token_id = seq.data.output_token_ids[-1]
        replaced_logprob = seq.output_logprobs[-1][replaced_token_id]
        seq.data.cumulative_logprob -= replaced_logprob.logprob
        seq.data.output_token_ids[-1] = eos_token_id

        # seq.output_logprobs: the last step now reports only the EOS token.
        seq.output_logprobs[-1] = {
            eos_token_id: Logprob(logprob=0.0, rank=1, decoded_token=""),
        }

    def maybe_stop_sequence_for_eos(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Run the stop checks, but leave a newly finished sequence RUNNING.

        When `maybe_stop_sequence` moves a RUNNING sequence to a finished
        status, this method:

        * rewrites the last generated token to EOS unless it already is EOS,
        * records the finished status on `seq.real_stop_status`, and
        * resets `seq.status` to `SequenceStatus.RUNNING`.

        Sequences that were not RUNNING on entry, or that did not finish,
        are only subject to the plain `maybe_stop_sequence` checks.

        Args:
            seq: Sequence to check; mutated in place.
            new_char_count: Number of chars added to the sequence's output
                text for the newly generated token.
            sampling_params: Sampling parameters holding the stop criteria.
            lora_req: Optional LoRA request, forwarded to
                `maybe_stop_sequence` so that a long-LoRA length cap is
                honoured by the max-model-len check.
        """
        status_before = seq.status
        self.maybe_stop_sequence(seq, new_char_count, sampling_params,
                                 lora_req)
        if status_before is SequenceStatus.RUNNING and seq.is_finished():
            if seq.get_last_token_id() != seq.eos_token_id:
                self._change_last_token(seq)
            # NOTE(review): `real_stop_status` is attached dynamically here;
            # setattr is kept for it, presumably because the attribute is not
            # declared on Sequence -- confirm.
            setattr(seq, 'real_stop_status', seq.status)
            seq.status = SequenceStatus.RUNNING

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        The checks run in a fixed order and the first one that matches
        decides the outcome: EOS token, stop token ids, stop strings,
        maximum model length, `max_tokens`. None of them runs until
        `sampling_params.min_tokens` output tokens have been generated.

        Args:
            seq: Sequence to check. On a match its `status` is updated, and
                `output_text` / `stop_reason` may be updated as well.
            new_char_count: Number of chars added to the sequence's output
                text for the newly generated token.
            sampling_params: Sampling parameters holding the stop criteria.
            lora_req: Optional LoRA request; may override the length cap
                used by the max-model-len check.
        """

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified
            # This prevents unintended exposure of the EOS token
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop_str = self._check_stop_strings(seq, new_char_count,
                                            sampling_params)
        if stop_str is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Stop strings are tried in the order given by `sampling_params.stop`
        and the first one found wins. Nothing is searched when
        `new_char_count` is zero.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text. The negative start index
            # counts from the end of the text: the window covers the new
            # chars plus `stop_string_len` chars before them, so a stop
            # string straddling old and new text is still found.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None
