"""
Batch Generation with Calculator using Cache
"""



import torch
import json
import os
import re
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList
import torch.nn.functional as F
import torch.nn as nn
from accelerate.accelerator import Accelerator
from typing import Union, Sequence, List, Set, Tuple, Dict, Any, Optional
from dataclasses import dataclass
import numpy as np
import math
import gc
from tqdm import tqdm
from transformers.generation.utils import ModelOutput
from utils.cached_models import PreTrainedTokenizer
from utils.constants import LLAMA_EQUALS_TOKENS, LLAMA_LEFTMARK_TOKENS, LLAMA_RIGHTMARK_TOKEN, LLAMA_NEWLINE_TOKEN
from utils.gsm8k.decoding import use_calculator



@dataclass
class SamplingOutput(ModelOutput):
    """Result of token-level sampling with calculator support (`_sample_tokens_with_calculator`).

    All fields default to ``None``; the ones that were not requested stay ``None``.
    """
    # Token ids (prompt + generation), cut after the first eos and kept left-padded by the sampler.
    sequences: Optional[torch.LongTensor] = None
    # Per-token transition scores for the generated part (only when requested).
    transition_scores: Optional[torch.FloatTensor] = None
    # Per-layer (key, value) cache tensors aligned with `sequences` (only when requested).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None

@dataclass
class StepSamplingOutput(ModelOutput):
    """Result of step-level sampling (`_step_level_sample_tokens`).

    All fields default to ``None``; the ones that were not requested stay ``None``.
    """
    # Token ids of prompt + the newly sampled step; anything after the step end is masked, left-padded.
    sequences: Optional[torch.LongTensor] = None
    # Same shape as `sequences`, but with every token before the current step replaced by the pad id.
    steps: Optional[torch.LongTensor] = None
    # Per-token transition scores for the generated part (only when requested).
    transition_scores: Optional[torch.FloatTensor] = None
    # Scores produced by `verifier_scoring` (defined outside this view — semantics unverified here).
    verifier_scores: Optional[torch.FloatTensor] = None
    # Per-layer (key, value) cache tensors aligned with `sequences` (only when requested).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None


class BatchCalculatorCallingCriteria(StoppingCriteria):
    """Halt generation as soon as any sample in the batch has opened a calculator expression.

    A row qualifies when its newest token is one of the keyword (equals-sign) ids and the most
    recent left-mark (``<<``) before it is not already followed by another keyword id, i.e. the row
    currently ends with an open ``<<expr=`` span.
    """

    def __init__(self, keywords_ids: set, left_mark_tokens: set, device: torch.device):
        self.keyword_tokens = torch.tensor(list(keywords_ids)).to(device)
        self.left_mark_tokens = torch.tensor(list(left_mark_tokens)).to(device)

    def _is_hit_keywords(self, input_ids: torch.LongTensor) -> bool:
        # Membership test against the keyword ids; yields a bool tensor with the shape of `input_ids`.
        candidates = self.keyword_tokens.view((1, 1, -1))
        return (input_ids.unsqueeze(-1) == candidates).any(2)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        newest_is_keyword = self._is_hit_keywords(input_ids[:, -1:])
        if not newest_is_keyword.any():
            return False

        hit_rows = newest_is_keyword.nonzero()[:, 0]
        # History of the qualifying rows, without the keyword token that was just produced.
        history = input_ids.index_select(0, hit_rows)[:, :-1]

        mark_pos = find_rightmost_tokens_positions(history, self.left_mark_tokens, wnby=False)
        keywords_since_mark = count_tokens_after_positions(history, tokens=self.keyword_tokens, positions=mark_pos, include_pos=False)
        open_expression = mark_pos.ne(-1) & keywords_since_mark.eq(0)
        return bool(open_expression.any())

class BatchEndStoppingCriteria(StoppingCriteria):
    """Halt once every row of the batch contains `end_token_id` somewhere (prompt included)."""

    def __init__(self, end_token_id: int, device: torch.device):
        self.end_token_id = torch.tensor([end_token_id]).to(device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        row_has_end = (input_ids == self.end_token_id).any(dim=1)
        return bool(row_has_end.all())

class BatchMultipleEndStoppingCriteria(StoppingCriteria):
    """Halt once every row contains at least one of several end ids (prompt included)."""

    def __init__(self, end_token_ids: Set[int], device: torch.device):
        self.end_token_ids = torch.tensor(list(end_token_ids)).to(device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        is_end_token = (input_ids.unsqueeze(-1) == self.end_token_ids.view((1, 1, -1))).any(2)
        return bool(is_end_token.any(1).all())

class StepStoppingCriteria(StoppingCriteria):
    """Halt once every row has produced one of `end_token_ids` among its newly generated tokens.

    "New" means: after the left padding and the first `cur_token_lens[i]` real tokens of row i.
    """

    def __init__(self, cur_token_lens: torch.LongTensor, end_token_ids: Set[int], pad_token_id: int, device: torch.device):
        self.cur_token_lens = cur_token_lens
        self.end_token_ids = torch.tensor(list(end_token_ids)).to(device)
        self.pad_token_id = pad_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        fresh_tokens = get_new_generated_tokens(input_ids, past_token_lens=self.cur_token_lens, pad_token_id=self.pad_token_id)
        step_closed = (fresh_tokens.unsqueeze(-1) == self.end_token_ids.view((1, 1, -1))).any(2).any(1)
        return bool(step_closed.all())




def find_leftmost_tokens_positions(input_ids: torch.LongTensor, tokens: Union[int, torch.LongTensor], wnby: bool=True) -> torch.LongTensor:
    """
    Get the index where any of `tokens` first appears in each row of `input_ids`.

    Args:
        input_ids: 2-D tensor of token ids, shape (bsz, seq_len).
        tokens: a single token id (int), or a tensor / list / tuple / set of token ids.
        wnby: "within boundary". For rows containing none of `tokens`, return `seq_len - 1`
            when True, otherwise `seq_len` (one past the last index).

    Returns:
        LongTensor of shape (bsz,).

    Raises:
        TypeError: if `tokens` is not an int, a tensor, or a collection of token ids.

    e.g.
    input_ids = torch.tensor([[1, 2, 3, 3], [7, 0, 4, 0], [3, 2, 1, 2]])
    tokens = torch.tensor([3, 0])
    find_leftmost_tokens_positions(input_ids, tokens)
    >> tensor([2, 1, 0])

    tokens = torch.tensor([3, 2])
    find_leftmost_tokens_positions(input_ids, tokens, wnby=True)
    >> tensor([1, 3, 0])

    find_leftmost_tokens_positions(input_ids, tokens, wnby=False)
    >> tensor([1, 4, 0])
    """
    assert input_ids.ndim == 2
    seq_len = input_ids.shape[1]
    if isinstance(tokens, (list, tuple, set, frozenset)):
        tokens = torch.tensor(list(tokens), dtype=torch.long, device=input_ids.device)
    if isinstance(tokens, int):
        mask = input_ids.eq(tokens)
    elif isinstance(tokens, torch.Tensor):
        mask = input_ids[:, :, None].eq(tokens.view(1, 1, -1)).any(2)
    else:
        # Previously this fell through and crashed later with an UnboundLocalError on `mask`.
        raise TypeError(f"`tokens` must be an int, a tensor or a collection of token ids, got {type(tokens).__name__}")
    # argmax returns the first maximal index, i.e. the first match in each row.
    positions = torch.where(mask.any(1), mask.float().argmax(dim=1), seq_len-1 if wnby else seq_len)
    return positions


def find_rightmost_tokens_positions(input_ids: torch.LongTensor, tokens: Union[int, torch.LongTensor], wnby: bool=True) -> torch.LongTensor:
    """
    Get the index where any of `tokens` last appears in each row of `input_ids`.

    Args:
        input_ids: 2-D tensor of token ids, shape (bsz, seq_len).
        tokens: a single token id (int), or a tensor / list / tuple / set of token ids.
        wnby: "within boundary". For rows containing none of `tokens`, return 0 when True,
            otherwise -1 (one before the first index).

    Returns:
        LongTensor of shape (bsz,).

    Raises:
        TypeError: if `tokens` is not an int, a tensor, or a collection of token ids.

    e.g.
    input_ids = torch.tensor([[1, 2, 3, 3], [7, 0, 4, 0], [3, 2, 1, 2]])
    tokens = torch.tensor([3, 0])
    find_rightmost_tokens_positions(input_ids, tokens)
    >> tensor([3, 3, 0])

    tokens = torch.tensor([3, 2])
    find_rightmost_tokens_positions(input_ids, tokens, wnby=True)
    >> tensor([3, 0, 3])

    find_rightmost_tokens_positions(input_ids, tokens, wnby=False)
    >> tensor([3, -1, 3])
    """
    assert input_ids.ndim == 2
    seq_len = input_ids.shape[1]
    if isinstance(tokens, (list, tuple, set, frozenset)):
        tokens = torch.tensor(list(tokens), dtype=torch.long, device=input_ids.device)
    if isinstance(tokens, int):
        mask = input_ids.eq(tokens)
    elif isinstance(tokens, torch.Tensor):
        mask = input_ids[:, :, None].eq(tokens.view(1, 1, -1)).any(2)
    else:
        # Previously this fell through and crashed later with an UnboundLocalError on `mask`.
        raise TypeError(f"`tokens` must be an int, a tensor or a collection of token ids, got {type(tokens).__name__}")
    # Search the flipped rows so that argmax (first maximal index) lands on the last match.
    positions = torch.where(mask.any(1), (seq_len - 1) - mask.flip(dims=[1]).float().argmax(dim=1), 0 if wnby else -1)
    return positions


def find_leftmost_notpadded_positions(tensor: torch.Tensor, pad_value: Union[int, float], wnby: bool=True) -> torch.Tensor:
    """Return, per row of the 2-D `tensor`, the column of the first entry that differs from `pad_value`.

    Rows made only of `pad_value` yield `seq_len - 1` when `wnby` ("within boundary") is True,
    otherwise `seq_len`.
    """
    assert tensor.ndim == 2
    seq_len = tensor.shape[1]
    column_idx = torch.arange(seq_len, device=tensor.device).expand_as(tensor)
    # Pad columns are pushed to `seq_len`, so the row minimum is the first real column (or seq_len).
    first_real = torch.where(tensor.ne(pad_value), column_idx, torch.full_like(column_idx, seq_len)).min(dim=1).values
    if wnby:
        first_real = first_real.clamp(max=seq_len - 1)
    return first_real


def find_rightmost_notpadded_positions(tensor: torch.Tensor, pad_value: Union[int, float], wnby: bool=True) -> torch.Tensor:
    """For right padding: return, per row of the 2-D `tensor`, the column of the last entry that differs from `pad_value`.

    Rows made only of `pad_value` yield 0 when `wnby` ("within boundary") is True, otherwise -1.
    """
    assert tensor.ndim == 2
    seq_len = tensor.shape[1]
    column_idx = torch.arange(seq_len, device=tensor.device).expand_as(tensor)
    # Pad columns are pulled down to -1, so the row maximum is the last real column (or -1).
    last_real = torch.where(tensor.ne(pad_value), column_idx, torch.full_like(column_idx, -1)).max(dim=1).values
    if wnby:
        last_real = last_real.clamp(min=0)
    return last_real


def count_right_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For right padding: number of trailing `pad_value` entries per row (`seq_len` for an all-pad row)."""
    last_real = find_rightmost_notpadded_positions(tensor, pad_value=pad_value, wnby=False)
    return tensor.shape[-1] - 1 - last_real


def count_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding: number of leading `pad_value` entries per row (`seq_len` for an all-pad row)."""
    # The column of the first real entry equals the length of the leading pad run.
    return find_leftmost_notpadded_positions(tensor, pad_value=pad_value, wnby=False)


def count_not_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding: number of entries per row that come after the leading `pad_value` run."""
    return tensor.shape[-1] - count_left_padding(tensor, pad_value=pad_value)


def count_shared_left_padding(tensor: torch.Tensor, pad_value: Union[int, float]) -> torch.Tensor:
    """For left padding: the smallest leading-pad count over the batch, i.e. columns that are padding in every row."""
    per_row = count_left_padding(tensor, pad_value)
    return per_row.min()




def get_mask_for_seq_area(tensor: torch.Tensor, left_borders: Optional[torch.LongTensor]=None, right_borders: Optional[torch.LongTensor]=None, include_left: bool=False, include_right: bool=False):
    """Return a boolean mask of shape (bsz, seq_len) that is True inside the area delimited by the borders.

    Column j of row i is True when `j > left_borders[i]` (if given) and `j < right_borders[i]` (if given).
    `include_left` / `include_right` make the border column itself part of the area.
    Borders may be per-row tensors of shape (bsz,) or scalars (ints / 0-d tensors) applied to every row;
    they are moved to `tensor`'s device.

    Args:
        tensor: 2-D tensor; only its shape and device are used.

    Raises:
        ValueError: if neither `left_borders` nor `right_borders` is given.
    """
    # An `assert` would be stripped under `python -O` and then fail with an UnboundLocalError.
    if left_borders is None and right_borders is None:
        raise ValueError("At least one of `left_borders` / `right_borders` must be given")
    bsz, seq_len = tensor.shape
    positions = torch.arange(seq_len, device=tensor.device).view(1, -1)

    mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=tensor.device)
    if left_borders is not None:
        left_borders = torch.as_tensor(left_borders, device=tensor.device)
        if include_left:
            left_borders = left_borders - 1
        mask = mask & (positions > left_borders.view(-1, 1))
    if right_borders is not None:
        right_borders = torch.as_tensor(right_borders, device=tensor.device)
        if include_right:
            right_borders = right_borders + 1
        mask = mask & (positions < right_borders.view(-1, 1))
    return mask


def mask_by_borders_2D(
    tensor: torch.Tensor, 
    left_borders: Optional[torch.LongTensor] = None, 
    right_borders: Optional[torch.LongTensor] = None, 
    include_left: bool = False, 
    include_right: bool = False,
    value: Union[int, float] = 0,
):
    """Return a copy of the 2-D `tensor` where positions after `left_borders` and/or before
    `right_borders` (border included when `include_left` / `include_right`) are set to `value`."""
    area = get_mask_for_seq_area(
        tensor=tensor,
        left_borders=left_borders,
        right_borders=right_borders,
        include_left=include_left,
        include_right=include_right,
    )
    return tensor.masked_fill(area, value=value)


def mask_by_borders_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], 
    left_borders: torch.LongTensor = None, 
    right_borders: torch.LongTensor = None, 
    include_left: bool = False, 
    include_right: bool = False,
    value: Union[int, float] = 0,
):
    """Return a new cache where the selected sequence positions of every key/value tensor are set to `value`.

    The area is the same as in `mask_by_borders_2D`, taken along dim 2 of each cache tensor.
    Assumes every cache tensor is 4-D with the sequence on dim 2 (presumably
    (bsz, n_heads, seq_len, head_dim)) and has the shape of `past_key_values[0][0]`.
    """
    reference = past_key_values[0][0]
    seq_mask = get_mask_for_seq_area(reference[:, 0, :, 0], left_borders=left_borders, right_borders=right_borders, include_left=include_left, include_right=include_right)
    # Broadcast the (bsz, seq_len) mask over dims 1 and 3.
    full_mask = seq_mask[:, None, :, None].expand_as(reference)

    return tuple(
        tuple(cache.masked_fill(full_mask, value=value) for cache in layer)
        for layer in past_key_values
    )



def count_tokens_after_positions(input_ids: torch.LongTensor, positions: torch.LongTensor, tokens: Union[int, torch.LongTensor], include_pos: bool=False) -> torch.LongTensor:
    """Count, per row, how many entries after `positions` are among `tokens`.

    Args:
        input_ids: 2-D tensor of token ids, shape (bsz, seq_len).
        positions: per-row positions, shape (bsz,). A position of -1 counts the whole row.
        tokens: a single token id (int), or a tensor / list / tuple / set of token ids.
        include_pos: also count the entry located exactly at `positions`.

    Returns:
        LongTensor of shape (bsz,).

    Raises:
        TypeError: if `tokens` is not an int, a tensor, or a collection of token ids
            (previously this silently returned None).
    """
    if isinstance(tokens, (list, tuple, set, frozenset)):
        tokens = torch.tensor(list(tokens), dtype=torch.long, device=input_ids.device)
    if not isinstance(tokens, (int, torch.Tensor)):
        raise TypeError(f"`tokens` must be an int, a tensor or a collection of token ids, got {type(tokens).__name__}")

    # Hide everything up to the position (inclusive unless include_pos) by overwriting it with -1.
    # NOTE: -1 is used as a sentinel, so it must not be one of `tokens`.
    mask = get_mask_for_seq_area(input_ids, right_borders=positions, include_right=not include_pos)
    input_ids = input_ids.masked_fill(mask, value=-1)
    if isinstance(tokens, int):
        return input_ids.eq(tokens).sum(1)
    return input_ids[:, :, None].eq(tokens.view(1, 1, -1)).any(2).sum(1)
    

def get_new_generated_tokens(input_ids: torch.LongTensor, past_token_lens: torch.LongTensor, pad_token_id: int=0):
    """Keep only newly generated tokens: in each row, the left padding and the first
    `past_token_lens[i]` real tokens are replaced by `pad_token_id`."""
    generation_start = count_left_padding(input_ids, pad_value=pad_token_id) + past_token_lens
    return mask_by_borders_2D(input_ids, right_borders=generation_start, include_right=False, value=pad_token_id)

    

def batched_shift_along_seq_dim_2D(tensor: torch.Tensor, shifts: torch.LongTensor=None):
    """Roll each row of the 2-D `tensor` right by its own amount `shifts[i]` (negative rolls left).

    Output element j of row i is input element `(j - shifts[i]) mod seq_len`, so values pushed
    past one end wrap around to the other.
    """
    bsz, seq_len = tensor.shape
    assert shifts.numel() == bsz

    columns = torch.arange(seq_len, device=tensor.device).unsqueeze(0)
    source_columns = (columns - shifts.view((bsz, 1))) % seq_len
    return tensor.gather(1, source_columns)



def batched_shift_along_seq_dim_past_key_values(past_key_values: Tuple[Tuple[torch.FloatTensor]], shifts: torch.LongTensor=None):
    """Roll every cache tensor along dim 2 (the sequence dim) by the per-sample amounts `shifts`.

    Same wrap-around semantics as `batched_shift_along_seq_dim_2D`; the modulus is the cache's own
    length `past_key_values[0][0].shape[2]`. All cache tensors are assumed to share that shape.
    """
    reference = past_key_values[0][0]
    bsz, seq_len = reference.shape[0], reference.shape[2]
    assert shifts.numel() == bsz

    columns = torch.arange(seq_len, device=reference.device).unsqueeze(0)
    source_columns = (columns - shifts.view((bsz, 1))) % seq_len
    source_columns = source_columns[:, None, :, None].expand_as(reference)
    return tuple(
        tuple(torch.gather(cache, 2, source_columns) for cache in layer)
        for layer in past_key_values
    )



def shift_padding_to_left_2D(tensor: torch.Tensor, pad_value: Union[int, float] = 0):
    """Turn right padding into left padding: roll each row right by its count of trailing `pad_value` entries."""
    trailing_pads = count_right_padding(tensor, pad_value=pad_value)
    return batched_shift_along_seq_dim_2D(tensor, shifts=trailing_pads)


def shift_padding_to_right_2D(tensor: torch.Tensor, pad_value: Union[int, float] = 0):
    """Turn left padding into right padding: roll each row left by its count of leading `pad_value` entries."""
    leading_pads = count_left_padding(tensor, pad_value=pad_value)
    return batched_shift_along_seq_dim_2D(tensor, shifts=-leading_pads)





class SamplingWithCalculator:
    def __init__(
        self,
        accelerator: Accelerator = None,
        model: transformers.PreTrainedModel = None, 
        verifier: nn.Module = None,
        tokenizer: PreTrainedTokenizer = None, 
        generation_args: dataclass = None,
    ):
        """Store the components and resolve the token ids the sampler relies on.

        Looks up the equals-sign and ``<<`` / ``>>`` mark ids used to detect calculator spans, and the
        newline / eos ids used as step boundaries. `generation_args` is split into the length limits
        (`max_new_tokens`, `max_length`) and the remaining keyword arguments passed to `generate`.
        """
        self.accelerator = accelerator
        self.model = model
        self.verifier = verifier
        self.tokenizer = tokenizer
        self.generation_args = generation_args
        self.device = accelerator.device

        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id

        print('tokenizer.bos_token_id: ', tokenizer.bos_token_id)
        print('tokenizer.eos_token_id: ', tokenizer.eos_token_id)
        print('tokenizer.pad_token_id: ', tokenizer.pad_token_id)
        print('tokenizer.unk_token_id: ', tokenizer.unk_token_id)

        def _collect_known_ids(candidate_ids, label):
            # Keep the ids the tokenizer actually resolved (not None / not unk) and log each of them.
            kept = set()
            for token_id in candidate_ids:
                if token_id is None or token_id == self.tokenizer.unk_token_id:
                    continue
                kept.add(token_id)
                print(label, token_id, ' / ', self.tokenizer.decode([token_id]))
            return kept

        # NOTE(review): '_=' / '_<<' use an ASCII underscore, not the SentencePiece marker '▁'; the
        # encode(' =') / encode(' <<') fallbacks presumably cover the word-initial variants — confirm.
        equals_tokens = _collect_known_ids(
            self.tokenizer.convert_tokens_to_ids(['_=', ')=', '=', '=-']) + [self.tokenizer.encode(' =')[-1]],
            'equal_id: ',
        )
        leftmark_tokens = _collect_known_ids(
            self.tokenizer.convert_tokens_to_ids(['<<', '_<<']) + [self.tokenizer.encode(' <<')[-1]],
            'leftmark_id: ',
        )
        # The ' >>' variant is intentionally left out here.
        rightmark_token = _collect_known_ids(self.tokenizer.convert_tokens_to_ids(['>>']), 'rightmark_id: ')
        newline_token = _collect_known_ids([self.tokenizer.encode('\n')[-1]], 'newline_id: ')

        self.equals_tokens = equals_tokens
        self.equal_token_ids = torch.LongTensor(list(equals_tokens)).to(self.device)
        self.leftmark_tokens = leftmark_tokens
        self.left_mark_token_ids = torch.LongTensor(list(leftmark_tokens)).to(self.device)
        self.rightmark_token = rightmark_token
        self.right_mark_token_ids = torch.LongTensor(list(rightmark_token)).to(self.device)
        self.newline_token = newline_token
        self.newline_token_ids = torch.LongTensor(list(newline_token)).to(self.device)
        # Steps are separated by newlines; a step also ends at eos.
        self.inter_step_end_token_ids = self.newline_token_ids
        eos_ids = torch.tensor([self.eos_token_id], device=self.device)
        self.step_end_token_ids = torch.concat([self.newline_token_ids, eos_ids])

        self.max_new_tokens = generation_args.max_new_tokens
        self.max_length = generation_args.max_length
        length_keys = ('max_new_tokens', 'max_length')
        self.generation_config = {key: val for key, val in generation_args.__dict__.items() if key not in length_keys}

    def _shift_padding_to_left(self, token_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None):
        """Turn right padding in `token_ids` into left padding and keep the companions aligned.

        `past_key_values` is rolled by the same per-row amounts as `token_ids`; `transition_scores`
        are re-aligned using their own trailing zeros (0 is their padding value).
        Returns the triple (token_ids, past_key_values, transition_scores); absent inputs stay None.
        """
        roll_by = count_right_padding(token_ids, pad_value=self.pad_token_id)
        shifted_ids = batched_shift_along_seq_dim_2D(token_ids, shifts=roll_by)

        shifted_cache = None
        if past_key_values is not None:
            shifted_cache = batched_shift_along_seq_dim_past_key_values(past_key_values, shifts=roll_by)

        shifted_scores = None
        if transition_scores is not None:
            shifted_scores = shift_padding_to_left_2D(transition_scores, pad_value=0)
        return shifted_ids, shifted_cache, shifted_scores

    def _truncate_left_padding(self, token_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None):
        """Drop leading columns that are padding in every row.

        The same number of columns is removed from `token_ids` and from the sequence dim (dim 2) of
        each cache tensor. `transition_scores` are trimmed independently by their own shared run of
        leading zeros. Returns the triple (token_ids, past_key_values, transition_scores).
        """
        n_common_pad = count_shared_left_padding(token_ids, pad_value=self.pad_token_id)
        trimmed_ids = token_ids[:, n_common_pad:]

        trimmed_cache = past_key_values
        if past_key_values is not None:
            trimmed_cache = tuple(
                tuple(cache[:, :, n_common_pad:] for cache in layer)
                for layer in past_key_values
            )

        trimmed_scores = transition_scores
        if transition_scores is not None:
            n_score_pad = count_shared_left_padding(transition_scores, pad_value=0)
            trimmed_scores = transition_scores[:, n_score_pad:]
        return trimmed_ids, trimmed_cache, trimmed_scores

    def _add_new_tokens_and_adjust(self, token_ids: torch.LongTensor, new_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None):
        """Add the new token ids to the existing batch and adjust the past_key_values and transition_scores accordingly. Shift as left padding + truncate redundant ones

        Args:
            token_ids: current left-padded batch of token ids.
            new_ids: ids to append to each row; rows with nothing (or less) to add are filled with
                `pad_token_id` on the right.
            past_key_values: optional cache; rolled by the same per-row shifts as `token_ids`.
            transition_scores: optional per-token scores; appended entries get 100 for real (non-pad)
                new tokens and 0 for padding.

        Returns:
            (token_ids, past_key_values, transition_scores) after appending, moving the new right
            padding to the left, and dropping columns that are padding in every row.
        """
        new_token_ids = torch.cat([token_ids, new_ids], dim=1)
        new_past_key_values = past_key_values
        new_transition_scores = transition_scores

        # Trailing padding introduced by `new_ids` -> roll each row right so it becomes left padding.
        shifts = count_right_padding(new_token_ids, pad_value=self.pad_token_id)
        new_token_ids = batched_shift_along_seq_dim_2D(new_token_ids, shifts=shifts)
        if past_key_values is not None:
            # Roll the cache by the same per-row amounts (modulo the cache's own length), then zero the
            # first `shifts` positions: they hold entries that wrapped around from the end of the cache
            # and now sit in the (left) padding area.
            new_past_key_values = batched_shift_along_seq_dim_past_key_values(new_past_key_values, shifts=shifts)
            new_past_key_values = mask_by_borders_past_key_values(new_past_key_values, right_borders=shifts, include_right=False, value=0)
        if transition_scores is not None:
            # Forced (calculator) tokens get the placeholder score 100; padding gets 0.
            # NOTE(review): presumably positive so `_cal_generation_scores`, which only counts negative
            # scores, ignores them — confirm.
            new_scores = torch.zeros_like(new_ids).to(transition_scores.dtype)
            new_scores.masked_fill_(new_ids.ne(self.pad_token_id), 100)
            new_transition_scores = torch.concat([transition_scores, new_scores], dim=1)
            # Scores are re-aligned by their own trailing zeros, not by `shifts`.
            new_transition_scores = shift_padding_to_left_2D(new_transition_scores, pad_value=0)

        new_token_ids, new_past_key_values, new_transition_scores = self._truncate_left_padding(new_token_ids, new_past_key_values, new_transition_scores)
        return new_token_ids, new_past_key_values, new_transition_scores

    def _call_calculator(self, token_ids: torch.LongTensor) -> torch.LongTensor:
        """Run the calculator for the rows whose last token is an equals token.

        Each qualifying row is decoded and passed to `use_calculator`. When it returns an answer, the
        text ``f"{answer}>>"`` becomes that row's continuation. Every other row gets an empty continuation.

        Args:
            token_ids: batch of token ids, shape (bsz, seq_len).

        Returns:
            Right-padded continuation token ids from `self.tokenizer.get_continued_input_ids`,
            moved to `self.device` (presumably one row per sample).
        """
        needs_calc = token_ids[:, -1, None].eq(self.equal_token_ids.view(1, -1)).any(1)
        rows = needs_calc.nonzero(as_tuple=True)[0].tolist()

        new_text_list = [''] * token_ids.shape[0]
        for i in rows:
            # NOTE(review): the decoded text also contains any left-padding tokens; presumably
            # `use_calculator` only inspects the trailing `<<expr=` span — confirm.
            answer = use_calculator(self.tokenizer.decode(token_ids[i]))
            if answer is not None:
                new_text_list[i] += (str(answer) + ">>")

        new_token_ids = self.tokenizer.get_continued_input_ids(new_text_list, right_padding=True, return_tensors=True).to(self.device)
        return new_token_ids

    def _cal_generation_scores(self, transition_scores: torch.FloatTensor):
        length = (transition_scores < 0).sum(1)
        length_penalty = self.generation_config.get('length_penalty', 1)
        return transition_scores.clip(max=0).sum(1) / (length**length_penalty)

    @torch.inference_mode(mode=True)
    def _sample_tokens_with_calculator(
        self, 
        input_ids: torch.LongTensor = None, 
        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, 
        max_length: int = None, 
        stopping_criteria_end: StoppingCriteria = None, 
        logits_processor: LogitsProcessorList = None, 
        output_transition_scores: bool = False,
        output_past_key_values: bool = True,
        **kwargs,
    ) -> SamplingOutput:
        """
        Batch sampling with calculator - model generation (token-level)

        Calls `model.generate` repeatedly. Generation pauses whenever some sample opens a calculator
        span (``<<expr=``). The calculator result (``answer>>``) is then appended to the affected rows
        and generation resumes from the updated ids and cache. The loop ends when
        `stopping_criteria_end` holds or the batch length reaches `max_length`.

        Args:
            input_ids: left-padded prompt token ids, shape (bsz, seq_len).
            past_key_values: optional tuple-of-tuples cache for `input_ids`, reused by `generate`.
            max_length: total length limit (prompt + generated + calculator tokens); required.
            stopping_criteria_end: "finished" criterion; defaults to every row containing eos.
            logits_processor: extra logits processors forwarded to `generate`.
            output_transition_scores: also collect per-token transition scores
                (`compute_transition_scores` with `normalize_logits=True`).
            output_past_key_values: include the cache in the returned output.
            **kwargs: forwarded to `generate`; must not repeat a key of `self.generation_config`
                (both are splatted into the same call).

        Returns:
            SamplingOutput with left-padded `sequences` cut after the first eos, plus the optional
            `transition_scores` / `past_key_values` processed the same way.
        """

        # if (
        #     self.pad_token_id is not None
        #     and len(input_ids.shape) == 2
        #     and (input_ids[:, 0] == self.pad_token_id).all()
        # ):
        #     print(
        #         "There are extra padding in the left!"
        #     )

        if stopping_criteria_end is None:
            stopping_criteria_end = BatchEndStoppingCriteria(end_token_id=self.eos_token_id, device=self.device)
        # `generate` halts either for a calculator pause or because `stopping_criteria_end` holds;
        # the two cases are told apart after each round.
        stopping_criteria_calculator = BatchCalculatorCallingCriteria(keywords_ids=self.equals_tokens, left_mark_tokens=self.leftmark_tokens, device=self.device)
        stopping_criteria = StoppingCriteriaList([stopping_criteria_calculator, stopping_criteria_end])

        if logits_processor is None:
            logits_processor = LogitsProcessorList()

        all_transition_scores = None
        if output_transition_scores:
            # Starts empty; each round's scores are appended along dim 1.
            all_transition_scores = torch.empty((input_ids.shape[0], 0)).float().to(self.device)

        # Real (non-padding) prompt length per row; later used to find where generation began.
        input_token_lens = count_not_left_padding(input_ids, pad_value=self.pad_token_id)

        cur_length = input_ids.shape[-1]
        while cur_length < max_length:
            # Remaining token budget for this round.
            max_new_tokens = max_length - cur_length
            outputs = self.accelerator.unwrap_model(self.model).generate(
                input_ids=input_ids, 
                attention_mask=input_ids.ne(self.pad_token_id),
                past_key_values=past_key_values,
                max_new_tokens=max_new_tokens, 
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                output_scores=output_transition_scores,
                return_dict_in_generate=True,
                **kwargs,
                **self.generation_config,
            )
            input_ids, past_key_values = outputs.sequences, outputs.past_key_values
            if output_transition_scores:
                # `outputs.scores` only covers the tokens generated in this round.
                transition_scores = self.accelerator.unwrap_model(self.model).compute_transition_scores(sequences=input_ids, scores=outputs.scores, beam_indices=outputs.get('beam_indices'), normalize_logits=True)
                all_transition_scores = torch.cat([all_transition_scores, transition_scores], dim=1)

            # Finished for good (as opposed to a calculator pause).
            if stopping_criteria_end(input_ids, None):
                break
            
            # Append calculator results; when no continuation tokens are produced (numel() == 0),
            # the batch is left unchanged.
            new_token_ids = self._call_calculator(input_ids)
            if new_token_ids.numel():
                input_ids, past_key_values, all_transition_scores = self._add_new_tokens_and_adjust(input_ids, new_token_ids, past_key_values, all_transition_scores)
                
            cur_length = input_ids.shape[-1]

        # For the convenience of subsequent operations/processing: mask everything after the first
        # eos of each row and keep the batch left-padded.
        input_ids, past_key_values, all_transition_scores = self._cut_after_eos_lp(input_ids, past_key_values, all_transition_scores, past_token_lens=input_token_lens)

        return SamplingOutput(
            sequences=input_ids,
            transition_scores=all_transition_scores,
            past_key_values=past_key_values if output_past_key_values else None,
        )


    def _convert_into_tensors(self, qns: Union[str, List[str], torch.LongTensor]):
        if isinstance(qns, list) and isinstance(qns[0], str):
            token_ids = self.tokenizer(qns, padding=True, return_tensors='pt').input_ids
        elif isinstance(qns, str):
            token_ids = self.tokenizer([qns], return_tensors='pt').input_ids
        elif isinstance(qns, torch.Tensor):
            token_ids = qns
            if token_ids.dim() == 1:
                token_ids = token_ids.unsqueeze(0)
        else:
            raise ValueError

        return token_ids.to(self.device)

    def sample(
        self, 
        qns: Union[str, List[str]],
    ) -> Union[str, List[str]]:
        """
        Batch sampling with calculator (string-level)

        Tokenizes the prompt(s) and generates with calculator support until eos or the length limit.
        The limit is `max_new_tokens` new tokens when that is positive, otherwise the absolute
        `max_length`. The prompt text is then stripped from each decoded completion.

        Args:
            qns: a single prompt string or a list of prompt strings.

        Return:
            responses (`Union[str, List[str]]`): a single string when `qns` is a string, otherwise one
            response per prompt.
        """
        input_ids = self._convert_into_tensors(qns)
        # Pair prompts with completions through a list. Zipping a bare string would iterate over its
        # characters and strip only the first character of the prompt.
        qn_list = [qns] if isinstance(qns, str) else qns

        cur_length = input_ids.shape[-1]
        if self.max_new_tokens > 0:
            max_length = cur_length + self.max_new_tokens
        else:
            max_length = self.max_length

        outputs = self._sample_tokens_with_calculator(
            input_ids=input_ids,
            max_length=max_length,
            stopping_criteria_end=BatchEndStoppingCriteria(end_token_id=self.eos_token_id, device=self.device),
        )
        completions = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        # NOTE(review): assumes each decoded completion starts with the exact prompt text; a tokenizer
        # that changes whitespace on round-trip would make this prefix strip inexact.
        responses = [completion[len(qn):].strip() for qn, completion in zip(qn_list, completions)]

        if isinstance(qns, str):
            return responses[0]
        return responses

    def _cut_after_eos_lp(self, input_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None, past_token_lens: torch.LongTensor=None):
        """Mask the tokens after eos and keep it left padding

        Everything after the first eos of each row is replaced by `pad_token_id`; cache entries and
        transition scores at those positions are set to 0. The resulting right padding is then rolled
        to the left.

        Args:
            input_ids: left-padded token ids, shape (bsz, seq_len).
            past_key_values: optional cache; masked along its own sequence dim.
            transition_scores: optional scores of the generated tokens, left-padded with 0.
            past_token_lens: number of real tokens per row that precede the generated part; required
                when `transition_scores` is given.

        Returns:
            (input_ids, past_key_values, transition_scores) after masking and re-left-padding.
        """
        new_past_key_values = past_key_values
        new_transition_scores = transition_scores

        # First eos per row. Rows without eos get seq_len - 1, so nothing is masked for them.
        # NOTE(review): assumes eos never occurs inside the prompt — confirm.
        valid_borders_right = find_leftmost_tokens_positions(input_ids, self.eos_token_id, wnby=True)

        # Positions strictly after the border become padding; the eos itself is kept.
        new_input_ids = mask_by_borders_2D(input_ids, left_borders=valid_borders_right, include_left=False, value=self.pad_token_id)

        if past_key_values is not None:
            new_past_key_values = mask_by_borders_past_key_values(past_key_values, left_borders=valid_borders_right, include_left=False, value=0)
        
        if transition_scores is not None:
            # Map the border from token coordinates to score coordinates. Generation begins at
            # (left padding + past_token_lens) in `input_ids`, while the scores carry their own
            # left padding of zeros.
            generate_begin_indices = count_left_padding(input_ids, pad_value=self.pad_token_id) + past_token_lens
            n_left_padding = count_left_padding(transition_scores, pad_value=0)
            borders_for_transitions = valid_borders_right - generate_begin_indices + n_left_padding
            new_transition_scores = mask_by_borders_2D(transition_scores, left_borders=borders_for_transitions, include_left=False, value=0)

        new_input_ids, new_past_key_values, new_transition_scores = self._shift_padding_to_left(new_input_ids, new_past_key_values, new_transition_scores)
        return new_input_ids, new_past_key_values, new_transition_scores

    def _cut_latter_steps(self, input_ids: torch.LongTensor, past_key_values: Tuple[Tuple[torch.FloatTensor]]=None, transition_scores: torch.FloatTensor=None, past_token_lens: torch.LongTensor=None):
        """Mask the latter steps and keep it left padding

        Among the tokens generated after the first `past_token_lens` real tokens of each row, the first
        step-end token (newline or eos, see `self.step_end_token_ids`) is located. Everything after it
        is replaced by `pad_token_id`, and cache entries / transition scores there are set to 0.
        The resulting right padding is then rolled to the left.

        Args:
            input_ids: left-padded token ids, shape (bsz, seq_len).
            past_key_values: optional cache; masked along its own sequence dim.
            transition_scores: optional scores of the generated tokens, left-padded with 0.
            past_token_lens: number of real tokens per row that precede the current step.

        Returns:
            (input_ids, past_key_values, transition_scores) after masking and re-left-padding.
        """
        new_past_key_values = past_key_values
        new_transition_scores = transition_scores

        # Earlier tokens are replaced by pad so that only the new step is searched.
        # NOTE(review): if pad_token_id is itself a step-end id (e.g. pad == eos), the padded prefix
        # would match here — confirm the tokenizer setup.
        new_tokens = get_new_generated_tokens(input_ids, past_token_lens=past_token_lens, pad_token_id=self.pad_token_id)
        # Rows without a step end get seq_len - 1, so nothing is masked for them.
        cur_step_borders_right = find_leftmost_tokens_positions(new_tokens, self.step_end_token_ids, wnby=True)

        new_input_ids = mask_by_borders_2D(input_ids, left_borders=cur_step_borders_right, include_left=False, value=self.pad_token_id)

        if past_key_values is not None:
            new_past_key_values = mask_by_borders_past_key_values(past_key_values, left_borders=cur_step_borders_right, include_left=False, value=0)

        if transition_scores is not None:
            # Translate the border from token coordinates into score coordinates (see the eos variant).
            generate_begin_indices = count_left_padding(input_ids, pad_value=self.pad_token_id) + past_token_lens
            n_left_padding = count_left_padding(transition_scores, pad_value=0)
            borders_for_transitions = cur_step_borders_right - generate_begin_indices + n_left_padding
            new_transition_scores = mask_by_borders_2D(transition_scores, left_borders=borders_for_transitions, include_left=False, value=0)

        new_input_ids, new_past_key_values, new_transition_scores = self._shift_padding_to_left(new_input_ids, new_past_key_values, new_transition_scores)
        return new_input_ids, new_past_key_values, new_transition_scores

    def _mask_former_steps(self, input_ids: torch.LongTensor, past_token_lens: torch.LongTensor=None):
        """Keep only the current step.

        In each row, the left padding and the first `past_token_lens[i]` real tokens are replaced by
        `pad_token_id`.
        """
        step_start = count_left_padding(input_ids, pad_value=self.pad_token_id) + past_token_lens
        return mask_by_borders_2D(input_ids, right_borders=step_start, include_right=False, value=self.pad_token_id)

    def _step_level_sample_tokens(
        self,
        input_ids: torch.LongTensor, 
        past_key_values: Tuple[Tuple[torch.FloatTensor]],
        num_sampling_sequences: int = 1,
        max_length: int = 2048,
        output_transition_scores: bool = False,
        output_verifier_scores: bool = False,
        output_past_key_values: bool = True,
    ) -> StepSamplingOutput:
        """
        Step-level sampling with calculator (token-level)

        Each row of `input_ids` (and of `past_key_values`) is repeated `num_sampling_sequences` times,
        with copies of the same row kept next to each other (`repeat_interleave`). Sampling then runs
        until every row has produced a step-end token (newline or eos) after its existing tokens.
        Tokens beyond that first step end are masked.

        Args:
            input_ids: left-padded token ids, shape (bsz, seq_len).
            past_key_values: optional tuple-of-tuples cache for `input_ids`.
            num_sampling_sequences: samples drawn per input row.
            max_length: total length limit passed to the sampler.
            output_transition_scores: also return transition scores.
            output_verifier_scores: also return `self.verifier_scoring(sequences)`.
            output_past_key_values: also return the cache.

        Returns:
            StepSamplingOutput with:
                - `sequences`: prompt + current step, left-padded;
                - `steps`: the same ids with all tokens before the current step replaced by pad;
                - `transition_scores`: scores aligned with the step-truncated `sequences`, when requested;
                - `verifier_scores` and `past_key_values`, when requested.
        """
        
        if (
            self.pad_token_id is not None
            and len(input_ids.shape) == 2
            and torch.sum(input_ids[:, -1] == self.pad_token_id) > 0
        ):
            print(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )
        
        input_ids = input_ids.repeat_interleave(num_sampling_sequences, dim=0)
        if past_key_values is not None:
            past_key_values = tuple(
                tuple(
                    past_key_value.repeat_interleave(num_sampling_sequences, dim=0)
                    for past_key_value in layer_past_key_values
                )
                for layer_past_key_values in past_key_values
            )

        # Tokens already present per row; anything generated after them belongs to the new step.
        cur_token_lens = count_not_left_padding(input_ids, pad_value=self.pad_token_id)
        stopping_criteria_step = StepStoppingCriteria(cur_token_lens=cur_token_lens, end_token_ids=self.step_end_token_ids.tolist(), pad_token_id=self.pad_token_id, device=self.device)

        outputs = self._sample_tokens_with_calculator(
            input_ids=input_ids,
            past_key_values=past_key_values,
            max_length=max_length,
            stopping_criteria_end=stopping_criteria_step,
            output_transition_scores=output_transition_scores,
            output_past_key_values=output_past_key_values,
        )
        sequences, past_key_values, transition_scores = outputs.sequences, outputs.past_key_values, outputs.transition_scores

        # For the convenience of subsequent operations/processing
        sequences, past_key_values, transition_scores = self._cut_latter_steps(sequences, past_key_values, transition_scores, past_token_lens=cur_token_lens)
        steps = self._mask_former_steps(sequences, past_token_lens=cur_token_lens)

        return StepSamplingOutput(
            sequences=sequences,
            steps=steps,
            # Return the step-truncated scores so they line up with `sequences`; previously the raw,
            # uncut `outputs.transition_scores` were returned and the cut version was discarded.
            transition_scores=transition_scores,
            verifier_scores=self.verifier_scoring(sequences) if output_verifier_scores else None,
            past_key_values=past_key_values if output_past_key_values else None,
        )

    def _group_step_level_sample(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Tuple[Tuple[torch.FloatTensor]],
        batch_size: int = 1,
        vs_batch_size: int = 1,
        num_sampling_sequences: int = 1,
        max_length: int = 2048,
        output_transition_scores: bool = False,
        output_verifier_scores: bool = False,
        output_past_key_values: bool = True,
    ) -> StepSamplingOutput:

        input_ids = input_ids.repeat_interleave(num_sampling_sequences, dim=0)                       # [n_beam * n_sampling_steps_per_beam, seq_len]
        if past_key_values is not None:
            past_key_values = tuple(tuple(past_key_value.repeat_interleave(num_sampling_sequences, dim=0) for past_key_value in layer_past_key_values) for layer_past_key_values in past_key_values)
        nseqs = input_ids.shape[0]
        n_split = math.ceil(nseqs / batch_size)

        batch_outputs = []
        for i in range(n_split):
            cur_input = input_ids[i*batch_size: min((i+1)*batch_size, nseqs)]
            cur_past_key_values = tuple(tuple(past_key_value[i*batch_size: min((i+1)*batch_size, nseqs)] for past_key_value in layer_past_key_values) for layer_past_key_values in past_key_values) if past_key_values is not None else None

            step_outputs = self._step_level_sample_tokens(                                           # [cur_batch_size, seq_len]
                input_ids=cur_input,
                past_key_values=cur_past_key_values,
                max_length=max_length, 
                output_transition_scores=output_transition_scores,
                output_verifier_scores=False,
                output_past_key_values=output_past_key_values,
            )
            batch_outputs.append(step_outputs)
        
        batch_outputs = self._concat_group_steps(batch_outputs, dim=0)                               # [n_beam * n_sampling_steps_per_beam, seq_len]

        if output_verifier_scores:
            batch_outputs.verifier_scores = self.verifier_scoring(batch_outputs.sequences, batch_size=vs_batch_size)

        return batch_outputs


    def _step_level_majority_tokens(
        self,
        step_ids: torch.LongTensor, 
        step_generation_scores: torch.FloatTensor, 
    ) -> torch.LongTensor:
        """
        Majority voting at the step level (token-level)

        Groups candidate steps by their calculator-result span: the tokens between the
        rightmost equals token and the rightmost right-mark token. Then picks one
        candidate:
          1. if the largest group of identical non-empty spans is strictly larger than
             the number of candidates containing EOS, the highest-scoring member of
             that group;
          2. otherwise, if any candidate contains EOS, the highest-scoring of those;
          3. otherwise, the highest-scoring candidate that has neither a span nor EOS.

        Parameters:
            step_ids (`torch.LongTensor`):
                token ids of shape `(num_sampling_sequences, generated_length)`
            step_generation_scores (`torch.FloatTensor`):
                one generation score per candidate; presumably shape
                `(num_sampling_sequences,)`. TODO: confirm against `_cal_generation_scores`.
        Return:
            index (`torch.LongTensor`):
                row index of the chosen candidate. NOTE(review): it is produced by
                indexing a 1-D index tensor with `argmax(0)`, so for 1-D scores this
                is a 0-dim tensor rather than shape `(1,)`.
        """

        bsz, seq_len = step_ids.shape

        # Identify valid steps and keep only the `=...>>` span in each valid step:
        # mask everything right of the rightmost right-mark token, then everything left
        # of the rightmost equals token. Exact border inclusivity is defined by
        # `mask_by_borders_2D`.
        cal_result_ids = step_ids
        right_borders = find_rightmost_tokens_positions(cal_result_ids, self.right_mark_token_ids, wnby=False)
        cal_result_ids = mask_by_borders_2D(cal_result_ids, left_borders=right_borders, include_left=False, value=self.pad_token_id)
        left_borders = find_rightmost_tokens_positions(cal_result_ids, self.equal_token_ids, wnby=False)
        cal_result_ids = mask_by_borders_2D(cal_result_ids, right_borders=left_borders, include_right=False, value=self.pad_token_id)
        # rows whose span still contains at least one non-pad token
        hit_res = cal_result_ids.ne(self.pad_token_id).any(1)
        
        if hit_res.any():
            # Hash each span to count duplicates. First move the padding to one side so
            # equal spans become equal rows.
            cal_result_ids = shift_padding_to_left_2D(cal_result_ids, pad_value=self.pad_token_id)
            # polynomial hash with base 31 in the ids' integer dtype; large powers may wrap, collisions are not handled
            multipliers = torch.pow(torch.full((seq_len,), 31, dtype=cal_result_ids.dtype, device=self.device), torch.arange(seq_len, device=self.device))
            hashes = (cal_result_ids * multipliers).sum(dim=1)
            _, inverse_indices = torch.unique(hashes, return_inverse=True)
            counts = torch.bincount(inverse_indices)
            counts_sorted, indices = torch.sort(counts, descending=True)
            # derive the majority steps and the corresponding count
            if cal_result_ids[inverse_indices == indices[0]].ne(self.pad_token_id).any():  # valid steps are majority
                count_step = counts_sorted[0]
                majority_step_indices = (inverse_indices == indices[0]).nonzero().view(-1)
            elif indices.shape[0] > 1:                                                     # non-steps (all-pad spans) are majority, consider the second
                count_step = counts_sorted[1]
                majority_step_indices = (inverse_indices == indices[1]).nonzero().view(-1)
        else:                                                                              # there are not valid steps
            count_step = 0
            majority_step_indices = None
        
        # find answers (eos)
        hit_ans = step_ids.eq(self.eos_token_id).any(1)
        ans_indices = torch.where(hit_ans)[0].unique()
        count_ans = ans_indices.shape[0]
        
        # others: candidates with neither a result span nor EOS
        others_indices = torch.where(torch.logical_and(~hit_res, ~hit_ans))[0]

        # choose the group, then the highest-scoring candidate inside it
        if count_step > count_ans:
            step_scores = step_generation_scores.index_select(0, majority_step_indices)
            index = majority_step_indices[step_scores.argmax(0)]
        elif count_ans != 0:
            ans_scores = step_generation_scores.index_select(0, ans_indices)
            index = ans_indices[ans_scores.argmax(0)]
        else:
            others_scores = step_generation_scores.index_select(0, others_indices)
            index = others_indices[others_scores.argmax(0)]

        return index

    def verifier_scoring(self, sequences: torch.LongTensor, batch_size: int = 1):
        """
        Score `sequences` with the verifier, `batch_size` rows at a time.

        Each chunk is passed to `scoring_sequences` of the accelerator-unwrapped
        verifier. The chunk results are concatenated along dim 0, presumably one
        score per row (`[bsz,]`).
        """
        total_rows = sequences.shape[0]
        chunk_count = math.ceil(total_rows / batch_size)

        scores_per_chunk = []
        for chunk_idx in range(chunk_count):
            start = chunk_idx * batch_size
            chunk = sequences[start:start + batch_size]          # slicing clamps at total_rows
            scorer = self.accelerator.unwrap_model(self.verifier)
            scores_per_chunk.append(scorer.scoring_sequences(chunk))

        return torch.cat(scores_per_chunk, dim=0)
    
    def sample_by_steps(
        self,
        qn_str: str = None,
        batch_size: int = 1,
        vs_batch_size: int = 1,
        n_beam: int = 1,
        n_sampling_steps: int = 2,
        max_n_step: int = 10,
        max_step_length: int = 100,
        inference_mode: str = 'beam',
        dedup_mode: int = 0,
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Sampling with step-level techniques.

        Only a single question string is supported for now.

        Parameters:
            qn_str (`str`):
                the question/prompt text.
            batch_size (`int`):
                used for sampling at each time
            vs_batch_size (`int`):
                batch size of verifier scoring ('beam' only)
            n_beam (`int`):
                number of kept beams ('beam' only)
            n_sampling_steps (`int`):
                number of total sampling sequences as next step candidates
            max_n_step (`int`):
                maximum number of steps ('beam' only; 'majority' runs until EOS or
                the length limit)
            max_step_length (`int`):
                maximum length for a single step
            inference_mode (`str`):
                'majority' or 'beam'
            dedup_mode (`int`):
                0/1, see `_highlight_unique_sequences` ('beam' only)

        Returns:
            (response, intermediates):
                response: the decoded completion with the first `len(qn_str)`
                    characters removed, then stripped.
                intermediates: one dict per step. For 'majority', each dict has
                    'steps' (sample_id / str / tscore per candidate) and 'choice'.
                    For 'beam', each dict has 'sequences' (sample_id / str / vscore per
                    candidate) and 'choices'.
        """
        assert inference_mode in ('majority', 'beam')
        input_ids = self._convert_into_tensors(qn_str)

        # overall length budget: prompt + max_new_tokens when set, otherwise max_length
        if self.max_new_tokens > 0:
            max_length = input_ids.shape[-1] + self.max_new_tokens
        else:
            max_length = self.max_length

        if inference_mode == 'majority':
            sequence, all_steps, all_choices, all_scores = self._steps_majority(
                input_ids, 
                batch_size=batch_size, 
                n_sampling_steps=n_sampling_steps, 
                max_step_length=max_step_length, 
                max_length=max_length,
            )
        else:  # 'beam'
            sequence, all_sequences, all_choices, all_scores = self._steps_beam_search(
                input_ids, 
                batch_size=batch_size, 
                vs_batch_size=vs_batch_size,
                n_beam=n_beam, 
                n_sampling_steps=n_sampling_steps, 
                max_n_step=max_n_step, 
                max_step_length=max_step_length, 
                max_length=max_length, 
                dedup_mode=dedup_mode,
            )

        completion = self.tokenizer.batch_decode(sequence, skip_special_tokens=True)[0]
        # assumes the decoded completion starts with the prompt text verbatim
        response = completion[len(qn_str):].strip()

        if inference_mode == 'majority':
            intermediates = [
                {
                    'steps': [
                        {
                            'sample_id': i,
                            'str': self.tokenizer.decode(step, skip_special_tokens=True),
                            'tscore': score.item(),
                        }
                        for i, (step, score) in enumerate(zip(steps, scores))
                    ],
                    'choice': choice.item(),
                }
                for steps, choice, scores in zip(all_steps, all_choices, all_scores)
            ]
        else:  # 'beam'
            intermediates = [
                {
                    'sequences': [
                        {
                            'sample_id': i,
                            'str': self.tokenizer.decode(seq, skip_special_tokens=True)[len(qn_str):],
                            'vscore': score.item(),
                        }
                        for i, (seq, score) in enumerate(zip(sequences, scores))
                    ],
                    'choices': choices.tolist(),
                }
                for sequences, choices, scores in zip(all_sequences, all_choices, all_scores)
            ]

        return response, intermediates

    def _steps_majority(
        self,
        input_ids: torch.LongTensor,
        batch_size: int = 1,
        n_sampling_steps: int = 2,
        max_step_length: int = 100,
        max_length: int = 2048,
    ):
        """
        Majority-voting

        Parameters:
            input_ids (`torch.LongTensor`)
            batch_size (`int`):
                used for sampling at each time
            n_sampling_steps (`int`):
                number of total sampling sequences as next step candidates
            max_step_length (`int`):
                maximal length for a single step
            max_length (`int`):
                maximal length for the complete response
        """

        cur_length = input_ids.shape[-1]

        past_key_values = None
        all_steps = []
        all_step_transition_scores = []
        all_choices = []
        while cur_length < max_length:
            cur_step_max_length = cur_length + max_step_length

            batch_candidates = self._group_step_level_sample(
                input_ids=input_ids,
                past_key_values=past_key_values,
                batch_size=batch_size,
                num_sampling_sequences=n_sampling_steps,
                max_length=min(cur_step_max_length, max_length), 
                output_transition_scores=True,
                output_verifier_scores=False,
            )
            batch_steps = batch_candidates.steps
            batch_transition_scores = batch_candidates.transition_scores

            # select the best steps/sequences
            batch_step_generation_scores = self._cal_generation_scores(batch_transition_scores)
            index = self._step_level_majority_tokens(batch_steps, batch_step_generation_scores)

            sequence = batch_candidates.sequences.index_select(0, index)
            past_key_values = batch_candidates.past_key_values
            if past_key_values is not None:
                past_key_values = tuple(
                    tuple(
                        past_key_value.index_select(0, index)
                        for past_key_value in layer_past_key_values
                    )
                    for layer_past_key_values in past_key_values
                )

            all_steps.append(batch_steps)
            all_step_transition_scores.append(batch_step_generation_scores)
            all_choices.append(index)
            if sequence.eq(self.eos_token_id).any():
                break
            input_ids, past_key_values, _ = self._truncate_left_padding(sequence, past_key_values)
            cur_length = input_ids.shape[-1]

        return sequence, all_steps, all_choices, all_step_transition_scores

    def _steps_beam_search(
        self,
        input_ids: torch.LongTensor,
        batch_size: int = 1,
        vs_batch_size: int = 1,
        n_beam: int = 2,
        n_sampling_steps: int = 2,
        max_n_step: int = 10,
        max_step_length: int = 100,
        max_length: int = 2048,
        dedup_mode: int = 0,
    ):
        """
        Beam search

        Parameters:
            input_ids (`torch.LongTensor`)
            batch_size (`int`):
                used for sampling at each time
            vs_batch_size (`int`):
                batch size of verifier scoring
            n_beam (`int`):
                number of kept sequences when progressing
            n_sampling_steps (`int`):
                number of total sampling sequences as next step candidates
            max_n_step (`int`)
            max_step_length (`int`):
                maximal length for a single step
            max_length (`int`):
                maximal length for the complete response
            dedup_mode (`int`):
                linguistics-level (mode=1); 0 indicates "no"
        """

        assert self.verifier is not None

        assert n_sampling_steps % n_beam == 0
        n_sampling_steps_per_beam = n_sampling_steps // n_beam

        input_ids = input_ids.repeat_interleave(n_beam, dim=0) # [n_beam, seq_len]

        cur_length = input_ids.shape[-1]

        past_key_values = None
        all_sequences = []
        all_vscores = []
        all_choices = []
        cur_step = 0
        while cur_length < max_length and cur_step < max_n_step:
            cur_step_max_length = cur_length + max_step_length

            batch_candidates = self._group_step_level_sample(
                input_ids=input_ids,
                past_key_values=past_key_values,
                batch_size=batch_size,
                vs_batch_size=vs_batch_size,
                num_sampling_sequences=n_sampling_steps_per_beam,
                max_length=min(cur_step_max_length, max_length), 
                output_transition_scores=False,
                output_verifier_scores=True,
            )
            batch_sequences = batch_candidates.sequences         # [n_beam * n_sampling_steps_per_beam, seq_len]
            batch_vscores = batch_candidates.verifier_scores

            # select the best steps/sequences
            hvscores = self._highlight_unique_sequences(batch_sequences, batch_vscores, dedup_mode=dedup_mode)
            _, indices = torch.topk(hvscores, k=n_beam, dim=0, largest=True)

            sequences = batch_sequences.index_select(0, indices) # [n_beam, seq_len]
            past_key_values = batch_candidates.past_key_values
            if past_key_values is not None:
                past_key_values = tuple(
                    tuple(
                        past_key_value.index_select(0, indices)
                        for past_key_value in layer_past_key_values
                    )
                    for layer_past_key_values in past_key_values
                )
            vscores = batch_vscores.index_select(0, indices)

            all_sequences.append(batch_sequences)
            all_vscores.append(batch_vscores)
            all_choices.append(indices)
            if sequences.eq(self.eos_token_id).any(1).all():
                break
            input_ids, past_key_values, _ = self._truncate_left_padding(sequences, past_key_values)
            cur_length = input_ids.shape[-1]

            cur_step += 1

        # final selection
        _, best_index = torch.topk(vscores, k=1, dim=0, largest=True)
        all_sequences.append(sequences)
        all_vscores.append(vscores)
        all_choices.append(best_index)

        sequence = sequences.index_select(0, best_index)
        return sequence, all_sequences, all_choices, all_vscores

    def _highlight_unique_sequences(self, sequences: torch.LongTensor, verifier_scores: torch.FloatTensor, dedup_mode: int=0) -> torch.FloatTensor:
        """
        Prioritize unique sequences: linguistics-level (mode=1)
        """
        if dedup_mode == 0:
            return verifier_scores
        
        seq_len = sequences.shape[-1]
        
        seqs = shift_padding_to_left_2D(sequences, pad_value=self.pad_token_id)
        multipliers = torch.pow(torch.full((seq_len,), 31, dtype=seqs.dtype, device=self.device), torch.arange(seq_len, device=self.device))
        hashes = (seqs * multipliers).sum(dim=1)

        unique_hashes = torch.unique(hashes)
        hightlighted_indices = (unique_hashes[:, None] == hashes[None, :]).float().argmax(dim=1)

        highlighted_vscores = verifier_scores.clone()
        highlighted_vscores[hightlighted_indices] += 100
        return highlighted_vscores


    def _resize_step_level_outputs_by_beam(self, step_outputs: StepSamplingOutput, n_beam: int, n_sampling_steps: int) -> StepSamplingOutput:
        """
        Split the leading (row) dimension of step outputs into `(n_beam, n_sampling_steps)`.

        `sequences`, `steps`, and (when present) `transition_scores` and
        `verifier_scores` are viewed as `(n_beam, n_sampling_steps, -1)`. Every cache
        tensor is viewed as `(n_beam, n_sampling_steps, *rest_of_its_shape)`.
        """
        def _split_rows(tensor):
            return tensor.view(n_beam, n_sampling_steps, -1)

        def _split_rows_optional(tensor):
            return None if tensor is None else _split_rows(tensor)

        split_sequences = _split_rows(step_outputs.sequences)
        split_steps = _split_rows(step_outputs.steps)
        split_transition = _split_rows_optional(step_outputs.transition_scores)
        split_verifier = _split_rows_optional(step_outputs.verifier_scores)

        split_cache = step_outputs.past_key_values
        if split_cache is not None:
            split_cache = tuple(
                tuple(kv.view(n_beam, n_sampling_steps, *kv.shape[1:]) for kv in layer)
                for layer in split_cache
            )

        return StepSamplingOutput(
            sequences=split_sequences,
            steps=split_steps,
            transition_scores=split_transition,
            verifier_scores=split_verifier,
            past_key_values=split_cache,
        )

    def _concat_group_tensors(self, tensor_list: List[torch.Tensor], left_padding = True, pad_value: int = 0, dim: int = 0):
        max_len = max(tensor.shape[-1] for tensor in tensor_list)
        if left_padding:
            tensor_list = [F.pad(tensor, (max_len - tensor.shape[-1], 0), value=pad_value) for tensor in tensor_list]
        else:
            tensor_list = [F.pad(tensor, (0, max_len - tensor.shape[-1]), value=pad_value) for tensor in tensor_list]

        tensors = torch.concat(tensor_list, dim=dim)
        return tensors

    def _concat_group_past_key_values(self, past_key_values: List[Tuple[Tuple[torch.FloatTensor]]], token_padding_lens: torch.LongTensor, dim: int = 0):
        # w/o beam: (bsz, n_heads, cache_len, embed_size)
        # w beam: (n_beam, n_sampling_steps_per_beam, n_heads, cache_len, embed_size)

        cache_lens = torch.LongTensor([cache[0][0].shape[-2] for cache in past_key_values]).to(self.device)
        padded_cache_lens = token_padding_lens + cache_lens
        min_cache_len = padded_cache_lens.min()
        cut_cache_lens = padded_cache_lens - min_cache_len

        past_key_values = tuple(
            tuple(
                torch.cat(
                    [F.pad(tensor.transpose(-2, -1), (token_padding_lens[i], -cut_cache_lens[i]), value=0).transpose(-2, -1) for i, tensor in enumerate(tensor_tuples)], 
                    dim=dim
                )
                for tensor_tuples in zip(*layer_tuples)
            )
            for layer_tuples in zip(*past_key_values)
        )
        return past_key_values

    def _concat_group_steps(self, instances: List[StepSamplingOutput], dim: int = 0):
        """
        Merge several `StepSamplingOutput` chunks into one.

        `sequences` and `steps` are left-padded with `self.pad_token_id` to a common
        length and concatenated. `transition_scores` is left-padded with 0 and
        concatenated, and `verifier_scores` is concatenated, each only when the first
        chunk has it. The KV caches are aligned to the same left padding (computed from
        the sequence lengths) by `_concat_group_past_key_values`.
        """
        field_names = ("sequences", "steps", "transition_scores", "verifier_scores", "past_key_values")
        grouped = {name: [inst.get(name) for inst in instances] for name in field_names}

        lengths = torch.LongTensor([s.shape[-1] for s in grouped["sequences"]]).to(self.device)
        longest = lengths.max()
        left_pad_lens = longest - lengths

        merged_sequences = self._concat_group_tensors(grouped["sequences"], pad_value=self.pad_token_id, dim=dim)
        merged_steps = self._concat_group_tensors(grouped["steps"], pad_value=self.pad_token_id, dim=dim)

        transition_list = grouped["transition_scores"]
        merged_transition = None
        if transition_list[0] is not None:
            merged_transition = self._concat_group_tensors(transition_list, pad_value=0, dim=dim)

        verifier_list = grouped["verifier_scores"]
        merged_verifier = None
        if verifier_list[0] is not None:
            merged_verifier = torch.cat(verifier_list, dim=dim)

        cache_list = grouped["past_key_values"]
        merged_cache = None
        if cache_list[0] is not None:
            merged_cache = self._concat_group_past_key_values(cache_list, left_pad_lens, dim=dim)

        return StepSamplingOutput(
            sequences=merged_sequences,
            steps=merged_steps,
            transition_scores=merged_transition,
            verifier_scores=merged_verifier,
            past_key_values=merged_cache,
        )
    
    def _flatten_step_level_outputs_beam(self, step_outputs: StepSamplingOutput) -> StepSamplingOutput:
        """
        Collapse per-beam step outputs back into flat rows.

        `sequences`, `steps` and (when present) `transition_scores` are viewed as
        `(-1, last_dim)`, and `verifier_scores` as 1-D. Cache tensors must be 5-D
        `(n_beam, n_sampling_steps, n_heads, cache_len, embed_size)`; they are viewed as
        `(-1, n_heads, cache_len, embed_size)`, using the first tensor's trailing sizes.
        """
        def _flatten_to_rows(tensor):
            return tensor.view(-1, tensor.shape[-1])

        flat_sequences = _flatten_to_rows(step_outputs.sequences)
        flat_steps = _flatten_to_rows(step_outputs.steps)

        transition = step_outputs.transition_scores
        flat_transition = None if transition is None else _flatten_to_rows(transition)

        verifier = step_outputs.verifier_scores
        flat_verifier = None if verifier is None else verifier.view(-1)

        flat_cache = step_outputs.past_key_values
        if flat_cache is not None:
            _, _, heads, length, width = flat_cache[0][0].shape
            flat_cache = tuple(
                tuple(kv.view(-1, heads, length, width) for kv in layer)
                for layer in flat_cache
            )

        return StepSamplingOutput(
            sequences=flat_sequences,
            steps=flat_steps,
            transition_scores=flat_transition,
            verifier_scores=flat_verifier,
            past_key_values=flat_cache,
        )

