# coding=utf-8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Rewrite GenerationMixin from Huggingface Transformers library to support 
- past_key_values

Note:
1. According to 
    https://github.com/huggingface/transformers/pull/17574
    past_key_values is an unmerged feature in Huggingface Transformers, so some
    code in this file may become obsolete in the future.
"""
import copy
import inspect
import warnings
import math
from copy import deepcopy
from dataclasses import dataclass
import re
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, Set
import logging

import torch
import torch.distributed as dist
from torch import nn
import torch.nn.functional as F

from packaging import version
import transformers
is_legacy_transformer_version = version.parse(transformers.__version__) < version.parse('4.48')
if is_legacy_transformer_version:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
else:
    from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.generation.beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList,LogitsProcessor
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import (
    GenerationMixin, 
    GenerationMode, 
    GenerateOutput,
    GenerateNonBeamOutput,
    GenerateEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
)
from transformers.utils import ModelOutput, is_torchdynamo_compiling

logger = logging.getLogger(__name__)


def generation_post_init(model):
    """Attach this module's custom generation routines to ``model``.

    The instance attributes ``generate`` and ``_sample`` are overwritten with
    thin wrappers that forward every positional and keyword argument to the
    module-level ``generate`` and ``_sample`` functions, passing ``model`` as
    the first argument.

    Returns:
        The same ``model`` object, after patching.
    """

    def _patched_generate(*call_args, **call_kwargs):
        return generate(model, *call_args, **call_kwargs)

    def _patched_sample(*call_args, **call_kwargs):
        return _sample(model, *call_args, **call_kwargs)

    # Only `generate` and `_sample` are replaced; `greedy_search` is left untouched.
    model.generate = _patched_generate
    model._sample = _patched_sample

    return model


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    r"""

    Generates sequences of token ids for models with a language modeling head.

    <Tip warning={true}>

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

    For an overview of generation strategies and code examples, check out the [following
    guide](../generation_strategies).

    </Tip>

    Parameters:
        inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
            The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
            method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
            should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
            `input_ids`, `input_values`, `input_features`, or `pixel_values`.
        generation_config ([`~generation.GenerationConfig`], *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which has the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors that complement the default logits processors built from arguments and
            generation config. If a logit processor is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            Custom stopping criteria that complements the default stopping criteria built from arguments and a
            generation config. If a stopping criteria is passed that is already created with the arguments or a
            generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
            sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
            intended for advanced users.
        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
            If provided, this function constraints the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
            `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
            on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
            for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
            Retrieval](https://arxiv.org/abs/2010.00904).
        synced_gpus (`bool`, *optional*):
            Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
            `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
            generating before other GPUs. Otherwise it'll be set to `False`.
        assistant_model (`PreTrainedModel`, *optional*):
            An assistant model that can be used to accelerate generation. The assistant model must have the exact
            same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
            is much faster than running generation with the model you're calling generate from. As such, the
            assistant model should be much smaller.
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            The negative prompt needed for some processors such as CFG. The batch size must match the input batch
            size. This is an experimental feature, subject to breaking API changes in future versions.
        negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Attention_mask for `negative_prompt_ids`.
        kwargs (`Dict[str, Any]`, *optional*):
            Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

            The following keys are consumed by this function itself and are never treated as model kwargs:

                - `output_past_key_values`: forwarded to `self._sample` (sample / greedy modes only).
                - `return_policy` (default `False`): when truthy, the selected decoding method and its keyword
                  arguments are returned instead of being executed.
                - `code_interpreter`: when not `None`, it is stored on `self.code_interpreter`. Whichever
                  interpreter is attached afterwards has its `reset()` called if it defines one.
                - `tokenizer`: passed to the stopping-criteria builder, token healing and assistant validation.
                - `assistant_tokenizer`: only consumed on non-legacy `transformers` versions, for assisted
                  generation.

    Return:
        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
        or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GenerateDecoderOnlyOutput`],
                - [`~generation.GenerateBeamDecoderOnlyOutput`]

            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GenerateEncoderDecoderOutput`],
                - [`~generation.GenerateBeamEncoderDecoderOutput`]

        If `return_policy` is truthy, a tuple `(generate_policy, generate_kwargs)` is returned instead, such that
        `generate_policy(**generate_kwargs)` runs the decoding.

    Raises:
        ValueError: on invalid argument combinations (see the individual checks below), on malformed
            `force_words_ids`, or when the resolved generation mode is not handled here.
    """
    # 0. process customized kwargs, so that they do not interfere with model_kwargs
    output_past_key_values = kwargs.pop("output_past_key_values", None)
    return_policy = kwargs.pop("return_policy", False)
    # use code_interpreter if provided in kwargs, otherwise use the one attached
    code_interpreter = kwargs.pop("code_interpreter", None)
    if code_interpreter is not None:
        self.code_interpreter = code_interpreter
    if hasattr(self, "code_interpreter") and hasattr(self.code_interpreter, "reset"):
        self.code_interpreter.reset()

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    self._validate_model_class()
    tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
    # `assistant_tokenizer` must be removed BEFORE kwargs are split into generation config / model kwargs,
    # otherwise it ends up in `model_kwargs`, where it fails validation and would be forwarded to the model.
    # The legacy path never consumed this key, so it is left untouched there.
    assistant_tokenizer = None
    if not is_legacy_transformer_version:
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

    generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
    self._validate_model_kwargs(model_kwargs.copy())
    if is_legacy_transformer_version:
        self._validate_assistant(assistant_model)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
                synced_gpus = True
            else:
                synced_gpus = False
    else:
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            from transformers.integrations.fsdp import is_fsdp_managed_module

            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
    requires_attention_mask = "encoder_outputs" not in model_kwargs
    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

    # 3. Define model inputs
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    device = inputs_tensor.device
    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

    # decoder-only models must use left-padding for batched generation.
    if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
        # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
        if (
            generation_config._pad_token_tensor is not None
            and batch_size > 1
            and len(inputs_tensor.shape) == 2
            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    # 4. Define other model kwargs
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)
    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
        model_kwargs["use_cache"] = True
    else:
        model_kwargs["use_cache"] = generation_config.use_cache

    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
        # The helper's signature differs between the legacy and the newer transformers API.
        if is_legacy_transformer_version:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
            )
        else:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config, model_kwargs
            )
    elif kwargs_has_attention_mask:
        # TODO (joao): generalize this check with other types of inputs
        if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
            raise ValueError("`attention_mask` passed to `generate` must be 2D.")

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name, generation_config
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

    if generation_config.token_healing:
        input_ids = self.heal_tokens(input_ids, tokenizer)

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

    # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
    # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
    # dynamically overrides this value as it can need more than the last token logits
    if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
        model_kwargs["num_logits_to_keep"] = 1

    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

    # 7. Prepare the cache.
    # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
    # - `max_length`, prepared above, is used to determine the maximum cache length
    max_cache_length = generation_config.max_length
    if (
        inputs_tensor.shape[1] != input_ids_length
        and model_input_name == "inputs_embeds"
        and not self.config.is_encoder_decoder
    ):
        max_cache_length += inputs_tensor.shape[1]
    self._prepare_cache_for_generation(
        generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
    )

    # 8. determine generation mode
    generation_mode = generation_config.get_generation_mode(assistant_model)

    if streamer is not None and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 9. prepare logits processors and stopping criteria
    prepared_logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )
    prepared_stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
    )

    # 10. go into different generation modes
    # Each branch only selects the decoding method (`generate_policy`) and assembles its keyword arguments
    # (`generate_kwargs`); the call itself is deferred to the end so that `return_policy` can hand both back.
    if generation_mode == GenerationMode.ASSISTED_GENERATION:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing assisted generate, "
                f"but is {generation_config.num_return_sequences}."
            )
        if batch_size > 1:
            raise ValueError("assisted generate is only supported for batch_size = 1")
        if not model_kwargs["use_cache"]:
            raise ValueError("assisted generate requires `use_cache=True`")
        if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
            raise ValueError("assisted generate is not supported with Static cache classes`")
        if self._is_stateful:
            # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
            # which is not possible with stateful models (they can't reset to a previous subset of generated text)
            raise ValueError(
                f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
            )

        # 11. Get the candidate generator, given the parameterization
        if is_legacy_transformer_version:
            candidate_generator = self._get_candidate_generator(
                generation_config=generation_config,
                input_ids=input_ids,
                inputs_tensor=inputs_tensor,
                assistant_model=assistant_model,
                logits_processor=logits_processor,
                model_kwargs=model_kwargs,
            )
        else:
            candidate_generator = self._get_candidate_generator(
                generation_config=generation_config,
                input_ids=input_ids,
                inputs_tensor=inputs_tensor,
                assistant_model=assistant_model,
                logits_processor=logits_processor,
                target_tokenizer=tokenizer,
                assistant_tokenizer=assistant_tokenizer,
                model_kwargs=model_kwargs,
            )

        # 12. run assisted generate
        generate_policy = self._assisted_decoding
        generate_kwargs = {
            "input_ids": input_ids,
            "candidate_generator": candidate_generator,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
            "streamer": streamer,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode == GenerationMode.DOLA_GENERATION:
        if self._is_stateful:
            # DoLa decoding was not designed for stateful models, and would require some changes
            raise ValueError(
                f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
            )
        generate_policy = self._dola_decoding
        generate_kwargs = {
            "input_ids": input_ids,
            "dola_layers": generation_config.dola_layers,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
            "streamer": streamer,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
        if not model_kwargs["use_cache"]:
            raise ValueError("Contrastive search requires `use_cache=True`")
        if self._is_stateful:
            # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
            raise ValueError(
                f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
            )

        generate_policy = self._contrastive_search
        generate_kwargs = {
            "input_ids": input_ids,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
            "streamer": streamer,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
        # 11. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
        generate_policy = self._sample
        generate_kwargs = {
            "input_ids": input_ids,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
            "streamer": streamer,
            # custom extension: only `_sample` receives this flag
            "output_past_key_values": output_past_key_values,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )

        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        generate_policy = self._beam_search
        generate_kwargs = {
            "input_ids": input_ids,
            "beam_scorer": beam_scorer,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            num_beam_groups=generation_config.num_beam_groups,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        generate_policy = self._group_beam_search
        generate_kwargs = {
            "input_ids": input_ids,
            "beam_scorer": beam_scorer,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
        }
        generate_kwargs.update(model_kwargs)

    elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
        final_constraints = []
        if generation_config.constraints is not None:
            final_constraints = generation_config.constraints

        if generation_config.force_words_ids is not None:

            def typeerror():
                raise ValueError(
                    "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` "
                    f"of positive integers, but is {generation_config.force_words_ids}."
                )

            if (
                not isinstance(generation_config.force_words_ids, list)
                or len(generation_config.force_words_ids) == 0
            ):
                typeerror()

            for word_ids in generation_config.force_words_ids:
                # Validate the container before indexing into it, so that malformed entries (e.g. `[]`)
                # raise the intended ValueError instead of an IndexError/TypeError from `word_ids[0]`.
                if not isinstance(word_ids, list) or len(word_ids) == 0:
                    typeerror()

                if isinstance(word_ids[0], list):
                    # nested lists: a disjunctive constraint (any one of several token sequences)
                    if any(not isinstance(token_ids, list) for token_ids in word_ids):
                        typeerror()
                    if any(
                        any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                        for token_ids in word_ids
                    ):
                        typeerror()

                    constraint = DisjunctiveConstraint(word_ids)
                else:
                    # flat list of token ids: a single phrase constraint
                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                        typeerror()

                    constraint = PhrasalConstraint(word_ids)
                final_constraints.append(constraint)

        # 11. prepare beam search scorer
        constrained_beam_scorer = ConstrainedBeamSearchScorer(
            constraints=final_constraints,
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        generate_policy = self._constrained_beam_search
        generate_kwargs = {
            "input_ids": input_ids,
            "constrained_beam_scorer": constrained_beam_scorer,
            "logits_processor": prepared_logits_processor,
            "stopping_criteria": prepared_stopping_criteria,
            "generation_config": generation_config,
            "synced_gpus": synced_gpus,
        }
        generate_kwargs.update(model_kwargs)

    else:
        # Fail with a clear message instead of an UnboundLocalError on `generate_policy` below.
        raise ValueError(f"Unsupported generation mode: {generation_mode}")

    if return_policy:
        return generate_policy, generate_kwargs
    return generate_policy(**generate_kwargs)


def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    output_past_key_values: bool = False,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling**
    (when `generation_config.do_sample` is truthy) or greedy argmax selection (otherwise), and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Compared with a plain token-by-token loop, this variant additionally:
    - can hand back the `past_key_values` of the most recent forward pass (`output_past_key_values`);
    - accepts a list of indices in `generation_config.output_hidden_states` to keep only selected hidden states;
    - calls `self.code_interpreter(input_ids=..., next_tokens=...)` after token selection when such a callable
      attribute exists, and uses the `(input_ids, next_tokens)` pair it returns.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        output_past_key_values (`bool`, *optional*, defaults to `False`):
            Only honoured together with `generation_config.return_dict_in_generate`. When both are set, the
            `past_key_values` attribute of the latest model output is stored and returned in the output object.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    # (each accumulator is `None` unless it was requested AND a dict-style output is returned)
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    past_key_values = () if (return_dict_in_generate and output_past_key_values) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    # (these two names are only bound in this case; they are only read in the matching return branch below)
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        # `output_hidden_states` may be a list of indices; the model itself only gets a boolean switch
        _output_hidden_states = bool(output_hidden_states) if isinstance(output_hidden_states, list) else output_hidden_states
        model_inputs.update({"output_hidden_states": _output_hidden_states} if _output_hidden_states else {})
        # forward pass to get next token
        outputs = self(**model_inputs, return_dict=True)
        
        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            # List form: keep only the hidden states at the requested indices. The selected tensors are
            # appended flat to the accumulator (no per-step nesting, unlike the `elif` branch below).
            # NOTE(review): with an empty list the accumulator above is `None`, so this `+=` would raise
            # a TypeError -- confirm callers never pass `[]`.
            if isinstance(output_hidden_states, list):
                decoder_hidden_states += (
                    tuple([outputs.decoder_hidden_states[hs_index] for hs_index in output_hidden_states])
                    if self.config.is_encoder_decoder
                    else tuple([outputs.hidden_states[hs_index] for hs_index in output_hidden_states])
                )
            elif output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

            # Overwritten (not accumulated) on every step: only the cache of the latest forward pass is kept
            if output_past_key_values:
                past_key_values = outputs.past_key_values

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # sequences with code should have the next token be the corresponding answer token
        # (optional hook: only used when the model object exposes a callable `code_interpreter` attribute;
        # it may replace both the running `input_ids` and the freshly selected `next_tokens`)
        if callable(getattr(self, "code_interpreter", None)):
            # input_ids, model_kwargs, next_tokens = self.code_interpreter(
            #     input_ids=input_ids, model=self, next_tokens=next_tokens, **model_kwargs)
            input_ids, next_tokens = self.code_interpreter(input_ids=input_ids, next_tokens=next_tokens)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # a sequence stays "unfinished" only while no stopping criterion has fired for it
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=past_key_values,
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=past_key_values,
            )
    else:
        return input_ids


@dataclass
class SamplingOutput(ModelOutput):
    """Result container for a sampling run; every field defaults to `None`."""
    # Token ids of the produced sequences.
    sequences: Optional[torch.LongTensor] = None
    # Scores attached to the produced tokens -- presumably per-token (log-)probabilities; confirm with the producer.
    transition_scores: Optional[torch.FloatTensor] = None
    # Key/value cache, annotated as a tuple of per-layer tuples of tensors.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class StepSamplingOutput(ModelOutput):
    """Result container for step-wise sampling; every field defaults to `None`."""
    # Token ids of the full sequences.
    sequences: Optional[torch.LongTensor] = None
    # Token ids of the step(s) -- presumably only the newly generated part; confirm with the producer.
    steps: Optional[torch.LongTensor] = None
    # Scores attached to the produced tokens -- presumably per-token (log-)probabilities; confirm with the producer.
    transition_scores: Optional[torch.FloatTensor] = None
    # Scores produced by a verifier / value function -- exact layout not visible here; confirm with the producer.
    verifier_scores: Optional[torch.FloatTensor] = None
    # Key/value cache, annotated as a tuple of per-layer tuples of tensors.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


class BatchEndStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once every row of the batch contains the end token."""

    def __init__(self, end_token_id: int, device: torch.device):
        # Keep the id as a one-element tensor on the target device so `eq` can broadcast against `input_ids`.
        end_token = torch.tensor([end_token_id])
        self.end_token_id = end_token.to(device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when each sequence in `input_ids` holds the end token at least once."""
        row_has_end_token = input_ids.eq(self.end_token_id).any(1)
        return bool(row_has_end_token.all())


class StepStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once every row of the batch has completed one step."""

    def __init__(self, cur_token_lens: torch.LongTensor, end_token_ids: Set[int], false_eos_ids: List[int], pad_token_id: int, device: torch.device):
        self.cur_token_lens = cur_token_lens
        # Each entry of `false_eos_ids` becomes its own tensor; they are later matched as token patterns
        # that must NOT be treated as a step end.
        false_patterns = []
        for false_eos_id in false_eos_ids:
            false_patterns.append(torch.tensor(false_eos_id).to(device))
        self.false_eos_ids = false_patterns
        self.end_token_ids = torch.tensor(list(end_token_ids)).to(device)
        self.pad_token_id = pad_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when every row has a genuine end-of-step token among its newly generated tokens."""
        # Only look at what was generated after the first `cur_token_lens` tokens; the rest is replaced by padding.
        new_tokens = get_new_generated_tokens_with_forward(
            input_ids,
            past_token_lens=self.cur_token_lens,
            pad_token_id=self.pad_token_id,
            forward_num=0,
        )

        # Cheap pre-check: every row must contain at least one end token before the stricter test runs.
        end_ids = self.end_token_ids.view((1, 1, -1))
        row_has_end = new_tokens.unsqueeze(-1).eq(end_ids).any(2).any(1)
        if not row_has_end.all():
            return False

        # Stricter test: an end token that is not part of a false-EOS pattern must exist in every row
        # (with wnby=False the helper returns seq_len for rows without one).
        positions = get_leftmost_token_position_with_false_tokens(new_tokens, self.end_token_ids, self.false_eos_ids, False)
        return bool((positions < new_tokens.shape[-1]).all())


def get_rightmost_token_position_with_false_tokens(inputs, target_tokens, false_tokens_list):
    """Find the last target token per row, ignoring tokens that belong to a "false" pattern.

    Args:
        inputs: 2-D tensor of token ids, one sequence per row.
        target_tokens: 1-D tensor with the token ids to search for.
        false_tokens_list: iterable of 1-D tensors; every occurrence of such a token pattern in a row is
            excluded from the search.

    Returns:
        A pair `(positions, right_most_false_positions)` of 1-D tensors with one entry per row:
        the index of the rightmost non-excluded target token (-1 if there is none), and the index of the
        rightmost token covered by any false pattern (-1 if no pattern occurs in the row).
    """
    seq_len = inputs.shape[-1]
    blocked = torch.zeros_like(inputs, dtype=torch.bool)
    for false_tokens in false_tokens_list:
        pattern_len = false_tokens.shape[0]
        # A pattern that is empty or longer than the sequence cannot occur; `unfold` would raise on it.
        if pattern_len == 0 or pattern_len > seq_len:
            continue
        # starts[b, i] is True when the pattern begins at column i of row b.
        starts = (inputs.unfold(1, pattern_len, 1) == false_tokens).all(2)
        n_windows = starts.shape[1]
        # Mark the whole span of every occurrence: a start at column i covers columns i .. i+pattern_len-1.
        for offset in range(pattern_len):
            blocked[:, offset:offset + n_windows] |= starts

    last_col = seq_len - 1
    right_most_false_positions = torch.where(
        blocked.any(1), last_col - blocked.flip(dims=[1]).float().argmax(dim=1), -1
    )

    # Excluded columns never count as a hit, even when 0 happens to be one of the target ids.
    hits = inputs[:, :, None].eq(target_tokens.view(1, 1, -1)).any(2) & ~blocked
    positions = torch.where(hits.any(1), last_col - hits.flip(dims=[1]).float().argmax(dim=1), -1)

    return positions, right_most_false_positions

def get_leftmost_token_position_with_false_tokens_noextreme(inputs, target_tokens, false_tokens_list, wnby):
    """Find the first target token per row, ignoring "false" patterns and the first non-zero token.

    Args:
        inputs: 2-D tensor of token ids, one sequence per row. Id 0 is treated as padding when locating
            the first real token of a row.
        target_tokens: 1-D tensor with the token ids to search for.
        false_tokens_list: iterable of 1-D tensors; every occurrence of such a token pattern in a row is
            excluded from the search.
        wnby: fallback selector for rows without a hit: `seq_len - 1` when truthy (stays within bounds),
            `seq_len` otherwise.

    Returns:
        1-D tensor with, per row, the index of the leftmost non-excluded target token or the fallback.
    """
    seq_len = inputs.shape[-1]
    blocked = torch.zeros_like(inputs, dtype=torch.bool)
    for false_tokens in false_tokens_list:
        pattern_len = false_tokens.shape[0]
        # A pattern that is empty or longer than the sequence cannot occur; `unfold` would raise on it.
        if pattern_len == 0 or pattern_len > seq_len:
            continue
        # starts[b, i] is True when the pattern begins at column i of row b.
        starts = (inputs.unfold(1, pattern_len, 1) == false_tokens).all(2)
        n_windows = starts.shape[1]
        # Mark the whole span of every occurrence: a start at column i covers columns i .. i+pattern_len-1.
        for offset in range(pattern_len):
            blocked[:, offset:offset + n_windows] |= starts

    # The first non-zero token of each row is never reported (column 0 for an all-zero row).
    first_non_pad_token_index = (inputs != 0).float().argmax(dim=1).reshape(-1, 1)
    blocked.scatter_(1, first_non_pad_token_index, True)

    # Excluded columns never count as a hit, even when 0 happens to be one of the target ids.
    hits = inputs[:, :, None].eq(target_tokens.view(1, 1, -1)).any(2) & ~blocked
    fallback = seq_len - 1 if wnby else seq_len
    positions = torch.where(hits.any(1), hits.float().argmax(dim=1), fallback)

    return positions


def get_leftmost_token_position_with_false_tokens(inputs, target_tokens, false_tokens_list, wnby):
    """Find the first genuine target token per row.

    A column qualifies when it holds a target token, is not directly preceded by another target token
    (only the first token of a run counts), is not part of a "false" pattern, and is not the first
    non-zero token of its row.

    Args:
        inputs: 2-D tensor of token ids, one sequence per row. Id 0 is treated as padding when locating
            the first real token of a row.
        target_tokens: 1-D tensor with any number of token ids to search for.
        false_tokens_list: iterable of 1-D tensors; every occurrence of such a token pattern in a row is
            excluded from the search.
        wnby: fallback selector for rows without a hit: `seq_len - 1` when truthy (stays within bounds),
            `seq_len` otherwise.

    Returns:
        1-D tensor with, per row, the index of the leftmost qualifying column or the fallback.
    """
    seq_len = inputs.shape[-1]

    # Match against ALL target ids (not just the first two) so the helper works for any id set size.
    is_target = inputs[:, :, None].eq(target_tokens.view(1, 1, -1)).any(2)
    preceded_by_target = torch.zeros_like(is_target)
    preceded_by_target[:, 1:] = is_target[:, :-1]
    run_starts = is_target & ~preceded_by_target

    blocked = torch.zeros_like(inputs, dtype=torch.bool)
    for false_tokens in false_tokens_list:
        pattern_len = false_tokens.shape[0]
        # A pattern that is empty or longer than the sequence cannot occur; `unfold` would raise on it.
        if pattern_len == 0 or pattern_len > seq_len:
            continue
        # starts[b, i] is True when the pattern begins at column i of row b.
        starts = (inputs.unfold(1, pattern_len, 1) == false_tokens).all(2)
        n_windows = starts.shape[1]
        # Mark the whole span of every occurrence: a start at column i covers columns i .. i+pattern_len-1.
        for offset in range(pattern_len):
            blocked[:, offset:offset + n_windows] |= starts

    # The first non-zero token of each row is never reported (column 0 for an all-zero row).
    first_non_pad_token_index = (inputs != 0).float().argmax(dim=1).reshape(-1, 1)
    blocked.scatter_(1, first_non_pad_token_index, True)

    candidates = run_starts & ~blocked
    fallback = seq_len - 1 if wnby else seq_len
    positions = torch.where(candidates.any(1), candidates.float().argmax(dim=1), fallback)

    return positions


def find_leftmost_tokens_positions(input_ids: torch.LongTensor, tokens: Union[int, torch.LongTensor], wnby: bool=True) -> torch.LongTensor:
    """
    Get the indices where `tokens` first appear in the `input_ids` for each sample in the batch.

    `tokens` may be a single int, a tensor of ids, or any sequence of ids accepted by `torch.as_tensor`;
    a position matches when it equals any of the given ids. For rows without a match the result is
    `seq_len - 1` when `wnby` (stay within the boundary) is True and `seq_len` otherwise.

    e.g.
    input_ids = torch.tensor([[1, 2, 3, 3], [7, 0, 4, 0], [3, 2, 1, 2]])
    tokens = torch.tensor([3, 0])
    find_leftmost_tokens_positions(input_ids, tokens)
    >> tensor([2, 1, 0])

    tokens = torch.tensor([3, 2])
    find_leftmost_tokens_positions(input_ids, tokens, wnby=True)
    >> tensor([1, 3, 0])

    find_leftmost_tokens_positions(input_ids, tokens, wnby=False)
    >> tensor([1, 4, 0])
    """
    assert input_ids.ndim == 2
    seq_len = input_ids.shape[1]
    if isinstance(tokens, int):
        mask = input_ids.eq(tokens)
    else:
        # Tensors pass through; lists/tuples of ids are converted instead of leaving `mask` undefined.
        token_ids = torch.as_tensor(tokens, device=input_ids.device)
        mask = input_ids[:, :, None].eq(token_ids.view(1, 1, -1)).any(2)
    # argmax on the 0/1 mask yields the first True column; rows without any True get the fallback.
    positions = torch.where(mask.any(1), mask.float().argmax(dim=1), seq_len-1 if wnby else seq_len)
    return positions


def find_rightmost_tokens_positions(input_ids: torch.LongTensor, tokens: Union[int, torch.LongTensor], wnby: bool=True) -> torch.LongTensor:
    """
    Get the index where `tokens` last appear in the `input_ids` for each sample in the batch.

    `tokens` may be a single int, a tensor of ids, or any sequence of ids accepted by `torch.as_tensor`;
    a position matches when it equals any of the given ids. For rows without a match the result is
    0 when `wnby` (stay within the boundary) is True and -1 otherwise.

    e.g.
    input_ids = torch.tensor([[1, 2, 3, 3], [7, 0, 4, 0], [3, 2, 1, 2]])
    tokens = torch.tensor([3, 0])
    find_rightmost_tokens_positions(input_ids, tokens)
    >> tensor([3, 3, 0])

    tokens = torch.tensor([3, 2])
    find_rightmost_tokens_positions(input_ids, tokens, wnby=True)
    >> tensor([3, 0, 3])

    find_rightmost_tokens_positions(input_ids, tokens, wnby=False)
    >> tensor([3, -1, 3])
    """
    assert input_ids.ndim == 2
    seq_len = input_ids.shape[1]
    if isinstance(tokens, int):
        mask = input_ids.eq(tokens)
    else:
        # Tensors pass through; lists/tuples of ids are converted instead of leaving `mask` undefined.
        token_ids = torch.as_tensor(tokens, device=input_ids.device)
        mask = input_ids[:, :, None].eq(token_ids.view(1, 1, -1)).any(2)
    # Searching the flipped mask from the left finds the last True column of the original row.
    positions = torch.where(mask.any(1), (seq_len - 1) - mask.flip(dims=[1]).float().argmax(dim=1), 0 if wnby else -1)
    return positions


def find_leftmost_notpadded_positions(tensor: torch.Tensor, pad_value: Union[int, float], wnby: bool=True) -> torch.Tensor:
    """Return, per row of the 2-D `tensor`, the index of the first entry that differs from `pad_value`.

    Rows made only of `pad_value` yield `seq_len - 1` when `wnby` is True and `seq_len` otherwise.
    """
    assert tensor.ndim == 2
    width = tensor.shape[1]
    is_content = tensor != pad_value
    # argmax over the 0/1 mask picks the first content column
    first_content = is_content.float().argmax(dim=1)
    fallback = width - 1 if wnby else width
    return torch.where(is_content.any(1), first_content, fallback)


def count_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding. Count the leading `pad_value` entries in each row of `tensor`.

    The index of the first non-pad entry equals the number of pads before it; with `wnby=False`
    a row consisting only of `pad_value` yields `seq_len`, i.e. the whole row counts as padding.
    """
    return find_leftmost_notpadded_positions(tensor, pad_value=pad_value, wnby=False)


def count_not_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding. Per row, the sequence length minus the number of leading `pad_value` entries."""
    total_len = tensor.shape[-1]
    n_left_pads = count_left_padding(tensor, pad_value=pad_value)
    return total_len - n_left_pads


def count_shared_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding. Smallest leading-padding length over all rows of the batch `tensor`."""
    per_row_padding = count_left_padding(tensor, pad_value)
    return per_row_padding.min()


def find_rightmost_notpadded_positions(tensor: torch.Tensor, pad_value: Union[int, float], wnby: bool=True) -> torch.Tensor:
    """For right padding. Return, per row of the 2-D `tensor`, the index of the last entry that differs from `pad_value`.

    Rows made only of `pad_value` yield 0 when `wnby` is True and -1 otherwise.
    """
    assert tensor.ndim == 2
    last_col = tensor.shape[1] - 1
    is_content = tensor != pad_value
    # distance of the last content column from the right edge, found on the reversed mask
    offset_from_right = is_content.flip(dims=[1]).float().argmax(dim=1)
    fallback = 0 if wnby else -1
    return torch.where(is_content.any(1), last_col - offset_from_right, fallback)


def count_right_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For right padding. Count the trailing `pad_value` entries in each row of `tensor`."""
    last_col = tensor.shape[-1] - 1
    # wnby=False reports -1 for an all-pad row, which makes the count below equal the full length
    last_content = find_rightmost_notpadded_positions(tensor, pad_value=pad_value, wnby=False)
    return last_col - last_content


def get_new_generated_tokens(input_ids: torch.LongTensor, past_token_lens: torch.LongTensor, pad_token_id: int=0):
    """Replace left padding and the first `past_token_lens` tokens of each row by `pad_token_id`, keeping the rest."""
    left_pad_counts = count_left_padding(input_ids, pad_value=pad_token_id)
    # everything strictly left of this column is padding or an already-known token
    first_new_column = left_pad_counts + past_token_lens
    return mask_by_borders_2D(
        input_ids,
        right_borders=first_new_column,
        include_right=False,
        value=pad_token_id,
    )


def get_new_generated_tokens_with_forward(input_ids: torch.LongTensor, past_token_lens: torch.LongTensor, pad_token_id: int=0, forward_num: int=0):
    """Like `get_new_generated_tokens`, but the kept region starts `forward_num` columns earlier."""
    left_pad_counts = count_left_padding(input_ids, pad_value=pad_token_id)
    # moving the border left by `forward_num` keeps that many already-known tokens visible
    first_kept_column = (left_pad_counts + past_token_lens) - forward_num
    return mask_by_borders_2D(
        input_ids,
        right_borders=first_kept_column,
        include_right=False,
        value=pad_token_id,
    )


def get_mask_for_seq_area(tensor: torch.Tensor, left_borders: Optional[torch.LongTensor]=None, right_borders: Optional[torch.LongTensor]=None, include_left: bool=False, include_right: bool=False):
    """Build a boolean mask over the columns of the 2-D `tensor` that is True inside the requested area.

    Per row, a column is selected when it lies right of `left_borders` (if given) and left of
    `right_borders` (if given). Borders are exclusive unless the matching `include_*` flag is set.
    At least one of the two borders must be provided.
    """
    assert not (left_borders is None and right_borders is None)
    _, width = tensor.shape
    columns = torch.arange(width).view(1, -1).to(tensor.device)

    mask = None
    if left_borders is not None:
        # widening by one column turns the exclusive bound into an inclusive one
        lower = left_borders - 1 if include_left else left_borders
        mask = columns > lower.view(-1, 1)
    if right_borders is not None:
        upper = right_borders + 1 if include_right else right_borders
        left_of_upper = columns < upper.view(-1, 1)
        mask = left_of_upper if mask is None else torch.logical_and(mask, left_of_upper)
    return mask


def mask_by_borders_2D(
    tensor: torch.Tensor, 
    left_borders: Optional[torch.LongTensor] = None, 
    right_borders: Optional[torch.LongTensor] = None, 
    include_left: bool = False, 
    include_right: bool = False,
    value: Union[int, float] = 0,
):
    """Return a copy of `tensor` in which the area selected by the borders is overwritten with `value`."""
    selected_area = get_mask_for_seq_area(
        tensor=tensor,
        left_borders=left_borders,
        right_borders=right_borders,
        include_left=include_left,
        include_right=include_right,
    )
    filled = tensor.masked_fill(selected_area, value=value)
    return filled


def mask_by_borders_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], 
    left_borders: torch.LongTensor = None, 
    right_borders: torch.LongTensor = None, 
    include_left: bool = False, 
    include_right: bool = False,
    value: Union[int, float] = 0,
):
    """Overwrite the area selected by the borders with `value` in every tensor of `past_key_values`.

    The mask is derived from the first tensor of the first layer, indexed as
    `[batch, 0, seq, 0]`, and broadcast to that tensor's shape; it is then applied to all tensors.
    """
    reference = past_key_values[0][0]
    area = get_mask_for_seq_area(
        reference[:, 0, :, 0],
        left_borders=left_borders,
        right_borders=right_borders,
        include_left=include_left,
        include_right=include_right,
    )
    area = area[:, None, :, None].expand_as(reference)

    masked_layers = []
    for layer_states in past_key_values:
        # layers may live on different devices, so the mask follows each tensor
        masked_layers.append(
            tuple(state.masked_fill(area.to(state.device), value=value) for state in layer_states)
        )
    return tuple(masked_layers)


def batched_shift_along_seq_dim_2D(tensor: torch.Tensor, shifts: torch.LongTensor=None):
    """Circularly shift each row of the 2-D `tensor` to the right by its own entry of `shifts`."""
    n_rows, width = tensor.shape
    assert shifts.numel() == n_rows

    columns = torch.arange(width).view((1, width)).to(tensor.device)
    # output column j reads from input column (j - shift) modulo the row width
    source_columns = (columns - shifts.view((n_rows, 1))) % width

    return tensor.gather(1, source_columns)


def batched_shift_along_seq_dim_past_key_values(past_key_values: Tuple[Tuple[torch.FloatTensor]], shifts: torch.LongTensor=None):
    """Circularly shift every tensor in `past_key_values` along dim 2, per batch row, by `shifts`.

    Batch size (dim 0), sequence length (dim 2) and the index layout are all taken from the first
    tensor of the first layer.
    """
    reference = past_key_values[0][0]
    n_rows = reference.shape[0]
    width = reference.shape[2]
    assert shifts.numel() == n_rows

    columns = torch.arange(width).view((1, width)).to(reference.device)
    # output position j reads from input position (j - shift) modulo the sequence length
    source_columns = (columns - shifts.view((n_rows, 1))) % width
    source_index = source_columns[:, None, :, None].expand_as(reference)

    shifted_layers = []
    for layer_states in past_key_values:
        shifted_layers.append(
            tuple(torch.gather(state, 2, source_index.to(state.device)) for state in layer_states)
        )
    return tuple(shifted_layers)


def shift_padding_to_left_2D(tensor: torch.Tensor, pad_value: Union[int, float] = 0):
    """Rotate each row of `tensor` so that its trailing `pad_value` entries end up at the front."""
    # unpacking doubles as a check that the input is 2-D
    n_rows, width = tensor.shape
    trailing_pads = count_right_padding(tensor, pad_value=pad_value)
    shifted = batched_shift_along_seq_dim_2D(tensor, shifts=trailing_pads)
    return shifted


class StepBeamSearch(GenerationMixin):

    def __init__(self, generate_policy, generate_kwargs, value_function, value_kwargs, generation_config, model_config, device):
        self.generate_policy = generate_policy
        self.generate_kwargs = generate_kwargs
        self.value_function = value_function
        self.value_kwargs = value_kwargs
        self.generation_config = deepcopy(generation_config)
        self.config = model_config
        self.device = device
        self.end_of_step_ids = torch.tensor([generation_config.end_of_step_id, generation_config.eos_token_id], device=self.device)
        self.false_eos_ids = [torch.tensor(false_eos_id).to(self.device) for false_eos_id in self.generation_config.false_eos_ids]
    
    def _shift_padding_to_left(self, token_ids: torch.LongTensor,  verifier_scores: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None):
        """Shift right padding in `token_ids` to the left and realign the companion tensors.

        Args:
            token_ids: 2D token ids whose right padding uses `generation_config.pad_token_id`.
            verifier_scores: Scores to realign together with `token_ids`.
            past_key_values: Optional cache, rotated along dimension 2 by the token shifts.
            transition_scores: Optional scores; their own right padding (value 0)
                is counted independently because they are not aligned with `token_ids`.

        Returns:
            Tuple `(token_ids, verifier_scores, past_key_values, transition_scores)`;
            optional inputs that were `None` stay `None`.
        """
        shifts = count_right_padding(token_ids, pad_value=self.generation_config.pad_token_id)

        if verifier_scores.shape == token_ids.shape:
            # The callers in this class mask `verifier_scores` at the same
            # positions as `token_ids`, using `pad_token_id` as fill value.
            # Counting zeros in the scores would therefore miss the padding
            # whenever `pad_token_id != 0`; reusing the token shifts keeps both
            # tensors aligned regardless of the pad id.
            verifier_scores = batched_shift_along_seq_dim_2D(verifier_scores, shifts=shifts)
        else:
            # Shapes differ, so the token shifts do not apply; fall back to
            # counting trailing zeros in the scores themselves.
            verifier_scores = shift_padding_to_left_2D(verifier_scores, pad_value=0)

        token_ids = batched_shift_along_seq_dim_2D(token_ids, shifts=shifts)
        if past_key_values is not None:
            past_key_values = batched_shift_along_seq_dim_past_key_values(past_key_values, shifts=shifts)
        if transition_scores is not None:
            transition_scores = shift_padding_to_left_2D(transition_scores, pad_value=0)
        return token_ids, verifier_scores, past_key_values, transition_scores
    
    def _cut_after_eos_lp(self, input_ids: torch.LongTensor, verifier_scores: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None, past_token_lens: torch.LongTensor=None):
        """Mask everything after the EOS border and restore left padding.

        The border per row comes from `find_leftmost_tokens_positions` on the
        EOS token id. Tokens and verifier scores are filled with
        `pad_token_id` after the border, cache entries and transition scores
        with 0; all results are then passed through `_shift_padding_to_left`.

        Returns:
            Tuple `(input_ids, verifier_scores, past_key_values, transition_scores)`.
        """
        pad_id = self.generation_config.pad_token_id
        eos_borders = find_leftmost_tokens_positions(input_ids, self.generation_config.eos_token_id, wnby=True)

        masked_scores = mask_by_borders_2D(verifier_scores, left_borders=eos_borders, include_left=False, value=pad_id)
        masked_ids = mask_by_borders_2D(input_ids, left_borders=eos_borders, include_left=False, value=pad_id)

        masked_cache = past_key_values
        if past_key_values is not None:
            masked_cache = mask_by_borders_past_key_values(past_key_values, left_borders=eos_borders, include_left=False, value=0)

        masked_transitions = transition_scores
        if transition_scores is not None:
            # Translate the token-space border into transition-score coordinates.
            first_generated = count_left_padding(input_ids, pad_value=pad_id) + past_token_lens
            transition_left_pad = count_left_padding(transition_scores, pad_value=0)
            transition_borders = eos_borders - first_generated + transition_left_pad
            masked_transitions = mask_by_borders_2D(transition_scores, left_borders=transition_borders, include_left=False, value=0)

        shifted_ids, shifted_scores, shifted_cache, shifted_transitions = self._shift_padding_to_left(
            masked_ids, masked_scores, masked_cache, masked_transitions
        )
        return shifted_ids, shifted_scores, shifted_cache, shifted_transitions
    
    def _cut_latter_steps(self, input_ids: torch.LongTensor, verifier_scores: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None, past_token_lens: torch.LongTensor=None):
        """Mask everything after the current step's border and restore left padding.

        The border per row comes from `find_rightmost_tokens_positions` applied
        to the newly generated tokens with `self.end_of_step_ids`. Tokens and
        verifier scores are filled with `pad_token_id` after the border, cache
        entries and transition scores with 0; all results are then passed
        through `_shift_padding_to_left`.

        Returns:
            Tuple `(input_ids, verifier_scores, past_key_values, transition_scores)`.
        """
        pad_id = self.generation_config.pad_token_id

        generated = get_new_generated_tokens(input_ids, past_token_lens=past_token_lens, pad_token_id=pad_id)
        step_borders = find_rightmost_tokens_positions(generated, self.end_of_step_ids, wnby=True)

        masked_ids = mask_by_borders_2D(input_ids, left_borders=step_borders, include_left=False, value=pad_id)
        masked_scores = mask_by_borders_2D(verifier_scores, left_borders=step_borders, include_left=False, value=pad_id)

        masked_cache = past_key_values
        if past_key_values is not None:
            masked_cache = mask_by_borders_past_key_values(past_key_values, left_borders=step_borders, include_left=False, value=0)

        masked_transitions = transition_scores
        if transition_scores is not None:
            # Translate the token-space border into transition-score coordinates.
            first_generated = count_left_padding(input_ids, pad_value=pad_id) + past_token_lens
            transition_left_pad = count_left_padding(transition_scores, pad_value=0)
            transition_borders = step_borders - first_generated + transition_left_pad
            masked_transitions = mask_by_borders_2D(transition_scores, left_borders=transition_borders, include_left=False, value=0)

        shifted_ids, shifted_scores, shifted_cache, shifted_transitions = self._shift_padding_to_left(
            masked_ids, masked_scores, masked_cache, masked_transitions
        )
        return shifted_ids, shifted_scores, shifted_cache, shifted_transitions
    
    def _cut_latter_steps_with_false_tokens(self, input_ids: torch.LongTensor, verifier_scores: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None, past_token_lens: torch.LongTensor=None):
        """Mask everything after the current step's border, taking false EOS ids into account.

        The border per row comes from
        `get_leftmost_token_position_with_false_tokens`, called with the newly
        generated tokens, `self.end_of_step_ids` and `self.false_eos_ids`.
        Tokens and verifier scores are filled with `pad_token_id` after the
        border, cache entries and transition scores with 0; all results are
        then passed through `_shift_padding_to_left`.

        Returns:
            Tuple `(input_ids, verifier_scores, past_key_values, transition_scores)`.
        """
        pad_id = self.generation_config.pad_token_id

        generated = get_new_generated_tokens_with_forward(input_ids, past_token_lens=past_token_lens, pad_token_id=pad_id, forward_num=0)
        step_borders = get_leftmost_token_position_with_false_tokens(generated, self.end_of_step_ids, self.false_eos_ids, True)

        masked_ids = mask_by_borders_2D(input_ids, left_borders=step_borders, include_left=False, value=pad_id)
        masked_scores = mask_by_borders_2D(verifier_scores, left_borders=step_borders, include_left=False, value=pad_id)

        masked_cache = past_key_values
        if past_key_values is not None:
            masked_cache = mask_by_borders_past_key_values(past_key_values, left_borders=step_borders, include_left=False, value=0)

        masked_transitions = transition_scores
        if transition_scores is not None:
            # Translate the token-space border into transition-score coordinates.
            first_generated = count_left_padding(input_ids, pad_value=pad_id) + past_token_lens
            transition_left_pad = count_left_padding(transition_scores, pad_value=0)
            transition_borders = step_borders - first_generated + transition_left_pad
            masked_transitions = mask_by_borders_2D(transition_scores, left_borders=transition_borders, include_left=False, value=0)

        shifted_ids, shifted_scores, shifted_cache, shifted_transitions = self._shift_padding_to_left(
            masked_ids, masked_scores, masked_cache, masked_transitions
        )
        return shifted_ids, shifted_scores, shifted_cache, shifted_transitions

    def _mask_former_steps(self, input_ids: torch.LongTensor, past_token_lens: torch.LongTensor=None):
        """Replace the tokens of earlier steps with padding.

        The right border per row is the left-padding count plus
        `past_token_lens`; positions before it are filled with `pad_token_id`
        via `mask_by_borders_2D` (border itself excluded).
        """
        pad_id = self.generation_config.pad_token_id
        step_start = count_left_padding(input_ids, pad_value=pad_id) + past_token_lens
        return mask_by_borders_2D(input_ids, right_borders=step_start, include_right=False, value=pad_id)

    def _truncate_left_padding(self, token_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None):
        """Drop the left padding columns reported by `count_shared_left_padding`.

        `token_ids` and `past_key_values` are cut by the count obtained for
        `token_ids` (pad value `pad_token_id`); `transition_scores` is cut by
        its own count (pad value 0).

        Returns:
            Tuple `(token_ids, past_key_values, transition_scores)`.
        """
        n_shared_pad = count_shared_left_padding(token_ids, pad_value=self.generation_config.pad_token_id)
        token_ids = token_ids[:, n_shared_pad:]

        if past_key_values is not None:
            trimmed_layers = []
            for layer in past_key_values:
                trimmed_layers.append(tuple(cache[:, :, n_shared_pad:] for cache in layer))
            past_key_values = tuple(trimmed_layers)

        if transition_scores is not None:
            n_shared_zero = count_shared_left_padding(transition_scores, pad_value=0)
            transition_scores = transition_scores[:, n_shared_zero:]

        return token_ids, past_key_values, transition_scores
    
    def step_beam_search(
        self,
        input_ids: torch.LongTensor,
        num_beams: Optional[int] = None,
        max_search_steps: Optional[int] = None,
        num_samples_per_search_step: Optional[int] = None,
        max_new_token_per_step: Optional[int] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        **model_kwargs,
        ):
        """Run step-level beam search guided by verifier scores.

        Each iteration expands every kept beam into
        `num_samples_per_search_step // num_beams` sampled continuations
        (`_expand_nodes`), keeps the `num_beams` best candidates per prompt
        (`_select_nodes`) and continues from them until a stop condition holds.

        Args:
            input_ids: Prompt token ids; dimension 0 is the batch.
            num_beams: Beams kept per prompt. Defaults to
                `generation_config.search_num_beams`.
            max_search_steps: Maximum number of search iterations. Defaults to
                `generation_config.max_search_steps`.
            num_samples_per_search_step: Candidates per prompt and iteration.
                Defaults to `generation_config.num_samples_per_search_step`.
            max_new_token_per_step: Token budget added per iteration. Defaults
                to `generation_config.max_new_token_per_step`.
            logits_processor: Defaults to `generate_kwargs["logits_processor"]`.
            stopping_criteria: Optional global stop test evaluated on the
                selected sequences; defaults to
                `generate_kwargs["stopping_criteria"]`. When neither is given no
                global stop test is applied.
            **model_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            * `generation_config.return_dict_in_generate`: a `StepSamplingOutput`
              with the best sequence per prompt and its transition scores.
            * else `generation_config.return_all_search_sequences`: a
              `StepSamplingOutput` with the last iteration's candidates ordered
              by descending verifier score.
            * otherwise: the tensor holding the best sequence per prompt.

        Side effects:
            Sets `self.generation_config.max_length` to the prompt length plus
            `generation_config.max_new_tokens`.
        """
        batch_size = input_ids.shape[0]
        if num_beams is None:
            num_beams = self.generation_config.search_num_beams
        if max_search_steps is None:
            max_search_steps = self.generation_config.max_search_steps
        if num_samples_per_search_step is None:
            num_samples_per_search_step = self.generation_config.num_samples_per_search_step
        if max_new_token_per_step is None:
            max_new_token_per_step = self.generation_config.max_new_token_per_step
        if logits_processor is None:
            logits_processor = self.generate_kwargs.get("logits_processor", None)
        if stopping_criteria is None:
            stopping_criteria = self.generate_kwargs.get("stopping_criteria", None)
        # generate_kwargs contains
        # 'input_ids', 'logits_processor', 'stopping_criteria', 'pad_token_id', 'eos_token_id',
        # 'output_scores', 'return_dict_in_generate', 'synced_gpus', 'streamer', 'output_past_key_values',
        # 'past_key_values', 'output_attentions', 'output_hidden_states', 'use_cache', 'attention_mask'
        self.generation_config.max_length = input_ids.shape[-1] + self.generation_config.max_new_tokens
        input_ids = input_ids.repeat_interleave(num_beams, dim=0)

        # Search state
        cur_length = input_ids.shape[-1]
        cur_step = 0
        past_key_values = None

        all_path_sequences = []
        all_path_vscores = []

        output_transition_scores = self.generation_config.value_by_transition_scores

        cur_node_values = torch.zeros((input_ids.shape[0], 1), device=input_ids.device)
        cur_node_transition_scores = torch.zeros((input_ids.shape[0], 1), device=input_ids.device) if output_transition_scores else None

        # Loop-invariant: number of continuations sampled from each beam.
        sample_nums = num_samples_per_search_step // num_beams

        while True:
            max_length_till_node = cur_length + max_new_token_per_step
            node_logits_processor = logits_processor
            # None lets `_generate_node_step` build its own step-level criteria.
            node_stop_criteria = None

            candidate_nodes = self._expand_nodes(
                input_ids=input_ids,
                cur_node_values=cur_node_values,
                cur_node_transition_scores=cur_node_transition_scores,
                past_key_values=past_key_values,
                sample_nums=sample_nums,
                batch_size=batch_size,
                max_length_till_node=min(max_length_till_node, self.generation_config.max_length),
                logits_processor=node_logits_processor,
                stopping_criteria=node_stop_criteria,
                output_past_key_values=False,
                output_transition_scores=output_transition_scores,
            )

            if candidate_nodes.verifier_scores is None:
                # Score every child's single-step result (log-prob product, reward
                # function, value function, Monte Carlo). TODO: give value by input_ids
                candidate_nodes.verifier_scores = self.value_function(
                    **{**self.value_kwargs, **{'input_ids':candidate_nodes.sequences}}
                    )

            # Record the candidate sequences of every step
            all_path_sequences.append(candidate_nodes.sequences)
            all_path_vscores.append(candidate_nodes.verifier_scores)

            # Select candidates by their scores (scoring differs between algorithms)
            selected_squences, selected_vscores, selected_transition_scores, selected_past_key_values, selected_indices = self._select_nodes(
                batch_candidates=candidate_nodes,
                batch_size=batch_size,
                num_beams=num_beams,
                dedup_mode=self.generation_config.dedup_mode)

            # Stop once every selected sequence contains EOS
            if selected_squences.eq(self.generation_config.eos_token_id).any(1).all():
                break
            # `stopping_criteria` may legitimately be absent; calling None would raise.
            if stopping_criteria is not None:
                should_stop = stopping_criteria(selected_squences, ())
                # Newer transformers return one flag per row instead of a bool;
                # stop only when every row is done.
                if torch.is_tensor(should_stop):
                    should_stop = bool(should_stop.all())
                if should_stop:
                    break

            # Build the next input from the selected candidates (left padded),
            # i.e. move one step forward
            input_ids, past_key_values, selected_transition_scores = self._truncate_left_padding(selected_squences, selected_past_key_values, selected_transition_scores)
            cur_node_values = selected_vscores
            cur_node_transition_scores = selected_transition_scores if output_transition_scores else None

            # Update the length
            cur_length = input_ids.shape[-1]
            cur_step += 1

            # Stop when the maximum search depth is reached
            if cur_step >= max_search_steps:
                break

        # Final selection: best beam per prompt, converted to a flat row index
        reshaped_selected_vscores = selected_vscores.reshape(batch_size, -1, selected_vscores.shape[-1])
        _, best_index = torch.topk(reshaped_selected_vscores, k=1, dim=1, largest=True)
        best_index = best_index.reshape(batch_size, -1)
        row_offsets = torch.arange(batch_size, device=best_index.device).unsqueeze(1) * reshaped_selected_vscores.shape[1]
        best_index = (best_index + row_offsets).reshape(-1)

        sequence = selected_squences.index_select(0, best_index)
        transition_scores = selected_transition_scores.index_select(0, best_index) if selected_transition_scores is not None else selected_transition_scores

        if self.generation_config.return_dict_in_generate:
            return StepSamplingOutput(
                sequences=sequence,
                transition_scores=transition_scores
            )
        elif self.generation_config.return_all_search_sequences:
            end_sequences = all_path_sequences[-1].reshape(batch_size, num_samples_per_search_step, -1)
            end_vscores = all_path_vscores[-1].reshape(batch_size, num_samples_per_search_step)
            ordered_vscores_indices = torch.argsort(end_vscores, dim=1, descending=True)
            ordered_sequences = torch.zeros_like(end_sequences, device=end_sequences.device)
            ordered_vscores = torch.zeros_like(end_vscores, device=end_vscores.device)
            for b, i in enumerate(ordered_vscores_indices):
                ordered_sequences[b] = end_sequences[b][i]
                ordered_vscores[b] = end_vscores[b][i]
            return StepSamplingOutput(
                sequences=ordered_sequences,
                verifier_scores=ordered_vscores,
            )
        else:
            return sequence

    def _expand_nodes(self,
        input_ids,
        cur_node_values,
        cur_node_transition_scores,
        past_key_values,
        sample_nums,
        batch_size,
        max_length_till_node,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        output_past_key_values=True,
        output_transition_scores=True,
    ) -> StepSamplingOutput:
        """Sample `sample_nums` one-step continuations for every row of `input_ids`.

        Rows whose last token is EOS are treated as finished: they are not sent
        to generation but replicated `sample_nums` times (with their current
        values / transition scores) and spliced back at their original position.
        All other rows are repeated `sample_nums` times and generated in chunks
        of at most `batch_size` rows via `_generate_node_step`.

        Args:
            input_ids: Current sequences, one row per beam.
            cur_node_values: Per-row scores reused for finished rows.
            cur_node_transition_scores: Per-row transition scores reused for
                finished rows when `output_transition_scores` is set.
            past_key_values: Optional cache for `input_ids`.
            sample_nums: Number of continuations per row.
            batch_size: Maximum number of rows per generation call.
            max_length_till_node: Forwarded as `max_length` to `_generate_node_step`.
            logits_processor: Forwarded to `_generate_node_step`.
            stopping_criteria: Forwarded to `_generate_node_step`.
            output_past_key_values: Whether the result carries a cache.
            output_transition_scores: Whether the result carries transition scores.

        Returns:
            A `StepSamplingOutput` concatenated over all rows.
        """

        # True for rows that still need generation (last token is not EOS)
        mask = input_ids[:,-1].ne(self.generation_config.eos_token_id)
        origin_input_ids = deepcopy(input_ids)
        origin_past_key_values = deepcopy(past_key_values)
        # Sequences that have not finished keep generating
        input_ids = input_ids[mask]
        input_ids = input_ids.repeat_interleave(sample_nums, dim=0)

        if past_key_values is not None:
            past_key_values = tuple(
                tuple(
                    past_key_value[mask.to(past_key_value.device)].repeat_interleave(sample_nums, dim=0)
                    for past_key_value in layer_past_key_values
                )
                for layer_past_key_values in past_key_values
            )

        # Restore the batch size
        nseqs = input_ids.shape[0]  # nseqs = batch_size * search_step_sample_nums
        n_split = math.ceil(nseqs / batch_size)

        batch_outputs=[]
        # Generate chunk by chunk
        for i in range(n_split):
            cur_input = input_ids[i*batch_size: min((i+1)*batch_size, nseqs)]
            if past_key_values is not None:
                cur_past_key_values = tuple(tuple(past_key_value[i*batch_size: min((i+1)*batch_size, nseqs)] \
                    for past_key_value in layer_past_key_values) for layer_past_key_values in past_key_values)
            else:
                cur_past_key_values = None

            # Generate a step that ends with \n or EOS or hits the max length,
            # reusing the past key values
            step_outputs = self._generate_node_step(
                input_ids=cur_input,
                past_key_values=cur_past_key_values,
                max_length=max_length_till_node,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                output_transition_scores=output_transition_scores,
                output_past_key_values=output_past_key_values,
            )
            batch_outputs.append(step_outputs)

        batch_outputs = self._concat_group_steps(batch_outputs, dim=0)
        # Finished sequences are replicated and spliced back into the generated outputs

        indices = torch.where(mask==False)[0]
        if len(indices) > 0:
            for index in indices:
                # NOTE(review): the three cache tuples below are only bound when
                # `past_key_values is not None`, but they are read whenever
                # `output_past_key_values` is true — confirm callers never
                # combine a missing cache with `output_past_key_values=True`
                # while finished rows exist.
                if past_key_values is not None:
                    new_past_key_values = tuple(
                        tuple(
                            past_key_value[None,index].repeat_interleave(sample_nums, dim=0)
                            for past_key_value in layer_past_key_values
                        )
                        for layer_past_key_values in origin_past_key_values
                    )
                # Replicated entry for the finished row; `steps` is a zero column
                new_outputs = StepSamplingOutput(
                    sequences=origin_input_ids[index].repeat(sample_nums,1),
                    steps=torch.zeros(sample_nums,1,device=origin_input_ids.device),
                    transition_scores=cur_node_transition_scores[index].repeat(sample_nums,1) if output_transition_scores else None,
                    verifier_scores=cur_node_values[index].repeat(sample_nums,1),
                    past_key_values=new_past_key_values if output_past_key_values else None,
                )
                # Row offset in `batch_outputs` where this row's samples belong
                batch_index = index * sample_nums
                if past_key_values is not None:
                    new_past_key_values1 = tuple(
                            tuple(
                                past_key_value[:batch_index]
                                for past_key_value in layer_past_key_values
                            )
                            for layer_past_key_values in  batch_outputs.past_key_values
                        )
                    new_past_key_values2 = tuple(
                            tuple(
                                past_key_value[batch_index:]
                                for past_key_value in layer_past_key_values
                            )
                            for layer_past_key_values in batch_outputs.past_key_values
                        )
                # Split the accumulated outputs around the insertion point
                unfinished_outputs1 = StepSamplingOutput(
                    sequences=batch_outputs.sequences[:batch_index],
                    steps=batch_outputs.steps[:batch_index],
                    transition_scores=batch_outputs.transition_scores[:batch_index] if output_transition_scores else None,
                    verifier_scores=batch_outputs.verifier_scores[:batch_index],
                    past_key_values=new_past_key_values1 if output_past_key_values else None,
                )
                unfinished_outputs2 = StepSamplingOutput(
                    sequences=batch_outputs.sequences[batch_index:],
                    steps=batch_outputs.steps[batch_index:],
                    transition_scores=batch_outputs.transition_scores[batch_index:] if output_transition_scores else None,
                    verifier_scores=batch_outputs.verifier_scores[batch_index:],
                    past_key_values=new_past_key_values2 if output_past_key_values else None,
                )
                # Skip the (empty) tail when the insertion point is at the end
                if batch_index < batch_outputs.sequences.shape[0]:
                    batch_outputs = self._concat_group_steps([unfinished_outputs1, new_outputs, unfinished_outputs2], dim=0)
                else:
                    batch_outputs = self._concat_group_steps([unfinished_outputs1, new_outputs], dim=0)
                                          # [n_beam * n_sampling_steps_per_beam, seq_len]

        return batch_outputs

    @torch.inference_mode(mode=True)
    def _generate_node_step(self,
        input_ids,
        past_key_values,
        max_length,
        stopping_criteria: StoppingCriteria = None, 
        logits_processor: LogitsProcessorList = None, 
        output_transition_scores=False,
        output_past_key_values=False,
        **kwargs
    ) -> StepSamplingOutput:
        """Generate one reasoning step for every row and score it.

        Calls `self.generate_policy`, derives verifier scores (either averaged
        transition scores or `self.value_function`), then cuts the result after
        EOS and after the current step and restores left padding.

        Args:
            input_ids: Left-padded token ids; a warning is printed if the last
                column contains `pad_token_id`.
            past_key_values: Optional cache forwarded to the policy.
            max_length: Maximum total sequence length for this step.
            stopping_criteria: Extra criteria. When `None`, a
                `StepStoppingCriteria` built from `self.end_of_step_ids` is used.
            logits_processor: Accepted for interface compatibility.
            output_transition_scores: Whether to compute transition scores.
            output_past_key_values: Forwarded to the policy.
            **kwargs: Extra keyword arguments forwarded to the policy
                (they take precedence over the defaults built here).

        Returns:
            `StepSamplingOutput` whose `verifier_scores` has one column holding
            the score at the last position of each processed row.

        Raises:
            ValueError: If `last_hidden_state` is requested in
                `self.value_kwargs` but no hidden states are available.

        Side effects:
            Overwrites the `input_ids` / `attention_mask` / `last_hidden_state`
            entries of `self.value_kwargs` when those keys are present.
        """
        if (
            self.generation_config.pad_token_id is not None
            and len(input_ids.shape) == 2
            and torch.sum(input_ids[:, -1] == self.generation_config.pad_token_id) > 0
        ):
            print(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

        # NOTE(review): `logits_processor` is normalised here but not forwarded
        # to the policy below; processors reach it only through
        # `self.generate_kwargs` / `kwargs` — confirm this is intended.
        if logits_processor is None:
            logits_processor = LogitsProcessorList()

        # Needed after generation regardless of which stopping criteria is used,
        # so it must be computed unconditionally (it used to be bound only in
        # the default-criteria branch, raising NameError for custom criteria).
        cur_token_lens = count_not_left_padding(input_ids, pad_value=self.generation_config.pad_token_id)

        if stopping_criteria is None:
            step_stopping_criteria = StepStoppingCriteria(
                cur_token_lens=cur_token_lens,
                end_token_ids=self.end_of_step_ids,
                pad_token_id=self.generation_config.pad_token_id,
                false_eos_ids=self.generation_config.false_eos_ids,
                device=self.device)
            stopping_criteria = StoppingCriteriaList([step_stopping_criteria])

        # Kept as a separate tensor from `cur_token_lens`, which is handed to
        # the stopping criteria.
        input_token_lens = count_not_left_padding(input_ids, pad_value=self.generation_config.pad_token_id)

        cur_length = input_ids.shape[-1]
        max_new_tokens = max_length - cur_length

        # Per-step copy so the shared generation config keeps its own limits.
        step_generation_config = deepcopy(self.generation_config)
        step_generation_config.max_new_tokens = max_new_tokens
        step_generation_config.max_length = max_length
        stopping_criteria = self._get_stopping_criteria(
            generation_config=step_generation_config,
            stopping_criteria=stopping_criteria)

        original_output_hidden_states = kwargs.get("output_hidden_states", None) or self.generation_config.output_hidden_states

        update_kwargs = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.generation_config.pad_token_id),
            'stopping_criteria': stopping_criteria,
            'use_cache': True,
            'return_dict_in_generate': True,
            'output_scores': output_transition_scores,
            'output_hidden_states': [-1] if not original_output_hidden_states else True,
            'past_key_values': past_key_values,
            'output_past_key_values': output_past_key_values
        }

        outputs = self.generate_policy(**{**self.generate_kwargs, **update_kwargs, **kwargs})

        transition_scores = None
        if output_transition_scores:
            transition_scores = self.compute_transition_scores(sequences=outputs.sequences, scores=outputs.scores, beam_indices=outputs.get('beam_indices'), normalize_logits=True)
            # Left-pad with zeros so the scores are as wide as the sequences.
            transition_scores = F.pad(transition_scores, (outputs.sequences.shape[-1]-transition_scores.shape[-1],0), value=0)

        if self.generation_config.value_by_transition_scores:
            verifier_scores = self.get_average_transition_scores(transition_scores, len(outputs.scores))
        else:
            # Refresh only the inputs the value function was configured with.
            for key in self.value_kwargs:
                if key == 'input_ids':
                    self.value_kwargs['input_ids'] = outputs.sequences
                if key == 'attention_mask':
                    self.value_kwargs['attention_mask'] = (outputs.sequences != self.generation_config.pad_token_id).to(torch.int64)
                if key == 'last_hidden_state':
                    if original_output_hidden_states:
                        last_hidden_state = [hidden_states_t[-1] for hidden_states_t in outputs.hidden_states][1:]
                    else:
                        last_hidden_state = outputs.hidden_states[1:]
                        outputs.hidden_states = None
                    if last_hidden_state is None or len(last_hidden_state)==0:
                        raise ValueError("last hiddent state is None of length is 0")
                    self.value_kwargs['last_hidden_state'] = last_hidden_state

            value_outputs = self.value_function(**self.value_kwargs)
            assert value_outputs.shape == outputs.sequences.shape, 'Value shape is different from the input shape!'
            verifier_scores = F.pad(value_outputs, (outputs.sequences.shape[-1]-value_outputs.shape[-1], 0),value=self.generation_config.pad_token_id)

        input_ids, past_key_values = outputs.sequences, outputs.past_key_values

        input_ids, verifier_scores, past_key_values, transition_scores = self._cut_after_eos_lp(input_ids, verifier_scores, past_key_values, transition_scores, past_token_lens=input_token_lens)

        sequences, verifier_scores, past_key_values, transition_scores = self._cut_latter_steps_with_false_tokens(input_ids, verifier_scores, past_key_values, transition_scores, past_token_lens=cur_token_lens)

        steps = self._mask_former_steps(sequences, past_token_lens=cur_token_lens)

        return StepSamplingOutput(
            sequences=sequences,
            steps=steps,
            transition_scores=transition_scores,
            # Keep only the score at the last position, as a single column.
            verifier_scores=verifier_scores[:, None, -1],
            past_key_values=past_key_values,
        )
    
    def get_average_transition_scores(self, transition_scores, step_length):
        """Spread each row's mean transition score over its last `step_length` positions.

        The mean is the row sum divided by `step_length`. The result has the
        same width as `transition_scores`; positions before the last
        `step_length` ones are zero.
        """
        total_width = transition_scores.shape[-1]
        row_means = transition_scores.sum(dim=1) / step_length
        tiled_means = row_means.reshape(-1, 1).repeat(1, step_length)
        # Left-pad with zeros back to the width of the input.
        return F.pad(tiled_means, (total_width - tiled_means.shape[-1], 0), value=0)
    
    def _select_nodes(self, batch_candidates, num_beams=1, batch_size:int=1, dedup_mode=False):
        batch_sequences = batch_candidates.sequences  # [n_beam * n_sampling_steps_per_beam, seq_len]
        batch_vscores = batch_candidates.verifier_scores
        batch_transition_scores = batch_candidates.transition_scores

        # select the best steps/sequences
        hvscores = self._highlight_unique_sequences(batch_sequences, batch_vscores, dedup_mode=dedup_mode)
        hvscores = hvscores.reshape(batch_size, -1, hvscores.shape[-1])
        _, indices = torch.topk(hvscores, k=num_beams, dim=1, largest=True)
        indices = indices.reshape(batch_size,-1)
        indices = torch.tensor([i*hvscores.shape[1]+idx for i, idxes in enumerate(indices) for idx in idxes  ]).to(indices.device) 
        sequences = batch_sequences.index_select(0, indices) # [n_beam, seq_len]
        past_key_values = batch_candidates.past_key_values
        if past_key_values is not None:
            past_key_values = tuple(
                tuple(
                    past_key_value.index_select(0, indices.to(past_key_value.device))
                    for past_key_value in layer_past_key_values
                )
                for layer_past_key_values in past_key_values
            )
        vscores = batch_vscores.index_select(0, indices)
        
        transition_scores = batch_transition_scores.index_select(0, indices) if batch_transition_scores is not None else batch_transition_scores

        return sequences, vscores, transition_scores, past_key_values, indices

    
    def _get_stopping_criteria(
        self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        """Build the default stopping criteria and merge the custom ones into them.

        A `MaxLengthCriteria` is added when `generation_config.max_length` is
        set and a `MaxTimeCriteria` when `generation_config.max_time` is set;
        `stopping_criteria` is then merged via `_merge_criteria_processor_list`.
        """
        defaults = StoppingCriteriaList()

        max_length = generation_config.max_length
        if max_length is not None:
            position_limit = getattr(self.config, "max_position_embeddings", None)
            length_criterion = MaxLengthCriteria(
                max_length=max_length,
                max_position_embeddings=position_limit,
            )
            defaults.append(length_criterion)

        max_time = generation_config.max_time
        if max_time is not None:
            defaults.append(MaxTimeCriteria(max_time=max_time))

        return self._merge_criteria_processor_list(defaults, stopping_criteria)
    
    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `.generate()`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `.generate()` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def _highlight_unique_sequences(self, sequences: torch.LongTensor, verifier_scores: torch.FloatTensor, dedup_mode: int=0) -> torch.FloatTensor:
        """
        Prioritize unique sequences: linguistics-level (mode=1)
        """
        if dedup_mode == 0:
            return verifier_scores
        
        seq_len = sequences.shape[-1]
        
        seqs = shift_padding_to_left_2D(sequences, pad_value=self.generation_config.pad_token_id)
        multipliers = torch.pow(torch.full((seq_len,), 31, dtype=seqs.dtype, device=self.device), torch.arange(seq_len, device=self.device))
        hashes = (seqs * multipliers).sum(dim=1)

        unique_hashes = torch.unique(hashes)
        hightlighted_indices = (unique_hashes[:, None] == hashes[None, :]).float().argmax(dim=1)

        highlighted_vscores = verifier_scores.clone()
        highlighted_vscores[hightlighted_indices] += 100
        return highlighted_vscores

    def _concat_group_tensors(self, tensor_list: List[torch.Tensor], left_padding = True, pad_value: int = 0, dim: int = 0):
        max_len = max(tensor.shape[-1] for tensor in tensor_list)
        if left_padding:
            tensor_list = [F.pad(tensor, (max_len - tensor.shape[-1], 0), value=pad_value) for tensor in tensor_list]
        else:
            tensor_list = [F.pad(tensor, (0, max_len - tensor.shape[-1]), value=pad_value) for tensor in tensor_list]

        tensors = torch.concat(tensor_list, dim=dim)
        return tensors

    def _concat_group_past_key_values(self, past_key_values: List[Tuple[Tuple[torch.FloatTensor]]], token_padding_lens: torch.LongTensor, dim: int = 0):
        # w/o beam: (bsz, n_heads, cache_len, embed_size)
        # w beam: (n_beam, n_sampling_steps_per_beam, n_heads, cache_len, embed_size)

        cache_lens = torch.LongTensor([cache[0][0].shape[-2] for cache in past_key_values]).to(self.device)
        padded_cache_lens = token_padding_lens + cache_lens
        min_cache_len = padded_cache_lens.min()
        cut_cache_lens = padded_cache_lens - min_cache_len

        past_key_values = tuple(
            tuple(
                torch.cat(
                    [F.pad(tensor.transpose(-2, -1), (token_padding_lens[i], -cut_cache_lens[i]), value=0).transpose(-2, -1) for i, tensor in enumerate(tensor_tuples)], 
                    dim=dim
                )
                for tensor_tuples in zip(*layer_tuples)
            )
            for layer_tuples in zip(*past_key_values)
        )
        return past_key_values
    
    def _concat_group_steps(self, instances: List[StepSamplingOutput], dim: int = 0):
        """Merge several `StepSamplingOutput` objects into one along `dim`.

        Sequences and steps are left-padded with `pad_token_id`, transition
        scores with 0; verifier scores are concatenated as they are and caches
        are aligned via `_concat_group_past_key_values`. Optional fields are
        `None` in the result when the first instance does not carry them.
        """
        field_names = ("sequences", "steps", "transition_scores", "verifier_scores", "past_key_values")
        gathered = {name: [instance.get(name) for instance in instances] for name in field_names}

        # Left padding each group needs to reach the longest sequence.
        seq_lens = torch.LongTensor([seq.shape[-1] for seq in gathered["sequences"]]).to(self.device)
        token_padding_lens = seq_lens.max() - seq_lens

        pad_id = self.generation_config.pad_token_id
        merged_sequences = self._concat_group_tensors(gathered["sequences"], pad_value=pad_id, dim=dim)
        merged_steps = self._concat_group_tensors(gathered["steps"], pad_value=pad_id, dim=dim)

        merged_transitions = None
        if gathered["transition_scores"][0] is not None:
            merged_transitions = self._concat_group_tensors(gathered["transition_scores"], pad_value=0, dim=dim)

        merged_vscores = None
        if gathered["verifier_scores"][0] is not None:
            merged_vscores = torch.cat(gathered["verifier_scores"], dim=dim)

        merged_cache = None
        if gathered["past_key_values"][0] is not None:
            merged_cache = self._concat_group_past_key_values(gathered["past_key_values"], token_padding_lens, dim=dim)

        return StepSamplingOutput(
            sequences=merged_sequences,
            steps=merged_steps,
            transition_scores=merged_transitions,
            verifier_scores=merged_vscores,
            past_key_values=merged_cache,
        )


class PythonCodeInterpreter:
    """ Code interpreter for executing python code presented as a list of strings, or nested list of strings.
    This class is intended to be used with model.generate method to evaluate python code on the fly.

    The idea is:
    1. detect that a code snippet has been generated by checking code pattern (code_prefix ... code_suffix);
    2. call Python executor to calculate the code results and save them to next_tokens_cache;
    3. when the input_ids is passed again, pop out the next token from next_tokens_cache.

    Args:
        tokenizer: tokenizer exposing `pad_token_id`, `padding_side`, `encode` and `decode`.
        code_prefix: regular expression fragment that opens a code snippet.
        code_suffix: text that closes a code snippet. It is used both literally
            (`str.endswith`, tokenization) and as a fragment of `code_pattern`.
        code_output_prefix: text put in front of a non-empty execution result.
        code_output_suffix: text put behind a non-empty execution result.
        executor: callable taking the batch of code snippets and returning one result
            per entry; when None, `utils.solver_utils.PythonExecutor` is instantiated.
    """
    def __init__(
        self,
        tokenizer,
        code_prefix="```[pP]ython",
        code_suffix="```",
        code_output_prefix="```output",
        code_output_suffix="```",
        executor=None,
    ):
        # check tokenizer
        # NOTE: the checks below are `assert` statements, so a failure surfaces as an
        # AssertionError (whose message is the exception instance) and the checks are
        # skipped entirely when Python runs with -O.
        self.tokenizer = tokenizer
        assert self.tokenizer.pad_token_id is not None, RuntimeError("Tokenizer has no pad_token.")
        # number of ids the tokenizer produces for a lone "\n"; used to strip the
        # helper newline that is prepended before encoding (see below)
        self._len_newline_ids = len(self.tokenizer.encode("\n", add_special_tokens=False))
        self.padding_side = self.tokenizer.padding_side
        assert self.padding_side in {"left", "right"}, NotImplementedError(
            "Only support left or right as padding_side; got {}".format(self.padding_side))
        # check code and output prefix/suffix
        self._code_prefix = code_prefix
        self._code_suffix = code_suffix
        self.code_pattern = fr"{code_prefix}([\s\S]*?){code_suffix}"
        # encode the suffix behind a "\n" and drop the ids of that "\n"
        self._code_suffix_ids = self.tokenizer.encode("\n"+self._code_suffix, add_special_tokens=False)[self._len_newline_ids:]
        self._len_suffix = len(self._code_suffix_ids)
        self._code_output_prefix = code_output_prefix
        self._code_output_suffix = code_output_suffix
        # configure python executor
        if executor is None:
            from utils.solver_utils import PythonExecutor

            self.executor = PythonExecutor(
                get_answer_from_stdout=True, share_runtime_in_batch=False)
        else:
            self.executor = executor
        # other arguments
        self.max_length = None  # passed during runtime to truncate the input_ids after padding
        self.next_tokens_cache = None
        self._message_printed = False

    def _parse_code_execution_result(self, result):
        """Turn one executor result into the text that should follow the code snippet.

        Args:
            result: either an `(output, exec_info)` pair, or a pair of lists
                `([output, ...], [exec_info, ...])`, in which case only the last
                `(output, exec_info)` pair is used.

        Returns:
            The output wrapped in the configured output prefix/suffix, or the
            unwrapped empty value when there is nothing to report. When
            `exec_info` is not "Done", `exec_info` itself is reported instead
            of the output.
        """
        if isinstance(result[0], list):
            return self._parse_code_execution_result(list(zip(*result))[-1])  # get last result

        result, exec_info = result
        if exec_info == "Done":
            result = result.strip()
            if re.match(r"^[0-9]+\.[0-9]+$", result):  # use round to avoid floating point error
                # the pattern only admits a plain decimal literal, which float()
                # parses directly; no need to eval() text produced by executed code
                result = str(round(float(result), 8))
        else:
            result = exec_info

        if result:  # in rare case, the code has no output, and output prefix and suffix should not be attached
            result = self._code_output_prefix + result + self._code_output_suffix

        return result

    def execute_on_inputs_ids(self, input_ids, is_code_suffix=None):
        """Run the code found in the selected rows and tokenize the results.

        Args:
            input_ids: batch of token id rows; each selected row is decoded with the tokenizer.
            is_code_suffix: per-row flags selecting the rows to execute; all rows when None.

        Returns:
            A list with one list of token ids per selected row (in row order), or
            an empty list when none of the selected rows produced any output.

        Raises:
            RuntimeError: if a selected row does not decode to text ending with the code suffix.
        """
        if not self._message_printed:
            self._message_printed = True
            logger.info("Confirmation of Code Interpreter in use")

        if is_code_suffix is None:
            is_code_suffix = [True] * input_ids.shape[0]
        batch_input_w_code = [
            self.tokenizer.decode(row_ids)
            for row_selected, row_ids in zip(is_code_suffix, input_ids)
            if row_selected
        ]
        # double check
        for input_w_code in batch_input_w_code:
            if not input_w_code.endswith(self._code_suffix):
                raise RuntimeError("Expect inputs to end with code tag {}; got\n{}".format(self._code_suffix, input_w_code))

        # a row with exactly one snippet is passed as a string, otherwise as a list of strings
        batch_code = []
        for input_w_code in batch_input_w_code:
            snippets = [snippet.strip() for snippet in re.findall(self.code_pattern, input_w_code)]
            batch_code.append(snippets[0] if len(snippets) == 1 else snippets)
        batch_results = [self._parse_code_execution_result(result) for result in self.executor(batch_code)]
        if not any(batch_results):  # in rare case, the code has no output, and no ids should be returned
            return []

        # NOTE cannot use tokenizer to encode input_text + result, because the following may not satisfy:
        # encode(decode(input_ids, ...), ...) == input_ids
        # This is because the model may not always generate the tokens for a word that the tokenizer would use.
        # So only the result is encoded: prepend a \n, encode it, and remove the ids of the \n.
        batch_output_ids = [
            self.tokenizer.encode('\n' + result, add_special_tokens=False)[self._len_newline_ids:]
            for result in batch_results
        ]

        return batch_output_ids

    def reset(self):
        """Drop all pending output tokens."""
        self.next_tokens_cache = None

    def __call__(
        self,
        input_ids,
        next_tokens=None,
        **kwargs
    ):
        """Execute freshly completed code snippets and inject their output token by token.

        Args:
            input_ids: 2-D batch of token ids generated so far.
            next_tokens: per-row next tokens; must support item assignment, because the
                entry of every row with pending output tokens is overwritten in place.
            **kwargs: accepted and ignored.

        Returns:
            The tuple `(input_ids, next_tokens)`; `input_ids` is returned untouched.
        """
        seq_len = input_ids.shape[1]
        # nothing to do while the rows are not longer than the suffix, or once max_length is reached
        if seq_len <= self._len_suffix or (self.max_length and seq_len >= self.max_length):
            return input_ids, next_tokens

        # prepare code_cache and _code_suffix_ids
        if not isinstance(self.next_tokens_cache, list) or len(self.next_tokens_cache) != input_ids.shape[0]:
            self.next_tokens_cache = [None] * input_ids.shape[0]
        if isinstance(self._code_suffix_ids, list):
            # first use: turn the suffix ids into a (1, len_suffix) tensor for broadcasting
            self._code_suffix_ids = torch.tensor(self._code_suffix_ids, dtype=input_ids.dtype, device=input_ids.device)
            self._code_suffix_ids = self._code_suffix_ids.unsqueeze(dim=0)
        else:  # torch tensor
            self._code_suffix_ids = self._code_suffix_ids.to(input_ids.device)

        # check if new code snippets have been generated; if so, add the solution to self.next_tokens_cache
        ends_with_suffix = (input_ids[:, -self._len_suffix:] == self._code_suffix_ids).all(dim=1)
        if ends_with_suffix.any():
            batch_output_ids = self.execute_on_inputs_ids(input_ids, ends_with_suffix)
            if batch_output_ids:  # in rare case, the code has no output, and no action should be conducted
                # outputs are ordered like the flagged rows, so consume them in step with the mask
                output_ids_iter = iter(batch_output_ids)
                for row, row_has_code in enumerate(ends_with_suffix):
                    if row_has_code:
                        assert self.next_tokens_cache[row] is None, RuntimeError("Cannot store output_ids in non-empty next_tokens_cache")
                        self.next_tokens_cache[row] = next(output_ids_iter)

        # check if existing self.next_tokens_cache is not empty; if so, pop it
        if any(self.next_tokens_cache):
            for row, pending in enumerate(self.next_tokens_cache):
                if pending:
                    next_tokens[row] = pending.pop(0)
            # exhausted (or empty) entries go back to None so the row can take a new result
            self.next_tokens_cache = [
                pending if isinstance(pending, list) and pending else None
                for pending in self.next_tokens_cache
            ]

        return input_ids, next_tokens