# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions related to data processing.

"""
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.getcwd()))

import copy
from dataclasses import dataclass, field
import io
import math
import numpy as np
from PIL import Image
import json
import re
import random
import torch
from typing import List, Optional, Union, Tuple, Any, Dict
from datasets import load_dataset, concatenate_datasets, disable_caching, Dataset
from transformers import (
    LlamaTokenizer,
    GenerationConfig,
    PreTrainedTokenizer,
    DataCollatorForSeq2Seq
)
from threading import Thread
from transformers import TextIteratorStreamer
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.tokenization_utils_base import TruncationStrategy

try:
    from pyhie import allspark
    from pyhie.allspark import Engine as AllSparkEngine
except ImportError:
    AllSparkEngine = type(None)

try:
    from vllm import LLM as VLLM
    from vllm.sampling_params import SamplingParams as VLLMSamplingParams
except ImportError:
    VLLM = type(None)

# from utils.generation_utils import GenerationMixinSearch

# Names of the preset prompt templates. DataArguments.__post_init__ checks a plain-string
# `prompt_templates` value for membership in this list.
# NOTE(review): "llama_new_qa" appears twice below (once under QA, once under chat); the
# chat entry may have been meant to be a chat-style name -- confirm before changing.
ALL_PROMPT_TEMPLATES = [
    # QA
    "alpaca", "simple_qa", "simplest_qa", "llama_qa", "llama_new_qa", "qwen_qa", "granite_guardian_qa",
    # chat
    "simple_chat", "llama_chat", "llama_new_qa", "mistral_chat", "qwen_chat"
]


# Alpaca-style QA prompt with an extra "Input" section.
# Placeholders: {INSTRUCTION}, {INPUT}, {OUTPUT}.
ALPACA_PROMPT_INPUT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{INSTRUCTION}

### Input:
{INPUT}

### Response:
{OUTPUT}"""

# Alpaca-style QA prompt without the "Input" section.
# Placeholders: {INSTRUCTION}, {OUTPUT}.
ALPACA_PROMPT_WO_INPUT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{INSTRUCTION}

### Response:
{OUTPUT}"""

# Simple QA prompt using <Human>/<Assistant> markers, with an input line.
# Placeholders: {INSTRUCTION}, {INPUT}, {OUTPUT}.
SIMPLE_QA_PROMPT_INPUT = """<Human>:
{INSTRUCTION}
{INPUT}
<Assistant>:
{OUTPUT}"""

# Simple QA prompt using <Human>/<Assistant> markers, without an input line.
# Placeholders: {INSTRUCTION}, {OUTPUT}.
SIMPLE_QA_PROMPT_WO_INPUT = """<Human>:
{INSTRUCTION}
<Assistant>:
{OUTPUT}"""

# Template for a single chat turn: the role in angle brackets, then the content.
# Placeholders: {ROLE}, {CONTENT}.
SIMPLE_CHAT_TEMPLATE="""<{ROLE}>:
{CONTENT}"""

# Bare QA prompt with no markers at all, with an input line.
# Placeholders: {INSTRUCTION}, {INPUT}, {OUTPUT}.
SIMPLEST_QA_PROMPT_INPUT = """{INSTRUCTION}
{INPUT}
{OUTPUT}"""

# Bare QA prompt with no markers at all, without an input line.
# Placeholders: {INSTRUCTION}, {OUTPUT}.
SIMPLEST_QA_PROMPT_WO_INPUT = """{INSTRUCTION}
{OUTPUT}"""

# Default system prompt used by DataArguments when system_prompt == "default" and
# prompt_templates == "llama_qa".
DEFAULT_LLAMA_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
# Default system prompt used by DataArguments when system_prompt == "default" and
# prompt_templates == "llama_new_qa"; {TODAY} is filled with the current date there.
DEFAULT_LLAMA_NEW_SYSTEM_PROMPT = """Cutting Knowledge Date: December 2023
Today Date: {TODAY}"""
# Wrapper that places a system message between <<SYS>> tags. Placeholder: {SYSTEM}.
LLAMA_SYSTEM_PROMPT = """<<SYS>>
{SYSTEM}
<</SYS>>

"""

# Default system message text for Qwen-style prompts.
DEFAULT_QWEN_SYSTEM_PROMPT = """You are a helpful assistant."""

# Wrapper that renders a system message under a "system" header line. Placeholder: {SYSTEM}.
QWEN_SYSTEM_PROMPT = """system
{SYSTEM}
"""
# The Llama QA templates below are formatted in two passes: format SYSTEM_PROMPT first
# (single braces); the doubled braces then collapse to {INSTRUCTION}, {INPUT} and {OUTPUT},
# which are filled by a second format call.
# after tokenization, should start with <s> and end with </s>
# we move {{OUTPUT}} to a new line to avoid the bug calculating the prompt_len for en
LLAMA_QA_PROMPT_INPUT="""[INST] {SYSTEM_PROMPT}{{INSTRUCTION}}\n{{INPUT}} [/INST]
{{OUTPUT}}"""
LLAMA_QA_PROMPT_WO_INPUT="""[INST] {SYSTEM_PROMPT}{{INSTRUCTION}} [/INST]
{{OUTPUT}}"""
# Header-token based variants (system / user / assistant sections separated by <|eot_id|>),
# using the same two-pass placeholder scheme as above.
LLAMA_NEW_QA_PROMPT_INPUT="""<|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}

<|eot_id|><|start_header_id|>user<|end_header_id|>

{{INSTRUCTION}}\n{{INPUT}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{{OUTPUT}}"""
LLAMA_NEW_QA_PROMPT_WO_INPUT="""<|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}

<|eot_id|><|start_header_id|>user<|end_header_id|>

{{INSTRUCTION}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{{OUTPUT}}"""


# Default Chinese "Solver" message ("the solver cannot find a feasible solution"), used by
# convert_prompt_instance_to_chat when an instance has an explanation but an empty solveResult.
NO_SOLUTION_DEFAULT_Q_CN="""求解器无法找到可行解。"""
# Chinese explanations for the no-solution case; convert_prompt_instance_to_chat picks one
# at random and stores it in the instance metadata.
NO_SOLUTION_DEFAULT_A_CN = [  # just for reference
    "抱歉，由于求解器无法提供运行结果，我将无法为您解释变量和目标值的意义。请提供运行结果，我将很乐意为您提供解释和帮助。",
    "对不起，求解器没有提供相关的运行结果。这可能有多种原因，比如约束冲突导致问题无解。请找出原因后再次调用求解器。待求解器提供结果后，我将很乐意为您解释。",
    "很抱歉，由于没有给出具体的运行结果，我无法进行解释。如果您能给出具体的运行结果，我将很乐意帮助您解释。",
    "抱歉，求解器未能提供可行的问题解决方案。我们可以考虑增删改一些约束条件，以使求解器能够找到一个可行的解。",
    "对不起，我无法继续因为求解器没有提供运行结果。一旦求解器提供了运行结果，我将帮助你解释变量和目标值的意义。",
    "对不起，我没看到您提供的优化问题的运行结果。如果您能提供运行结果，我将非常乐意帮助您解释问题中变量和目标值的意义。",
    "对不起，您提供的问题运行结果为空。请您检查问题的正确性，以便求解器能够给出运行结果，从而我能进行正确的结果解释。",
    "很抱歉，求解器未能提供可行解。但是，我们可以分析约束条件和目标函数，看看是否有任何调整可以改善结果。",
    "很遗憾，求解器未能找到问题的可行解决方案。我们可以更仔细地分析约束条件和目标函数，看看是否有任何不一致或错误。"
]
# English counterpart of NO_SOLUTION_DEFAULT_Q_CN, used for non-"cn" instances.
NO_SOLUTION_DEFAULT_Q_EN="""The solver fails to find a feasible solution."""
# English counterpart of NO_SOLUTION_DEFAULT_A_CN, used for non-"cn" instances.
NO_SOLUTION_DEFAULT_A_EN = [  # just for reference
    "The solver results is missing from the problem. This may be caused by that solver failed to give a feasible solution. Please check your problem description and try again.",
    "The optimization problem is not solved, hence it is impossible to provide an explanation based on the solver results.",
    "I'm sorry to inform you that the solver was unable to provide a feasible solution to the problem. However, we can try to analyze the constraints and objective function to see if there are any adjustments we can make to improve the results.",
    "Unfortunately, the solver did not find a feasible solution for the problem. Perhaps we can review the objective and constraints and see if there are any conflicts that need to be resolved before running the solver again.",
    "I'm sorry, but the solver was unable to find a feasible solution to the problem at this time. We can investigate if there are any alternate approaches or algorithms that may yield better results.",
    "I regret to inform you that the solver could not generate a feasible solution for the problem. We can review the problem formulation and check if there are any errors or inconsistencies, which may be causing the issue.",
    "I apologize, but the solver was not able to provide a feasible solution for the problem. Perhaps there are some ambiguities or inconsistencies in the problem formulation which may be impacting the solver's performance.",
    "Unfortunately, the solver could not produce a feasible solution. We can review the constraints and try again."
]


def get_prompt_templates(prompt_templates: Union[List[str], str]):
    """Resolve a prompt-template specification into a pair of template strings.

    Args:
        prompt_templates: The template to use in prompts. Could be:
            - a preset name with built-in templates: `alpaca`, `simple_qa`, `simplest_qa`,
              `llama_qa` or `llama_new_qa`;
            - `granite_guardian_qa`, or any name containing `chat`: no template strings are
              returned for these (the granite guardian prompt is rendered by the tokenizer's
              built-in template);
            - a list/tuple whose first two items are the prompt with input placeholder and
              the prompt without input placeholder. Extra items are ignored.

    Returns:
        A tuple `(prompt_w_input, prompt_wo_input)`. Both are None for
        `granite_guardian_qa` and chat templates.

    Raises:
        NotImplementedError: if a string is given that is not a supported preset name.
        ValueError: if a list/tuple with fewer than two templates is given.
        TypeError: if `prompt_templates` is neither a string nor a list/tuple.
    """
    if isinstance(prompt_templates, (list, tuple)):
        # Previously a too-short sequence fell through every branch and failed with an
        # UnboundLocalError at the return statement; fail with a clear message instead.
        if len(prompt_templates) < 2:
            raise ValueError(
                'prompt_templates given as a list must contain two templates '
                '(with input, without input); got {}'.format(prompt_templates))
        return prompt_templates[0], prompt_templates[1]

    if not isinstance(prompt_templates, str):
        raise TypeError(
            'prompt_templates should be a str or a list of two str; got {}'.format(type(prompt_templates)))

    # Preset names are matched exactly; constants are only looked up for the chosen preset.
    if prompt_templates == 'alpaca':
        return ALPACA_PROMPT_INPUT, ALPACA_PROMPT_WO_INPUT
    if prompt_templates == 'simple_qa':
        return SIMPLE_QA_PROMPT_INPUT, SIMPLE_QA_PROMPT_WO_INPUT
    if prompt_templates == 'simplest_qa':
        return SIMPLEST_QA_PROMPT_INPUT, SIMPLEST_QA_PROMPT_WO_INPUT
    if prompt_templates == 'llama_qa':
        return LLAMA_QA_PROMPT_INPUT, LLAMA_QA_PROMPT_WO_INPUT
    if prompt_templates == 'llama_new_qa':
        return LLAMA_NEW_QA_PROMPT_INPUT, LLAMA_NEW_QA_PROMPT_WO_INPUT
    if prompt_templates == 'granite_guardian_qa':  # for granite guardian, we use template built in tokenizer
        return None, None
    if "chat" in prompt_templates:  # due to inheritance, TokenizedChatProcessorWithDA will call this
        return None, None

    raise NotImplementedError('prompt_templates {} not supported.'.format(prompt_templates))


# Risk names accepted by get_granite_guardian_qa_prompt, grouped by how the chat is composed:
# harm risks use the user message (and the assistant message unless the prompt is targeted),
GRANITE_HARM_RISKS = ["social_bias", "jailbreak", "profanity", "sexual_content", "unethical_behavior", "violence", "harm"]
# RAG risks combine user / context / assistant messages depending on the specific risk,
GRANITE_RAG_RISKS = ["groundedness", "answer_relevance", "context_relevance"]
# and the function-call risk uses the user and assistant messages.
GRANITE_FUNCTION_CALL_RISKS = ["function_call"]


def get_granite_guardian_qa_prompt(tokenizer, instruction, inputs, output, risk_name, **kwargs):
    """Build the granite guardian detection prompt for one risk (or a list of risks).

    The prompt is rendered by the chat template built into the tokenizer, in three steps:
        1. validate `risk_name` (and split off an optional `-prompt` / `-response` suffix);
        2. compose the chat messages required by that risk;
        3. call `tokenizer.apply_chat_template` with `guardian_config={"risk_name": ...}`.

    Messages are added in this order when present: an optional `system` message, then the
    `user` / `context` / `assistant` messages selected by the risk.

    Args:
        tokenizer: object providing `apply_chat_template`.
        instruction: the user query (content of the `user` message).
        inputs: extra context, e.g. retrieved documents (content of the `context` message).
        output: the LLM response (content of the `assistant` message).
        risk_name: a supported risk name as a string, optionally followed by a hyphen and the
            part to check, e.g. "harm-response". Without a suffix, "jailbreak" checks the
            prompt and every other risk checks the response. A list of such strings yields
            a list of prompts, one per entry.
        **kwargs: `system` (optional system message text) and `add_generation_prompt`
            (defaults to True).

    Returns:
        The rendered prompt, or a list of rendered prompts when `risk_name` is a list.

    Raises:
        NotImplementedError: if `risk_name` is not a str/list or names an unsupported risk.
    """
    if isinstance(risk_name, list):
        return [
            get_granite_guardian_qa_prompt(tokenizer, instruction, inputs, output, single_risk, **kwargs)
            for single_risk in risk_name
        ]

    if not isinstance(risk_name, str):
        raise NotImplementedError('risk_name {} not supported.'.format(risk_name))

    # Work out which side of the conversation is being checked.
    if '-' in risk_name:
        risk_name, detect_on = [piece.strip() for piece in risk_name.split('-')]
    elif risk_name == 'jailbreak':
        detect_on = "prompt"
    else:
        detect_on = "response"

    supported_risks = GRANITE_HARM_RISKS + GRANITE_RAG_RISKS + GRANITE_FUNCTION_CALL_RISKS
    if risk_name not in supported_risks:
        raise NotImplementedError('risk_name {} not supported.'.format(risk_name))

    system_text = kwargs.get("system", "")
    chat = [{"role": "system", "content": system_text}] if system_text else []

    user_msg = {'role': 'user', 'content': instruction}
    context_msg = {'role': 'context', 'content': inputs}
    assistant_msg = {'role': 'assistant', 'content': output}

    if risk_name in GRANITE_HARM_RISKS:
        chat.append(user_msg)
        if detect_on != "prompt":
            chat.append(assistant_msg)
    elif risk_name in GRANITE_RAG_RISKS:
        rag_layouts = {
            "groundedness": [context_msg, assistant_msg],
            "answer_relevance": [user_msg, assistant_msg],
            "context_relevance": [user_msg, context_msg],
        }
        if risk_name not in rag_layouts:
            raise NotImplementedError('risk_name {} not supported.'.format(risk_name))
        chat += rag_layouts[risk_name]
    elif risk_name == "function_call":
        chat += [user_msg, assistant_msg]
    else:
        raise NotImplementedError('risk_name {} not supported.'.format(risk_name))

    return tokenizer.apply_chat_template(
        chat,
        guardian_config={"risk_name": risk_name},
        add_generation_prompt=kwargs.get("add_generation_prompt", True),
        tokenize=False,
    )


def granite_guardian_post_process_fn(tokenizer, generation_output, responses, safe_token="No", unsafe_token="Yes", nlogprobs=20):
    """Turn granite guardian generations into "<label>" or "<label>, <risk probability>" strings.

    Args:
        tokenizer: object providing `convert_ids_to_tokens`.
        generation_output: object whose `scores` is a sequence of per-step tensors (documented
            by the caller as N-by-V normalized token logits, the last one being EOS). The last
            entry is dropped before the top-k lookup. Only read when `nlogprobs > 0`.
        responses: List[str], a batch of responses from the generator model.
        safe_token: label meaning "no risk"; matched case-insensitively.
        unsafe_token: label meaning "risk"; matched case-insensitively.
        nlogprobs: number of top entries inspected per step; 0 or less disables probabilities.

    Returns:
        One string per response: the matched label (followed by ", <probability of risk>" when
        probabilities were computed), or "Failed" when the response starts with neither label.
    """
    num_samples = len(responses)

    def _risk_distribution(topk_per_step):
        # Start from a tiny mass so the later log() never sees an exact zero.
        safe_mass = [1e-50] * num_samples
        unsafe_mass = [1e-50] * num_samples
        for step_topk in topk_per_step:
            step_values = step_topk.values.tolist()
            step_indices = step_topk.indices.tolist()
            for row, (row_values, row_indices) in enumerate(zip(step_values, step_indices)):
                row_tokens = tokenizer.convert_ids_to_tokens(row_indices)
                safe_mass[row] += sum(
                    math.exp(logprob) for logprob, tok in zip(row_values, row_tokens)
                    if tok.strip().lower() == safe_token.lower()
                )
                unsafe_mass[row] += sum(
                    math.exp(logprob) for logprob, tok in zip(row_values, row_tokens)
                    if tok.strip().lower() == unsafe_token.lower()
                )

        # Row 0 holds the safe mass, row 1 the unsafe mass; normalise across the two rows.
        stacked = torch.tensor([safe_mass, unsafe_mass])
        return torch.softmax(torch.log(stacked), dim=0)

    risk_probs = None
    if nlogprobs > 0:
        step_scores = list(generation_output.scores)[:-1]
        topk_per_step = [
            torch.topk(step, k=nlogprobs, largest=True, sorted=True) for step in step_scores
        ]
        if topk_per_step:
            risk_probs = _risk_distribution(topk_per_step)[1].tolist()

    formatted = []
    for position, reply in enumerate(responses):
        reply_lower = reply.lower()
        # The unsafe label is tested first, as it takes precedence when both would match.
        if reply_lower.startswith(unsafe_token.lower()):
            label = unsafe_token
        elif reply_lower.startswith(safe_token.lower()):
            label = safe_token
        else:
            formatted.append("Failed")
            continue
        formatted.append(label if risk_probs is None else f"{label}, {risk_probs[position]}")

    return formatted


@dataclass
class DataArguments:
    """ Arguments pertaining to what data we are going to input our model for training and eval.

    Parameters:
        train_file (`str`):
            The path of training data file. Could be a list of such files, either given as a
            list/tuple or as its string literal (e.g. "['a.json', 'b.json']").
        validation_file (`str`):
            (Optional) the path of evaluation data file to evaluate the perplexity on. Could be a list of such files.
            "None" and "" are treated as not set.
        validation_split (`float`):
            The percentage of the train set used as validation set in case there's no validation_file.
            If larger than 1, would be interpreted as the number of samples to use as validation set.
            Forced to 0 when a validation_file is given.
        max_seq_length (`int`):
            The maximum total input sequence length after tokenization. Sequences longer 
            than this will be truncated, sequences shorter will be padded.
            If not set, will pad the samples dynamically when batching to the maximum length in the batch.
        prompt_templates (`str`):
            Name of a preset template, or (the string literal of) a list of templates.
        system_prompt (`str`):
            System prompt text, the path of a file containing it, "default" to use the default
            system prompt of the current prompt_templates, or None / "" / " " / "None" for an
            empty system prompt.
    """
    # NOTE: field annotations are intentionally left as-is; argument parsers may derive
    # command-line types from them.
    train_file: str = field(
        default=None, metadata={"help": "The path of training data file."}
    )
    validation_file: str = field(
        default=None,
        metadata={"help": "(Optional) the path of evaluation data file to evaluate the perplexity on."},
    )
    validation_split: Optional[float] = field(
        default=0.0,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation_file."
            "If larger than 1, would be interpreted as the number of samples to use as validation set."
        },
    )
    max_seq_length: int = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
            "If not set, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    prompt_templates: str = field(
        default="simple_qa",
        metadata={
            "help": "The template to use in prompts. Could be a str in {} to use preset templates,"
            "or a list of two strings, corresponding to prompt with input placeholder, and prompt without input placeholder.".format(ALL_PROMPT_TEMPLATES)
        },
    )
    system_prompt: str = field(
        default=None,
        metadata={
            "help": "The template to use as system prompt. Could be any str, including special cases: "
            "None means using empty system prompt; default means using the default system prompt of current prompt_templates."
        },
    )

    @staticmethod
    def _is_sequence_literal(value) -> bool:
        """Return True when `value` is a string that looks like a list/tuple literal."""
        return isinstance(value, str) and value.startswith(('[', '(')) and value.endswith((']', ')'))

    def __post_init__(self):
        """Normalise raw (typically command-line) values into their final form.

        Non-string values (e.g. None or an already-parsed list) are left untouched instead of
        failing on a missing `.startswith` attribute.
        """
        # NOTE(security): eval() executes arbitrary expressions. It is kept for backward
        # compatibility with existing invocations; only pass trusted values, and consider
        # ast.literal_eval if plain list/tuple literals are all that is ever needed.
        # in case multiple files are provided as a list/tuple
        if self._is_sequence_literal(self.train_file):
            self.train_file = eval(self.train_file)

        if self.validation_file in ["None", '', None]:
            self.validation_file = None
        elif self._is_sequence_literal(self.validation_file):
            self.validation_file = eval(self.validation_file)

        if self.validation_file:
            self.validation_split = 0  # ignore validation_split
        elif self.validation_split is None:
            self.validation_split = 0.0  # None means "no split", same as the field default
        elif self.validation_split > 1:
            self.validation_split = int(self.validation_split)

        # check prompt_templates
        # The checks stay as asserts so the exception type seen by callers is unchanged.
        parsed_from_literal = self._is_sequence_literal(self.prompt_templates)
        if parsed_from_literal:
            self.prompt_templates = eval(self.prompt_templates)
        if parsed_from_literal or isinstance(self.prompt_templates, (list, tuple)):
            assert all('qa' in prompt_template or 'chat' in prompt_template for prompt_template in self.prompt_templates), ValueError(
                "prompt_templates should contain either qa or chat; got {}".format(self.prompt_templates))
        else:
            assert self.prompt_templates in ALL_PROMPT_TEMPLATES, NotImplementedError(
                "Prompt template {} has not been implemented.".format(self.prompt_templates))

        # resolve system_prompt: empty, preset default, file content, or literal text
        if self.system_prompt in {"", " ", "None"} or self.system_prompt is None:
            self.system_prompt = ""
        elif self.system_prompt == "default":
            if self.prompt_templates == "llama_qa":
                self.system_prompt = DEFAULT_LLAMA_SYSTEM_PROMPT
            elif self.prompt_templates == "llama_new_qa":
                from datetime import datetime

                self.system_prompt = DEFAULT_LLAMA_NEW_SYSTEM_PROMPT.format(TODAY=datetime.today().strftime("%d %b %Y"))
            else:
                raise NotImplementedError(
                    "Current prompt_templates {} does not have default system prompt.".format(self.prompt_templates))
        elif os.path.isfile(self.system_prompt):  # load system prompt from file
            with open(self.system_prompt, "r") as f:
                self.system_prompt = f.read()


@dataclass
class ImageProcessingArguments:
    """ Arguments pertaining to how to process image inputs and outputs.
    Parameters are set based on LlavaNextImageProcessor.

    Parameters:
        See metadata below. String values of `image_size`, `image_grid_pinpoints`,
        `image_mean` and `image_std` are evaluated into Python objects in `__post_init__`;
        values that are already Python objects are kept as given.
    """
    do_resize: bool = field(
        default=True,
        metadata={"help": "Whether to resize the image's (height, width) dimensions to the specified `size`."},
    )
    image_size: str = field(
        default=None,
        metadata={
            "help": "Size of the image after resizing. Could be: tuple[int] for (H, W); int for shortest_edge; "
            "or dict such as {'shortest_edge': 224}. Set to None to use the default value of corresponding processor."
        },
    )
    image_grid_pinpoints: str = field(
        default=None,
        metadata={
            "help": "A list of possible resolutions to use for processing high resolution images."
            "Set to None to use the default value of corresponding processor."
        },
    )
    image_resample: str = field(
        default="bicubic",
        metadata={"help": "Resampling filter to use when resizing the image. Only `bicubic` is currently supported."},
    )
    do_center_crop: bool = field(
        default=True,
        metadata={"help": "Whether to center crop the image to the specified `crop_size`."},
    )
    crop_size: int = field(
        default=224,
        metadata={"help": "Size of the output image after applying `center_crop`."},
    )
    # we hide do_scale and rescale_factor as: 
    # 1. it seems a must-do to scale the image to [0,1]; 
    # 2. it overlaps with do_normalize.
    do_normalize: bool = field(
        default=True,
        metadata={"help": "Whether to normalize the image by `image_mean` and `image_std`."},
    )
    image_mean: str = field(
        default="None",
        metadata={"help": "Mean to use if normalizing the image. Set to None to use the default value of corresponding processor."},
    )
    image_std: str = field(
        default="None",
        metadata={"help": "Standard deviation to use if normalizing the image. Set to None to use the default value of corresponding processor."},
    )

    def __post_init__(self):
        """Convert string-valued options into the Python objects the image processor expects.

        Raises:
            NotImplementedError: if `image_resample` is anything other than "bicubic".
        """
        # NOTE(security): eval() executes arbitrary expressions. It is kept for backward
        # compatibility; only pass trusted values.
        # Only strings are evaluated. Values that are already Python objects (tuple, int,
        # dict, list) are kept as given; they used to be silently replaced with None.
        if isinstance(self.image_size, str):
            self.image_size = eval(self.image_size)
        if isinstance(self.image_grid_pinpoints, str):
            self.image_grid_pinpoints = eval(self.image_grid_pinpoints)

        if self.image_resample == "bicubic":
            self.image_resample = PILImageResampling.BICUBIC
        else:
            raise NotImplementedError("Image resampling method {} not supported.".format(self.image_resample))

        if isinstance(self.image_mean, str):
            self.image_mean = eval(self.image_mean)
        if isinstance(self.image_std, str):
            self.image_std = eval(self.image_std)


def _convert_image_bytes_to_rgb_array(image_bytes):
    """This function does:
    1. Convert image to RGB mode;
    2. Convert image in bytes to numpy array (uint8);
    The output array will be of shape (H, W, 3).
    """
    if isinstance(image_bytes, bytes):
        image = Image.open(io.BytesIO(image_bytes))
        image_arr = np.asarray(image if image.mode == "RGB" else image.convert("RGB"), np.uint8)
    else:
        raise TypeError("Expect image format to be bytes; got {}".format(type(image_bytes)))

    return image_arr


def convert_prompt_instance_to_chat(instance: dict, metadata_keys=None, image_processor=None):
    """ Convert a flat instruction/output instance into the chat layout used by ChatRecords.

    Input format:
    {
        "index": 32,
        "instruction": "...",
        "output": "...",
        ...
    }
    Output format:
    {
        "index": 32,
        "chat": [
            {"role": "...", "content": "...", "metadata": {"language": "cn/en", "task": "..."}},
            ...
        ],
        "metadata": {"language": "cn/en", ...},
    }

    Args:
        instance: the record to convert. If it already carries a non-None "chat" entry it is
            returned unchanged.
        metadata_keys: keys copied from `instance` into the top-level metadata. When empty,
            every key except the reserved ones (index, instruction, input, output,
            solveResult, solveResult_explanation, language) is copied.
        image_processor: optional callable applied to the decoded image array; its result is
            merged into the Human turn's metadata. Without it, the raw array and its
            (H, W) size are stored instead.

    Returns:
        The instance itself (when it already has a chat) or a new dict with "index", "chat"
        and "metadata".
    """
    if instance.get("chat", None) is not None:
        return instance

    if not metadata_keys:
        # e.g., "var_description", "objective_description", "constraint_description", "value_labels"
        reserved_keys = (
            "index", "instruction", "input", "output",
            "solveResult", "solveResult_explanation", "language",
        )
        metadata_keys = [name for name in instance if name not in reserved_keys]
    assert isinstance(metadata_keys, (list, tuple)), TypeError(
        "Expect metadata_keys to be list; got {}".format(type(metadata_keys)))

    metadata = {}
    for name in metadata_keys:
        if name in instance:
            metadata[name] = instance[name]

    # Fall back to detecting CJK characters in the instruction when no language is given.
    language = instance.get("language", None)
    if not language:
        has_cjk = re.search(u'[\u4e00-\u9fff]', instance['instruction'])
        language = "cn" if has_cjk else "en"
    metadata["language"] = language

    def _turn(role, content, task):
        return {"role": role, "content": content, "metadata": {"language": language, "task": task}}

    chat = [
        _turn("Human", instance['instruction'], "SimpleMathFormulationTask.Q"),
        _turn("Assistant", instance['output'], "SimpleMathFormulationTask.A"),
    ]

    # move value_labels from general metadata to Assistant's metadata
    if metadata.get("value_labels", None):
        chat[1]["metadata"]["value_labels"] = metadata.pop("value_labels", None)

    if "solveResult" in instance and instance.get("solveResult_explanation", None):
        solve_result = instance['solveResult']
        solve_result_explanation = instance['solveResult_explanation']
        if not solve_result and solve_result_explanation:  # in case of no solution
            if language == 'cn':
                default_question, default_answers = NO_SOLUTION_DEFAULT_Q_CN, NO_SOLUTION_DEFAULT_A_CN
            else:
                default_question, default_answers = NO_SOLUTION_DEFAULT_Q_EN, NO_SOLUTION_DEFAULT_A_EN
            solve_result = default_question
            solve_result_explanation = random.choice(default_answers)
        # NOTE(review): the Assistant turn uses the instance's own explanation, while the
        # metadata below stores `solve_result_explanation` (a random default in the
        # no-solution case) -- confirm this difference is intended.
        chat.append(_turn("Solver", solve_result, "SolverResultExplanationTask.Q"))
        chat.append(_turn("Assistant", instance['solveResult_explanation'], "SolverResultExplanationTask.A"))
        metadata["solveResult"] = solve_result
        metadata["solveResult_explanation"] = solve_result_explanation

    raw_image = instance.get("image", None)
    if raw_image is not None:
        image_arr = _convert_image_bytes_to_rgb_array(raw_image)
        if image_processor is None:
            image_features = {"pixel_values": [image_arr], "image_sizes": [image_arr.shape[:2]]}
        else:
            # image_processor output is transformers.image_processing_utils.BatchFeature
            image_features = image_processor(image_arr)
        chat[0]['metadata'].update(image_features)

    return {
        "index": instance.get("index", 0),
        "chat": chat,
        "metadata": metadata,
    }


def _check_possible_roles_in_chat_template(chat_template, roles):
    is_human = ('"human"' in chat_template) or ("'human'" in chat_template)
    is_Human = ('"Human"' in chat_template) or ("'Human'" in chat_template)
    is_user = ('"user"' in chat_template) or ("'user'" in chat_template)
    is_User = ('"User"' in chat_template) or ("'User'" in chat_template)
    is_assistant = ('"assistant"' in chat_template) or ("'assistant'" in chat_template)
    is_Assistant = ('"Assistant"' in chat_template) or ("'Assistant'" in chat_template)

    role_map = {}
    for role in roles:
        if role not in role_map:
            if role in ['system']:
                role_map[role] = 'system'
            if role in ['human', 'Human', 'user', 'User']:
                if is_human:
                    role_map[role] = 'human'
                elif is_Human:
                    role_map[role] = 'Human'
                elif is_user:
                    role_map[role] = 'user'
                elif is_User:
                    role_map[role] = 'User'
                else:
                    raise RuntimeError('Cannot detect roles in chat_template; please provide role map explicitly.')
            elif role in ['assistant', 'Assistant']:
                if is_assistant:
                    role_map[role] = 'assistant'
                elif is_Assistant:
                    role_map[role] = 'Assistant'
            else:  # for roles like solver, tools, check if its lower case and capital case in chat_template
                role_lower = role.lower()
                role_capital = role_lower[0].upper() + role_lower[1:]
                is_origin = (f'"{role}"' in chat_template) or (f"'{role}'" in chat_template)
                is_lower = (f'"{role_lower}"' in chat_template) or (f"'{role_lower}'" in chat_template)
                is_capital = (f'"{role_capital}"' in chat_template) or (f"'{role_capital}'" in chat_template)
                if is_origin:
                    role_map[role] = role
                elif is_lower:
                    role_map[role] = role_lower
                elif is_capital:
                    role_map[role] = role_capital
                else:
                    raise RuntimeError(f'chat_template does not support role {role}; please provide role map explicitly.')

    return role_map


class ChatRecords(list):
    """Extends list to manage multi-round chat, including:
    - update current chat by append/extend/set new items just like modifying a list
    - concatenate chat items into string via Chat._concat_to_string
    - tokenize current chat via ChatRecords.tokenize

    Args:
        iterable (List[dict]): 
            A list of dict of the following format:
            [
                {'role': 'user', 'content': '...', 'metadata': {...}},
                {'role': 'assistant', 'content': '...', 'metadata': {...}},
                {'role': 'solver', 'content': '...', 'metadata': {...}},
                {'role': 'assistant', 'content': '...', 'metadata': {...}},
                ...
            ]
            Optionally, `content` could be a list of str. In this case, it is a stochastic chat,
            where the trajectory is sampled at runtime. Do not use this feature in a preprocessor.
        prepend_token_to (Optional[dict]): 
            A dict of the following format:
            {
                'user': '...'  # some token
            }
            The token will be prepended to each content of the role when converting chat to string.
            Applicable when use_tokenizer_chat_template=False.
        append_token_to (Optional[dict]): 
            A dict of the following format:
            {
                'assistant': '...'  # some token
            }
            The token will be appended to each content of the role when converting chat to string.
            Applicable when use_tokenizer_chat_template=False.
        template (str): 
            A customized template with ROLE and CONTENT tags. Applicable when use_tokenizer_chat_template=False.
        tokenizer (Optional[PreTrainedTokenizerBase]): 
            Any tokenizer, if the input_ids needs to be got.
        max_seq_length (int, default to 8192): 
            The maximum number of tokens. 
        add_generation_prompt (bool, default to False): 
            Whether need to append a generation prompt, e.g., <|start_of_role|>assistant<|end_of_role|>, to the text. 
            Used in generation mode.
        use_tokenizer_chat_template (bool, default to False): 
            Whether to use the default template provided by the tokenizer.apply_chat_template. 
        system_prompt (str): 
            System message. May also be provided in iterable.
        role_tags (Optional[dict]): 
            Some tokenizers may have special tokens for the role tags (e.g., <｜User｜>, <｜Assistant｜>), instead of 
            standard format <bos_token> + role (e.g., <|im_start|>assistant). For these tokenizers, need to provide
            role_tags dict in order to get a correct label.
        role_map (Optional[dict]): 
            Some tokenizers cannot handle the roles in chat. Provide a role map if you do not want the role to be 
            automatically detected and mapped. 

    Usage:
        messages = [
            {"role": "system", "content": "You are a bot that responds to queries."},
            {"role": "user", "content": "Janet's ducks lay 16 eggs per day."},
            {"role": "assistant", "content": "It is -3 outside."},
            {"role": "user", "content": "What you said was not related to my question."},
            {"role": "assistant", "content": "You were not asking any question either."},
        ]
        tokenizer = AutoTokenizer.from_pretrained("../output/DeepSeek-R1-Distill-Qwen-1.5B_hf")
        chat = ChatRecords(
            messages, 
            tokenizer=tokenizer,
            max_seq_length=4096,
            add_generation_prompt=True,
            use_tokenizer_chat_template=True,
            role_tags={"user": "<｜User｜>", "assistant": "<｜Assistant｜>"},
        )
        print(chat.text)
        print(chat.input_ids)
        print(chat.labels())
    """
    def __init__(
        self, 
        iterable, 
        *, 
        prepend_token_to=None, 
        append_token_to=None, 
        template=None,
        tokenizer=None,
        max_seq_length=8192,
        add_generation_prompt=False,
        use_tokenizer_chat_template=False,
        system_prompt=None,
        **kwargs
        ):
        # set system prompt; if the role of iterable[0] is not system; prepend system utterrance
        if system_prompt and (len(iterable) > 0) and (iterable[0].get("role", "").lower() != "system"):
            iterable = [{"role": "system", "content": system_prompt}] + iterable
        # if system_prompt has not been set before, set it now
        # this is used in TokenizedChatProcessorWithDA to generate outputs without tokenization
        if not hasattr(self, "system_prompt"):
            self.system_prompt = system_prompt

        super(ChatRecords, self).__init__(iterable)
        assert all(isinstance(item, dict) for item in self.__iter__()), ValueError(
            'Expect each item of Chat to be dict. Got {}.'.format(
                [type(item) for item in self.__iter__()]))

        self.prepend_token_to = prepend_token_to or {}
        self.append_token_to = append_token_to or {}
        self.template = template or SIMPLE_CHAT_TEMPLATE  # TODO support external template for tokenizer
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        # image processing
        self.image_padding_side = kwargs.get('image_padding_side', "left")  # "left", "right", "random"
        # others
        self.add_generation_prompt = add_generation_prompt
        self._use_tokenizer_chat_template = use_tokenizer_chat_template
        self._role_map = None
        if self._use_tokenizer_chat_template and self.tokenizer:  # convert roles to those supported by tokenizer.chat_template
            self._role_map = kwargs.get("role_map", {})
            if not self._role_map:
                self._role_map = _check_possible_roles_in_chat_template(
                    self.tokenizer.chat_template, [utterrance['role'] for utterrance in self])
            if self._role_map and not kwargs.get("_ignore_role_map_check", False):
                for utterrance in self:
                    utterrance['role'] = self._role_map[utterrance['role']]
        self._chat_template_history = None
        self._role_tags = kwargs.get("role_tags", {})
        if self._role_tags:
            assert isinstance(self._role_tags, dict), TypeError(f"Expect role_tags to be a dict; got {type(self._role_tags)}")

        # Additionally, ChatRecords maintains another four lists:
        # _text: the text that concats role tag and content for each item
        # _input_ids: the tokenized ids, requires tokenizer
        # _role_tag_lens: the length of tokenized role tags, requires tokenizer
        # _join_tag_lens: the length of tokenized join tags, requires tokenizer
        # _predictable: whether this utterance can be used in labels
        # _stochastic: whether this utterance should be sampled
        # if _text, _input_ids, _role_tag_lens, _join_tag_lens, _predictable are provided in kwargs, use them.
        _text = kwargs.get('_text', [])
        if _text:
            self._text = _text
            self._input_ids = kwargs.get('_input_ids', [])
            self._pixel_values = kwargs.get('_pixel_values', [])
            self._image_sizes = kwargs.get('_image_sizes', [])
            self._role_tag_lens = kwargs.get('_role_tag_lens', [])
            self._join_tag_lens = kwargs.get('_join_tag_lens', [])
            self._predictable = kwargs.get('_predictable', [])
            self._stochastic = kwargs.get('_stochastic', [])
        else:
            self.format_items()

        # Additionally, ChatRecords maintains a global metadata
        # user needs to define how metadata is updated when editing ChatRecords
        # So for each chat utterance, there are two metadata: one global, one local
        self.metadata = kwargs.get('metadata', None)
        # for stochastic chat, record the trajectory
        self._trajectory_indices = []
        self._use_existing_trajectory_indices = False

    def format_items(self):
        # Additionally, ChatRecords maintains another three lists:
        self._text = []  # the text that concats role tag and content for each item
        self._input_ids = []  # the tokenized ids, requires tokenizer
        self._pixel_values = []  # processed the image pixel values
        self._image_sizes = []  # the size of the original images
        self._role_tag_lens = []  # the length of tokenized role tags, requires tokenizer
        self._join_tag_lens = []  # the length of tokenized join tags, requires tokenizer
        self._predictable = []  # whether this utterance can be used in labels
        self._stochastic = []  # whether this utterance should be sampled

        for index, item in enumerate(self.__iter__()):
            (
                formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
                join_tag_lens, is_predictable, is_stochastic
            ) = self._format_item(item, index)
            self._text.append(formatted_item)
            if input_ids:  # when tokenizer is not provided, input_ids is None
                self._input_ids.append(input_ids)
                self._role_tag_lens.append(role_tag_len)
                self._join_tag_lens.append(join_tag_lens)
            self._pixel_values.append(pixel_values)
            self._image_sizes.append(image_sizes)
            self._predictable.append(is_predictable)
            self._stochastic.append(is_stochastic)
        else:
            self._chat_template_history = None

    def _consolidate_chat_template_history(self, chat, _history_id=None):
        if any(self._stochastic) or _history_id is not None:
            consolidate_chat = []
            for index, utterrance in enumerate(chat):
                if isinstance(utterrance["content"], str):
                    consolidate_chat.append(utterrance)
                elif isinstance(utterrance["content"], list):
                    # for historical utterrance, take the content at id 0, as we are not sure how many options it has.
                    # for the last utterrance, _history_id is the index of content to be tokenized.
                    _history_id2 = 0 if index < len(chat) - 1 else _history_id
                    utterrance = {key: (value if key != "content" else value[_history_id2]) for key, value in utterrance.items()}
                    consolidate_chat.append(utterrance)
                else:
                    content = utterrance["content"]
                    raise RuntimeError(f"Expect content to be str or list; got {content}")
            return consolidate_chat
        else:
            return chat

    def _fill_in_template(self, role, content, index, _history_id=None):
        if self._use_tokenizer_chat_template and self.tokenizer:
            # if _history_id is provided, it is the index of content to be tokenized.
            formatted_item = self.tokenizer.apply_chat_template(
                self._consolidate_chat_template_history(self[:(index+1)], _history_id), 
                add_generation_prompt=(index == (len(self)-1)) and self.add_generation_prompt, 
                tokenize=False
            )
            # if formatted_item.endswith("<think>\n"):
            #     formatted_item = formatted_item.rsplit('<think>\n')[0]
            # isolate the latest utterrace
            if self._chat_template_history:
                # for historical utterrance, take the content at id 0, as we are not sure how many options it has
                _chat_template_history = self._chat_template_history[0] \
                    if isinstance(self._chat_template_history, list) else self._chat_template_history
                formatted_item = formatted_item.split(_chat_template_history)[-1]
            elif index > 0:
                _chat_template_history = self.tokenizer.apply_chat_template(
                    self._consolidate_chat_template_history(self[:index]), 
                    add_generation_prompt=False, 
                    tokenize=False
                )
                formatted_item = formatted_item.split(_chat_template_history)[-1]
        else:
            content =  self.prepend_token_to.get(role, '') + content + self.append_token_to.get(role, '')
            formatted_item = self.template.format(ROLE=role, CONTENT=content)

        return formatted_item

    def _tokenize_formatted_item(self, formatted_item, role, index=None):       
        tokenized_prompt = self.tokenizer(
            formatted_item,
            padding=False,
            add_special_tokens=False,
        )
        input_ids = tokenized_prompt["input_ids"]
        
        # identify role_tag, such as <｜User｜>, <|im_start|>user
        role_tag = None
        role_tag_len = None
        if self._role_tags and role in self._role_tags:
            role_tag = self._role_tags[role]
            if formatted_item.startswith(self.tokenizer.bos_token + role_tag):
                role_tag = self.tokenizer.bos_token + role_tag
        else:  # identify role tag automatically
            matches = re.search(re.compile("^.*{}.*\n+".format(role)), formatted_item)
            if matches:
                role_tag = matches.group(0)
        if role_tag:
            tokenized_prompt = self.tokenizer(
                role_tag,
                padding=False,
                add_special_tokens=False,
            )
            role_tag_ids = tokenized_prompt["input_ids"]
            assert input_ids[:len(role_tag_ids)] == role_tag_ids, RuntimeError(
                "role tag ids cannot be identified.")
            role_tag_len = len(role_tag_ids)

        # identify join_tag, such as \n; we consider all tokens after eos_token as join_tag
        join_tag_len = 0
        if self.tokenizer.eos_token_id in input_ids:
            join_tag_len = len(input_ids) - 1 - input_ids.index(self.tokenizer.eos_token_id)

        return input_ids, role_tag_len, join_tag_len

    def _match_image_to_tag(self, content, pixel_values):
        """ When pixel_values is not None or empty, content must have the same number of <image> tags. 
        If not, prepend or append <image> tags
        """
        matches = re.findall(r"<image>", content)
        if matches:
            assert len(matches) == len(pixel_values)
        else:  # content has no <image> tag
            if self.image_padding_side == "random":
                image_padding_side = "left" if random.uniform(0, 1) < 0.5 else "right"
            else:
                image_padding_side = self.image_padding_side
            if image_padding_side == "left":
                content = "<image>\n" + content
            else:
                content = content + "\n<image>"

        return content

    def _format_item(self, item, index=None):
        role = item['role']
        pixel_values = item['metadata'].get('pixel_values', []) if 'metadata' in item else []
        image_sizes = item['metadata'].get('image_sizes', []) if 'metadata' in item else []
        assert isinstance(pixel_values, list)

        if isinstance(item['content'], str):
            if len(pixel_values) > 0:
                item['content'] = self._match_image_to_tag(item['content'], pixel_values)
            formatted_item = self._fill_in_template(role, item['content'], index)
            is_stochastic = False
        else:
            assert isinstance(item['content'], list) and all(isinstance(content, str) for content in item['content']), ValueError(
                "content must be a str or List[str]; got {}".format(item['content'])
            )
            if len(pixel_values) > 0:
                item['content'] = [self._match_image_to_tag(content, pixel_values) for content in item['content']]
            formatted_item = [
                self._fill_in_template(role, content, index, _history_id) for _history_id, content in enumerate(item['content'])]
            is_stochastic = True

        # for _use_tokenizer_chat_template; use _chat_template_history to manage the history
        if self._use_tokenizer_chat_template and self.tokenizer:
            self._chat_template_history = formatted_item

        input_ids = None
        role_tag_len = None
        if self.tokenizer:
            if isinstance(formatted_item, str):
                input_ids, role_tag_len, join_tag_len = self._tokenize_formatted_item(formatted_item, role, index)
            else:
                results = [self._tokenize_formatted_item(_formatted_item, role, index) for _formatted_item in formatted_item]
                input_ids = [result[0] for result in results]
                role_tag_len = [result[1] for result in results]
                join_tag_len = [result[2] for result in results]

        is_predictable = item.get("metadata", {}).get("predictable", True)

        return formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, join_tag_len, is_predictable, is_stochastic

    def __setitem__(self, index, item):
        if isinstance(item, str):  # update content only
            ori_item = self.__getitem__(index)
            item = {'role': ori_item['role'], 'content': item}
            if 'metadata' in ori_item:
                item.update({'metadata': ori_item['metadata']})
        assert isinstance(item, dict), ValueError(
            'Expect the new item to be dict. Got {}.'.format(type(item)))
        super(ChatRecords, self).__setitem__(index, item)

        # set item in text and input_ids
        (
            formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
            join_tag_lens, is_predictable, is_stochastic
        ) = self._format_item(item, index)
        self._text.__setitem__(index, formatted_item)
        if input_ids:
            self._input_ids.__setitem__(index, input_ids)
            self._role_tag_lens.__setitem__(index, role_tag_len)
            self._join_tag_lens.__setitem__(index, join_tag_lens)
        self._pixel_values.__setitem__(index, pixel_values)
        self._image_sizes.__setitem__(index, image_sizes)
        self._predictable.__setitem__(index, is_predictable)
        self._stochastic.__setitem__(index, is_stochastic)

    def __add__(self, other):
        if isinstance(other, type(self)):
            assert (
                self.prepend_token_to == other.prepend_token_to
                and self.append_token_to == other.append_token_to
                and self.template == other.template
                and type(self.tokenizer) == type(other.tokenizer)
            ), ValueError("Cannot add two ChatRecords because they have different settings.")
        else:
            assert all(isinstance(item, dict) for item in other), ValueError(
                'Expect each item of the extended list to be dict. Got {}.'.format(
                    [type(item) for item in other]))
        iterable = super(ChatRecords, self).__add__(other)
        # in case the format of item also depends on index, we redo the format of other
        base_index = len(self._text)
        for index, item in enumerate(other.__iter__()):
            (
                formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
                join_tag_lens, is_predictable, is_stochastic
            ) = self._format_item(item, index + base_index)
            self._text.append(formatted_item)
            if input_ids:
                self._input_ids.append(input_ids)
                self._role_tag_lens.append(role_tag_len)
                self._join_tag_lens.append(join_tag_lens)
            self._pixel_values.append(pixel_values)
            self._image_sizes.append(image_sizes)
            self._predictable.append(is_predictable)
            self._stochastic.append(is_stochastic)
        else:
            self._chat_template_history = None

        return type(self)(
            iterable,
            prepend_token_to=self.prepend_token_to,
            append_token_to=self.append_token_to, 
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            add_generation_prompt=self.add_generation_prompt,
            use_tokenizer_chat_template=self._use_tokenizer_chat_template,
            _text=self._text,
            _input_ids=self._input_ids,
            _role_tag_lens=self._role_tag_lens,
            _join_tag_lens=self._join_tag_lens,
            _predictable=self._predictable,
            _stochastic=self._stochastic,
            )

    def __iadd__(self, other):
        if isinstance(other, type(self)):
            assert (
                self.prepend_token_to == other.prepend_token_to
                and self.append_token_to == other.append_token_to
                and self.template == other.template
                and type(self.tokenizer) == type(other.tokenizer)
            ), ValueError("Cannot add two ChatRecords because they have different settings.")
        else:
            assert all(isinstance(item, dict) for item in other), ValueError(
                'Expect each item of the extended list to be dict. Got {}.'.format(
                    [type(item) for item in other]))
        # in case the format of item also depends on index, we redo the format of other
        base_index = len(self._text)
        for index, item in enumerate(other.__iter__()):
            (
                formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
                join_tag_lens, is_predictable, is_stochastic
            ) = self._format_item(item, index + base_index)
            self._text.append(formatted_item)
            if input_ids:
                self._input_ids.append(input_ids)
                self._role_tag_lens.append(role_tag_len)
                self._join_tag_lens.append(join_tag_lens)
            self._pixel_values.append(pixel_values)
            self._image_sizes.append(image_sizes)
            self._predictable.append(is_predictable)
            self._stochastic.append(is_stochastic)
        else:
            self._chat_template_history = None

        return super(ChatRecords, self).__iadd__(other)

    def append(self, item):
        assert isinstance(item, dict), ValueError(
            'Expect the new item to be dict. Got {}.'.format(type(item)))
        if len(item) == 0:  # if item is an empty dict, do nothing
            return
        super(ChatRecords, self).append(item)

        # append item to text and input_ids
        index = len(self._text)
        (
            formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
            join_tag_lens, is_predictable, is_stochastic
        ) = self._format_item(item, index)
        self._text.append(formatted_item)
        if input_ids:
            self._input_ids.append(input_ids)
            self._role_tag_lens.append(role_tag_len)
            self._join_tag_lens.append(join_tag_lens)
        self._pixel_values.append(pixel_values)
        self._image_sizes.append(image_sizes)
        self._predictable.append(is_predictable)
        self._stochastic.append(is_stochastic)

    def extend(self, other):
        if isinstance(other, type(self)):
            assert (
                self.prepend_token_to == other.prepend_token_to
                and self.append_token_to == other.append_token_to
                and self.template == other.template
                and type(self.tokenizer) == type(other.tokenizer)
            ), ValueError("Cannot extend two ChatRecords because they have different settings.")
            super(ChatRecords, self).extend(other)
            self._text.extend(other._text)
            self._input_ids.extend(other._input_ids)
            self._pixel_values.extend(other._pixel_values)
            self._image_sizes.extend(other._image_sizes)
            self._role_tag_lens.extend(other._role_tag_lens)
            self._join_tag_lens.extend(other._join_tag_lens)
            self._predictable.extend(other._predictable)
            self._stochastic.extend(other._stochastic)
        else:
            assert all(isinstance(item, dict) for item in other), ValueError(
                'Expect each item of the extended list to be dict. Got {}.'.format(
                    [type(item) for item in other]))
            if len(other) == 1 and len(other[0]) == 0:  # if item is [{}], do nothing
                return  
            super(ChatRecords, self).extend(other)

            # extend items to text and input_ids
            index = len(self._text)
            for item in other:
                (
                    formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
                    join_tag_lens, is_predictable, is_stochastic
                ) = self._format_item(item, index)
                self._text.append(formatted_item)
                if input_ids:
                    self._input_ids.append(input_ids)
                    self._role_tag_lens.append(role_tag_len)
                    self._join_tag_lens.append(join_tag_lens)
                self._pixel_values.append(pixel_values)
                self._image_sizes.append(image_sizes)
                self._predictable.append(is_predictable)
                self._stochastic.append(is_stochastic)
                index += 1
            else:
                self._chat_template_history = None

    def insert(self, index, item):
        assert isinstance(item, dict), ValueError(
            'Expect the new item to be dict. Got {}.'.format(type(item)))
        super(ChatRecords, self).insert(index, item)

        (
            formatted_item, input_ids, pixel_values, image_sizes, role_tag_len, 
            join_tag_lens, is_predictable, is_stochastic
        ) = self._format_item(item, index)
        self._text.insert(index, formatted_item)
        if input_ids:
            self._input_ids.insert(index, input_ids)
            self._role_tag_lens.insert(index, role_tag_len)
            self._join_tag_lens.insert(index, join_tag_lens)
        self._pixel_values.insert(index, pixel_values)
        self._image_sizes.insert(index, image_sizes)
        self._predictable.insert(index, is_predictable)
        self._stochastic.insert(index, is_stochastic)

    @property
    def text(self):
        if any(self._stochastic):
            sampled_text = [random.choice(_text) if is_stochastic else _text for _text, is_stochastic in zip(self._text, self._stochastic)]
            return "".join(sampled_text)
        else:
            return "".join(self._text)

    @property
    def input_ids(self):
        if self._input_ids:  # Note that _past_key_values cannot handle token_sep among rounds of chat, so we do not insert any token
            if any(self._stochastic):
                joined_ids = []
                if not self._use_existing_trajectory_indices:
                    self._trajectory_indices = [-1] * len(self._input_ids)
                else:
                    assert (
                        isinstance(self._trajectory_indices, list) 
                        and len(self._trajectory_indices) == len(self._input_ids)
                    ), RuntimeError("Expecting existing trajectory_indices to be a list of length {}; got {}".format(
                        len(self._input_ids), self._trajectory_indices))
                for index, (_input_ids, is_stochastic) in enumerate(zip(self._input_ids, self._stochastic)):
                    if not self._use_existing_trajectory_indices:
                        self._trajectory_indices[index] = random.choice(range(len(_input_ids))) if is_stochastic else 0
                    joined_ids += _input_ids[self._trajectory_indices[index]] if is_stochastic else _input_ids
            else:
                joined_ids = [input_id for _input_ids in self._input_ids for input_id in _input_ids]

            # check max_seq_length
            if self.max_seq_length and self.max_seq_length < len(joined_ids):
                joined_ids = joined_ids[:self.max_seq_length]  # TODO introduce truncation_side

            return joined_ids
        else:
            return []

    @staticmethod
    def _get_label_from_input_ids(_input_ids, is_predictable, role_tag_len, join_tag_len):
        if is_predictable:
            if role_tag_len:  # remove role tag
                label = [-100] * role_tag_len + _input_ids[role_tag_len:].copy()
            else:
                label = _input_ids.copy()
            if join_tag_len:
                label = label[:-join_tag_len] + [-100] * join_tag_len
        else:
            label = [-100] * len(_input_ids)

        return label

    def labels(self, roles_to_predict: Union[str, List[str]] = None, trajectory_indices=None):
        if not self._input_ids:
            return []

        if roles_to_predict is None:
            roles_to_predict = []
        if isinstance(roles_to_predict, str):
            roles_to_predict = [roles_to_predict]
        if self._role_map:
            roles_to_predict = [self._role_map[role] for role in roles_to_predict]
        
        roles = [item['role'] for item in self.__iter__()]
        assert all(role in roles for role in roles_to_predict), ValueError(
            "Expect roles exist in chat; got {}.".format([role for role in roles_to_predict if role not in roles]))
        labels = []
        trajectory_indices = trajectory_indices or self._trajectory_indices  # use self._trajectory_indices if not provided
        if not trajectory_indices:
            trajectory_indices = [-1] * len(self._input_ids)
        for role, _input_ids, role_tag_len, join_tag_len, is_predictable, is_stochastic, trajectory_index in zip(
            roles, self._input_ids, self._role_tag_lens, self._join_tag_lens, self._predictable, self._stochastic, trajectory_indices):
            # if trajectory_index is provided, use it; otherwise do sampling
            if is_stochastic:
                if trajectory_index < 0:
                    trajectory_index = random.choice(range(len(_input_ids)))
                _input_ids = _input_ids[trajectory_index]
                role_tag_len = role_tag_len[trajectory_index]
                join_tag_len = join_tag_len[trajectory_index]
            # is_predictable is True if:
            # roles_to_predict is empty (i.e., every role is predictable)
            # roles_to_predict is not empty but current role is predictable
            is_predictable = (not roles_to_predict or role in roles_to_predict) and is_predictable
            label = self._get_label_from_input_ids(_input_ids, is_predictable, role_tag_len, join_tag_len)
            labels.append(label)
        joined_labels = [label for _labels in labels for label in _labels]

        # check max_seq_length
        if self.max_seq_length and self.max_seq_length < len(joined_labels):
            joined_labels = joined_labels[:self.max_seq_length]  # TODO introduce truncation_side

        return joined_labels

    @staticmethod
    def _flatten_list(sequences):
        if isinstance(sequences, list) and isinstance(sequences[0], list):  # lazy check nested list
            return [_seq for seq in sequences for _seq in seq]
        else:
            return sequences

    def tokenize(self, roles_to_predict):
        input_ids = self.input_ids
        tokenized_chat = {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": self.labels(roles_to_predict, self._trajectory_indices),
        }
        # if pixel_values is provided, include it, otherwise pass only input_ids
        pixel_values = self._flatten_list(self._pixel_values)
        image_sizes = self._flatten_list(self._image_sizes)
        if pixel_values:
            tokenized_chat.update({
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
            })
        
        return tokenized_chat

    def clone(self):
        return type(self)(
            copy.deepcopy(list(self.__iter__())),
            prepend_token_to=self.prepend_token_to,
            append_token_to=self.append_token_to, 
            template=self.template,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            add_generation_prompt=self.add_generation_prompt,
            use_tokenizer_chat_template=self._use_tokenizer_chat_template,
            system_prompt=self.system_prompt,
            _text=copy.deepcopy(self._text),
            _input_ids=copy.deepcopy(self._input_ids),
            _pixel_values=copy.deepcopy(self._pixel_values),
            _image_sizes=copy.deepcopy(self._image_sizes),
            _role_tag_lens=copy.deepcopy(self._role_tag_lens),
            _join_tag_lens=copy.deepcopy(self._join_tag_lens),
            _predictable=copy.deepcopy(self._predictable),
            _stochastic=copy.deepcopy(self._stochastic),
            image_padding_side=self.image_padding_side,
            role_map=self._role_map,
            _ignore_role_map_check=True,
            role_tags=self._role_tags,
            metadata=self.metadata,
            )


class LlamaChatRecords(ChatRecords):
    """ Extends ChatRecords to manage Llama's special requirement on chat format:
    <s>[INST] <<SYS>>
    {SYSTEM}
    <</SYS>>

    {{INSTRUCTION}} [/INST]
    {{OUTPUT}}</s>
    <s>[INST] {{INSTRUCTION}} [/INST]
    {{OUTPUT}}</s>...

    Llama's chat format does not have role tag, so for ROLE not in {'human', 'user', 'assistant'},
    we also attach it to the query round.

    If the first record has role "system" (case-insensitive), its content becomes the system
    prompt and the record is removed from the chat; otherwise the ``system_prompt`` keyword
    argument (default "") is used.

    Args:
        iterable: sequence of chat records (dicts with "role" and "content").
        prepend_token_to: optional, a dict of the following format:
            {
                'human': '...'  # some token
            }
            The token will be prepended to each content of the role when converting chat to string.
        append_token_to: optional, a dict of the following format:
            {
                'assistant': '...'  # some token
            }
            The token will be appended to each content of the role when converting chat to string.
        template: accepted for signature compatibility; this class does not forward it to
            the base constructor.
        tokenizer: callable tokenizer; called with ``padding=False, add_special_tokens=False``
            and expected to return a mapping containing "input_ids".
        max_seq_length: forwarded to the base constructor.
        **kwargs: ``system_prompt`` is consumed here; everything else is forwarded to the
            base constructor.

    Usage:
        chat = LlamaChatRecords(
            instance['chat'],
            prepend_token_to=None,
            append_token_to={'Assistant': tokenizer.eos_token},
            tokenizer=tokenizer,
            max_seq_length=8192)
        print(chat.text)
        print(chat.input_ids)
        print(chat.labels())
    """
    def __init__(
        self,
        iterable,
        *,
        prepend_token_to=None,
        append_token_to=None,
        template=None,
        tokenizer=None,
        max_seq_length=8192,
        **kwargs
        ):
        # get system prompt; if the role of iterable[0] is system, use its content as system prompt
        self.system_prompt = kwargs.pop("system_prompt", "")
        if len(iterable) > 0 and iterable[0].get("role", "").lower() == "system":
            self.system_prompt = iterable[0]["content"]
            iterable = iterable[1:]
        super(LlamaChatRecords, self).__init__(
            iterable,
            prepend_token_to=prepend_token_to,
            append_token_to=append_token_to,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            **kwargs
            )

    def _fill_in_template(self, role, content, index):
        """Render one utterance as text in the Llama chat format.

        Args:
            role: role of the utterance, used as given for the prepend/append token lookup.
            content: utterance text.
            index: position of the utterance; index 0 must be a user/human turn and also
                carries the system prompt block.

        Returns:
            The formatted string for this utterance.
        """
        content = self.prepend_token_to.get(role, '') + content + self.append_token_to.get(role, '')
        lowered_role = role.lower()
        if index == 0:
            assert lowered_role in {'user', 'human'}, "the first utterance must be a user/human turn"
            system_prompt = LLAMA_SYSTEM_PROMPT.format(SYSTEM=self.system_prompt)
            return "<s>[INST] {SYSTEM_PROMPT}{CONTENT} [/INST]".format(
                SYSTEM_PROMPT=system_prompt, CONTENT=content)
        if lowered_role in {'user', 'human'}:
            return "<s>[INST] {CONTENT} [/INST]".format(CONTENT=content)
        if lowered_role == 'assistant':
            return "{CONTENT}".format(CONTENT=content)
        # special role, such as solver: rendered as a query round tagged with the role name
        return "<s>[INST] [{ROLE}]\n{CONTENT} [/INST]".format(ROLE=role, CONTENT=content)

    def _tokenize_formatted_item(self, formatted_item, role, index=None):
        """Tokenize a formatted utterance and measure its role-tag token spans.

        Args:
            formatted_item: string produced by ``_fill_in_template``.
            role: role of the utterance.
            index: position of the utterance; at index 0 the prefix tag also covers the
                system prompt block.

        Returns:
            Tuple ``(input_ids, prefix_len, suffix_len)``. Both lengths are ``None`` for
            assistant turns, which carry no role tag.

        Raises:
            RuntimeError: if the expected role tag cannot be located in the text or its
                token ids do not line up with ``input_ids``.
        """
        input_ids = self.tokenizer(
            formatted_item,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]

        # identify role tag
        lowered_role = role.lower()
        role_tag_prefix = role_tag_suffix = None
        if index == 0:  # role_tag covers the system message
            # Non-greedy so the tag ends at the first closing marker instead of running
            # into user content that happens to contain "<</SYS>>".
            matches = re.search(r"<s>\[INST\]([\S\s]*?)<</SYS>>", formatted_item)
            if matches is None:
                raise RuntimeError("role tag ids cannot be identified.")
            role_tag_prefix = matches.group(0) + "\n\n"
            role_tag_suffix = "[/INST]"
        elif lowered_role in {'user', 'human'}:
            role_tag_prefix = "<s>[INST] "
            role_tag_suffix = "[/INST]"
        elif lowered_role != 'assistant':
            role_tag_prefix = "<s>[INST] [{ROLE}]\n".format(ROLE=role)
            role_tag_suffix = "[/INST]"

        role_tag_len = [None, None]
        if role_tag_prefix:
            role_tag_len = [0, 0]
            role_tag_ids_prefix = self.tokenizer(
                role_tag_prefix,
                padding=False,
                add_special_tokens=False,
            )["input_ids"]
            prefix_size = len(role_tag_ids_prefix)
            if input_ids[:prefix_size] == role_tag_ids_prefix:
                role_tag_len[0] = prefix_size
            elif input_ids[:prefix_size - 1] == role_tag_ids_prefix[:-1]:
                # The last tag token may differ once it is tokenized together with the
                # content that follows it; accept a match on all but that token.
                role_tag_len[0] = prefix_size - 1
            else:
                raise RuntimeError("role tag ids cannot be identified.")
        if role_tag_suffix:
            role_tag_ids_suffix = self.tokenizer(
                role_tag_suffix,
                padding=False,
                add_special_tokens=False,
            )["input_ids"]
            # Raise explicitly (instead of assert) so the check survives `python -O` and
            # matches the exception type used for the prefix above.
            if input_ids[-len(role_tag_ids_suffix):] != role_tag_ids_suffix:
                raise RuntimeError("role tag ids cannot be identified.")
            role_tag_len[1] = len(role_tag_ids_suffix)

        return input_ids, role_tag_len[0], role_tag_len[1]


class LlamaNewChatRecords(ChatRecords):
    """ Extends ChatRecords to manage Llama's special requirement on chat format (for llama version later than llama 3):
    <|begin_of_text|><|start_header_id|>system<|end_header_id|>

    {SYSTEM}

    <|eot_id|><|start_header_id|>user<|end_header_id|>

    {{INSTRUCTION}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

    {{OUTPUT}}<|eot_id|>

    If the first record has role "system" (case-insensitive), its content becomes the system
    prompt and the record is removed from the chat; otherwise the ``system_prompt`` keyword
    argument (default "") is used. The special value "default" is replaced by
    DEFAULT_LLAMA_NEW_SYSTEM_PROMPT formatted with today's date.

    Args:
        iterable: sequence of chat records (dicts with "role" and "content").
        prepend_token_to: optional, a dict of the following format:
            {
                'user': '...'  # some token
            }
            The token will be prepended to each content of the role when converting chat to string.
        append_token_to: optional, a dict of the following format:
            {
                'assistant': '...'  # some token
            }
            The token will be appended to each content of the role when converting chat to string.
        template: accepted for signature compatibility; this class does not forward it to
            the base constructor.
        tokenizer: callable tokenizer; called with ``padding=False, add_special_tokens=False``
            and expected to return a mapping containing "input_ids".
        max_seq_length: forwarded to the base constructor.
        add_generation_prompt: when true, the assistant header is appended after the last
            utterance.
        **kwargs: ``system_prompt`` is consumed here; everything else is forwarded to the
            base constructor.

    Usage:
        chat = LlamaNewChatRecords(
            instance['chat'],
            prepend_token_to=None,
            append_token_to={'assistant': tokenizer.eos_token},
            tokenizer=tokenizer,
            max_seq_length=8192)
        print(chat.text)
        print(chat.input_ids)
        print(chat.labels())
    """
    def __init__(
        self,
        iterable,
        *,
        prepend_token_to=None,
        append_token_to=None,
        template=None,
        tokenizer=None,
        max_seq_length=8192,
        add_generation_prompt=False,
        **kwargs
        ):
        # get system prompt; if the role of iterable[0] is system, use its content as system prompt
        self.system_prompt = kwargs.pop("system_prompt", "")
        if len(iterable) > 0 and iterable[0].get("role", "").lower() == "system":
            self.system_prompt = iterable[0]["content"]
            iterable = iterable[1:]
        # llama 3.3 requires a default system_prompt
        if self.system_prompt == "default":
            from datetime import datetime

            self.system_prompt = DEFAULT_LLAMA_NEW_SYSTEM_PROMPT.format(TODAY=datetime.today().strftime("%d %b %Y"))

        super(LlamaNewChatRecords, self).__init__(
            iterable,
            prepend_token_to=prepend_token_to,
            append_token_to=append_token_to,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            add_generation_prompt=add_generation_prompt,
            **kwargs
            )

    def _fill_in_template(self, role, content, index):
        """Render one utterance as text in the Llama 3 header format.

        Args:
            role: role of the utterance; the prepend/append token lookup uses it as given,
                the header uses its lower-cased form with "human" mapped to "user".
            content: utterance text.
            index: position of the utterance; index 0 must be a user turn and also carries
                ``<|begin_of_text|>`` and, when set, the system prompt block.

        Returns:
            The formatted string, followed by the assistant header when this is the last
            utterance and ``add_generation_prompt`` is set.
        """
        content = self.prepend_token_to.get(role, '') + content + self.append_token_to.get(role, '')
        role = role.lower()
        if role == 'human':  # in case chat is appended, the new utterance may use human as role
            role = 'user'

        turn_template = "<|start_header_id|>{ROLE}<|end_header_id|>\n\n{CONTENT}<|eot_id|>"
        if index == 0:
            assert role == 'user', "the first utterance must be a user/human turn"
            formatted_item = "<|begin_of_text|>"
            if self.system_prompt:
                sys_template = "<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|>"
                formatted_item += sys_template.format(SYSTEM_PROMPT=self.system_prompt)
            formatted_item += turn_template.format(ROLE='user', CONTENT=content)
        else:
            # user, assistant and special roles (such as solver) all share the same shape
            formatted_item = turn_template.format(ROLE=role, CONTENT=content)

        if index == (len(self)-1) and self.add_generation_prompt:
            formatted_item += '<|start_header_id|>assistant<|end_header_id|>\n\n'

        return formatted_item

    def _tokenize_formatted_item(self, formatted_item, role, index=None):
        """Tokenize a formatted utterance and measure its leading role-tag token span.

        Args:
            formatted_item: string produced by ``_fill_in_template``.
            role: role of the utterance.
            index: position of the utterance; at index 0 the tag also covers
                ``<|begin_of_text|>`` and the system prompt block.

        Returns:
            Tuple ``(input_ids, prefix_len, 0)``; this format has no trailing role tag.

        Raises:
            RuntimeError: if the role tag cannot be located in the text or its token ids
                do not line up with ``input_ids``.
        """
        role = role.lower()
        if role == 'human':  # in case chat is appended, the new utterance may use human as role
            role = 'user'
        input_ids = self.tokenizer(
            formatted_item,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]

        # identify role tag
        header = "<|start_header_id|>{ROLE}<|end_header_id|>\n\n".format(ROLE=role)
        if index == 0:  # role_tag covers the system message
            if role == 'system':
                role_tag_prefix = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            elif self.system_prompt:
                matches = re.search(
                    r"<\|begin_of_text\|><\|start_header_id\|>system<\|end_header_id\|>\n\n([\S\s]*?)<\|eot_id\|>",
                    formatted_item,
                )
                if matches is None:
                    raise RuntimeError("role tag ids cannot be identified.")
                role_tag_prefix = matches.group(0) + header
            else:
                role_tag_prefix = "<|begin_of_text|>" + header
        else:
            # Must mirror _fill_in_template exactly: special roles are written as a bare
            # role name in the header, without surrounding brackets.
            role_tag_prefix = header

        role_tag_ids_prefix = self.tokenizer(
            role_tag_prefix,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]
        prefix_size = len(role_tag_ids_prefix)
        if input_ids[:prefix_size] == role_tag_ids_prefix:
            prefix_len = prefix_size
        elif input_ids[:prefix_size - 1] == role_tag_ids_prefix[:-1]:
            # The last tag token may differ once it is tokenized together with the content
            # that follows it; accept a match on all but that token.
            prefix_len = prefix_size - 1
        else:
            raise RuntimeError("role tag ids cannot be identified.")

        return input_ids, prefix_len, 0


class MistralChatRecords(LlamaChatRecords):
    """ Extends LlamaChatRecords to manage mistral's special requirement on chat format:
    <s>[INST] {{INSTRUCTION}} [/INST]
    {{OUTPUT}}</s>...

    mistral's chat format does not have role tag, so for ROLE not in {'human', 'user', 'assistant'},
    we also attach it to the query round.

    There is no dedicated system block: when a system prompt is set, it is put in front of
    the content of a user/human utterance at index 0, separated by a newline.

    Args:
        prepend_token_to: optional, a dict of the following format:
            {
                'human': '...'  # some token
            }
            The token will be prepended to each content of the role when converting chat to string.
        append_token_to: optional, a dict of the following format:
            {
                'assistant': '...'  # some token
            }
            The token will be appended to each content of the role when converting chat to string.
        template: see LlamaChatRecords.
        tokenizer: see LlamaChatRecords.

    Usage:
        chat = MistralChatRecords(
            instance['chat'],
            prepend_token_to=None,
            append_token_to={'Assistant': tokenizer.eos_token},
            tokenizer=tokenizer,
            max_seq_length=8192)
        print(chat.text)
        print(chat.input_ids)
        print(chat.labels())
    """
    def _fill_in_template(self, role, content, index):
        """Render one utterance as text in the Mistral chat format.

        Args:
            role: role of the utterance, used as given for the prepend/append token lookup.
            content: utterance text.
            index: position of the utterance; only index 0 receives the system prompt.

        Returns:
            The formatted string for this utterance.
        """
        content = self.prepend_token_to.get(role, '') + content + self.append_token_to.get(role, '')
        lowered_role = role.lower()
        if lowered_role in {'user', 'human'}:
            if self.system_prompt and index == 0:
                content = self.system_prompt + "\n" + content
            return "<s>[INST] {CONTENT} [/INST]".format(CONTENT=content)
        if lowered_role == 'assistant':
            return "{CONTENT}".format(CONTENT=content)
        # special role, such as solver: rendered as a query round tagged with the role name
        return "<s>[INST] [{ROLE}]\n{CONTENT} [/INST]".format(ROLE=role, CONTENT=content)

    def _tokenize_formatted_item(self, formatted_item, role, index=None):
        """Tokenize a formatted utterance and measure its role-tag token spans.

        Args:
            formatted_item: string produced by ``_fill_in_template``.
            role: role of the utterance.
            index: unused here; kept for signature compatibility.

        Returns:
            Tuple ``(input_ids, prefix_len, suffix_len)``. Both lengths are ``None`` for
            assistant turns, which carry no role tag.

        Raises:
            RuntimeError: if the token ids of the role tag do not line up with ``input_ids``.
        """
        input_ids = self.tokenizer(
            formatted_item,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]

        # identify role tag
        lowered_role = role.lower()
        role_tag_prefix = role_tag_suffix = None
        if lowered_role in {'user', 'human'}:
            role_tag_prefix = "<s>[INST] "
            role_tag_suffix = "[/INST]"
        elif lowered_role != 'assistant':
            role_tag_prefix = "<s>[INST] [{ROLE}]\n".format(ROLE=role)
            role_tag_suffix = "[/INST]"

        role_tag_len = [None, None]
        if role_tag_prefix:
            role_tag_len = [0, 0]
            role_tag_ids_prefix = self.tokenizer(
                role_tag_prefix,
                padding=False,
                add_special_tokens=False,
            )["input_ids"]
            prefix_size = len(role_tag_ids_prefix)
            if input_ids[:prefix_size] == role_tag_ids_prefix:
                role_tag_len[0] = prefix_size
            elif input_ids[:prefix_size - 1] == role_tag_ids_prefix[:-1]:
                # The last tag token may differ once it is tokenized together with the
                # content that follows it; accept a match on all but that token.
                role_tag_len[0] = prefix_size - 1
            else:
                raise RuntimeError("role tag ids cannot be identified.")
        if role_tag_suffix:
            role_tag_ids_suffix = self.tokenizer(
                role_tag_suffix,
                padding=False,
                add_special_tokens=False,
            )["input_ids"]
            # Raise explicitly (instead of assert) so the check survives `python -O` and
            # matches the exception type used for the prefix above.
            if input_ids[-len(role_tag_ids_suffix):] != role_tag_ids_suffix:
                raise RuntimeError("role tag ids cannot be identified.")
            role_tag_len[1] = len(role_tag_ids_suffix)

        return input_ids, role_tag_len[0], role_tag_len[1]
    

class QwenChatRecords(ChatRecords):
    """ Extends ChatRecords to manage qwen's special requirement on chat format.

    The first utterance is rendered together with a system block, every later utterance
    starts on a new line:
    <|im_start|>{SYSTEM BLOCK}<|im_end|>
    <|im_start|>{ROLE}
    {CONTENT}<|im_end|>
    <|im_start|>{ROLE}
    {CONTENT}<|im_end|>
    ...

    The system block is QWEN_SYSTEM_PROMPT filled with the system prompt, falling back to
    DEFAULT_QWEN_SYSTEM_PROMPT when no system prompt is set. If the first record has role
    "system", its content becomes the system prompt and the record is removed from the
    chat; otherwise the ``system_prompt`` keyword argument (default "") is used.

    Roles are lower-cased and "human" is mapped to "user". The normalization is applied to
    shallow copies of the records, so the caller's dicts are left untouched.

    Args:
        iterable: sequence of chat records (dicts with "role" and "content").
        prepend_token_to: optional, a dict of the following format:
            {
                'human': '...'  # some token
            }
            Forwarded to the base constructor.
        append_token_to: optional, a dict of the following format:
            {
                'assistant': '...'  # some token
            }
            Forwarded to the base constructor.
        template: accepted for signature compatibility; this class does not forward it to
            the base constructor.
        tokenizer: callable tokenizer; called with ``padding=False, add_special_tokens=False``
            and expected to return a mapping containing "input_ids".
        max_seq_length: forwarded to the base constructor.
        add_generation_prompt: when true, the assistant header is appended after the last
            utterance.
        **kwargs: ``system_prompt`` is consumed here; everything else is forwarded to the
            base constructor.

    Usage:
        chat = QwenChatRecords(
            instance['chat'],
            prepend_token_to=None,
            append_token_to={'Assistant': tokenizer.eos_token},
            tokenizer=tokenizer,
            max_seq_length=8192)
        print(chat.text)
        print(chat.input_ids)
        print(chat.labels())
    """
    def __init__(
        self,
        iterable,
        *,
        prepend_token_to=None,
        append_token_to=None,
        template=None,
        tokenizer=None,
        max_seq_length=8192,
        add_generation_prompt=False,
        **kwargs
        ):
        # change chat format to qwen format role; work on shallow copies so that the
        # records owned by the caller are not modified
        normalized = []
        for chat in iterable:
            chat = copy.copy(chat)
            chat['role'] = chat['role'].lower()
            if chat['role'] == 'human':
                chat['role'] = 'user'
            normalized.append(chat)
        iterable = normalized
        # get system prompt; if the role of iterable[0] is system, use its content as system prompt
        self.system_prompt = kwargs.pop("system_prompt", "")
        if len(iterable) > 0 and iterable[0].get("role", "").lower() == "system":
            self.system_prompt = iterable[0]["content"]
            iterable = iterable[1:]

        super(QwenChatRecords, self).__init__(
            iterable,
            prepend_token_to=prepend_token_to,
            append_token_to=append_token_to,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            add_generation_prompt=add_generation_prompt,
            **kwargs
            )

    def _fill_in_template(self, role, content, index):
        """Render one utterance as text in the Qwen chat format.

        Args:
            role: role of the utterance; lower-cased, with "human" mapped to "user".
            content: utterance text.
            index: position of the utterance; index 0 also carries the system block.

        Returns:
            The formatted string, followed by the assistant header when this is the last
            utterance and ``add_generation_prompt`` is set.
        """
        # NOTE(review): prepend_token_to / append_token_to are not applied here, unlike
        # the Llama variants in this file -- confirm this is intended.
        role = role.lower()
        if role == 'human':  # in case chat is appended, the new utterance may use human as role
            role = 'user'
        if index == 0:
            if role == 'system':
                system_prompt = QWEN_SYSTEM_PROMPT.format(SYSTEM=content)
                formatted_item = "<|im_start|>{SYSTEM_PROMPT}<|im_end|>".format(SYSTEM_PROMPT=system_prompt)
            else:
                system_prompt = QWEN_SYSTEM_PROMPT.format(
                    SYSTEM=self.system_prompt if self.system_prompt else DEFAULT_QWEN_SYSTEM_PROMPT)
                template = "<|im_start|>{SYSTEM_PROMPT}<|im_end|>\n<|im_start|>{ROLE}\n{CONTENT}<|im_end|>"
                formatted_item = template.format(SYSTEM_PROMPT=system_prompt, ROLE=role, CONTENT=content)
        else:  # any later utterance, including special roles such as solver
            formatted_item = "\n<|im_start|>{ROLE}\n{CONTENT}<|im_end|>".format(ROLE=role, CONTENT=content)

        if index == (len(self)-1) and self.add_generation_prompt:
            formatted_item = formatted_item + '\n<|im_start|>assistant\n'

        return formatted_item

    def _tokenize_formatted_item(self, formatted_item, role, index=None):
        """Tokenize a formatted utterance and measure its leading role-tag token span.

        Args:
            formatted_item: string produced by ``_fill_in_template``.
            role: role of the utterance.
            index: position of the utterance; at index 0 the tag also covers the system
                block.

        Returns:
            Tuple ``(input_ids, prefix_len, 0)``; this format has no trailing role tag.

        Raises:
            RuntimeError: if the role tag cannot be located in the text or its token ids
                do not line up with ``input_ids``.
        """
        role = role.lower()
        if role == 'human':  # in case chat is appended, the new utterance may use human as role
            role = 'user'
        input_ids = self.tokenizer(
            formatted_item,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]

        # identify role tag
        if index == 0:  # role_tag covers the system message
            if role == 'system':
                role_tag_prefix = "<|im_start|>system\n"
            else:
                matches = re.search(r"<\|im_start\|>system\n([\S\s]*?)<\|im_end\|>\n", formatted_item)
                if matches is None:
                    raise RuntimeError("role tag ids cannot be identified.")
                role_tag_prefix = matches.group(0) + "<|im_start|>{ROLE}\n".format(ROLE=role)
        else:
            role_tag_prefix = "\n<|im_start|>{ROLE}\n".format(ROLE=role)

        role_tag_ids_prefix = self.tokenizer(
            role_tag_prefix,
            padding=False,
            add_special_tokens=False,
        )["input_ids"]
        prefix_size = len(role_tag_ids_prefix)
        if input_ids[:prefix_size] == role_tag_ids_prefix:
            prefix_len = prefix_size
        elif input_ids[:prefix_size - 1] == role_tag_ids_prefix[:-1]:
            # The last tag token may differ once it is tokenized together with the content
            # that follows it; accept a match on all but that token.
            prefix_len = prefix_size - 1
        else:
            raise RuntimeError("role tag ids cannot be identified.")

        return input_ids, prefix_len, 0
    

class DataCollatorForSeq2SeqForAllKeys(DataCollatorForSeq2Seq):
    """DataCollatorForSeq2Seq that first pads every additional per-example feature.

    Every key other than "labels", "input_ids" and "attention_mask" whose value in the
    first example is not an int is padded to a common length (rounded up to
    ``pad_to_multiple_of`` when set), on the side given by ``tokenizer.padding_side``.
    The features are modified in place and then handed to the parent collator.
    """

    # pad value used for the "inferenced_logprobs" key
    logprobs_pad_token_id: float = -0.0
    # pad value used for every other additional key
    default_pad_token_id: int = -100

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        handled_by_parent = ("labels", "input_ids", "attention_mask")
        for key in features[0].keys():
            if key in handled_by_parent:
                continue
            column = [example[key] for example in features]
            if isinstance(column[0], int):
                # scalar feature: nothing to pad
                continue

            target_length = max(len(values) for values in column)
            multiple = self.pad_to_multiple_of
            if multiple is not None:
                # round up to the next multiple
                target_length = (target_length + multiple - 1) // multiple * multiple

            pad_on_right = self.tokenizer.padding_side == "right"
            if key == "inferenced_logprobs":
                pad_value = self.logprobs_pad_token_id
            else:
                pad_value = self.default_pad_token_id

            for example in features:
                values = example[key]
                filler = [pad_value] * (target_length - len(values))
                if isinstance(values, list):
                    example[key] = values + filler if pad_on_right else filler + values
                else:
                    # array-like values are concatenated and cast to float32
                    pieces = [values, filler] if pad_on_right else [filler, values]
                    example[key] = np.concatenate(pieces).astype(np.float32)

        return super().__call__(features, return_tensors=return_tensors)


def load_datasets(file_paths, columns_to_read=None, is_disable_caching=False):
    """ Load datasets from one or more json or other files.

    Args:
        file_paths: str or List[str]. A list/tuple is loaded file by file and concatenated.
        columns_to_read: optional, List[str], features to read in the loaded data; all
            other columns are removed.
        is_disable_caching: bool, when true the data is loaded with
            download_mode "force_redownload" instead of "reuse_dataset_if_exists".

    Returns:
        The 'train' split of the loaded data, or None when file_paths is neither a str
        nor a list/tuple.

    Raises:
        ValueError: if file_paths is an empty list/tuple.
        RuntimeError: if the same path is given twice, or if columns_to_read would remove
            every column.
        FileNotFoundError: if a path does not point to an existing file.
    """
    if isinstance(file_paths, (list, tuple)):
        if not file_paths:
            raise ValueError("No dataset files were given.")
        # check if some dataset is repeated
        if len(set(file_paths)) != len(file_paths):
            raise RuntimeError("Repeated datasets detected.")

        multi_data = [
            load_datasets(file_path, columns_to_read, is_disable_caching=is_disable_caching)
            for file_path in file_paths
        ]
        # merge a list of data
        return concatenate_datasets(multi_data) if len(multi_data) > 1 else multi_data[0]

    if not isinstance(file_paths, str):
        return None

    if file_paths.startswith(".."):
        file_paths = os.path.abspath(file_paths)
    # Raise explicitly (instead of assert) so the check also runs under `python -O`.
    if not os.path.isfile(file_paths):
        raise FileNotFoundError(
            'The given dataset {} is not a valid path.'.format(file_paths))

    download_mode = "force_redownload" if is_disable_caching else "reuse_dataset_if_exists"
    if file_paths.endswith((".json", ".jsonl")):  # the json builder also reads JSON Lines
        data = load_dataset("json", data_files=file_paths, split='train', download_mode=download_mode)
    elif file_paths.endswith(".parquet"):
        data = load_dataset("parquet", data_files=file_paths, split='train', download_mode=download_mode)
    else:
        data = load_dataset(file_paths, split='train', download_mode=download_mode)

    if isinstance(columns_to_read, list) and columns_to_read:
        features = list(data.features.keys())
        columns_to_remove = [feature for feature in features if feature not in columns_to_read]
        if len(columns_to_remove) == len(features):
            raise RuntimeError('All features are removed.')
        data = data.remove_columns(columns_to_remove)

    return data


def get_filtered_data_by_indices(data, indices, include_by_indices=True):
    """Keep or drop the instances of ``data`` located at the given positions.

    Args:
        data: a list or a ``datasets.Dataset``.
        indices: a list (tuple/set are accepted as well) of integer positions, or the path
            of a pickle file containing one.
        include_by_indices: when true, keep only the instances at ``indices``; otherwise
            keep everything else.

    Returns:
        The filtered data, of the same kind as ``data``. The original order is preserved.

    Raises:
        TypeError: if ``indices`` is not a list, tuple or set after loading.
        NotImplementedError: if ``data`` is neither a list nor a Dataset.
    """
    if isinstance(indices, str) and os.path.isfile(indices):
        import pickle

        # NOTE: unpickling can execute arbitrary code; only load index files that come
        # from a trusted source.
        with open(indices, "rb") as filter_ids:
            indices = pickle.load(filter_ids)

    # Raise explicitly (instead of assert) so the check also runs under `python -O`.
    if not isinstance(indices, (list, tuple, set, frozenset)):
        raise TypeError(
            "Expect filter_data_by_indices to be a list; got\n{}".format(indices))

    # A set makes each membership test O(1) instead of a scan of the whole index list.
    index_set = frozenset(indices)

    if isinstance(data, list):
        if include_by_indices:
            return [instance for index, instance in enumerate(data) if index in index_set]
        return [instance for index, instance in enumerate(data) if index not in index_set]

    if isinstance(data, Dataset):
        if include_by_indices:
            return data.filter(lambda example, idx: idx in index_set, with_indices=True)
        return data.filter(lambda example, idx: idx not in index_set, with_indices=True)

    raise NotImplementedError(
        "Currently only support list or datasets.arrow_dataset.Dataset; got {}".format(type(data)))


def prepare_datasets(
    data_args,
    preprocessor_fn=None,
    eval_preprocessor_fn=None,
    postprocessor_fn=None,
    eval_postprocessor_fn=None,
    columns_to_read=None,
    shuffle=True
    ):
    """ This function:
    1. Loads datasets from possibly multiple json or other files.
    2. Preprocesses datasets (via map) if preprocessor_fn is provided.
    3. Postprocesses datasets (via set_format) if postprocessor_fn is provided.
    4. Outputs the datasets.
    The difference between preprocessing and postprocessing is that preprocessed data remains the same
    during training while postprocessing happens every time an instance is visited. Data augmentation
    is best passed as a postprocessor_fn.

    The validation set comes from data_args.validation_file when it is set; otherwise, when
    data_args.validation_split > 0, it is split off the training data; otherwise there is no
    validation set.

    Args:
        data_args: any argument structure with attributes train_file, validation_file, validation_split,
            and optionally disable_caching and filter_data_by_indices (applied to the training
            data only).
        preprocessor_fn: a callable function applied to the training data during preprocessing.
        eval_preprocessor_fn: applied to the validation data during preprocessing; only used
            when preprocessor_fn is provided.
        postprocessor_fn: a callable function applied to the training data during postprocessing.
        eval_postprocessor_fn: applied to the validation data during postprocessing; only used
            when postprocessor_fn is provided.
        columns_to_read: columns to read from data, before any data processing.
        shuffle: whether to shuffle the datasets.

    Returns:
        Tuple (train_data, validation_data); validation_data is None when there is no
        validation set.
    """
    is_disable_caching = getattr(data_args, 'disable_caching', False)
    if is_disable_caching:
        disable_caching()
    map_kwargs = dict(keep_in_memory=is_disable_caching, load_from_cache_file=not is_disable_caching)

    # train
    train_data = load_datasets(
        data_args.train_file, columns_to_read=columns_to_read, is_disable_caching=is_disable_caching)

    filter_data_by_indices = getattr(data_args, "filter_data_by_indices", None)
    if filter_data_by_indices:
        train_data = get_filtered_data_by_indices(train_data, filter_data_by_indices)

    # validation: decide where the validation set comes from first, so that both variables
    # are bound on every path before shuffling and mapping
    validation_file = getattr(data_args, "validation_file", None)
    validation_split = getattr(data_args, "validation_split", None) or 0
    needs_shuffle = shuffle
    if validation_file:
        validation_data = load_datasets(
            validation_file, columns_to_read=columns_to_read, is_disable_caching=is_disable_caching)
    elif validation_split > 0:
        train_val = train_data.train_test_split(
            test_size=validation_split, shuffle=shuffle,  # default seed is acquired from np
        )
        train_data, validation_data = train_val["train"], train_val["test"]
        # train_test_split already shuffled (or deliberately did not shuffle) both parts
        needs_shuffle = False
    else:
        validation_data = None

    if needs_shuffle:
        train_data = train_data.shuffle()
        if validation_data is not None:
            validation_data = validation_data.shuffle()

    if preprocessor_fn:
        train_data = train_data.map(preprocessor_fn, **map_kwargs)
        if validation_data is not None:
            validation_data = validation_data.map(eval_preprocessor_fn, **map_kwargs)

    if postprocessor_fn:
        train_data.set_format('custom', transform=postprocessor_fn)
        if validation_data is not None:
            validation_data.set_format('custom', transform=eval_postprocessor_fn)

    if is_disable_caching:
        train_data.cleanup_cache_files()
        if validation_data is not None:
            validation_data.cleanup_cache_files()

    return train_data, validation_data


def load_jsons(file_path: Union[str, List[str]], is_multi_lines: bool = False):
    """ Load a single json file, or multiple json files and concatenate the instances.

    Arguments:
        file_path: path of a json/jsonl file, or a list of such paths.
        is_multi_lines: bool, whether the data instances are saved in single line or one instance per line.
            single-line json: [dict, dict, ..., dict]
            multi-line json: dict, dict, ..., dict
            Blank lines are skipped in multi-line mode.

    Returns:
        The list of loaded instances. A list nested once ([[dict, ...]]) is unwrapped to its
        first element. An empty file content ([] or no lines) yields an empty list.

    Raises:
        FileNotFoundError: if a path is not an existing file ending with 'json' or 'jsonl'.
        ValueError: if file_path is neither str nor list, or if the top-level JSON value of
            a single-line file is an object instead of an array.
    """
    if isinstance(file_path, str):
        # Raise explicitly (instead of assert) so the check also runs under `python -O`.
        if not (os.path.isfile(file_path) and file_path.endswith(('json', 'jsonl'))):
            raise FileNotFoundError(
                "{} is not a valid json file path.".format(file_path))
        with open(file_path, 'r', encoding='utf-8') as f:
            if is_multi_lines:
                instances = [json.loads(line) for line in f if line.strip()]
            else:
                instances = json.load(f)
        if isinstance(instances, dict):
            raise ValueError(
                "Expect {} to contain a json array of instances; got a json object.".format(file_path))
        # Guard against empty content before peeking at the first element.
        if instances and isinstance(instances[0], list):
            instances = instances[0]
        return instances

    if isinstance(file_path, list):
        instances = []
        for single_file_path in file_path:
            instances.extend(load_jsons(single_file_path, is_multi_lines))
        return instances

    raise ValueError('Cannot parse {}. Expect it to be str or List[str].'.format(file_path))


def sort_instances(instances):
    """Return a new list with the instances ordered by their "index" value.

    The sort is stable: instances sharing the same index keep their relative order.
    The input sequence itself is left untouched.
    """
    return sorted(instances, key=lambda instance: instance["index"])


def find_sublist_in_full_list(full_list: List[Any], sublist: List[Any], count: int = None):
    """ Find the occurrences of a sublist in full_list.

    Arguments:
        full_list: the list to search in.
        sublist: the contiguous run of elements to look for. It is compared with slices
            of full_list via `==`, so both should be of the same sequence type.
        count: optional positive int; stop after this many occurrences. When None
            (or 0), all occurrences are returned.

    Returns:
        A list of (start, end) index pairs, in increasing order of start, such that
        full_list[start:end] == sublist. Overlapping occurrences are all reported.
        An empty sublist matches with zero width at every position.
    """
    len_full = len(full_list)
    len_sub = len(sublist)
    if count:
        assert count > 0, ValueError('count must be positive, got {}'.format(count))

    occurrances = []
    # The last valid start is len_full - len_sub, hence the +1: without it a match
    # ending exactly at the end of full_list would never be examined.
    for idx in range(len_full - len_sub + 1):
        if full_list[idx:idx+len_sub] == sublist:
            occurrances.append((idx, idx+len_sub))
            if count and len(occurrances) >= count:
                break

    return occurrances


def fill_system_prompt(prompt_template, data_instance, default_system_prompt=""):
    """Fill the SYSTEM_PROMPT placeholder of a prompt template, if it has one.

    The system text is taken from data_instance["system"], falling back to
    default_system_prompt. Templates that do not mention SYSTEM_PROMPT are
    returned unchanged. Only the llama QA and llama-new QA templates are
    supported; any other template carrying the placeholder raises
    NotImplementedError.
    """
    # Nothing to fill in.
    if "SYSTEM_PROMPT" not in prompt_template:
        return prompt_template

    if prompt_template in (LLAMA_QA_PROMPT_INPUT, LLAMA_QA_PROMPT_WO_INPUT):
        # The llama QA templates wrap the system text in LLAMA_SYSTEM_PROMPT first.
        wrapped_system = LLAMA_SYSTEM_PROMPT.format(
            SYSTEM=data_instance.get("system", default_system_prompt))
        return prompt_template.format(SYSTEM_PROMPT=wrapped_system)

    if prompt_template in (LLAMA_NEW_QA_PROMPT_INPUT, LLAMA_NEW_QA_PROMPT_WO_INPUT):
        # The llama-new QA templates take the system text as-is.
        return prompt_template.format(
            SYSTEM_PROMPT=data_instance.get("system", default_system_prompt))

    raise NotImplementedError(
        "SYSTEM_PROMPT for the following prompt_template is not supported:\n{}".format(prompt_template))


class TokenizedPromptProcessor:
    """ Data processor used for:
    1. processing data dict instance into prompt;
    2. tokenizing the prompt.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        config: (:class:`DataArguments`):
            Any DataArguments with the following attributes:
            train_on_inputs: whether to calculate training loss on instruction/inputs, or solely on outputs.
            max_seq_length: maximum number of tokens kept by `tokenize` (longer prompts are truncated).
            prompt_templates: the template to use in prompts. Could be:
                - str in [`simple_qa`, `simplest_qa`, `alpaca`]: to use preset templates
                - a list of two strings, corresponding to prompt with input placeholder, and prompt without input placeholder.
            system_prompt: string, could be empty
        is_eval: when True, `__call__` tokenizes the prompt without the output and
            returns the output tokens as labels.
    """
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        config,
        is_eval=False,
        ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.train_on_inputs = config.train_on_inputs
        self.max_seq_length = config.max_seq_length
        self.prompt_w_input, self.prompt_wo_input = get_prompt_templates(config.prompt_templates)
        self.system_prompt = config.system_prompt
        self.is_eval = is_eval

    def fill_prompt_template(self, data_instance, eval_mode=False):
        """Render data_instance into a prompt string.

        The template with an input placeholder is used when data_instance has a
        non-empty "input"; otherwise the template without it is used, provided
        "output" is present. With eval_mode=True the OUTPUT placeholder is filled
        with an empty string.

        Raises:
            NotImplementedError: if the instance has neither a non-empty "input"
                nor an "output" (e.g. multi-turn chat data).
        """
        if eval_mode:
            output = ""
        else:
            output = data_instance["output"]
        if data_instance.get("input", None):
            prompt_w_input = fill_system_prompt(
                self.prompt_w_input, data_instance, self.system_prompt)
            return prompt_w_input.format(
                INSTRUCTION=data_instance["instruction"], INPUT=data_instance["input"], OUTPUT=output)
        elif data_instance.get("output", None) is not None:
            prompt_wo_input = fill_system_prompt(
                self.prompt_wo_input, data_instance, self.system_prompt)
            return prompt_wo_input.format(
                INSTRUCTION=data_instance["instruction"], OUTPUT=output)
        else:  # multi-turn chat should be handled differently.
            raise NotImplementedError(
                "data_instance does not have input and output keys. Do you mean to use TokenizedChatProcessor instead?")

    def tokenize(self, prompt):
        """Tokenize a training prompt, truncated to max_seq_length.

        An eos token is appended when the sequence does not already end with one
        and there is still room below max_seq_length. The returned dict has keys
        "input_ids", "attention_mask" and "labels" (a copy of "input_ids").
        """
        tokenized_prompt = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_seq_length,
            padding=False,
            return_tensors=None,
        )
        input_ids = tokenized_prompt["input_ids"]
        # An empty encoding has no last token to inspect; treat it as "no eos yet"
        # instead of failing on input_ids[-1].
        ends_with_eos = bool(input_ids) and input_ids[-1] == self.tokenizer.eos_token_id
        if not ends_with_eos and len(input_ids) < self.max_seq_length:
            input_ids.append(self.tokenizer.eos_token_id)
            tokenized_prompt["attention_mask"].append(1)
        tokenized_prompt["labels"] = input_ids.copy()

        return tokenized_prompt

    def eval_tokenize(self, prompt, instance):
        """Tokenize an evaluation prompt and attach the output tokens as labels.

        `prompt` is the prompt without the output; the full prompt (with output) is
        rendered from `instance`. The labels are the tokens of the full prompt that
        come after the first len(prompt tokens) positions. No truncation is applied.
        """
        total_prompt = self.fill_prompt_template(instance)
        tokenized_total_prompt = self.tokenizer(  # dict with keys ['input_ids', 'attention_mask']
            total_prompt,
            padding=True,
            return_tensors=None,
        )
        tokenized_prompt = self.tokenizer(  # dict with keys ['input_ids', 'attention_mask']
            prompt,
            padding=True,
            return_tensors=None,
        )
        tokenized_label = tokenized_total_prompt["input_ids"][len(tokenized_prompt["input_ids"]):]
        tokenized_prompt["labels"] = tokenized_label
        return tokenized_prompt

    def __call__(self, data_instance: dict) -> dict:
        """
        Args:
            data_instance: a dict with keys "instruction", "output" and optionally "input"

        Returns:
            tokenized_prompt: a dict with keys "input_ids", "attention_mask", "labels".
            In training mode with train_on_inputs=False, the labels of the prompt part
            are set to -100 so that only the output part carries labels.
        """
        if self.is_eval:
            eval_prompt = self.fill_prompt_template(data_instance, eval_mode=True)
            tokenized_prompt = self.eval_tokenize(eval_prompt, data_instance)
        else:
            prompt = self.fill_prompt_template(data_instance)
            tokenized_prompt = self.tokenize(prompt)  # dict with keys ['input_ids', 'attention_mask', 'labels']
            if not self.train_on_inputs:
                if data_instance.get("output", None) is not None:  # single turn QA
                    # get the length of prompt without data_instance["output"]
                    partial_prompt = self.fill_prompt_template({**data_instance, "output": ""})
                    partial_tokenized_prompt = self.tokenize(partial_prompt)
                    prompt_len = len(partial_tokenized_prompt["input_ids"]) - 1  # partial_tokenized_prompt has eos_token
                    # set the labels to -100 except the output parts
                    tokenized_prompt["labels"] = [-100] * prompt_len + tokenized_prompt["labels"][prompt_len:]
                else:  # multi-turn chat
                    raise NotImplementedError("data_instance does not have output key. Do you mean to use TokenizedChatProcessor instead?")

        return tokenized_prompt


def simple_image_preprocessor(data_instance):
    """Replace the "image" entry of data_instance, when present, with its converted form.

    The conversion is delegated to `_convert_image_bytes_to_rgb_array`. Instances
    without an "image" entry (or with it set to None) are returned untouched.
    The instance is modified in place and also returned.
    """
    raw_image = data_instance.get("image", None)
    if raw_image is None:
        return data_instance

    data_instance['image'] = _convert_image_bytes_to_rgb_array(raw_image)
    return data_instance


class TokenizedChatProcessor:
    """ Data processor used for:
    1. turning a data dict instance into a chat-records object matching the prompt template;
    2. producing the tokenized form (or the plain message list) of that chat.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        config (:class:`DataArguments`):
            Any DataArguments with the following attributes:
            train_on_inputs: whether to calculate training loss on instruction/inputs, or solely on outputs.
            roles_to_predict: (optional) List[str], the roles to predict; takes precedence over train_on_inputs.
            max_seq_length:
            prompt_templates: name of a `*_chat` template; None means `simple_chat`.
            system_prompt:
            do_tokenization, role_tags, role_map: (optional) default to True, {} and {}.
        is_eval: forwarded as `add_generation_prompt` to the record classes that accept it.
        image_processor (:class:`BaseImageProcessor`):
            Any image processor that accepts image inputs, resizes, crop, normalize it and return the tensor outputs.
            The image inputs could be of format: `PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`,
            `List[np.ndarray]`, `List[torch.Tensor]`.

    Inputs:
        data_instance: dict or datasets.formatting.formatting.LazyRow object with the following keys:
            `chat`: List[dict], each with keys `content`, `role` and optionally `metadata`.
            `metadata`: (optional)
    """
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        config,
        is_eval=False,
        image_processor: BaseImageProcessor = None,
        ) -> None:
        self.tokenizer = tokenizer
        self.prepend_token_to = None
        self.append_token_to = {'Assistant': tokenizer.eos_token}

        # Explicit roles win; otherwise fall back to the legacy train_on_inputs switch.
        explicit_roles = getattr(config, "roles_to_predict", None)
        if explicit_roles:
            self.roles_to_predict = config.roles_to_predict
        elif config.train_on_inputs:
            self.roles_to_predict = None
        else:  # keep it for back compatibility
            is_qwen = config.prompt_templates == 'qwen_chat'
            self.roles_to_predict = ['assistant'] if is_qwen else ['Assistant']

        self.max_seq_length = config.max_seq_length
        configured_template = config.prompt_templates
        self.prompt_templates = "simple_chat" if configured_template is None else configured_template
        assert self.prompt_templates.endswith("chat"), ValueError(
            "Expect prompt_templates to be *_chat; got {}.".format(self.prompt_templates))
        self.system_prompt = config.system_prompt
        # image related
        self.image_processor = image_processor
        self.is_eval = is_eval
        # others
        self.do_tokenization = getattr(config, "do_tokenization", True)
        self.role_tags = getattr(config, "role_tags", {})
        self.role_map = getattr(config, "role_map", {})

    def _get_chat(self, data_instance):
        """Build the chat-records object that corresponds to self.prompt_templates."""
        template = self.prompt_templates
        if template == "qwen_chat":
            # The qwen template gets no appended token; this also updates the processor state.
            self.append_token_to = None

        def build(record_cls, **extra):
            # Arguments shared by every record class; `extra` carries the template-specific ones.
            return record_cls(
                data_instance['chat'],
                prepend_token_to=self.prepend_token_to,
                append_token_to=self.append_token_to,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_seq_length,
                metadata=data_instance.get('metadata', None),
                system_prompt=self.system_prompt,
                **extra,
            )

        if template == "llama_chat":
            return build(LlamaChatRecords)
        if template == "llama_new_chat":
            return build(LlamaNewChatRecords, add_generation_prompt=self.is_eval)
        if template == "mistral_chat":
            return build(MistralChatRecords)
        if template == "qwen_chat":
            return build(QwenChatRecords, add_generation_prompt=self.is_eval)
        # simple_chat
        return build(
            ChatRecords,
            add_generation_prompt=self.is_eval,
            use_tokenizer_chat_template=True,
            role_tags=self.role_tags,
            role_map=self.role_map,
        )

    def __call__(self, data_instance: dict) -> dict:
        """Convert one data instance into its tokenized chat, or into {"messages": [...]}
        when tokenization is disabled."""
        # keep it for back compatibility: prompt-style instances are converted to chat first
        is_prompt_style = "instruction" in data_instance and data_instance.get("chat", None) is None
        if is_prompt_style:
            data_instance = convert_prompt_instance_to_chat(data_instance, image_processor=self.image_processor)

        chat = self._get_chat(data_instance)
        if self.do_tokenization:
            # datasets.map writes the dataset to a tmp file, and the ChatRecords info
            # would be lost while writing; hence the chat is still converted to
            # input_ids, attention_mask and labels here.
            return chat.tokenize(self.roles_to_predict)

        # convert back to list as ChatRecords is not supported in accelerate
        messages = list(chat)
        if chat.system_prompt:
            messages = [{'role': 'system', 'content': chat.system_prompt}] + messages
        return {"messages": messages}


def _check_chat_metadata(chat: ChatRecords, check_keys=None):
    """ Check whether check_keys are provided in chat. 

    Note there are two places to store metadata:
    - global metadata given by chat.metadata
    - local metadata of each utterance accessed by chat[index].get("metadata", None)
    """
    assert (
        chat.metadata is not None
        or any(utterance.get("metadata", None) is not None for utterance in chat)
    ), ValueError("metadata must be provided.")
    if check_keys is None:
        check_keys = ["objective_description", "var_description", "constraint_description"]
    assert isinstance(check_keys, (list, tuple)), ValueError("check_keys must be a list of string keys.")
    
    # check global metadata first
    global_metadata = chat.metadata
    if isinstance(global_metadata, dict):
        has_key = all(key in global_metadata for key in check_keys)
    elif isinstance(global_metadata, list):  # list of dict
        has_key = all(all(key in metadata for key in check_keys) for metadata in global_metadata)
    else:
        has_key = False
    if has_key:
        return True
    # check local metadata
    local_metadata = [utterance["metadata"] for utterance in chat if utterance.get("metadata", None) is not None]
    has_key = any(all(key in metadata for key in check_keys) for metadata in local_metadata)

    return has_key


class TextAugmentationBase:
    """ Base class for text augmentation methods.

    Subclasses override `augment` (for prompt-style dict instances) and/or
    `augment_chat` (for ChatRecords). The base implementations return their
    input unchanged.
    """
    def augment(self, instance):
        """Augment a prompt-style dict instance; the base class is a no-op."""
        return instance

    def augment_chat(self, chat):
        """Augment a ChatRecords object; the base class is a no-op."""
        return chat

    def __call__(self, instance):
        """Dispatch to `augment` or `augment_chat` depending on the input type.

        Raises:
            NotImplementedError: if instance is neither a dict with an
                "instruction" key nor a ChatRecords object.
        """
        if isinstance(instance, dict) and "instruction" in instance:
            return self.augment(instance)
        elif isinstance(instance, ChatRecords):
            return self.augment_chat(instance)
        else:
            # The literal braces must be doubled: left single, str.format parses
            # {'instruction': ...} as a replacement field and raises KeyError
            # before the intended NotImplementedError is ever built.
            raise NotImplementedError(
                "Expect input to be either dict {{'instruction': ...}} or ChatRecords. Got {}.".format(instance))


class LPSymmetricShuffleAug(TextAugmentationBase):
    r"""
    Linear programming program is defined as,
    minimize \sum_{i}{a_i * x_i} + b, s.t. \sum_{i}{c_{ji} * x_{i}} + d_j <= 0 \forall j.
    Its solution is invariant under
    - rename the name of variables
    - shuffle of the order of constraints;
    - shuffle of the order of variables (with the corresponding coefs also shuffled);
    - scaling of the objective or any constraint by a non-zero constant s != 0;
    Note if the scaling constant is negative (s < 0), we need to
        - change the optimization direction: minimize \sum_{i}{a_i * x_i} + b --> maximize s * (\sum_{i}{a_i * x_i} + b);
        - change the sign of inequalities: \sum_{i}{c_{ji} * x_{i}} + d_j <= 0 --> s * (\sum_{i}{c_{ji} * x_{i}} + d_j) >= 0;

    We call the above the symmetries of LP. Here we implement the first two, because judging
    straightforwardly from the LP description, we usually have a preference for the rest.

    Args:
        p_rename: probability of renaming the variables of an instance.
        p_shuffle: probability of shuffling the order of the constraints.
    """
    # Only utterances whose task starts with one of these prefixes trigger augmentation of a chat.
    _AUGMENTABLE_TASK_PREFIXES = ("SimpleMathFormulationTask.A", "IncorrectMathCheckingTask.A")

    def __init__(self, p_rename=0.0, p_shuffle=1.0):
        self.p_rename = p_rename
        self.p_shuffle = p_shuffle

        self.variable_set = [
            ['x', 'y', 'z'],
            ['x1', 'x2', 'x3'],
            ['y1', 'y2', 'y3'],
            ['z1', 'z2', 'z3'],
            ['a', 'b', 'c'],
        ]

        # Outcome of the last random draw in `augment`; read by `augment_chat`.
        self._do_rename = None
        self._do_shuffle = None

    def _swap_variable_case(self, instance):
        """Swap the letter case of every variable name (in place)."""
        instance["var_description"] = {
            key.swapcase(): value for key, value in instance["var_description"].items()}
        # minimize and maximize should not be case-swapped: the first 8 characters
        # of the objective are kept as they are.
        objective = instance["objective_description"][1]
        instance["objective_description"][1] = objective[:8] + objective[8:].swapcase()
        instance["constraint_description"] = {
            key.swapcase(): [value_.swapcase() for value_ in value]
            for key, value in instance["constraint_description"].items()}

    def _rename_variables(self, instance):
        """Replace the variable names by a randomly chosen name set sharing no name with them (in place)."""
        source_var_def = list(var.lower() for var in instance['var_description'].keys())
        indices = list(range(len(self.variable_set)))
        random.shuffle(indices)
        for index in indices:
            if not [var for var in source_var_def if var in self.variable_set[index]]:
                break
        else:
            raise RuntimeError("source_var_def {} covers all variables sets.".format(source_var_def))
        target_var_def = self.variable_set[index]
        choice = random.random()
        if choice < 0.5:
            target_var_def = [var.swapcase() for var in target_var_def]
        # Match a variable only as a whole token; the name is escaped so that it is
        # always taken literally by the regex engine.
        patterns = [
            re.compile("(?<![0-9a-zA-Z]){}(?![0-9a-zA-Z])".format(re.escape(var)))
            for var in instance['var_description'].keys()]
        instance["var_description"] = {
            key: value for key, value in zip(target_var_def, instance["var_description"].values())}
        for source, target in zip(patterns, target_var_def):
            instance["objective_description"][1] = re.sub(source, target, instance["objective_description"][1])
            instance["constraint_description"] = {
                re.sub(source, target, key): [re.sub(source, target, value_) for value_ in value]
                for key, value in instance["constraint_description"].items()}

    def augment(self, instance, _do_copy=True):
        """Randomly rename the variables and/or shuffle the constraints of an LP instance.

        Args:
            instance: dict with keys "var_description" and "constraint_description"
                (and "objective_description" when renaming happens).
            _do_copy: when True (default), the input is deep-copied before any change,
                so the caller's data is never modified. With False the input is
                modified in place.

        Returns:
            The augmented instance; the input object itself when nothing was changed.
        """
        assert ("var_description" in instance) and ("constraint_description" in instance), KeyError(
            "Expect var_description and constraint_description to be provided.")

        self._do_rename = random.random() < self.p_rename
        self._do_shuffle = random.random() < self.p_shuffle
        # we should never directly change the original data
        if (self._do_rename or self._do_shuffle) and _do_copy:
            instance = copy.deepcopy(instance)

        if self._do_rename:
            choice = random.random()
            # uppercase to lowercase or reverse; may be disabled if uppercase means matrix.
            # Also the only option with more than 3 variables, as every name set has 3 names.
            if choice < 0.5 or len(instance["var_description"]) > 3:
                self._swap_variable_case(instance)
            else:
                self._rename_variables(instance)
        if self._do_shuffle:
            constraints = list(instance["constraint_description"].keys())
            random.shuffle(constraints)
            instance["constraint_description"] = {key: instance["constraint_description"][key] for key in constraints}

        return instance

    def augment_chat(self, chat):
        """Augment the global metadata of a chat.

        The chat is returned unchanged when the required metadata is missing, when
        no utterance belongs to an augmentable task, or when the random draw decides
        on no change. Otherwise a clone carrying the augmented metadata is returned;
        the metadata of the original chat is left untouched.
        """
        # check if the required info is in chat; if not, cannot do augmentation
        if not _check_chat_metadata(chat, ["objective_description", "var_description", "constraint_description"]):
            return chat
        # the required keys may only be present in local metadata; nothing global to augment then
        if not isinstance(chat.metadata, dict):
            return chat

        def task_of(utterance):
            # utterances without metadata or task simply do not qualify
            metadata = utterance.get("metadata", None) or {}
            return metadata.get("task", None) or ""

        if not any(task_of(utterance).startswith(self._AUGMENTABLE_TASK_PREFIXES) for utterance in chat):
            return chat

        # augment works on a deep copy, so the metadata held by the original chat
        # stays as it was even though the clone is created afterwards
        metadata = self.augment(chat.metadata)
        if self._do_rename or self._do_shuffle:
            chat = chat.clone()
            chat.metadata = metadata

        return chat


class LPIncorrectAnswerAug(TextAugmentationBase):
    """
    Linear programming program is defined as,
    minimize \sum_{i}{a_i * x_i} + b, s.t. \sum_{i}{c_{ji} * x_{i}} + d_j <= 0 \forall j.
    This method is used to automatically generate multiple incorrect answers.
    The types of incorrect answers are：
    wrong_fuctions = ['obj_dir_wrong','obj_coeff_wrong','obj_coeff_wrong_hd','obj_coeff_switch_wrong','obj_extra_wrong','obj_extra_wrong_hd','obj_miss_wrong',
            'cons_miss_wrong','cons_miss_wrong_hd','cons_more_wrong','cons_more_wrong_hd','cons_dir_wrong','cons_coeff_wrong','cons_coeff_wrong_hd','con_switch_coeff_wrong','cons_switch_coeff_wrong',
            'con_miss_wrong','con_extra_wrong','con_extra_wrong_hd','con_var_wrong','con_var_wrong_hd']
    """
    explain = r'\(\s*[\w\']+((\s)[\w\']+)+\s*\)'
    explain_ch= r'[(（]([^\(\)]*[一-龥][^\(\)]*)+[)）]'

    def __init__(self, p_wrong=0.5, p_wrong_map = None, p_real_wrong=0.05):
        '''
        change the direction of objective:
            change maximize 20x+30y
            to  minimize 20x+30y
        Args:
            cons: a string list, constains constraints.
        Returns:
            processed constraints list
        '''
        self.p_wrong = p_wrong

        self.wrong_fn_map = {'obj_dir_wrong': self.obj_dir_wrong,
            'obj_coeff_wrong': self.obj_coeff_wrong,
            'obj_coeff_wrong_hd': self.obj_coeff_wrong_hd,
            'obj_coeff_switch_wrong': self.obj_coeff_switch_wrong,
            'obj_extra_wrong': self.obj_extra_wrong,
            'obj_extra_wrong_hd': self.obj_extra_wrong_hd,
            'obj_miss_wrong': self.obj_miss_wrong,
            'cons_miss_wrong': self.cons_miss_wrong,
            'cons_miss_wrong_hd': self.cons_miss_wrong_hd,
            'cons_more_wrong': self.cons_more_wrong,
            'cons_more_wrong_hd': self.cons_more_wrong_hd,
            'cons_dir_wrong': self.cons_dir_wrong,
            'cons_coeff_wrong': self.cons_coeff_wrong,
            'cons_coeff_wrong_hd': self.cons_coeff_wrong_hd,
            'con_switch_coeff_wrong': self.con_switch_coeff_wrong,
            'cons_switch_coeff_wrong': self.cons_switch_coeff_wrong,
            'con_miss_wrong': self.con_miss_wrong,
            'con_extra_wrong': self.con_extra_wrong,
            'con_extra_wrong_hd': self.con_extra_wrong_hd,
            'con_var_wrong': self.con_var_wrong,
            'con_var_wrong_hd': self.con_var_wrong_hd
            }
        if p_wrong_map == None:
            self.p_wrong_map = { k:round(1/len(self.wrong_fn_map),2) for k in self.wrong_fn_map.keys()}
        else:
            self.p_wrong_map = { k:p_wrong_map.get(k, 0.0) for k in self.wrong_fn_map.keys()}
        self.p_real_wrong = p_real_wrong
        self.letters = letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
        self.max_sample = 20
        self.do_wrong = False

    def get_all_nums(self, obj, cons):
        '''
        Collect the distinct non-zero numbers written in the objective and the constraints.

        Args:
            obj: objective string
            cons: a string list, contains constraints.
        Returns:
            list of number tokens (strings such as '3', '2.5' or '1/2'), without
            duplicates, in order of first appearance. Digits that are part of a
            variable name (e.g. the 2 of x2) are not numbers.
        '''
        nump = r'\b(?:\d+(?:\.\d+|\/\d+)?)'

        def is_nonzero(token):
            # For a fraction 'a/b' only the numerator decides whether the value is
            # zero. Parsing with float (instead of eval) also copes with tokens
            # such as '007' or '1/0' without raising.
            numerator = token.partition('/')[0]
            return float(numerator) != 0

        nums = re.findall(nump, obj)
        for con in cons:
            nums.extend(re.findall(nump, con))

        # dict.fromkeys removes duplicates while keeping a stable order, so that a
        # seeded random.choice over the result is reproducible between runs.
        return list(dict.fromkeys(num for num in nums if is_nonzero(num)))
    
    def obj_dir_wrong(self, obj):
        '''
        change the direction of objective:
            change maximize 20x+30y
            to  minimize 20x+30y
        Args:
            obj: objective string.
        Returns:
            the objective with its direction keyword replaced by the opposite one
            (always written as lowercase 'minimize' / 'maximize').
        Raises:
            KeyError: if no direction keyword is found.
        '''
        keywords = r'(Minimize|minimize|Min|min|Maximize|maximize|Max|max)'
        # Prefer a keyword standing as a whole word, so that e.g. the "min" inside
        # "vitamin" is not mistaken for a direction.
        direction = re.search(r'\b' + keywords + r'\b', obj)
        if direction is None:
            # Fall back to a keyword glued to the start of the objective ("maximize20x").
            direction = re.match(r'\s*' + keywords, obj)
            if direction is None:
                raise KeyError('The objective direction is wrong')
            span = direction.span(1)
        else:
            span = direction.span()

        keyword = obj[span[0]:span[1]]
        opposite = 'maximize' if keyword.lower().startswith('min') else 'minimize'
        # Only the keyword itself is replaced; the rest of the objective is kept verbatim.
        return obj[:span[0]] + opposite + obj[span[1]:]
    
    def obj_coeff_wrong(self, obj):
        '''
        change a coefficiet of objective formula to a random number from[0-100]:
            change maximize 20x+30y
            to  maximize 20x+23y
        Args:
            cons: a string list, constains constraints.
        Returns:
            processed constraints list
        '''

        nump = r'\b(?:\d+(?:\.\d+|\/\d+)?)'
        varp = r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var1 = r'([+\-\(\/]\s*)([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var2 = r'^\s*([a-zA-Z](?:[a-zA-Z]|\d|_)*)'

        obj_pattern = r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)'
        obj_match = re.findall(obj_pattern, obj)[0]
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]
        fm_parts = list(re.split(r'(\+|\/|<=|>=|<|>|=)',obj_fm))
        token_indexes = [i for i in range(0,len(fm_parts),2)]
        index_choice = random.choice(token_indexes)
        token = fm_parts[index_choice]

        new_token = re.sub(single_var1, r'\g<1>1 * \g<2>', token)
        new_token = re.sub(single_var2, r'1 * \g<1>',new_token)
        all_nums = list(re.findall(nump,new_token))
        num_choice = random.choice(all_nums)
        rand_num = num_choice
        count = 0
        while count < self.max_sample and rand_num==num_choice:
            rand_num = random.randint(1,100)
            count += 1

        new_token = re.sub(num_choice,str(rand_num),new_token)

        fm_parts[index_choice] = new_token
        new_formula = ''.join(fm_parts)

        return obj_dir + ' ' + new_formula
    
    def obj_coeff_wrong_hd(self, obj, cons):
        '''
        change a coefficiet of objective formula to a number sample from all nums form objective and constraints
            change maximize 20x+30y
            to  maximize 20x+23y
        Args:
            obj: objective string
            cons: a string list, constains constraints.
        Returns:
            processed objective
        '''

        nump = r'\b(?:\d+(?:\.\d+|\/\d+)?)'
        varp = r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var1 = r'([+\-\(\/]\s*)([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var2 = r'^\s*([a-zA-Z](?:[a-zA-Z]|\d|_)*)'

        obj_pattern = r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)'
        obj_match = re.findall(obj_pattern, obj)[0]
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]
        
        fm_parts = list(re.split(r'(\+|\/|<=|>=|<|>|=)',obj_fm))
        token_indexes = [i for i in range(0,len(fm_parts),2)]
        index_choice = random.choice(token_indexes)
        token = fm_parts[index_choice]

        new_token = re.sub(single_var1, r'\g<1>1 * \g<2>', token)
        new_token = re.sub(single_var2, r'1 * \g<1>',new_token)
        token_all_nums = list(re.findall(nump,new_token))
        num_choice = random.choice(token_all_nums)

        all_nums = self.get_all_nums(obj, cons)
        rand_num = num_choice
        count = 0
        while count < self.max_sample and rand_num==num_choice:
            rand_num = random.choice(all_nums)
            count += 1
        new_token = re.sub(num_choice,str(rand_num),new_token)

        fm_parts[index_choice] = new_token
        new_formula = ''.join(fm_parts)

        return obj_dir + ' ' + new_formula
    
    def get_token_and_coefficients(self, token):
        '''
        Write the implicit coefficient 1 of bare variables out, then list the numbers.

        A variable counts as bare when it follows one of + - ( / or opens the token.
        Args:
            token: one term of a formula.
        Returns:
            (normalized token, list of the number strings found in it)
        '''
        after_operator = r'([+\-\(\/]\s*)([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        at_start = r'^\s*([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        number = r'\b(?:\d+(?:\.\d+|\/\d+)?)'

        normalized = re.sub(after_operator, r'\g<1>1 * \g<2>', token)
        normalized = re.sub(at_start, r'1 * \g<1>', normalized)
        coefficients = re.findall(number, normalized)
        return normalized, coefficients

    def find_2token_with_different_coeff(self, fm_parts):
        for i in range(0,len(fm_parts),2):
            for j in range(i+2,len(fm_parts),2):
                t1, t2 = fm_parts[i], fm_parts[j]
                nt1, t1l = self.get_token_and_coefficients(t1)
                nt2, t2l = self.get_token_and_coefficients(t2)
                for n1 in t1l:
                    for n2 in t2l:
                        if n1 != n2:
                            # subtitude n1 in t1 to n2 in t2
                            return i, j, nt1, nt2, n1, n2
        return None
    
    def obj_coeff_switch_wrong(self, obj, cons):
        '''
        change a coefficiet of objective formula to a number sample from all nums form objective and constraints
            change maximize 20x+30y
            to  maximize 20x+23y
        Args:
            obj: objective string
            cons: a string list, constains constraints.
        Returns:
            processed objective
        '''
        
        obj_pattern = r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)'
        obj_match = re.findall(obj_pattern, obj)[0]
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]

        fm_parts = list(re.split(r'(\+|\/|<=|>=|<|>|=)',obj_fm))
        rlt = self.find_2token_with_different_coeff(fm_parts)
        if rlt == None:
            return self.obj_coeff_wrong_hd(obj,cons)
        else:
            i, j, nt1, nt2, n1, n2 = rlt
            new1, new2 = re.sub(n1, n2, nt1), re.sub(n2, n1, nt2)

            fm_parts[i] = new1
            fm_parts[j] = new2
            new_formula = ''.join(fm_parts)
            
            return obj_dir + ' ' + new_formula
    
    def obj_extra_wrong(self, obj):
        '''
        add a token (random number from [1,100] multyply a random letter) to the objective formula:
            change maximize 20x+30y
            to  maximize 20x+30y+2s
        Args:
            obj: a string objective.
        Returns:
            processed constraints list
        '''

        obj_pattern = r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)'
        obj_match = re.findall(obj_pattern, obj)[0]
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]
        
        letter_choice = random.choice(self.letters)
        rand_num = random.randint(1,100)
        new_formula = f'{obj_fm} + {rand_num} * {letter_choice}'

        return obj_dir + ' ' + new_formula
    
    def obj_extra_wrong_hd(self, varlist, obj, cons):
        '''
        Append one spurious term to the objective formula. The variable is drawn
        from ``varlist`` and the coefficient from ``self.get_all_nums(obj, cons)``:
            change maximize 20x+30y
            to  maximize 20x+30y + 20 * x
        Args:
            varlist: a list of variable names to draw the extra variable from.
            obj: a string objective that starts with a direction keyword.
            cons: a string list of constraints, only forwarded to ``get_all_nums``.
        Returns:
            the direction keyword, one space, then the extended formula.
        Raises:
            IndexError: if ``obj`` does not start with a direction keyword.
        '''
        direction, formula = re.findall(
            r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)', obj)[0]

        # variable first, coefficient second: same draw order as before
        extra_var = random.choice(varlist)
        extra_coeff = random.choice(self.get_all_nums(obj, cons))

        return ' '.join([direction, f'{formula} + {extra_coeff} * {extra_var}'])
    
    def obj_miss_wrong(self, obj, p_miss1 = 0.3, p_miss2 = 0.6):
        '''
        Randomly drop the direction, the formula, or both from the objective:
            change maximize 20x+30y
            to  ' 20x+30y' or 'maximize ' or ' '
        Args:
            obj: a string objective that starts with a direction keyword.
            p_miss1: a uniform draw below this value drops only the direction.
            p_miss2: a draw in [p_miss1, p_miss2) drops only the formula;
                a draw at or above it drops both parts.
        Returns:
            the surviving direction and formula joined by one space
            (the space is always present, even when a part is dropped).
        Raises:
            IndexError: if ``obj`` does not start with a direction keyword.
        '''
        direction, formula = re.findall(
            r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)', obj)[0]

        roll = random.random()
        if roll < p_miss1:
            return ' ' + formula
        if roll < p_miss2:
            return direction + ' '
        return ' '
    
    def cons_miss_wrong(self, cons):
        '''
        Delete one randomly chosen constraint from the list:
            change [x+y<=10, x>=0, y>=0]
            to [x+y<=10, y>=0]
        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            the same list object with one element removed.
        Raises:
            ValueError: if ``cons`` is empty (no index can be drawn).
        '''
        victim = random.randint(0, len(cons) - 1)
        del cons[victim]
        return cons
    
    def cons_miss_wrong_hd(self, cons):
        '''
        Drop every non-negativity constraint, i.e. every constraint that starts
        with ``<variable> >= 0``:
            change [x+y<=10, x>=0, y>=0]
            to [x+y<=10]
        Args:
            cons: a string list of constraints; it is not modified.
        Returns:
            a new list holding the remaining constraints in their original order.
        '''
        nonneg = re.compile(r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0')
        kept = []
        for con in cons:
            if nonneg.match(con) is None:
                kept.append(con)
        return kept

    def cons_more_wrong(self, varlist, cons):
        '''
        Append one random constraint over all variables to the constraint list:
            change [x+y<=10, x>=0, y>=0]
            to [x+y<=10, x>=0, y>=0, 50 * x + 47 * y >= 98]
        Every coefficient and the right-hand side are random integers from
        [1, 100]; the operator is '<=' or '>='.
        Args:
            varlist: a string list of the variables used in the new constraint.
            cons: a string list of constraints; it is modified in place.
        Returns:
            the same list object with the new constraint appended.
        '''
        # draw order (coefficients, right-hand side, operator) is kept on purpose
        terms = []
        for var in varlist:
            terms.append(f'{random.randint(1,100)} * {var}')
        lhs = ' + '.join(terms)
        rhs = random.randint(1, 100)
        operator = random.choice(['<=', '>='])
        cons.append(f'{lhs} {operator} {rhs}')
        return cons
    
    def cons_more_wrong_hd(self, varlist, obj, cons):
        '''
        Append one random constraint over all variables to the constraint list,
        drawing every number from ``self.get_all_nums(obj, cons)``:
            change [x+y<=10, x>=0, y>=0]
            to [x+y<=10, x>=0, y>=0, 10 * x + 0 * y >= 10]
        Args:
            varlist: a string list of the variables used in the new constraint.
            obj: the objective string, only forwarded to ``get_all_nums``.
            cons: a string list of constraints; it is modified in place.
        Returns:
            the same list object with the new constraint appended.
        '''
        pool = self.get_all_nums(obj, cons)
        # draw order (coefficients, right-hand side, operator) is kept on purpose
        terms = []
        for var in varlist:
            terms.append(f'{random.choice(pool)} * {var}')
        lhs = ' + '.join(terms)
        rhs = random.choice(pool)
        operator = random.choice(['<=', '>='])
        cons.append(f'{lhs} {operator} {rhs}')
        return cons

    def cons_dir_wrong(self, cons):
        '''
        Change the comparison operator(s) of one randomly chosen constraint:
            change [x+y<=10, x>=0, y>=0]
            to [x+y>=10, x>=0, y>=0]
        Each operator is replaced by one of two alternatives picked at random
        (e.g. '<=' becomes '>=' or '<'); every operator occurrence in the
        chosen constraint is rewritten.
        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            the same list object with the chosen constraint rewritten.
        '''
        target = random.randint(0, len(cons) - 1)
        # one replacement is drawn per operator kind, in this fixed order,
        # whether or not that operator occurs in the chosen constraint
        flipped = {
            '>=': random.choice(['<=', '>']),
            '<=': random.choice(['>=', '<']),
            '=': random.choice(['<=', '>=']),
            '>': random.choice(['<', '>=']),
            '<': random.choice(['>', '<=']),
        }

        def _swap(match):
            return flipped.get(match.group(1), match.group(0))

        cons[target] = re.sub(r'(>=|<=|>|<|=)', _swap, cons[target])
        return cons

    def cons_coeff_wrong(self, cons):
        '''
        change a coefficiet of constraint which is random choosed from constraints list to a random number from[0-100]:
            change [x+y<=10, x>=0, y>=0]
            to [x+50y<=10, x>=0, y>=0]
        Args:
            cons: a string list, constains constraints.
        Returns:
            processed constraints list
        '''
        
        choice_index = random.randint(0,len(cons)-1)
        choice_con = cons[choice_index]

        fm_parts = list(re.split(r'(\+|\/|<=|>=|<|>|=)',choice_con))
        token_indexes = [i for i in range(0,len(fm_parts),2)]
        index_choice = random.choice(token_indexes)
        choice_token = fm_parts[index_choice]

        nump = r'\b(?:\d+(?:\.\d+|\/\d+)?)'
        varp = r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var1 = r'([+\-\(\/]\s*)([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var2 = r'^\s*([a-zA-Z](?:[a-zA-Z]|\d|_)*)'

        new_token = re.sub(single_var1, r'\g<1>1 * \g<2>', choice_token)
        new_token = re.sub(single_var2, r'1 * \g<1>',new_token)
        all_nums = list(re.findall(nump,new_token))

        num_choice = random.choice(all_nums)
        rand_num = num_choice
        count = 0
        while count < self.max_sample and rand_num==num_choice:
            rand_num = random.randint(1,100)
            count += 1
        new_token = re.sub(num_choice,str(rand_num), new_token)

        fm_parts[index_choice] = new_token
        new_formula = ''.join(fm_parts)

        cons[choice_index] = new_formula

        return cons
    
    def cons_coeff_wrong_hd(self, obj, cons):
        '''
        Replace one coefficient of a randomly chosen constraint with a number
        drawn from ``self.get_all_nums(obj, cons)``:
            change [x+y<=10, x>=0, y>=0]
            to [x+10 * y<=10, x>=0, y>=0]
        Constraints that start with ``<variable> >= 0`` are re-drawn, up to
        ``self.max_sample`` attempts. A bare variable is first rewritten as
        ``1 * var``. Exactly one number occurrence is replaced, and digits that
        are part of a variable name (e.g. the ``1`` in ``x1``) are never touched.
        Args:
            obj: the objective string, only forwarded to ``get_all_nums``.
            cons: a string list of constraints; it is modified in place.
        Returns:
            the same list object with the chosen constraint rewritten.
        Raises:
            ValueError: if ``cons`` is empty.
            IndexError: if the chosen token contains no number.
        '''
        nonneg = r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0'
        # the placeholder matches ``nonneg`` so the loop draws at least once
        choice_con = 'x >= 0'
        count = 0
        while count < self.max_sample and re.match(nonneg, choice_con):
            con_choice_index = random.randint(0, len(cons) - 1)
            choice_con = cons[con_choice_index]
            count += 1

        # with a capturing group re.split keeps the operators, so the tokens
        # sit at the even positions and the operators at the odd ones
        fm_parts = re.split(r'(\+|\/|<=|>=|<|>|=)', choice_con)
        index_choice = random.choice(range(0, len(fm_parts), 2))
        choice_token = fm_parts[index_choice]

        nump = r'\b(?:\d+(?:\.\d+|\/\d+)?)'
        single_var1 = r'([+\-\(\/]\s*)([a-zA-Z](?:[a-zA-Z]|\d|_)*)'
        single_var2 = r'^\s*([a-zA-Z](?:[a-zA-Z]|\d|_)*)'

        # give bare variables an explicit coefficient of 1
        new_token = re.sub(single_var1, r'\g<1>1 * \g<2>', choice_token)
        new_token = re.sub(single_var2, r'1 * \g<1>', new_token)

        # choose one concrete occurrence and splice by position; substituting
        # the number text as a regex would also rewrite digits inside
        # variable names and every other equal number in the token
        target = random.choice(list(re.finditer(nump, new_token)))
        old_num = target.group()

        all_nums = self.get_all_nums(obj, cons)
        # compare the textual form so the check works whether get_all_nums
        # yields strings or numbers
        new_num = old_num
        count = 0
        while count < self.max_sample and new_num == old_num:
            new_num = str(random.choice(all_nums))
            count += 1
        new_token = new_token[:target.start()] + new_num + new_token[target.end():]

        fm_parts[index_choice] = new_token
        cons[con_choice_index] = ''.join(fm_parts)

        return cons
    
    def con_switch_coeff_wrong(self, cons):
        '''
        Swap two different coefficients inside one constraint:
            change [x+50y<=10, x>=0, y>=0]
            to [50x+y<=10, x>=0, y>=0]
        Constraints that start with ``<variable> >= 0`` are never chosen. If no
        other constraint exists the work is delegated to ``con_extra_wrong``;
        if ``find_2token_with_different_coeff`` finds no pair to swap it is
        delegated to ``cons_coeff_wrong``.
        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        '''
        nonneg = re.compile(r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0')
        candidates = [con for con in cons if not nonneg.match(con)]
        if not candidates:
            return self.con_extra_wrong(cons)

        picked = random.choice(candidates)
        picked_at = cons.index(picked)
        # tokens at even positions, operators at odd positions
        parts = re.split(r'(\+|\/|<=|>=|<|>|=)', picked)

        found = self.find_2token_with_different_coeff(parts)
        if found is None:
            return self.cons_coeff_wrong(cons)

        i, j, token_i, token_j, coeff_i, coeff_j = found
        parts[i] = re.sub(coeff_i, coeff_j, token_i)
        parts[j] = re.sub(coeff_j, coeff_i, token_j)
        cons[picked_at] = ''.join(parts)
        return cons
    
    def cons_switch_coeff_wrong(self, cons):
        '''
        Swap two different coefficients taken from two different constraints:
            change [x+50y<=10, 3x+y<=7, x>=0, y>=0]
            to [3x+50y<=10, x+y<=7, x>=0, y>=0]
        Constraints that start with ``<variable> >= 0`` are never chosen. With
        fewer than two other constraints the work is delegated to
        ``con_switch_coeff_wrong``; if the two sampled constraints share no
        pair of different coefficients it is delegated to ``cons_coeff_wrong``.
        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        '''
        nonneg = re.compile(r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0')
        candidates = [con for con in cons if not nonneg.match(con)]
        if len(candidates) < 2:
            return self.con_switch_coeff_wrong(cons)

        first, second = random.sample(candidates, 2)
        first_at = cons.index(first)
        second_at = cons.index(second)
        # tokens at even positions, operators at odd positions
        first_parts = re.split(r'(\+|\/|<=|>=|<|>|=)', first)
        second_parts = re.split(r'(\+|\/|<=|>=|<|>|=)', second)

        def _different_coeff_pairs():
            # lazily walk token pairs (one token per constraint) and yield
            # every pair of unequal coefficients; only the first is consumed
            for i in range(0, len(first_parts), 2):
                for j in range(0, len(second_parts), 2):
                    token_i, coeffs_i = self.get_token_and_coefficients(first_parts[i])
                    token_j, coeffs_j = self.get_token_and_coefficients(second_parts[j])
                    for coeff_i in coeffs_i:
                        for coeff_j in coeffs_j:
                            if coeff_i != coeff_j:
                                yield i, j, token_i, token_j, coeff_i, coeff_j

        found = next(_different_coeff_pairs(), None)
        if found is None:
            return self.cons_coeff_wrong(cons)

        i, j, token_i, token_j, coeff_i, coeff_j = found
        first_parts[i] = re.sub(coeff_i, coeff_j, token_i)
        second_parts[j] = re.sub(coeff_j, coeff_i, token_j)
        cons[first_at] = ''.join(first_parts)
        cons[second_at] = ''.join(second_parts)
        return cons

    def con_miss_wrong(self, cons):
        '''
        Remove one randomly chosen token (together with an adjacent operator)
        from a randomly chosen constraint:
            change [x+y<=10, x>=0, y>=0]
            to [y<=10, x>=0, y>=0]

        Constraints that start with ``<variable> >= 0`` are re-drawn, up to
        ``self.max_sample`` attempts.
        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        '''
        nonneg = re.compile(r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0')
        # the placeholder matches ``nonneg`` so the loop draws at least once
        picked = 'x >= 0'
        attempts = 0
        while attempts < self.max_sample and nonneg.match(picked):
            con_choice_index = random.randint(0, len(cons) - 1)
            picked = cons[con_choice_index]
            attempts += 1

        # tokens at even positions, operators at odd positions
        parts = re.split(r'(\+|\/|<=|>=|<|>|=)', picked)
        token_positions = list(range(0, len(parts), 2))
        drop = random.choice(token_positions)

        if drop == 0:
            # first token goes together with the operator that follows it
            kept = parts[2:]
        elif drop == token_positions[-1]:
            # last token goes together with the operator in front of it
            kept = parts[:drop - 1]
        else:
            # inner token goes together with the operator in front of it
            kept = parts[:drop - 1] + parts[drop + 1:]

        cons[con_choice_index] = ''.join(kept)
        return cons
    
    def con_extra_wrong(self, cons):
        '''
        Add one spurious term to the left-hand side of a randomly chosen
        constraint. The term is a random integer from [1, 100] multiplied by a
        letter drawn from ``self.letters``:
            change [x+y<=10, x>=0, y>=0]
            to [x+y<=10, x+ 10 * z >=0, y>=0]

        Args:
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        Raises:
            AttributeError: if the chosen constraint has no comparison operator.
        '''
        target = random.randint(0, len(cons) - 1)
        # split at the first comparison operator: left side, operator, right side
        lhs, operator, rhs = re.match(r'(.*?)(>=|<=|>|<|=)(.*)', cons[target]).groups()

        # letter first, number second: same draw order as before
        extra_var = random.choice(self.letters)
        extra_coeff = random.randint(1, 100)

        cons[target] = f'{lhs}+ {extra_coeff} * {extra_var} {operator}{rhs}'
        return cons
    
    def con_extra_wrong_hd(self, varlist, obj,cons):
        '''
        Add one spurious term to the left-hand side of a randomly chosen
        constraint. The variable is drawn from ``varlist`` and the coefficient
        from ``self.get_all_nums(obj, cons)``:
            change [x+y<=10, x>=0, y>=0]
            to [x+y+ 10 * x <=10, x>=0, y>=0]

        Constraints that start with ``<variable> >= 0`` are re-drawn, up to
        ``self.max_sample`` attempts.
        Args:
            varlist: a string list of variable names to draw from.
            obj: the objective string, only forwarded to ``get_all_nums``.
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        Raises:
            AttributeError: if the chosen constraint has no comparison operator.
        '''
        nonneg = re.compile(r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)\s*>=\s*0')
        # the placeholder matches ``nonneg`` so the loop draws at least once
        picked = 'x >= 0'
        attempts = 0
        while attempts < self.max_sample and nonneg.match(picked):
            con_choice_index = random.randint(0, len(cons) - 1)
            picked = cons[con_choice_index]
            attempts += 1
        # split at the first comparison operator: left side, operator, right side
        lhs, operator, rhs = re.match(r'(.*?)(>=|<=|>|<|=)(.*)', picked).groups()

        # variable first, number second: same draw order as before
        extra_var = random.choice(varlist)
        extra_coeff = random.choice(self.get_all_nums(obj, cons))

        cons[con_choice_index] = f'{lhs}+ {extra_coeff} * {extra_var} {operator}{rhs}'
        return cons

    def con_var_wrong(self, varlist, cons):
        '''
        Rename every variable in the constraints to a randomly sampled letter
        from ``self.letters`` that is not already a variable:
            change [x+y<=10, x>=0, y>=0]
            to [c+d<=10, c>=0, d>=0]

        Args:
            varlist: a string list of the variables currently used in the constraints.
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        Raises:
            ValueError: if fewer unused letters than variables are available.
        '''
        unused = [letter for letter in self.letters if letter not in varlist]
        renamed = dict(zip(varlist, random.sample(unused, len(varlist))))

        # whole-word alternation, e.g. \bx\b|\by\b, so "x" inside "x1" is left alone
        matcher = re.compile(r'\b' + r"\b|\b".join(renamed.keys()) + r'\b')

        def _rename(match):
            return renamed[match.group()]

        for position in range(len(cons)):
            cons[position] = matcher.sub(_rename, cons[position])
        return cons
    
    def con_var_wrong_hd(self, varlist, cons):
        '''
        Rename the variables in the constraints to x1, x2, ...; if ``x1`` is
        already a variable, rename them to the first unused entries of
        ``self.letters`` instead:
            change [x+y<=10, x>=0, y>=0]
            to [x1+x2<=10, x1>=0, x2>=0]

        Args:
            varlist: a string list of the variables currently used in the constraints.
            cons: a string list of constraints; it is modified in place.
        Returns:
            processed constraints list
        '''
        if 'x1' not in varlist:
            new_names = [f'x{number}' for number in range(1, len(varlist) + 1)]
        else:
            unused = [letter for letter in self.letters if letter not in varlist]
            new_names = unused[:len(varlist)]
        renamed = dict(zip(varlist, new_names))

        # whole-word alternation, e.g. \bx\b|\by\b, so "x" inside "x1" is left alone
        matcher = re.compile(r'\b' + r"\b|\b".join(renamed.keys()) + r'\b')

        def _rename(match):
            return renamed[match.group()]

        for position in range(len(cons)):
            cons[position] = matcher.sub(_rename, cons[position])
        return cons

    def extract_obj(self, answer):
        '''
        Pull the objective out of a formulation answer.

        The objective is the text after "The objective is" / "目标", starting at
        a direction keyword and ending before the constraints section, a
        sentence-ending full stop, or a line break.
        Args:
            answer: the answer text to parse.
        Returns:
            a string ``<direction>,<formula>`` with all whitespace removed,
            e.g. ``maximize,20*x+30*y``.
        Raises:
            IndexError: if no objective is found in ``answer``.
        '''
        pred_obj_pattern = r'(?:The objective is|目标)[:：]?\s*\b(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*?)(?:(?=(?:约束|The constraints|(?:s\.t\.?)|(?:[Ss]ubject to)))|(?:(?:\.[^0-9])|。|\n))'
        direction, formula = re.findall(pred_obj_pattern, answer)[0]

        # strip whatever the two explanation patterns match, Chinese pattern first
        for noise in (self.explain_ch, self.explain):
            formula = re.sub(noise, '', formula)

        # for "z = 20x + 30y" keep only what follows the first '='
        eq_pos = formula.find('=')
        if eq_pos != -1:
            formula = formula[eq_pos + 1:]

        # str
        return re.sub(r'\s+', '', direction + ',' + formula)

    def extract_cons(self, answer):
        '''
        Pull the list of constraints out of a formulation answer.

        The constraints are the text after "The constraints are" / "约束条件" /
        "s.t." / "subject to". Thousands separators inside numbers are removed,
        whatever ``self.explain_ch`` and ``self.explain`` match is stripped, and
        the text is split on commas, semicolons and "and". Fragments that
        contain two or more consecutive Chinese characters are treated as prose
        and dropped.
        Args:
            answer: the answer text to parse.
        Returns:
            a list of constraint strings with all whitespace removed; an empty
            list if the constraints section is present but empty.
        Raises:
            IndexError: if no constraints section is found in ``answer``.
        '''
        cons_pattern = r'(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*?)(?:(?:\.[^0-9]|。)))|(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*)\.$)|(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*)$)'
        # one capture group per alternative; at most one of them is non-empty
        groups = re.findall(cons_pattern, answer)[0]
        cons_text = next((group.strip() for group in groups if group != ''), None)
        if cons_text is None:
            # section header found but nothing captured after it
            return []

        # "60,000" -> "60000", so the comma is not mistaken for a separator
        cons_text = re.sub(r'(?<=\d),(\d\d\d)', r'\g<1>', cons_text)
        cons_text = re.sub(self.explain_ch, '', cons_text)
        cons_text = re.sub(self.explain, '', cons_text)

        pieces = re.split(r'[,;]\s+and|and|\s*[,，]\s*|;|；', cons_text)
        pieces = [re.sub(r'\s+', '', piece) for piece in pieces]
        # build a new list instead of removing while iterating, which skipped
        # the element right after every removed one
        return [piece for piece in pieces if not re.search(r'[一-龥]{2,}', piece)]

    def parse_formula_from_answer(self, answer):
        '''
        Best-effort parse of an answer into its objective and constraints.

        The answer is first normalised to end with exactly one full stop
        ('。' if it contains Chinese characters, '.' otherwise; a trailing comma
        of the same script is dropped), then handed to ``extract_obj`` and
        ``extract_cons``.
        Args:
            answer: the answer text to parse.
        Returns:
            a tuple ``(pred_obj, pred_cons)``. A part that could not be
            extracted is returned as ``''`` / ``[]``; if only the constraints
            fail, the objective that was already extracted is kept.
        '''
        # defaults returned for whatever part fails below
        pred_obj, pred_cons = '', []
        try:
            if re.search(r'[一-龥]', answer):
                comma, full_stop = '，', '。'
            else:
                comma, full_stop = ',', '.'
            if answer[-1] == comma:
                answer = answer[:-1]
            if answer[-1] != full_stop:
                answer = answer + full_stop

            pred_obj = self.extract_obj(answer)
            pred_cons = self.extract_cons(answer)
        except Exception:
            # parsing is deliberately best effort (empty answer, no match, ...);
            # unlike a bare except this lets KeyboardInterrupt/SystemExit through
            pass

        return pred_obj, pred_cons

    def augment(self, instance, _do_copy=True):
        '''
        Return a labelled copy of ``instance`` whose formulation is, with
        probability ``self.p_wrong``, made incorrect.

        The decision is stored in ``self.do_wrong``. A corrupted copy gets
        ``label = 'no'``; an untouched copy gets ``label = 'yes'``. When
        corrupting, the recorded ``'incorrect output'`` of the instance is
        re-parsed with probability ``self.p_real_wrong`` (if it is non-empty);
        otherwise one function is drawn from ``self.wrong_fn_map`` with the
        weights in ``self.p_wrong_map`` and applied to the objective or to the
        constraints.
        Args:
            instance: a dict with ``var_description``, ``constraint_description``
                and ``objective_description``; it is deep-copied, never modified.
            _do_copy: accepted for interface compatibility; the copy is always made.
        Returns:
            the modified copy, with ``label`` set and
            ``add_nonnegative_from_var_def`` set to False.
        '''
        assert ("var_description" in instance) and ("constraint_description" in instance), KeyError(
            "Expect var_description and constraint_description to be provided.")

        # work on a private copy so the caller's record stays untouched
        instance = copy.deepcopy(instance)

        self.do_wrong = random.random() < self.p_wrong
        if not self.do_wrong:
            instance['label'] = 'yes'
            instance['add_nonnegative_from_var_def'] = False
            return instance

        obj = instance['objective_description'][1]
        varlist = list(instance['var_description'].keys())
        cons_dict = instance["constraint_description"]
        cons = list(cons_dict.keys())

        # sentinel meaning "this part was not rewritten"
        untouched = object()
        wrong_obj = untouched
        wrong_cons = untouched

        # drawn unconditionally so the random stream does not depend on the data
        do_real_wrong = random.random() < self.p_real_wrong
        if instance.get('incorrect output', "") != "" and do_real_wrong:
            # reuse the recorded wrong answer: re-parse its objective and constraints
            wrong_obj, wrong_cons = self.parse_formula_from_answer(instance['incorrect output'])
        else:
            wrong_fn = random.choices(list(self.wrong_fn_map.values()), list(self.p_wrong_map.values()))[0]
            # call the drawn function with the argument list its group expects
            if wrong_fn in [self.obj_dir_wrong, self.obj_coeff_wrong, self.obj_extra_wrong, self.obj_miss_wrong]:
                wrong_obj = wrong_fn(obj)
            elif wrong_fn in (self.obj_coeff_wrong_hd, self.obj_coeff_switch_wrong):
                wrong_obj = wrong_fn(obj, cons)
            elif wrong_fn == self.obj_extra_wrong_hd:
                wrong_obj = wrong_fn(varlist, obj, cons)
            elif wrong_fn in (self.cons_more_wrong, self.con_var_wrong, self.con_var_wrong_hd):
                wrong_cons = wrong_fn(varlist, cons)
            elif wrong_fn in (self.cons_more_wrong_hd, self.con_extra_wrong_hd):
                wrong_cons = wrong_fn(varlist, obj, cons)
            elif wrong_fn == self.cons_coeff_wrong_hd:
                wrong_cons = wrong_fn(obj, cons)
            else:
                wrong_cons = wrong_fn(cons)

        if wrong_obj is not untouched:
            instance['objective_description'][1] = wrong_obj
        if wrong_cons is not untouched:
            # constraints that survived keep their original value list;
            # new or rewritten ones map to themselves
            instance["constraint_description"] = {con: cons_dict.get(con, [con]) for con in wrong_cons}

        instance['label'] = 'no'
        instance['add_nonnegative_from_var_def'] = False
        return instance

    def augment_chat(self, chat):
        '''
        Apply ``augment`` to the metadata of a chat and mark the affected answers.

        ``augment`` runs at most once per chat, on the first utterance whose
        task starts with ``SimpleMathFormulationTask.A`` or
        ``IncorrectMathCheckingTask.A``. If it decided to corrupt the
        formulation (``self.do_wrong``), the chat is cloned once, the clone
        receives the augmented metadata, and every such utterance of the clone
        is flagged ``predictable = False``.
        Args:
            chat: the chat record to augment.
        Returns:
            the original chat if the required metadata is missing or nothing
            was corrupted, otherwise the modified clone.
        '''
        # check if the required info is in chat; if not, cannot do augmentation
        if not _check_chat_metadata(chat, ["objective_description", "var_description", "constraint_description"]):
            return chat

        answer_tasks = ("SimpleMathFormulationTask.A", "IncorrectMathCheckingTask.A")
        metadata = None
        copied = False
        for index, utterance in enumerate(chat):
            if not utterance["metadata"]["task"].startswith(answer_tasks):
                continue
            if metadata is None:
                metadata = self.augment(chat.metadata, _do_copy=False)
            if self.do_wrong:
                if not copied:
                    chat = chat.clone()
                    copied = True
                chat.metadata = metadata
                # The loop keeps iterating the pre-clone chat, so ``utterance``
                # may belong to the original; write through the clone instead.
                # NOTE(review): assumes clone() keeps utterance order and that
                # the chat supports index access - confirm against ChatRecords.
                chat[index]["predictable"] = False

        return chat


class LPContrastiveLearningAug(TextAugmentationBase):
    """ Turn one instance into a [preferred, rejected] pair.

    The preferred item is a deep copy of the instance. The rejected item
    reuses its instruction and input, and takes its output and logprobs from
    one randomly chosen entry of the instance's wrong outputs.
    """
    def augment(self, instance):
        chosen = copy.deepcopy(instance)
        # pick one of the recorded wrong outputs; output and logprobs share the index
        pick = random.randint(0, len(instance["wrong_instances_output"]) - 1)
        rejected = dict(
            instruction=instance["instruction"],
            input=instance["input"],
            output=instance["wrong_instances_output"][pick],
            inferenced_logprobs=instance["wrong_instances_logprobs"][pick],
        )
        return [chosen, rejected]

    def augment_chat(self, chat):
        raise NotImplementedError("augment_chat for LPContrastiveLearningAug is not implemented.")


class TextGenerationTaskBase:
    """ Base class for text related generation task.

    Calling an instance dispatches on the input type: a dict that has an
    "instruction" key goes to ``process_instance``, a ``ChatRecords`` goes to
    ``process_chat``. Both hooks return their input unchanged here; subclasses
    override them.
    """
    def process_instance(self, instance):
        return instance

    def process_chat(self, chat):
        return chat

    def __call__(self, instance):
        """ Dispatch ``instance`` to the matching hook.

        Raises:
            NotImplementedError: if ``instance`` is neither a dict with an
                "instruction" key nor a ``ChatRecords``.
        """
        if isinstance(instance, dict) and "instruction" in instance:
            return self.process_instance(instance)
        if isinstance(instance, ChatRecords):
            return self.process_chat(instance)
        # the literal braces must be doubled: un-escaped, str.format reads
        # {'instruction': ...} as a replacement field and raises KeyError
        raise NotImplementedError(
            "Expect input to be either dict {{'instruction': ...}} or ChatRecords. Got {}.".format(instance))


class SimpleMathFormulationTask(TextGenerationTaskBase):
    """ Convert info in instance to a math formulation for single-round QA, e.g.,
    <Human>:
    A pharmaceutical factory plans to hire production workers and quality inspectors.
    Production workers earn $800 per week and inspectors earn $1000 per week. The
    factory needs at least 50 employees, of whom at least 5 must be inspectors. The
    number of inspectors must be at least one tenth of the number of production
    workers. The weekly labour cost must stay below $60,000. Build an LP model that
    minimizes the labour cost.
    <Assistant>:
    The variables are: number of production workers, number of inspectors. Define them as: x, y.\n
    The objective is: minimize 800 * x + 1000 * y.\nThe constraints are:
    x + y >= 50, y >= 1/10 * x, 800 * x + 1000 * y <= 60000, y >= 5, x >= 0, y >= 0.

    For language 'cn' the same structure is produced with the Chinese template.

    If a chat is provided, we assume:
        every utterance, which has metadata on var_description, objective_description and constraint_description,
        is trying to answer an instruction in previous round;
    thus we reformat the utterance in this round.

    Args:
        instance: dict with keys:
            var_description: a dict
            objective_description: a list of two string
            constraint_description: a dict
            add_nonnegative_from_var_def: bool
            language: str
    """
    def __init__(self, deduplicate=False, keep_origin_instance=False):
        """
        Args:
            deduplicate: if True, list the keys of constraint_description;
                otherwise list every entry of its values.
            keep_origin_instance: if True, process_instance adds 'output' to
                the given instance instead of returning a fresh three-key dict.
        """
        self.deduplicate = deduplicate
        # matches a lone non-negativity constraint such as "x >= 0" or "x1 >= 0"
        self.non_negative_pattern = re.compile(r"^[a-zA-Z]+[0-9]?[a-zA-Z]* >= 0$")
        self.keep_origin_instance = keep_origin_instance

    def get_output(self, instance):
        """ Render the formulation text for ``instance``.

        Args:
            instance: dict as described on the class; ``var_description`` and
                ``constraint_description`` may also be given as the string
                representation of a dict.
        Returns:
            the formulation in the language named by ``instance['language']``.
        Raises:
            NotImplementedError: if the language is neither 'cn' nor 'en'.
            AssertionError: if re-ordering the non-negativity constraints
                changes the set of constraints.
        """
        var_description = instance['var_description']
        objective_description = instance['objective_description']
        constraint_description = instance['constraint_description']
        # NOTE(security): eval executes arbitrary code. This is only acceptable
        # for trusted data files; consider ast.literal_eval if the stored
        # strings are plain dict literals.
        if isinstance(var_description, str):
            var_description = eval(var_description)
        if isinstance(constraint_description, str):
            constraint_description = eval(constraint_description)
        lang = instance['language']

        # get variables and definitions
        var_def = list(var_description.keys())
        var_des = list(var_description.values())
        add_nonnegative_from_var_def = instance.get("add_nonnegative_from_var_def", True)
        # deduplicate variables
        if self.deduplicate:
            constraints = list(constraint_description.keys())
        else:
            constraints = [const for group in constraint_description.values() for const in group]

        # move non-negative constraints to the end.
        is_non_negative = self.non_negative_pattern.search
        if any(is_non_negative(const) for const in constraints):
            others = [const for const in constraints if not is_non_negative(const)]
            if add_nonnegative_from_var_def:
                # regenerate one "var >= 0" per declared variable
                constraints2 = others + ["{} >= 0".format(var) for var in var_description]
            else:
                # keep exactly the non-negativity constraints that were given
                constraints2 = others + [const for const in constraints if is_non_negative(const)]
            assert set(constraints) == set(constraints2), ValueError(
                "Constraints do not match: {} vs {}".format(constraints, constraints2))
            constraints = constraints2

        # generate outputs
        if lang == 'cn':
            var_des = '，'.join(var_des)
            var_def = '，'.join(var_def)
            const_des = '，'.join(constraints)
            output = "变量：{}。分别定义为：{}。\n目标：{}。\n约束条件：{}。".format(
                var_des, var_def, objective_description[1], const_des)
        elif lang == 'en':
            var_des = ', '.join(var_des)
            var_def = ', '.join(var_def)
            const_des = ', '.join(constraints)
            output = "The variables are: {}. Define them as: {}.\nThe objective is: {}.\nThe constraints are: {}.".format(
                var_des, var_def, objective_description[1], const_des)
        else:
            raise NotImplementedError("Preprocessing for language {} is not implemented.".format(lang))

        return output

    def process_instance(self, instance):
        """ Attach the rendered formulation to ``instance`` as its output.

        Returns:
            the given instance with 'output' set if ``keep_origin_instance``,
            otherwise a new dict with 'instruction', 'input' and 'output'.
        """
        output = self.get_output(instance)
        if self.keep_origin_instance:
            instance['output'] = output
        else:
            instance = {
                "instruction": instance["instruction"],
                "input": instance["input"],
                "output": output,
            }
        return instance

    def process_chat(self, chat):
        """ Re-render formulation answers in ``chat`` from its metadata.

        Every utterance whose task starts with "SimpleMathFormulationTask.A"
        is replaced by one holding the output of ``get_output``. If the
        metadata carries a solver result and its explanation, and no
        utterance belongs to a SolverResultExplanationTask yet, a
        question/answer pair for it is appended.

        Returns:
            the same chat object, modified in place (or untouched if the
            required metadata is missing).
        """
        # check if the required info is in chat; if not, cannot do augmentation
        if not _check_chat_metadata(chat, ["objective_description", "var_description", "constraint_description"]):
            return chat

        # reformat SimpleMathFormulationTask.A
        for index, utterance in enumerate(chat):
            if utterance["metadata"]["task"].startswith("SimpleMathFormulationTask.A"):
                instance = {
                    "objective_description": chat.metadata["objective_description"],
                    "var_description": chat.metadata["var_description"],
                    "constraint_description": chat.metadata["constraint_description"],
                    "language": chat.metadata["language"],
                    # forward the flag when the metadata carries it, so the chat
                    # path treats it like the instance path; absent means True,
                    # as before
                    "add_nonnegative_from_var_def": chat.metadata.get("add_nonnegative_from_var_def", True),
                }
                output = self.get_output(instance)
                new_utterance = {
                    "role": utterance["role"],
                    "content": output,
                    "metadata": utterance["metadata"],
                }
                chat[index] = new_utterance

        # check for SolverResultExplanationTask
        # if solveResult is provided and not appended, insert them
        tasks = [utterance["metadata"]["task"] for utterance in chat]
        if (
            not any(task.startswith("SolverResultExplanationTask") for task in tasks)
            and chat.metadata.get("solveResult", "")
            and chat.metadata.get("solveResult_explanation", "")
        ):
            utterances_to_insert = [
                {
                    # the solver result is attributed to either role at random
                    "role": "Human" if random.random() < 0.5 else "Solver",
                    "content": chat.metadata["solveResult"],
                    "metadata": {
                        "language": "en",
                        "task": "SolverResultExplanationTask.Q"
                    }
                },
                {
                    "role": "Assistant",
                    "content": chat.metadata["solveResult_explanation"],
                    "metadata": {
                        "language": "en",
                        "task": "SolverResultExplanationTask.A"
                    }
                },
            ]
            chat.append(utterances_to_insert[0])
            chat.append(utterances_to_insert[1])

        return chat


class IncorrectMathCheckingTask(SimpleMathFormulationTask):
    """Build a yes/no "is this formulation correct?" checking task.

    For single-round QA the problem statement and a candidate formulation are
    packed into one prompt and the label is the expected answer, e.g. (English
    rendering of the sample; the formulation text itself comes from
    ``self.get_output``)::

        <Human>:
        #Question#: A pharmaceutical factory plans to hire production workers
        and quality inspectors. Production workers earn $800 per week and
        quality inspectors earn $1000 per week. The factory requires at least
        50 employees, at least 5 of whom must be quality inspectors. To ensure
        quality, the number of inspectors should be at least one tenth of the
        number of production workers. The factory wants to keep weekly labor
        cost below $60,000. Please build an LP model to minimize labor cost.
        #Answer#: Variables: number of production workers hired, number of
        quality inspectors hired, defined as x, y respectively.
        Objective: minimize 800 * x + 1000 * y.
        Constraints: x + y >= 50, y >= 1/10 * x, 800 * x + 1000 * y <= 60000,
        y >= 5, x >= 0, y >= 0.
        Is the answer to the above question correct? Please answer 'yes' for
        correct and 'no' for incorrect.
        <Assistant>:
        yes

    If a chat is provided, we assume every utterance tagged
    ``SimpleMathFormulationTask.A`` answers the instruction of the previous
    round. A further round is then inserted right after it::

        <Human>:
        (problem statement)
        <Assistant>:
        (formulation)
        <Human>:
        Is the answer to the above question correct? Please answer 'yes' for
        correct and 'no' for incorrect.
        <Assistant>:
        yes

    Args:
        instance: dict with keys:
            var_description: a dict
            objective_description: a list of two string
            constraint_description: a dict
            label: yes, or no
            add_nonnegative_from_var_def: bool
            language: str
    """
    # Single-round prompt templates (Chinese / English).
    template_cn = "#问题#：{INSTRUCTION}\n#解答#：{INPUT}\n请问以上问题的解答是否正确？正确请回答'yes'，错误请回答'no'。"
    template_en = "#Question#: {INSTRUCTION}\n#Answer#: {INPUT}\nIs the answer to the above question correct? Please answer 'yes' for correct and 'no' for incorrect."
    # Question/answer utterance pair inserted into a chat; the answer content
    # is filled in from chat.metadata["label"] on a deep copy.
    utterances_to_insert_cn = [
        {
            "role": "Human",
            "content": "请问以上问题的解答是否正确？正确请回答'yes'，错误请回答'no'。",
            "metadata": {
                "language": "cn",
                "task": "IncorrectMathCheckingTask.Q"
            }
        },
        {
            "role": "Assistant",
            "content": "",
            "metadata": {
                "language": "cn",
                "task": "IncorrectMathCheckingTask.A"
            }
        },
    ]
    utterances_to_insert_en = [
        {
            "role": "Human",
            "content": "Is the answer to the above question correct? Please answer 'yes' for correct and 'no' for incorrect.",
            "metadata": {
                "language": "en",
                "task": "IncorrectMathCheckingTask.Q"
            }
        },
        {
            "role": "Assistant",
            "content": "",
            "metadata": {
                "language": "en",
                "task": "IncorrectMathCheckingTask.A"
            }
        },
    ]

    def process_instance(self, instance):
        """Turn one instance into an instruction/input/output checking sample.

        Args:
            instance: dict providing at least ``language``, ``instruction``,
                ``input`` and ``label`` plus whatever ``self.get_output`` reads.

        Returns:
            dict with keys ``instruction`` (question + candidate answer +
            yes/no prompt), ``input`` and ``output`` (the label).

        Raises:
            NotImplementedError: if ``language`` is neither 'cn' nor 'en'.
        """
        output = self.get_output(instance)

        lang = instance["language"]
        if lang == 'cn':
            template = self.template_cn
        elif lang == 'en':
            template = self.template_en
        else:
            raise NotImplementedError("Preprocessing for language {} is not implemented.".format(lang))
        instruction = template.format(INSTRUCTION=instance["instruction"], INPUT=output)
        return {
            "instruction": instruction,
            "input": instance["input"],
            "output": instance["label"],
        }

    def process_chat(self, chat):
        """Insert a correctness-check round after every formulation answer.

        Args:
            chat: chat records object exposing ``metadata``, iteration,
                item assignment and ``insert``.

        Returns:
            The same ``chat`` object, modified in place. It is returned
            untouched when the required metadata is missing.
        """
        # check if the required info is in chat; if not, cannot do augmentation
        if not _check_chat_metadata(chat, ["objective_description", "var_description", "constraint_description"]):
            return chat

        all_utterances_to_insert = []
        for index, utterance in enumerate(chat):
            if not utterance["metadata"]["task"].startswith("SimpleMathFormulationTask.A"):
                continue
            utterances_to_insert = copy.deepcopy(
                self.utterances_to_insert_cn if chat.metadata["language"] == 'cn' else self.utterances_to_insert_en)
            utterances_to_insert[1]["content"] = chat.metadata["label"]
            # Fixed: the original *called* the list here, which raised TypeError.
            all_utterances_to_insert.append((index, utterances_to_insert))
            # when chat.metadata["label"] is False, the metadata contains wrong info;
            # thus we also need to change the original output
            # NOTE(review): the class docstring describes label as 'yes'/'no'; a
            # non-empty string is truthy, so this branch only runs for falsy
            # labels (False, '', None, 0) - confirm the intended label type.
            if not chat.metadata["label"]:
                instance = {
                    "objective_description": chat.metadata["objective_description"],
                    "var_description": chat.metadata["var_description"],
                    "constraint_description": chat.metadata["constraint_description"],
                    "language": chat.metadata["language"],
                }
                output = self.get_output(instance)
                chat[index] = {
                    "role": utterance["role"],
                    "content": output,
                    "metadata": utterance["metadata"],
                }

        # Insert from the back so earlier recorded indices stay valid.
        for index, utterances_to_insert in reversed(all_utterances_to_insert):
            chat.insert(index + 1, utterances_to_insert[0])
            chat.insert(index + 2, utterances_to_insert[1])

        return chat


class LPProblemDescriptionGenerationTask(SimpleMathFormulationTask):
    """Generate a problem description from a math formulation.

    The instance is first converted to a math formulation with
    ``self.get_output``; the model is then asked to produce the natural
    language problem from it, e.g. (English rendering of the sample)::

        <Human>:
        Based on the mathematical expressions below, generate a linear
        programming problem case.
        Objective: minimize 800 * x + 1000 * y.
        Constraints: x + y >= 50, y >= 1/10 * x, 800 * x + 1000 * y <= 60000,
        y >= 5, x >= 0, y >= 0.
        <Assistant>:
        A pharmaceutical factory plans to hire production workers and quality
        inspectors. ... Please build an LP model to minimize labor cost.

    If a chat is provided, we assume the SimpleMathFormulationTask is to be
    replaced with this LPProblemDescriptionGenerationTask.

    Args:
        instance: dict with keys:
            var_description: a dict
            objective_description: a list of two string
            constraint_description: a dict
            language: str
    """
    # "pre" instructions are placed before the formulation, "post" ones after.
    pre_cn_instructions = [
        "请根据以下线性规划问题的数学表达式，生成问题的自然语言描述。",
        "下面我给出数学表达式，包括目标和约束，请根据表达式生成线性规划问题的案例。",
        "根据下面的表达式生成线性规划问题的描述：",
        "根据下面的目标和约束，给一个线性规划问题的例子。",
        "根据下面的数学表达式，生成一段线性规划问题的描述：",
        "根据以下目标和约束，生成一个线性规划问题的实例",
        "给定下列目标函数和约束条件，请生成一段线性规划问题的自然语言描述：",
        "根据以下线性规划问题的数学表达式，生成问题示例：",
        "已知目标函数和约束条件，请据此生成一段问题的案例。"
    ]
    pre_en_instructions = [
        "Please generate a description of LP problem based on the objective and constraints below.",
        "The objective and constraints of a LP problem are provided below. Please generate the corresponding problem description.",
        "Generate a natural language description of a linear programming problem based on the following mathematical expression:",
        "Create an example of a linear programming problem based on the following objective and constraints:",
        "Provide a description of a linear programming problem based on the following mathematical notation.",
        "Given the objective function and constraints below, generate a natural language description of a linear programming problem:",
        "Develop an example of a linear programming problem based on the following objective and constraints:",
        "Based on the objective function and constraints presented below, create a natural language example of a linear programming problem:",
        "Write a description of a linear programming problem based on the following mathematical expression."
    ]
    post_cn_instructions = [
        "请根据上述线性规划问题的数学表达式，生成问题的自然语言描述。",
        "前面我给出了数学表达式，包括目标和约束，请根据表达式生成线性规划问题的案例。",
        "根据前面的表达式生成线性规划问题的描述：",
        "根据前述的目标和约束，给一个线性规划问题的例子。",
        "根据前面的数学表达式，生成一段线性规划问题的描述：",
        "根据上述目标和约束，生成一个线性规划问题的实例",
        "按照给定的目标函数和约束条件，生成一段线性规划问题的自然语言描述：",
        "根据以上线性规划问题的数学表达式，生成问题示例：",
        "已知目标函数和约束条件，请据此生成一段问题的案例。"
    ]
    post_en_instructions = [
        "Please generate a description of LP problem based on the objective and constraints above.",
        "The objective and constraints of a LP problem are provided before. Please generate the corresponding problem description.",
        "Generate a natural language description of a linear programming problem based on the aforementioned mathematical expression:",
        "Create an example of a linear programming problem based on the aforementioned objective and constraints:",
        "Provide a description of a linear programming problem based on the mathematical notation just mentioned.",
        "Given the objective function and constraints above, generate a natural language description of a linear programming problem:",
        "Develop an example of a linear programming problem based on the above objective and constraints:",
        "Based on the objective function and constraints presented, create a natural language example of a linear programming problem:",
        "Write a description of a linear programming problem based on the aforementioned mathematical expression."
    ]

    def process_instance(self, instance):
        """Build a "formulation -> problem description" sample.

        Args:
            instance: dict providing ``language``, ``instruction`` (the
                problem description, which becomes the output), ``input``
                and whatever ``self.get_output`` reads.

        Returns:
            dict with keys ``instruction``, ``input`` and ``output``.
        """
        output = self.get_output(instance)
        # Drop the first line of the formulation and keep the rest.
        math_formulations = "\n".join(output.split("\n")[1:])
        is_cn = instance['language'] == 'cn'
        if random.random() < 0.5:
            # instruction first, formulation afterwards
            instruction = random.choice(self.pre_cn_instructions if is_cn else self.pre_en_instructions)
            instruction += "\n" + math_formulations
        else:
            # formulation first, instruction afterwards
            instruction = random.choice(self.post_cn_instructions if is_cn else self.post_en_instructions)
            instruction = math_formulations + "\n" + instruction

        return {
            "instruction": instruction,
            "input": instance["input"],
            "output": instance["instruction"],
        }

    def process_chat(self, chat):
        """Replace each formulation round with a description-generation round.

        The utterance tagged ``SimpleMathFormulationTask.Q`` and the utterance
        right after it are overwritten with the new question/answer pair.

        Args:
            chat: chat records object exposing ``metadata``, iteration and
                item assignment.

        Returns:
            The same ``chat`` object, modified in place. It is returned
            untouched when the required metadata is missing.
        """
        # check if the required info is in chat; if not, cannot do augmentation
        if not _check_chat_metadata(chat, ["objective_description", "var_description", "constraint_description"]):
            return chat

        for index, utterance in enumerate(chat):
            if not utterance["metadata"]["task"].startswith("SimpleMathFormulationTask.Q"):
                continue
            instance = {
                "objective_description": chat.metadata["objective_description"],
                "var_description": chat.metadata["var_description"],
                "constraint_description": chat.metadata["constraint_description"],
                "language": chat.metadata["language"],
                # Fixed: process_instance reads "instruction" and "input"; the
                # original omitted both and raised KeyError. The content of the
                # question utterance is used as the description to generate.
                "instruction": utterance["content"],
                "input": "",
            }
            instance = self.process_instance(instance)
            # NOTE(review): metadata language is hard-coded to "en" even when
            # chat.metadata["language"] is 'cn' - confirm this is intended.
            utterances_to_replace = [
                {
                    "role": "Human",
                    "content": instance["instruction"],
                    "metadata": {
                        "language": "en",
                        "task": "LPProblemDescriptionGenerationTask.Q"
                    }
                },
                {
                    "role": "Assistant",
                    "content": instance["output"],
                    "metadata": {
                        "language": "en",
                        "task": "LPProblemDescriptionGenerationTask.A"
                    }
                },
            ]
            chat[index] = utterances_to_replace[0]
            chat[index+1] = utterances_to_replace[1]

        # Fixed: the original fell off the end and returned None.
        return chat


class SimpleFilterTask(TextGenerationTaskBase):
    """Keep only the instruction/input/output/inferenced_logprobs fields."""

    def process_instance(self, instance):
        """Project one instance, or a list of instances, onto the kept fields.

        Args:
            instance: a dict, or a list of dicts, each containing the keys
                ``instruction``, ``input``, ``output`` and
                ``inferenced_logprobs``.

        Returns:
            A list of dicts (always a list, even for a single input dict),
            each holding exactly those four keys.
        """
        records = instance if isinstance(instance, list) else [instance]
        kept_keys = ("instruction", "input", "output", "inferenced_logprobs")
        return [{key: record[key] for key in kept_keys} for record in records]

    def process_chat(self, chat):
        """Chats are not supported by this task."""
        raise NotImplementedError()


class ValueLabelPredictionTask(TextGenerationTaskBase):
    """Prepare value label for chat.

    Note this class only works with ChatRecords.

    Args in config:
        eos_token_id: token id that always marks an end of step; None or
            negative disables this check.
        end_of_step_id: a token id indicating the end of step (position to predict value label),
            could be None or negative to indicate that the value label should be predicted at every token.
        false_eos_ids: List[List[int]] indicate the cases that should not be considered as the
            end of step.
    """
    def __init__(self, config, omit_labels=False):
        """Read the step-boundary settings from ``config``.

        Args:
            config: object whose optional attributes ``eos_token_id``,
                ``end_of_step_id`` and ``false_eos_ids`` are read.
            omit_labels: when True the tokenized chat carries no ``labels``
                entry, only ``value_labels``.

        Raises:
            AssertionError: if ``end_of_step_id`` is not an int or
                ``false_eos_ids`` is not a list of lists.
        """
        eos_token_id = getattr(config, "eos_token_id", -1)
        end_of_step_id = getattr(config, "end_of_step_id", -1)
        false_eos_ids = getattr(config, "false_eos_ids", [])
        # None is documented as "disabled"; normalise it to the negative
        # sentinel so the ``>= 0`` comparisons below never see None.
        self.eos_token_id = -1 if eos_token_id is None else eos_token_id
        self.end_of_step_id = -1 if end_of_step_id is None else end_of_step_id
        self.false_eos_ids = [] if false_eos_ids is None else false_eos_ids
        assert isinstance(self.end_of_step_id, int), TypeError(
            "Expect end_of_step_id to be integer; got {}".format(self.end_of_step_id))
        assert (
            isinstance(self.false_eos_ids, list)
            and all(isinstance(fids, list) for fids in self.false_eos_ids)
        ), TypeError("Expect false_eos_ids to be List[list[int]]; got {}".format(self.false_eos_ids))
        self.omit_labels = omit_labels

    def process_instance(self, instance):
        """Not supported; this task only operates on chats."""
        raise NotImplementedError("ValueLabelPredictionTask only works with chat.")

    def _locate_end_of_step_id(self, _input_ids, index, role_tag_len):
        """Tell whether position ``index`` of ``_input_ids`` ends a step.

        Args:
            _input_ids: token ids of one utterance (list or torch.Tensor).
            index: position to test.
            role_tag_len: falsy, or a pair whose two entries give how many
                leading / trailing positions are excluded.

        Returns:
            True when a value label should be placed at ``index``.
        """
        if isinstance(_input_ids, torch.Tensor):  # cannot compare tensor to list, so we convert _input_ids to list
            _input_ids = _input_ids.to("cpu").tolist()
        # positions inside the leading/trailing role tags are never step ends
        if role_tag_len and (index < role_tag_len[0] or index >= len(_input_ids) - role_tag_len[1]):
            return False
        if self.eos_token_id >= 0 and _input_ids[index] == self.eos_token_id:
            return True
        if index == len(_input_ids) - 1:  # last token must be step end
            return True
        if self.end_of_step_id < 0:
            # no dedicated end-of-step token: every position is a step end
            return True
        if _input_ids[index] == self.end_of_step_id:
            # reject the match if the tokens ending at ``index`` spell out one
            # of the configured false-positive patterns
            for false_eos_ids in self.false_eos_ids:
                pattern = _input_ids[max(index - len(false_eos_ids) + 1, 0): (index+1)]
                if pattern == false_eos_ids:
                    return False
            return True
        return False

    def get_custom_tokenize(self, chat, default_value_label=None):
        """Create a tokenize function for ``chat`` that also emits value labels.

        Args:
            chat: chat records object; its ``input_ids``, ``_input_ids``,
                ``_role_tag_lens``, ``_stochastic``, ``_trajectory_indices``,
                ``_predictable``, ``max_seq_length`` and ``labels`` are used.
            default_value_label: value label used for "Assistant" utterances
                whose metadata has no ``value_labels`` entry.

        Returns:
            A function ``custom_tokenize(roles_to_predict)`` returning a dict
            with ``input_ids``, ``attention_mask``, ``value_labels`` and,
            unless ``omit_labels`` is set, ``labels``.
        """

        def custom_tokenize(roles_to_predict):
            # read input_ids first, before chat._trajectory_indices is consulted
            input_ids = chat.input_ids
            joint_value_labels = []
            for utterance_index, (utterance, _input_ids, role_tag_len, is_stochastic) in enumerate(
                zip(chat.__iter__(), chat._input_ids, chat._role_tag_lens, chat._stochastic)):
                trajectory_index = chat._trajectory_indices[utterance_index] if is_stochastic else None
                if is_stochastic:
                    # stochastic utterances hold several candidates; pick the chosen one
                    _input_ids = _input_ids[trajectory_index]
                    role_tag_len = role_tag_len[trajectory_index]

                value_labels = utterance["metadata"].get(
                    "value_labels", default_value_label if utterance["role"] == "Assistant" else None)
                if value_labels is not None:
                    if (
                        isinstance(value_labels, str)
                        and value_labels.startswith("[")
                        and value_labels.endswith("]")
                    ):
                        # SECURITY NOTE(review): eval executes arbitrary code; if this
                        # metadata can come from untrusted data, consider
                        # ast.literal_eval instead.
                        value_labels = eval(value_labels)
                    value_label = value_labels[trajectory_index] if is_stochastic else value_labels  # 0 or 1
                    # do not predict wrong answer
                    if value_label == 0:
                        chat._predictable[utterance_index] = False

                    # convert once here rather than once per position inside
                    # _locate_end_of_step_id
                    if isinstance(_input_ids, torch.Tensor):
                        _input_ids = _input_ids.to("cpu").tolist()
                    # get extended_value_labels by locating end_of_step_id in _input_ids
                    extended_value_labels = [
                        value_label if self._locate_end_of_step_id(_input_ids, index, role_tag_len) else -100
                        for index in range(len(_input_ids))]
                else:
                    extended_value_labels = [-100] * len(_input_ids)

                joint_value_labels += extended_value_labels

            # check max_seq_length
            if chat.max_seq_length and chat.max_seq_length < len(joint_value_labels):
                joint_value_labels = joint_value_labels[:chat.max_seq_length]  # TODO introduce truncation_side

            tokenized_chat = {
                "input_ids": input_ids,
                "attention_mask": [1] * len(input_ids),
            }
            if not getattr(self, "omit_labels", False):  # self refers to ValueLabelPredictionTask
                tokenized_chat["labels"] = chat.labels(roles_to_predict, chat._trajectory_indices)
            tokenized_chat["value_labels"] = joint_value_labels
            return tokenized_chat

        return custom_tokenize

    def process_chat(self, chat, default_value_label=None):
        """Replace ``chat.tokenize`` with the value-label aware tokenizer.

        Args:
            chat: chat records object to patch.
            default_value_label: forwarded to ``get_custom_tokenize``.

        Returns:
            The same ``chat`` object.
        """
        chat.tokenize = self.get_custom_tokenize(chat, default_value_label)
        return chat


class ValueLabelContrastiveLearningTask(ValueLabelPredictionTask):
    """Prepare value label for contrastive learning on chat.

    Note this class only works with ChatRecords.

    Args in config:
        end_of_step_id: a token id indicating the end of step (position to predict value label),
            could be None or negative to indicate that the value label should be predicted at every token.
        false_eos_ids: List[List[int]] indicate the cases that should not be considered as the
            end of step.
        num_negative_value_samples: num negative samples to acquire from the sample list. The label is given
            by value_labels of metadata. When the number of negative samples is smaller than this value,
            all the negative samples are returned.
        num_positive_value_samples: similar to num_negative_value_samples. Suggest to set this to 1.
    """
    def __init__(self, config, omit_labels=False):
        """Read the sampling sizes from ``config`` (both default to 1)."""
        super().__init__(config, omit_labels=omit_labels)
        self.num_positive_value_samples = getattr(config, "num_positive_value_samples", 1)
        self.num_negative_value_samples = getattr(config, "num_negative_value_samples", 1)

    def get_custom_tokenize(self, chat, default_value_label=None):
        """Create a tokenize function that returns one sample per trajectory.

        Args:
            chat: chat records object; its ``_input_ids``, ``_stochastic`` and
                trajectory-index attributes are read and written.
            default_value_label: value label used for "Assistant" utterances
                whose metadata has no ``value_labels`` entry.

        Returns:
            A function ``custom_tokenize(roles_to_predict)`` returning a list
            of tokenized chats that share one random ``group_index``.
        """

        custom_tokenize_by_one = super().get_custom_tokenize(chat, default_value_label)

        def custom_tokenize(roles_to_predict):
            # get all sets of possible trajectory_indices
            # None means "no labelled stochastic utterance seen yet". The
            # original used 1 as that sentinel, so a genuine count of 1
            # followed by a different count skipped the consistency check and
            # failed later with an IndexError.
            num_samples = None
            trajectory_indices = [-1] * len(chat._input_ids)
            for index, (utterance, _input_ids, is_stochastic) in enumerate(
                zip(chat.__iter__(), chat._input_ids, chat._stochastic)):
                if not is_stochastic:
                    trajectory_indices[index] = 0
                    continue

                value_labels = utterance["metadata"].get(
                    "value_labels", default_value_label if utterance["role"] == "Assistant" else None)
                if value_labels is None:
                    # unlabelled stochastic utterance: pick any candidate
                    trajectory_indices[index] = random.choice(range(len(_input_ids)))
                    continue

                if (
                    isinstance(value_labels, str)
                    and value_labels.startswith("[")
                    and value_labels.endswith("]")
                ):
                    # SECURITY NOTE(review): eval executes arbitrary code; if this
                    # metadata can come from untrusted data, consider
                    # ast.literal_eval instead.
                    value_labels = eval(value_labels)  # get value_labels in list
                # sample positive samples and negative samples without replacement
                pos_value_indices = [pos for pos, value in enumerate(value_labels) if value == 1]
                neg_value_indices = [pos for pos, value in enumerate(value_labels) if value == 0]
                if self.num_positive_value_samples < len(pos_value_indices):
                    pos_value_indices = random.sample(pos_value_indices, self.num_positive_value_samples)
                if self.num_negative_value_samples < len(neg_value_indices):
                    neg_value_indices = random.sample(neg_value_indices, self.num_negative_value_samples)
                trajectory_indices[index] = pos_value_indices + neg_value_indices

                _current_num_samples = len(trajectory_indices[index])
                if num_samples is None:
                    num_samples = _current_num_samples
                elif num_samples != _current_num_samples:
                    raise RuntimeError("Number of samples are inconsistent across utterrances of one chat. This may be caused by different numbers of value labels provided.")

            if num_samples is None:
                # no labelled stochastic utterance: emit exactly one sample
                num_samples = 1

            chat._use_existing_trajectory_indices = True

            tokenized_chats = []
            group_index = random.randint(1, 65536)  # assign a random group id
            for sample_index in range(num_samples):
                # int entries are shared by all samples; list entries are per-sample
                sample_t_ids = [t_id if isinstance(t_id, int) else t_id[sample_index] for t_id in trajectory_indices]
                chat._trajectory_indices = sample_t_ids
                chat.tokenize_by_one = custom_tokenize_by_one

                tokenized_chat = chat.tokenize_by_one(roles_to_predict)
                tokenized_chat["group_index"] = group_index
                tokenized_chats.append(tokenized_chat)

            return tokenized_chats

        return custom_tokenize


class TokenizedPromptProcessorWithDA(TokenizedPromptProcessor):
    """ This class extends TokenizedPromptProcessor with data augmentations.
    It does:
    1. Applying a few data augmentations to the data dict instance.
    2. processing data dict instance into prompt;
    3. tokenizing the prompt.
    As usually the data augmentation is applied epoch-wise, this class is better used as a postprocessor.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        config: (:class:`DataArguments`): Any DataArguments with the following attributes:
            train_on_inputs: whether to calculate training loss on instruction/inputs, or solely on outputs.
            max_seq_length:
            prompt_templates: the template to use in prompts. Could be:
                - str in [`simple_qa`, `simplest_qa`, `alpaca`]: to use preset templates
                - a list of two strings, corresponding to prompt with input placeholder, and prompt without input placeholder.
            data_augmentations: None, a single element, or a list, with each element to be:
                - a string indicating the augmentation method (without arguments).
                - a dict {"name": x, ...} indicating the augmentation method and other arguments.
                - an instance of subclass of TextAugmentationBase
                - a callable augmentation method.
            do_tokenization: bool
            data_generation_task: str or list of str, each one of:
                - SimpleMathFormulation
                - LPProblemDescriptionGeneration
                - IncorrectMathChecking
                - SimpleFilter
                - SimpleQA / SimpleRL
                - ValueLabelPrediction / ValueLabelPredictionWithoutGeneration
                - ValueLabelContrastiveLearning / ValueLabelContrastiveLearningWithoutGeneration
        is_eval: when True, ``eval_augmentations`` / ``eval_task`` are applied
            instead of the training ones.
    """
    supported_tasks = [
        'IncorrectMathChecking', 'SimpleMathFormulation', 'LPProblemDescriptionGeneration',
        'SimpleFilter'
    ]
    supported_augmentations = {
        'symmetric_shuffle': LPSymmetricShuffleAug,
        'incorrect_answer': LPIncorrectAnswerAug,
        'contrastive_learning': LPContrastiveLearningAug
        }

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        config,
        is_eval=False,
        ):
        """Resolve the generation tasks and augmentations named in ``config``.

        Raises:
            NotImplementedError: if ``data_generation_task`` is neither a str
                nor a list, or names an unknown task.
            AssertionError: if ``data_augmentations`` is not list-like after
                normalisation, or names an unsupported augmentation.
            KeyError: if a dict augmentation has no ``name`` entry.
            ValueError: if an augmentation entry has an unsupported type.
        """
        super().__init__(tokenizer, config, is_eval=is_eval)
        # check task
        generation_task = getattr(config, "data_generation_task", "SimpleMathFormulation")

        if isinstance(generation_task, str):
            self.task = [self._get_generation_task(generation_task, config)]
        elif isinstance(generation_task, list):
            self.task = [self._get_generation_task(task, config) for task in generation_task]
        else:
            raise NotImplementedError("generation_task {} cannot be parsed.".format(generation_task))
        # check augmentation
        augmentations = getattr(config, "data_augmentations", [])
        if augmentations is None:
            augmentations = []
        # a single augmentation spec is wrapped into a one-element list
        if isinstance(augmentations, (str, dict, TextAugmentationBase)) or callable(augmentations):
            augmentations = [augmentations]
        assert isinstance(augmentations, list), TypeError(
            "Expect config.augmentations to be a list or None; got {}".format(type(augmentations)))
        self.augmentations = []
        for index, aug in enumerate(augmentations):
            if isinstance(aug, str):
                assert aug in self.supported_augmentations, NotImplementedError(
                    "Index {} augmentation {} not implemented.".format(index, aug))
                self.augmentations.append(self.supported_augmentations[aug]())
            elif isinstance(aug, dict):
                method = aug.get("name", None)
                if not method:
                    raise KeyError("Index {} augmentation does not have name key.".format(index))
                assert method in self.supported_augmentations, NotImplementedError(
                    "Index {} augmentation {} not implemented.".format(index, method))
                # every entry other than "name" is passed to the constructor
                kwargs = {name: value for name, value in aug.items() if name != "name"}
                self.augmentations.append(self.supported_augmentations[method](**kwargs))
            elif callable(aug) or isinstance(aug, TextAugmentationBase):
                self.augmentations.append(aug)
            else:
                raise ValueError("Cannot parse index {} augmentation; got {}.".format(index, aug))

        self.do_tokenization = getattr(config, "do_tokenization", True)
        self.do_contrastive_learning = any(
            isinstance(aug, LPContrastiveLearningAug) for aug in self.augmentations)
        # ``in`` is a substring test when generation_task is a str and a
        # membership test when it is a list.
        if 'IncorrectMathChecking' in generation_task:
            self.eval_augmentations = self.augmentations
            self.eval_task = self.task
        else:
            self.eval_augmentations = []
            self.eval_task = [self._get_generation_task('SimpleMathFormulation', config)]

    def _get_generation_task(self, generation_task, config):
        """Instantiate the task object named by ``generation_task``.

        Raises:
            NotImplementedError: if the name is not recognised.
        """
        if generation_task == 'IncorrectMathChecking':
            deduplicate = getattr(config, "data_const_deduplicate", False)
            task = IncorrectMathCheckingTask(deduplicate)
        elif generation_task == 'SimpleMathFormulation':
            deduplicate = getattr(config, "data_const_deduplicate", False)
            keep_origin_instance = getattr(config, "keep_origin_instance", False)
            task = SimpleMathFormulationTask(deduplicate, keep_origin_instance)
        elif generation_task == 'LPProblemDescriptionGeneration':
            deduplicate = getattr(config, "data_const_deduplicate", False)
            task = LPProblemDescriptionGenerationTask(deduplicate)
        elif generation_task == 'SimpleFilter':
            task = SimpleFilterTask()
        elif generation_task in {'SimpleQA', 'SimpleRL'}:
            task = TextGenerationTaskBase()
        elif generation_task in {"ValueLabelPrediction", "ValueLabelPredictionWithoutGeneration"}:
            task = ValueLabelPredictionTask(config, omit_labels=generation_task=="ValueLabelPredictionWithoutGeneration")
        elif generation_task in {"ValueLabelContrastiveLearning", "ValueLabelContrastiveLearningWithoutGeneration"}:
            task = ValueLabelContrastiveLearningTask(config, omit_labels=generation_task=="ValueLabelContrastiveLearningWithoutGeneration")
        else:
            raise NotImplementedError("Task {} cannot be parsed. Currently support: {}".format(
                generation_task, self.supported_tasks))
        return task

    @staticmethod
    def _is_str_dict(value):
        """Return True for a string that starts with "{" and ends with "}"."""
        return isinstance(value, str) and value.startswith("{") and value.endswith("}")

    def _eval_str_dict(self, instances):
        """load_dataset does not support nested dict. When instance is nested dict, it has to be
        saved as string. So here we recover them.

        Lists and dicts are walked recursively; values stored under the keys
        ``content``, ``output``, ``instruction`` and ``solveResult`` are left
        untouched.
        """
        if isinstance(instances, list):
            instances = [self._eval_str_dict(instance) for instance in instances]
        elif isinstance(instances, dict):
            keys_to_ignore = ["content", "output", "instruction", "solveResult"]
            instances = {
                key: value if key in keys_to_ignore else self._eval_str_dict(value) for key, value in instances.items()}
        elif self._is_str_dict(instances):
            # SECURITY NOTE(review): eval executes arbitrary code; if the dataset
            # can come from untrusted data, consider ast.literal_eval instead.
            instances = eval(instances)

        return instances

    def __call__(self, instance):
        """Augment, task-process and (optionally) tokenize ``instance``.

        Args:
            instance: a single data dict, or a dict of batched (list) values.

        Returns:
            For batched input a dict of lists keyed like the first processed
            prompt; otherwise the first processed prompt.
        """
        # when used as a postprocessor, instance is a dict with batched values.
        # we convert it to batched dict, do processing, then convert it back.
        is_batched = isinstance(instance['instruction'], list) or (self.do_contrastive_learning and not self.is_eval)
        if is_batched:
            num_samples = len(instance['instruction'])
            instances = [{key: value[index] for key, value in instance.items()} for index in range(num_samples)]
        else:
            instances = [instance]

        # restore dict values.
        instances = self._eval_str_dict(instances)
        # do augmentation, post-processing and tokenization
        tokenized_prompts = []
        for raw_instance in instances:
            processed = raw_instance
            if self.is_eval == False:
                for augmentation in self.augmentations:
                    processed = augmentation(processed)
                for task in self.task:
                    processed = task.process_instance(processed)
            else:
                for augmentation in self.eval_augmentations:
                    if augmentation is not None:
                        processed = augmentation(processed)
                for task in self.eval_task:
                    processed = task.process_instance(processed)

            # a task may expand one instance into several
            expanded = processed if isinstance(processed, list) else [processed]
            for item in expanded:
                if self.do_tokenization:
                    tokenized_prompt = super().__call__(item)
                else:
                    tokenized_prompt = item
                # ``is not None`` instead of ``!= None``: the latter is
                # ambiguous for array-like log-probabilities.
                if self.do_contrastive_learning and item.get("inferenced_logprobs", None) is not None:
                    tokenized_prompt["inferenced_logprobs"] = item["inferenced_logprobs"]
                tokenized_prompts.append(tokenized_prompt)

        if is_batched:
            tokenized_prompts = {key: [tp[key] for tp in tokenized_prompts] for key in tokenized_prompts[0]}
        else:
            tokenized_prompts = tokenized_prompts[0]

        return tokenized_prompts


class PromptProcessorWithDA(TokenizedPromptProcessorWithDA):
    """Augmentation-only variant of TokenizedPromptProcessorWithDA.

    Construction is delegated unchanged to the parent class; afterwards
    ``do_tokenization`` is forced to ``False`` so that the processed instances
    are returned without being tokenized.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, config):
        """Build the parent processor, then switch its tokenization step off."""
        super().__init__(tokenizer, config)
        # The parent consults this flag to decide whether to tokenize.
        self.do_tokenization = False


class TokenizedChatProcessorWithDA(TokenizedPromptProcessorWithDA):
    """Chat-oriented counterpart of TokenizedPromptProcessorWithDA.

    For every data instance this processor:
    1. turns the data dict into a chat (prompt-style instances are first
       converted to chat-style ones);
    2. applies the configured augmentations and task post-processing to the chat;
    3. tokenizes the chat, or returns its raw messages when tokenization is off.
    Since augmentation is usually applied epoch-wise, the class is best used as a
    postprocessor.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        config: (:class:`DataArguments`): Any DataArguments with the following attributes:
            train_on_inputs: whether to calculate training loss on instruction/inputs, or solely on outputs.
            max_seq_length:
            prompt_templates: the template to use in prompts. Could be:
                - str in [`simple_qa`, `simplest_qa`, `alpaca`]: to use preset templates
                - a list of two strings, corresponding to prompt with input placeholder, and prompt without input placeholder.
            augmentations: a list, with each element to be:
                - a string indicating the augmentation method (without arguments).
                - a dict {"name": x, ...} indicating the augmentation method and other arguments.
                - an instance of subclass of TextAugmentationBase
                - a callable augmentation method.
            do_tokenization: bool
            data_generation_task: str or list of str
                - SimpleMathFormulation
                - LPProblemDescriptionGeneration
                - IncorrectMathChecking
    """
    # Reuse the chat construction logic of TokenizedChatProcessor.
    _get_chat = TokenizedChatProcessor._get_chat

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        config,
        is_eval=False,
        image_processor: BaseImageProcessor = None,
    ):
        """Set up the parent processor plus the chat-specific settings."""
        super().__init__(tokenizer, config, is_eval=is_eval)
        self.prepend_token_to = None
        self.append_token_to = {'Assistant': tokenizer.eos_token}

        # Roles to predict: an explicit config value wins; otherwise fall back
        # to the legacy `train_on_inputs` switch (kept for back compatibility).
        configured_roles = getattr(config, "roles_to_predict", None)
        if configured_roles:
            self.roles_to_predict = configured_roles
        elif config.train_on_inputs:
            self.roles_to_predict = None
        elif config.prompt_templates == 'qwen_chat':
            self.roles_to_predict = ['assistant']
        else:
            self.roles_to_predict = ['Assistant']

        self.prompt_templates = config.prompt_templates
        # image related
        self.image_processor = image_processor
        # others
        self.role_tags = getattr(config, "role_tags", {})
        self.role_map = getattr(config, "role_map", {})

    def _get_num_samples(self, instance):
        """Return how many samples `instance` holds (1 when it is not batched).

        Raises:
            KeyError: if `instance` has neither an `instruction` nor a `chat` key
                (only checked when contrastive learning is off).
        """
        if self.do_contrastive_learning:
            return len(instance["instruction"])

        if "instruction" in instance:
            instructions = instance["instruction"]
            if isinstance(instructions, list):
                return len(instructions)
            return 1

        if "chat" in instance:
            chats = instance["chat"]
            # a batch of chats is a list whose first element is itself a list
            if isinstance(chats[0], list):
                return len(chats)
            return 1

        raise KeyError("Expect instance to have key `instruction` or `chat`; got {}".format(
            list(instance.keys())))

    def __call__(self, instance):
        """Process one instance, or a dict of batched values, into (tokenized) chats.

        When used as a postprocessor `instance` is a dict with batched values: it
        is split into per-sample dicts, each sample is processed, and the results
        are merged back into a dict of lists. For a non-batched input the first
        processed result is returned.

        Raises:
            TypeError: if a processed chat is neither a dict nor a list of dicts.
        """
        num_samples = self._get_num_samples(instance)
        is_batched = (
            num_samples > 1
            or ("instruction" in instance and isinstance(instance["instruction"], list))
            or ("chat" in instance and isinstance(instance["chat"][0], list))
        )

        if is_batched:
            # dict of columns -> list of per-sample dicts
            instances = [
                {key: column[row] for key, column in instance.items()}
                for row in range(num_samples)
            ]
        else:
            instances = [instance]

        # restore dict values.
        instances = self._eval_str_dict(instances)

        # do augmentation, post-processing and tokenization
        tokenized_chats = []
        for sample in instances:
            is_prompt_style = (
                sample.get("instruction", None) is not None
                and sample.get("chat", None) is None
            )
            if is_prompt_style:
                sample = convert_prompt_instance_to_chat(sample, image_processor=self.image_processor)

            # convert instance to chat; tokenization is delayed
            chat = self._get_chat(sample)

            # get the preferenced answer instance: with contrastive learning the
            # task-processed, un-augmented chat is tokenized and emitted first
            if self.do_contrastive_learning:
                for task in self.task:
                    chat = task.process_chat(chat)
                tokenized_chats.append(chat.tokenize(self.roles_to_predict))

            for augmentation in self.augmentations:
                chat = augmentation(chat)
            for task in self.task:
                chat = task.process_chat(chat)

            if self.do_tokenization:
                tokenized_chat = chat.tokenize(self.roles_to_predict)
            else:
                # convert back to list as ChatRecords is not supported in accelerate
                messages = [utterance for utterance in chat]
                if chat.system_prompt and chat[0]['role'] != 'system':
                    messages = [{'role': 'system', 'content': chat.system_prompt}] + messages
                tokenized_chat = {"messages": messages}

            if isinstance(tokenized_chat, dict):
                tokenized_chats.append(tokenized_chat)
                continue
            if isinstance(tokenized_chat, list) and all(isinstance(tc, dict) for tc in tokenized_chat):
                tokenized_chats.extend(tokenized_chat)
                continue
            offending_type = type(tokenized_chat[0]) if isinstance(tokenized_chat, list) else type(tokenized_chat)
            raise TypeError("Expect tokenized_chat to be dict or List[dict]; got {}".format(offending_type))

        if not is_batched:
            return tokenized_chats[0]

        # TODO get batched processed_chat
        return {key: [tc[key] for tc in tokenized_chats] for key in tokenized_chats[0]}


class MyLlamaTokenizer(LlamaTokenizer):
    """LlamaTokenizer with a few fixes, such that it
    - adds no space to the beginning of a sentence in decoding;
    - does not truncate partial Chinese words.
    """
    # NOTE: an override of `no_prefix_space_tokens` used to live here; that
    # property seems to have been removed in recent transformers releases.

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string.

        Special tokens are emitted verbatim and never passed to the sentencepiece
        model; the runs of ordinary tokens between them are decoded with
        `self.sp_model`. A single space is inserted in front of a special token
        that directly follows ordinary tokens.
        """
        pieces = []
        pending = []  # ordinary tokens collected since the last special token
        last_was_special = True
        for tok in tokens:
            if tok not in self.all_special_tokens:
                pending.append(tok)
                last_was_special = False
                continue
            # make sure that special tokens are not decoded using sentencepiece model
            if not last_was_special:
                pieces.append(" ")
            pieces.append(self.sp_model.decode(pending) + tok)
            last_was_special = True
            pending = []
        pieces.append(self.sp_model.decode(pending))
        return "".join(pieces)

    def adjust_truncation_for_Chinese_words(self, ids, overflowing_tokens):
        """Move tokens of a partially truncated word from `ids` to the overflow.

        The six tokens at the truncated end of `ids` are decoded; when the text
        starts (left truncation) or ends (right truncation) with the replacement
        character "�", the number of such characters is taken as the number of
        additional tokens to remove from that end.

        Returns:
            The adjusted `(ids, overflowing_tokens)` pair; both are returned
            unchanged when nothing needs to be removed.
        """
        if self.truncation_side == 'left':
            head_text = self.decode(ids[:6])
            if not head_text.startswith("�"):
                return ids, overflowing_tokens
            num_extra = len(head_text) - len(head_text.lstrip("�"))
            return ids[num_extra:], overflowing_tokens + ids[:num_extra]

        if self.truncation_side == 'right':
            tail_text = self.decode(ids[-6:])
            if not tail_text.endswith("�"):
                return ids, overflowing_tokens
            num_extra = len(tail_text) - len(tail_text.rstrip("�"))
            return ids[:-num_extra], ids[-num_extra:] + overflowing_tokens

        return ids, overflowing_tokens

    def truncate_sequences(
        self,
        ids: List[int],
        pair_ids: Optional[List[int]] = None,
        num_tokens_to_remove: int = 0,
        truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
        stride: int = 0,
    ) -> Tuple[List[int], List[int], List[int]]:
        """Truncate as the parent class does, then repair partially cut words.

        The repair is applied to `ids` when it is non-empty, otherwise to
        `pair_ids`, and only when the parent reported overflowing tokens.
        """
        ids, pair_ids, overflowing_tokens = super().truncate_sequences(
            ids=ids,
            pair_ids=pair_ids,
            num_tokens_to_remove=num_tokens_to_remove,
            truncation_strategy=truncation_strategy,
            stride=stride,
        )
        if overflowing_tokens:
            if ids:
                ids, overflowing_tokens = self.adjust_truncation_for_Chinese_words(ids, overflowing_tokens)
            elif pair_ids:
                pair_ids, overflowing_tokens = self.adjust_truncation_for_Chinese_words(pair_ids, overflowing_tokens)

        return (ids, pair_ids, overflowing_tokens)
    

class TextGenerator:
    """ Text preprocessing and postprocessing class that
    - fill in the template based on input instance
    - tokenize the prompt
    - call the model.generate
    - post-process the responses

    Args:
        model
        config
        tokenizer
        prompt_templates:
        system_prompt: default system prompt, used when an instance provides none.
        post_process_fn: a callable function or a list of callable functions. 
            responses = post_process_fn(generation_output, responses)
        num_parallel_adapters: provide this if multiple adapters are used in parallel to provide 
            outputs for each input.
        enforce_cn_chars: for some tokenizer, provide this to enforce that partial Chinese are not tokenized.
            This is deprecated.
        max_mu_seq_len: The maximum allowed moving average mean of sequence length, used to dynamically 
            control batch size. Set it to 0 to disable it.

    Raises:
        ValueError: if an element of `post_process_fn` is not callable.
        NotImplementedError: if streaming is requested together with parallel
            adapters (parallel molora strategy) or with a ChatGLM model.
    """
    def __init__(
        self, 
        model: torch.nn.Module, 
        config: GenerationConfig, 
        tokenizer: PreTrainedTokenizer, 
        prompt_templates: Union[List[str], str] = None,
        system_prompt: str = None,
        post_process_fn = None,
        num_parallel_adapters = 1,
        enforce_cn_chars: bool = True,  # This is deprecated.
        max_mu_seq_len: int = 0,  # 0 means not effective
        ) -> None:
        self.model = model
        self.model.generation_config = config
        self.config = config
        self.tokenizer = tokenizer
        # a streamer is only created when the generation config asks for streaming
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True) if getattr(self.config, "streaming", False) else None
        self.prompt_w_input, self.prompt_wo_input = get_prompt_templates(prompt_templates)
        self.prompt_templates = prompt_templates
        self.system_prompt = system_prompt

        if post_process_fn:
            if not isinstance(post_process_fn, list):
                post_process_fn = [post_process_fn]
            # explicit raise instead of `assert`, which is stripped under `python -O`
            if not all(callable(fn) for fn in post_process_fn):
                raise ValueError(
                    "Expect post_process_fn to be callable; got {}.".format(post_process_fn))
        self.post_process_fn = post_process_fn  # None or list

        self.num_parallel_adapters = num_parallel_adapters
        self.molora_strategy = None
        if hasattr(self.model, 'peft_config'):
            self.molora_strategy = getattr(
                list(self.model.peft_config.values())[0], 'molora_strategy', 'average')
        if self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel' and self.streamer:
            raise NotImplementedError(
                "Streaming cannot be used together with num_parallel_adapters > 1 and parallel molora_strategy.")

        if enforce_cn_chars:
            self.prefix_allowed_tokens_fn = PrefixCNCharacterTokenGen(
                self.tokenizer, num_scores=self.model.get_input_embeddings().num_embeddings)
        else:
            self.prefix_allowed_tokens_fn = None

        self.logits_processor = LogitsProcessorList()
        if "ChatGLM" in self.model._get_name():
            self.logits_processor.append(InvalidScoreLogitsProcessor())
            if self.streamer:
                raise NotImplementedError("Streaming cannot be used together with ChatGLM.")
        
        # mean sequence length, used to dynamically control batch size
        self._mu_seq_len = None
        self.max_mu_seq_len = max_mu_seq_len
        # if prompts is a List[List[str]], we need to restore the structure for responses
        self._sub_list_size = None

    def fill_prompt_template(self, instruction, inputs, **kwargs):
        """Build the prompt for one instance.

        When at least one of the resolved templates is a string, the system
        prompt (``kwargs["system"]`` or, failing that, ``self.system_prompt``)
        is filled in and the template is formatted with an empty output.
        Otherwise a template built into the tokenizer is applied.

        Raises:
            NotImplementedError: if no string template is available and
                `self.prompt_templates` is not a supported built-in template.
        """
        if isinstance(self.prompt_w_input, str) or isinstance(self.prompt_wo_input, str):
            system = {"system": kwargs.get("system", None) or self.system_prompt}
            if inputs:
                prompt_w_input = fill_system_prompt(self.prompt_w_input, system)
                return prompt_w_input.format(INSTRUCTION=instruction, INPUT=inputs, OUTPUT="")
            else:
                prompt_wo_input = fill_system_prompt(self.prompt_wo_input, system)
                return prompt_wo_input.format(INSTRUCTION=instruction, OUTPUT="")
        else:  # apply build-in templates in tokenizer
            if self.prompt_templates == "granite_guardian_qa":
                prompt = get_granite_guardian_qa_prompt(self.tokenizer, instruction, inputs, **kwargs)
            else:
                raise NotImplementedError("Prompt template {} has not been implemented.".format(self.prompt_templates))
            return prompt

    def _uses_parallel_adapters(self):
        """Whether each input is repeated once per adapter (parallel molora strategy)."""
        return self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel'

    def _get_input_ids(self, prompts):
        """Tokenize `prompts` with padding and move the ids to the model device.

        With parallel adapters the batch is repeated `num_parallel_adapters`
        times along the batch dimension.
        """
        tokenized_prompts = self.tokenizer(  # dict with keys ['input_ids', 'attention_mask', 'labels']
            prompts,
            padding=True,
            return_tensors="pt",
        )
        input_ids = tokenized_prompts["input_ids"].to(self.model.device)

        if self._uses_parallel_adapters():
            input_ids = input_ids.repeat(self.num_parallel_adapters, 1)

        return input_ids

    def _unflatten(self, somelist):
        """Regroup a flat list into sub-lists sized by `self._sub_list_size`.

        Raises:
            RuntimeError: if `somelist` does not hold exactly
                `sum(self._sub_list_size)` elements.
        """
        expected_size = sum(self._sub_list_size)
        if len(somelist) != expected_size:
            raise RuntimeError(
                f"Number of responses does not match that of flatten list of prompts; got {len(somelist)} and {expected_size}")

        newlist = []
        s_index = 0
        for sub_list_size in self._sub_list_size:
            e_index = s_index + sub_list_size
            newlist.append(somelist[s_index:e_index])
            s_index = e_index

        return newlist

    def _get_responses(self, generation_output, input_ids):
        """Decode generated sequences and strip the echoed prompt from each.

        Args:
            generation_output: object whose `sequences` attribute holds the
                generated token ids.
            input_ids: the prompt token ids the sequences were generated from.

        Returns:
            A list of responses; with parallel adapters, a list holding one list
            of `num_parallel_adapters` responses per input. The result is then
            passed through every function in `self.post_process_fn`.
        """
        output_ids = generation_output.sequences
        outputs = [self.tokenizer.decode(output, skip_special_tokens=True) for output in output_ids]
        prompts = [self.tokenizer.decode(_input_ids, skip_special_tokens=True) for _input_ids in input_ids]
        # Drop the prompt by slicing off the prefix. Splitting on the prompt
        # would lose everything after a second occurrence of the prompt text
        # inside the response, and fails on an empty prompt.
        responses = [
            output[len(prompt):].strip() if output.startswith(prompt) else output.strip()
            for output, prompt in zip(outputs, prompts)]

        if "ChatGLM" in self.model._get_name():
            responses = [chatglm_process_response(response) for response in responses]

        if self._uses_parallel_adapters():
            # rows are laid out adapter-major: regroup them per input
            chunk_size = len(responses) // self.num_parallel_adapters
            responses = [
                [responses[index_x + index_y * chunk_size] for index_y in range(self.num_parallel_adapters)]
                for index_x in range(chunk_size)]
        
        if self.post_process_fn:  # None or list
            for post_process_fn in self.post_process_fn:
                responses = post_process_fn(generation_output, responses)

        return responses

    def response(self, instruction, inputs=None, **kwargs):
        """Generate the response for a single instruction.

        Returns:
            The single response when exactly one was produced, otherwise the
            list of responses.
        """
        prompt = self.fill_prompt_template(instruction, inputs, **kwargs)
        input_ids = self._get_input_ids(prompt)
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids, 
                generation_config=self.config, 
                return_dict_in_generate=True,
                logits_processor=self.logits_processor,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                )
        response = self._get_responses(generation_output, input_ids)
        if isinstance(response, list) and len(response) == 1:
            response = response[0]

        return response
    
    def __call__(self, instruction, inputs=None, **kwargs):
        """Shorthand for :meth:`response`."""
        return self.response(instruction, inputs, **kwargs)

    def stream_response(self, instruction, inputs=None, **kwargs):
        """Start generation in a background thread and return the streamer.

        Returns:
            `self.streamer`, which is None unless the generation config enables
            streaming.
        """
        prompt = self.fill_prompt_template(instruction, inputs, **kwargs)
        input_ids = self._get_input_ids(prompt)
        self.model.eval()
        generation_kwargs = dict(
            input_ids=input_ids, 
            streamer=self.streamer, 
            generation_config=self.config, 
            return_dict_in_generate=True,
            logits_processor=self.logits_processor,
            prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
            )

        def _generate_without_grad():
            # torch.no_grad() is thread-local, so it has to be entered inside
            # the worker thread to have any effect on the generation.
            with torch.no_grad():
                self.model.generate(**generation_kwargs)

        thread = Thread(target=_generate_without_grad)
        thread.start()

        return self.streamer

    def _gather_sub_batch(self, input_ids, row_indices):
        """Pick `row_indices` from `input_ids`, keeping only the trailing columns
        needed by the longest selected sequence."""
        sub_batch = input_ids[row_indices]
        max_seq_len = int((sub_batch != self.tokenizer.pad_token_id).sum(dim=1).max())
        # NOTE(review): slicing from the right presumably relies on left padding — confirm
        return sub_batch[:, -max_seq_len:]

    def batch_response(self, instances):
        """Generate responses for a batch of instances.

        Args:
            instances: List[Dict["instruction", "input", ...]], a list of
                (instruction, input[, system]) lists/tuples, or List[str].
                Extra dict keys are forwarded to `fill_prompt_template`.

        Returns:
            The responses in the order of `instances` (an empty list for empty
            input). If the prompts form a List[List[str]], the responses are
            regrouped to that structure.

        Raises:
            TypeError: if the instances are of an unsupported type.
        """
        if len(instances) == 0:
            return []
        # forget structure info a previous call may have left behind on failure
        self._sub_list_size = None

        first_instance = instances[0]
        if isinstance(first_instance, dict):
            prompts = [self.fill_prompt_template(
                instance["instruction"], instance.get("input"), 
                **{key: value for key, value in instance.items() if key not in {"instruction", "input"}}) for instance in instances]
        elif isinstance(first_instance, (list, tuple)):
            prompts = [self.fill_prompt_template(
                instance[0], instance[1], system=instance[2] if len(instance) > 2 else None) for instance in instances]
        elif isinstance(first_instance, str):
            prompts = [self.fill_prompt_template(instance, None) for instance in instances]
        else:
            raise TypeError(
                "Expect instances to be a list of dict, list/tuple or str; got {}".format(type(first_instance)))

        # in case prompts is a List[List[str]], flatten it
        if isinstance(prompts, list) and isinstance(prompts[0], list):  
            self._sub_list_size = [len(prompt) for prompt in prompts]
            prompts = [_prompt for prompt in prompts for _prompt in prompt]  # flatten the list
        input_ids = self._get_input_ids(prompts)

        # divide current batch into sub-batches to avoid exceptionally long sequences
        indices = None
        batch_size = input_ids.shape[0]
        sub_batches = [input_ids]
        if batch_size > 1:
            # get the mean stats of input_ids
            seq_len = (input_ids != self.tokenizer.pad_token_id).sum(dim=1)
            mu_seq_len = seq_len.float().mean().item()
            if self._mu_seq_len is None:
                self._mu_seq_len = mu_seq_len
            else:
                self._mu_seq_len = 0.99 * self._mu_seq_len + 0.01 * mu_seq_len
            if self.max_mu_seq_len > 0:
                self._mu_seq_len = min(self._mu_seq_len, self.max_mu_seq_len)
            # keep using this batch as it is would possibly double the memory; 
            # so we resort to the safest strategy, we divide them into sub-batches
            # such that the total seq length of sub-batches sums up to self._mu_seq_len * batch_size.
            # Rows repeated for parallel adapters must stay in one batch, as
            # _get_responses regroups them by position.
            if not self._uses_parallel_adapters() and bool((seq_len > 2 * self._mu_seq_len).any()):
                seq_len_sorted, indices = torch.sort(seq_len, dim=-1, descending=False)
                threshold = self._mu_seq_len * batch_size
                sub_batches = []
                sub_batch_indices = []
                seq_len_sum = 0
                use_flash_attn = getattr(self.model.config, "use_flash_attn", False)
                for index, length in zip(indices.tolist(), seq_len_sorted.tolist()):
                    # with flash attention the cost is the sum of lengths; otherwise
                    # every row is padded to the longest (= current) length
                    new_seq_len_sum = (seq_len_sum + length) if use_flash_attn else (len(sub_batch_indices) + 1) * length
                    if new_seq_len_sum <= threshold:
                        seq_len_sum = new_seq_len_sum
                        sub_batch_indices.append(index)
                    else:
                        if sub_batch_indices:
                            sub_batches.append(self._gather_sub_batch(input_ids, sub_batch_indices))
                        sub_batch_indices = [index]
                        seq_len_sum = length
                if sub_batch_indices:
                    sub_batches.append(self._gather_sub_batch(input_ids, sub_batch_indices))

        responses = []
        for batch in sub_batches:
            with torch.no_grad():
                generation_output = self.model.generate(
                    input_ids=batch, 
                    generation_config=self.config, 
                    return_dict_in_generate=True,
                    logits_processor=self.logits_processor,
                    prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                    )
            # every element of response is a list
            responses.extend(self._get_responses(generation_output, batch))
        if indices is not None:
            # undo the sort by sequence length to restore the input order
            inv_indices = torch.argsort(indices, dim=-1, descending=False).tolist()
            responses = [responses[index] for index in inv_indices]

        if self._sub_list_size:
            responses = self._unflatten(responses)
            self._sub_list_size = None

        return responses


class ProbsGenerator(TextGenerator):
    """Score given outputs instead of generating text.

    Prompts are built from (instruction, input, output[, system]) instances, run
    through the model in a single forward pass, and the per-token log
    probabilities of the output part are returned.
    """
    def __init__(self, 
        model: torch.nn.Module, 
        config: GenerationConfig, 
        tokenizer: PreTrainedTokenizer, 
        prompt_templates: Union[List[str], str] = None,
        system_prompt: str = None,
        post_process_fn = None,
        num_parallel_adapters = 1,
        enforce_cn_chars: bool = True,):
        super(ProbsGenerator, self).__init__(
            model, 
            config, 
            tokenizer, 
            prompt_templates=prompt_templates, 
            system_prompt=system_prompt, 
            post_process_fn=post_process_fn, 
            num_parallel_adapters=num_parallel_adapters, 
            enforce_cn_chars=enforce_cn_chars)
        self.max_seq_length = 8192

    def fill_prompt_template(self, instruction, inputs, output, system=None):
        """Build the full prompt, including `output`, for one instance.

        The template with the input placeholder is used when `inputs` is truthy,
        the one without it otherwise. `system` is only filled in when given.
        """
        system = {"system": system} if system else {}
        if inputs:
            prompt_w_input = fill_system_prompt(self.prompt_w_input, system)
            return prompt_w_input.format(INSTRUCTION=instruction, INPUT=inputs, OUTPUT=output)
        else:
            prompt_wo_input = fill_system_prompt(self.prompt_wo_input, system)
            return prompt_wo_input.format(INSTRUCTION=instruction, OUTPUT=output)

    def _instance_to_prompt(self, instance, blank_output=False):
        """Turn one instance into a prompt string.

        Args:
            instance: a dict with keys "instruction", "output" and optionally
                "input"/"system"; a list/tuple (instruction, input, output[, system]);
                or a bare instruction string (which carries no output).
            blank_output: if True, the output is replaced by "" so that only the
                prompt part is produced. Everything else, including the system
                prompt, is kept identical to the full prompt.

        Raises:
            TypeError: if `instance` is of an unsupported type.
        """
        if isinstance(instance, dict):
            output = "" if blank_output else instance["output"]
            return self.fill_prompt_template(
                instance["instruction"], instance.get("input"), output, instance.get("system", None))
        if isinstance(instance, (list, tuple)):
            output = "" if blank_output else instance[2]
            system = instance[3] if len(instance) > 3 else None
            return self.fill_prompt_template(instance[0], instance[1], output, system)
        if isinstance(instance, str):
            # an empty output rather than None, which would be formatted as "None"
            return self.fill_prompt_template(instance, None, "", None)
        raise TypeError(
            "Expect instance to be a dict, list/tuple or str; got {}".format(type(instance)))

    def _tokenize(self, prompts):
        """Tokenize `prompts` (each terminated by the EOS token) onto the model device.

        Returns:
            The tokenizer output with an additional "labels" entry that is a
            copy of "input_ids".
        """
        eos_token = self.tokenizer.eos_token
        prompts = [prompt if prompt.endswith(eos_token) else prompt + eos_token for prompt in prompts]
        tokenized_prompts = self.tokenizer(  # dict with keys ['input_ids', 'attention_mask', 'labels']
            prompts,
            padding=True,
            return_tensors="pt",
        )
        for k, v in tokenized_prompts.items():
            tokenized_prompts[k] = v.to(self.model.device)

        tokenized_prompts["labels"] = torch.clone(tokenized_prompts["input_ids"])
    
        return tokenized_prompts
      
    def _get_answer_labels(self, data_instances, labels):
        """Mask everything but the answer tokens in `labels` with -100.

        The prompt part of each instance is re-tokenized with an empty output to
        find where the answer starts.

        Args:
            data_instances: the instances the labels were built from.
            labels: token ids of the full prompts, one row per instance.

        Returns:
            A tensor shaped like `labels` in which each row ends with the answer
            tokens and is filled with -100 before them.
        """
        partial_prompts = [
            self._instance_to_prompt(data_instance, blank_output=True) for data_instance in data_instances]
        partial_tokenized_prompts = self._tokenize(partial_prompts)
        masked_labels = []
        for input_ids, label in zip(partial_tokenized_prompts["input_ids"], labels):
            # NOTE(review): token id 0 is treated as padding and the answer is
            # right-aligned, i.e. this assumes pad_token_id == 0 and left padding — confirm.
            # The -1 presumably discounts the EOS token appended by _tokenize.
            prompt_len = len(input_ids[input_ids != 0]) - 1
            net_answer = label[label != 0][prompt_len:]
            masked_labels.append((len(label) - len(net_answer)) * [-100] + net_answer.tolist())
        return torch.tensor(masked_labels, device=labels.device)

    def _get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
        """Compute the per-token log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: Currently unused; no averaging or summation is performed.

        Returns:
            A tensor of shape (batch_size, sequence_length - 1) holding the log probability
            of every label token, with 0 at the ignored (-100) positions.
        """
        assert logits.shape[:-1] == labels.shape

        # shift so that the logits at position t are scored against the token at t + 1
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = (labels != -100)

        # dummy token; the losses on these tokens are zeroed out by loss_mask below
        labels[labels == -100] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        
        return (per_token_logps * loss_mask)

    def _get_answer_probs(self, generation_output, labels):
        """Return the per-token log probabilities of `labels` under the model output logits."""
        output_logits = generation_output.logits
        # as_tensor avoids re-wrapping (and the warning of) torch.tensor on a tensor
        labels = torch.as_tensor(labels, device=output_logits.device)

        return self._get_batch_logps(output_logits, labels)

    def batch_response(self, instances, **kwargs):
        """Compute the per-token log probabilities of the outputs of `instances`.

        Args:
            instances: List[Dict["instruction", "input", "output"]], a list of
                (instruction, input, output[, system]) lists/tuples, or List[str].
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A nested list with one row of per-token log probabilities per
            instance (0 outside the answer); an empty list for empty input.

        Raises:
            TypeError: if an instance is of an unsupported type.
        """
        if len(instances) == 0:
            return []

        prompts = [self._instance_to_prompt(instance) for instance in instances]
        tokenized_prompts = self._tokenize(prompts)
        with torch.no_grad():
            generation_output = self.model(**tokenized_prompts)
        labels = self._get_answer_labels(instances, tokenized_prompts["labels"])
        response = self._get_answer_probs(generation_output, labels)  # keep it as a list

        return response.tolist()


def batched_index_select(input_tensor: torch.Tensor, dim: int, index: Union[torch.Tensor, List[List[int]]]):
    """ Expand torch.index_select to support index per sample in the batch.

    Args:
        input_tensor: torch.Tensor whose first dimension is the batch.
        dim: the non-batch dimension to select along; negative values count
            from the last dimension.
        index: per-sample indices, given as a tensor or a nested list with one
            row of indices per sample.

    Returns:
        A tensor shaped like `input_tensor`, except that dimension `dim` has one
        entry per index of a row of `index`.

    Raises:
        ValueError: if `dim` refers to the batch dimension or is out of range.
    """
    # accept plain nested lists as well as tensors
    index = torch.as_tensor(index, device=input_tensor.device)

    num_dims = input_tensor.dim()
    if dim < 0:
        dim += num_dims
    if not 0 < dim < num_dims:
        raise ValueError(
            "Expect dim to be a non-batch dimension of a {}-d tensor; got {}".format(num_dims, dim))

    # give index a singleton axis for every non-batch dimension other than `dim` ...
    for axis in range(1, num_dims):
        if axis != dim:
            index = index.unsqueeze(axis)
    # ... then broadcast those axes to the input size, keeping batch and `dim` as they are
    expanse = list(input_tensor.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input_tensor, dim, index)


class ChatGenerator(TokenizedChatProcessor, ValueLabelPredictionTask):
    """ Chat preprocessing and postprocessing class that
    - use ChatRecord to manage and tokenize the dialogue history
    - call the model.generate with kv cache
    - post-process the responses

    Args:
        model
        config
        tokenizer
        prompt_templates:
        post_process_fn: a callable function or a list of callable functions. 
            responses = post_process_fn(generation_output, responses)
            post_process_fn should not change responses, otherwise _past_key_values will be wrong.
        num_parallel_adapters: provide this if multiple adapters are used in parallel to provide 
            outputs for each input.
        enforce_cn_chars: for some tokenizer, provide this to enforce that partial Chinese are not tokenized.
            This is deprecated.
        max_mu_seq_len: The maximum allowed moving average mean of sequence length, used to dynamically 
            control batch size. Set it to 0 to disable it.
        use_chat_cache: use _past_key_values of previous utterances in generation.
    """
    def __init__(
        self,
        model: Union[torch.nn.Module, AllSparkEngine],
        config: GenerationConfig,
        tokenizer: PreTrainedTokenizer,
        prompt_templates: Union[List[str], str] = None,
        system_prompt: str = None,
        post_process_fn = None,
        num_parallel_adapters = 1,
        enforce_cn_chars: bool = True,  # This is deprecated.
        max_mu_seq_len: int = 0,  # 0 means not effective
        use_chat_cache: bool = False,
        data_generation_task: str = None,
        **kwargs
        ) -> None:
        """Resolve the backend(s), complete the config and reset the chat state.

        ``model`` may also be a list: its first element is the main model and
        the remaining elements are auxiliary models. ``**kwargs`` is forwarded
        to ``_check_model_type``.

        Raises:
            ValueError: if a provided post_process_fn is not callable.
            NotImplementedError: if streaming is combined with parallel
                adapters under the 'parallel' molora strategy, or with a
                model whose name contains "ChatGLM".
        """
        # whether to use cache or the full chat history
        # (must be set before _check_model_type, which reads it)
        self.use_chat_cache = use_chat_cache

        if isinstance(model, list):
            model, aux_model = model[0], model[1:]
            if len(aux_model) == 1:
                # a single auxiliary model is stored directly, not as a list
                aux_model = aux_model[0]
                self.aux_model, self.aux_engine, self.aux_model_name, _ = self._check_model_type(aux_model, config, **kwargs)
            else:
                self.aux_model, self.aux_engine, self.aux_model_name = [], [], []
                for _aux_model in aux_model:
                    _aux_model, _aux_engine, _aux_model_name, _ = self._check_model_type(_aux_model, config, **kwargs)
                    self.aux_model.append(_aux_model)
                    self.aux_engine.append(_aux_engine)
                    self.aux_model_name.append(_aux_model_name)
        else:
            self.aux_model = self.aux_engine = self.aux_model_name = None
        self.model, self.engine, self.model_name, max_position_embeddings = self._check_model_type(model, config, **kwargs)

        # Work on a private copy so that the defaults filled in below do not
        # leak into the caller's config object.
        config = copy.deepcopy(config)
        if not hasattr(config, "prompt_templates"):
            config.prompt_templates = prompt_templates
        if not hasattr(config, "system_prompt"):
            config.system_prompt = system_prompt
        if not getattr(config, "roles_to_predict", None):
            if config.prompt_templates == 'qwen_chat':
                config.roles_to_predict = ['assistant']
            else:
                config.roles_to_predict = ['Assistant']

        if not hasattr(config, "max_seq_length"):
            config.max_seq_length = max_position_embeddings
        super().__init__(tokenizer, config, is_eval=True)
        self.config = config
        # chat cache
        self.chat = None
        self._past_key_values = None

        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True) if getattr(self.config, "streaming", False) else None

        if post_process_fn:
            if not isinstance(post_process_fn, list):
                post_process_fn = [post_process_fn]
            # Explicit raise: an assert would be stripped under `python -O`.
            if not all(callable(fn) for fn in post_process_fn):
                raise ValueError(
                    "Expect post_process_fn to be callable; got {}.".format(post_process_fn))
        self.post_process_fn = post_process_fn  # None or list

        self.num_parallel_adapters = num_parallel_adapters
        self.molora_strategy = None
        if hasattr(self.model, 'peft_config'):
            # the strategy is read from the first adapter config only
            self.molora_strategy = getattr(
                list(self.model.peft_config.values())[0], 'molora_strategy', 'average')
        if self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel' and self.streamer:
            raise NotImplementedError(
                "Streaming cannot be used together with num_parallel_adapters > 1 and parallel molora_strategy.")

        self.logits_processor = LogitsProcessorList()
        if "ChatGLM" in self.model_name:
            self.logits_processor.append(InvalidScoreLogitsProcessor())
            if self.streamer:
                raise NotImplementedError("Streaming cannot be used together with ChatGLM.")

        # mean sequence length, used to dynamically control batch size
        self._mu_seq_len = None
        self.max_mu_seq_len = max_mu_seq_len

        # special generation task related
        self._is_StepBeamSearch_task = self._is_subtask("StepBeamSearch", data_generation_task)
        self._is_ValueLabelPrediction_task = self._is_subtask("ValueLabelPrediction", data_generation_task)
        self._is_ValueLabelPredictionWithOutput_task = self._is_subtask(
            "ValueLabelPredictionWithOutput", data_generation_task)
        if (
            self._is_StepBeamSearch_task
            or self._is_ValueLabelPrediction_task
            or self._is_ValueLabelPredictionWithOutput_task
        ):
            self.eos_token_id = getattr(config, "eos_token_id", -1)
            self.end_of_step_id = getattr(config, "end_of_step_id", -1)
            self.false_eos_ids = getattr(config, "false_eos_ids", [])
        # vllm (_check_model_type may already have set it for a VLLM engine)
        if not hasattr(self, "_vllm_config"):
            self._vllm_config = None
        # code interpreter
        pass
        # others
        self.role_tags = getattr(config, "role_tags", {})
        self.role_map = getattr(config, "role_map", {})

    def _check_model_type(self, model, config, **kwargs):
        if isinstance(model, torch.nn.Module):
            model.generation_config = config
            model_name = model._get_name()
            engine = None
            max_position_embeddings = model.config.max_position_embeddings
        elif isinstance(model, AllSparkEngine) and AllSparkEngine is not type(None):
            engine = model
            model = None
            model_name = kwargs.get("model_name")
            max_position_embeddings = kwargs.get("max_position_embeddings", 4096)

            assert not self.use_chat_cache, NotImplementedError("AllSparkEngine does not support use_chat_cache.")
        elif isinstance(model, VLLM) and VLLM is not type(None):
            engine = model
            model = None
            model_name = kwargs.get("model_name")
            max_position_embeddings = kwargs.get("max_position_embeddings", 4096)
            use_beam_search = config.num_beams > 1
            self._vllm_config = VLLMSamplingParams(
                n=config.num_return_sequences,
                best_of=config.num_beams if use_beam_search else None,
                temperature=config.temperature if config.do_sample else 0.0,
                length_penalty=config.length_penalty,
                top_p=config.top_p,
                top_k=config.top_k,
                seed=config.seed,
                use_beam_search=use_beam_search, 
                max_tokens=config.max_new_tokens,
                spaces_between_special_tokens=True,
            )

            assert not self.use_chat_cache, NotImplementedError("VLLM engine does not support use_chat_cache.")
        else:
            raise NotImplementedError("Cannot hand model of type {}.".format(type(model)))
        
        return model, engine, model_name, max_position_embeddings

    @staticmethod
    def _is_subtask(task1, task2):
        task1 = [task1] if isinstance(task1, str) else task1
        task2 = [task2] if isinstance(task2, str) else task2
        return all(task in task2 for task in task1)

    def _convert_to_data_instance_format(self, message, role="Human", parse_str_message=True, allow_stochastic_chat=False):
        if isinstance(message, str):
            if (
                message.startswith("{") 
                and message.endswith("}")
                and parse_str_message
            ):  # dict message passed in as str
                try:
                    message_dict = eval(message)
                except:
                    message_dict = {}
                if "role" in message_dict or "chat" in message_dict:
                    return self._convert_to_data_instance_format(message_dict, role, parse_str_message, allow_stochastic_chat)
            message = {
                "role": role, 'content': message, 'metadata': {}
            }
            instance = {"chat": [message]}
        elif isinstance(message, dict):
            if 'chat' in message:
                instance = message
            elif 'role' in message:
                instance = {"chat": [message]}
            elif len(message) == 0:  # return empty chat
                instance = {"chat": []}
            else:
                raise KeyError("Expect message to have key chat or role; got {}.".format(message))
        elif isinstance(message, list):
            if len(message) == 0:  # return empty chat
                instance = {"chat": []}
            elif isinstance(message[0], dict):
                instance = {"chat": message}
            elif (
                isinstance(message[0], str)
                and allow_stochastic_chat
                and (self.config.num_return_sequences > 1 or self.config.return_all_search_sequences)
            ):
                message = {
                    "role": role, 'content': message, 'metadata': {}
                }
                instance = {"chat": [message]}
            else:
                instance = [self._convert_to_data_instance_format(_message, role, parse_str_message, allow_stochastic_chat) for _message in message]
        else:
            raise NotImplementedError("Cannot parse message of type {}.".format(type(message)))

        return instance

    def add_to_chat(self, message, role="Human", parse_str_message=True, allow_stochastic_chat=False):
        """
        Args:
            message: str, dict, or list.
                For a single message, could be
                    str: e.g., "what is the date today?"
                    dict: {"role": 'Human', 'content': "", "metadata": {...}}
                    List[dict]: list of dict, to pass history
                For multiple message, could be
                    List[str]: each element is a single message
                    List[List[dict]]: each element is a single message
            role: in case of single message or List[str], could additionally provide role.
        """
        instance = self._convert_to_data_instance_format(message, role, parse_str_message, allow_stochastic_chat)
        if self.chat:
            if isinstance(self.chat, ChatRecords):
                if isinstance(instance, dict):
                    self.chat.extend(instance['chat'])
                elif isinstance(instance, list) and len(instance) == 1:
                    self.chat.extend(instance[0]['chat'])
                else:
                    raise RuntimeError("Cannot append multiple instance to one chat")
            else:
                if isinstance(instance, dict):
                    instance = [instance]
                for _instance, _chat in zip(instance, self.chat):
                    _chat.extend(_instance['chat'])
        else:
            if isinstance(instance, dict):
                self.chat = self._get_chat(instance)
            else:
                self.chat = [self._get_chat(_instance) for _instance in instance]

    def _pad(self, input_ids, target_length, padding_side="left"):
        if input_ids.shape[1] < target_length:
            pads = torch.tensor(
                [[self.tokenizer.pad_token_id] * (target_length - input_ids.shape[1])], 
                dtype=torch.int32, device=input_ids.device
                )
            if input_ids.shape[0] > 1:  # corresponds to cases where num_parallel_adapters > 1
                pads = pads.repeat(input_ids.shape[0], 1)
            if padding_side=="left":
                input_ids = torch.cat([pads, input_ids], dim=1)
            elif padding_side=="right":
                input_ids = torch.cat([input_ids, pads], dim=1)
            else:
                raise NotImplementedError("Padding side {} not supported.".format(padding_side))

        return input_ids

    def _pad_and_cat(self, input_ids, split_past=True, padding_side="left"):
        if len(input_ids) == 1:
            return input_ids[0]
        else:
            if self.use_chat_cache and (self._past_key_values is not None) and split_past:
                # need to separately pad historical and current utterance
                past_ids, current_ids = [], []
                for _input_ids in input_ids:
                    last_round_index = torch.nonzero(_input_ids[0] == self.tokenizer.bos_token_id)[-1, 0].item()
                    past_ids.append(_input_ids[:, :last_round_index])
                    current_ids.append(_input_ids[:, last_round_index:])
                past_ids = self._pad_and_cat(past_ids, False, padding_side=padding_side)
                current_ids = self._pad_and_cat(current_ids, False, padding_side=padding_side)
                input_ids = torch.cat([past_ids, current_ids], dim=1)
            else:
                target_length = max([_input_ids.shape[1] for _input_ids in input_ids])
                input_ids = [self._pad(_input_ids, target_length, padding_side=padding_side) for _input_ids in input_ids]
                input_ids = torch.cat(input_ids, dim=0)

            return input_ids

    def _get_input_ids(self, chat=None, device='cpu', padding_side="left"):
        """Build the int32 input-id tensor for a chat or a batch of chats.

        Args:
            chat: ChatRecords or List[ChatRecords]; defaults to self.chat.
            device: device of the returned tensor.
            padding_side: side used when padding a batch to a common length.

        Raises:
            TypeError: if chat is neither ChatRecords nor a list of them.
        """
        chat = self.chat if chat is None else chat

        if isinstance(chat, ChatRecords):
            ids = torch.tensor([chat.input_ids], dtype=torch.int32, device=device)
            if self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel':
                # one copy of the prompt per parallel adapter
                ids = ids.repeat(self.num_parallel_adapters, 1)
            return ids

        if isinstance(chat, list) and isinstance(chat[0], ChatRecords):  # lazy check
            per_chat_ids = [self._get_input_ids(_chat, device) for _chat in chat]
            return self._pad_and_cat(per_chat_ids, padding_side=padding_side)

        raise TypeError("Expect chat to be ChatRecords or List[ChatRecords]; got {}".format(type(chat)))

    @staticmethod
    def _flatten_list(sequences):
        if isinstance(sequences, list) and isinstance(sequences[0], list):  # lazy check nested list
            return [_seq for seq in sequences for _seq in seq]
        else:
            return sequences

    def _get_responses(self, generation_output, input_ids):
        """Decode generated sequences into response text and append it to the chat.

        Args:
            generation_output: either a tensor of generated token ids or an
                object exposing them as ``.sequences``.
            input_ids: not read by this method.

        Returns:
            None for the ValueLabelPredictionWithOutput task (nothing is
            decoded). Otherwise the responses, regrouped per input when several
            sequences are produced per input; a list holding exactly one
            element is returned as that element.
        """
        if self._is_ValueLabelPredictionWithOutput_task:
            return None
        output_ids = generation_output if isinstance(generation_output, torch.Tensor) \
            else generation_output.sequences
        if self._is_StepBeamSearch_task and self.config.return_all_search_sequences:
            # here output_ids is iterated two levels deep (one extra nesting level)
            outputs = [self.tokenizer.decode(output, skip_special_tokens=True) for _output_ids in output_ids for output in _output_ids]
        else:
            outputs = [self.tokenizer.decode(output, skip_special_tokens=True) for output in output_ids]

        # Recover the text of the last prompt segment; it is used below as the
        # separator between the echoed prompt and the newly generated text.
        # NOTE(review): `prompts` stays unbound if self.chat is neither
        # ChatRecords nor a list.
        if self.prompt_templates == 'qwen_chat':
            if isinstance(self.chat, ChatRecords):
                # ['role'] == 'assistant' means qwen continues its utterance
                prompt_loc = 1 if self.chat[-1]['role'] == 'assistant' else -2
                prompts = [self.chat._text[-1].split(self.tokenizer.bos_token)[prompt_loc]]
            elif isinstance(self.chat, list):  # list
                prompt_locs = [1 if chat[-1]['role'] == 'assistant' else -2 for chat in self.chat]
                prompts = [chat._text[-1].split(self.tokenizer.bos_token)[prompt_loc] for chat, prompt_loc in zip(self.chat, prompt_locs)]
        else:
            # we use chat._text[-1] as the separator
            if isinstance(self.chat, ChatRecords):
                prompts = [self.chat._text[-1].split(self.tokenizer.bos_token)[-1]]
            elif isinstance(self.chat, list):  # list
                prompts = [chat._text[-1].split(self.tokenizer.bos_token)[-1] for chat in self.chat]
        # since outputs is output_ids decoded with skip_special_tokens=True, we also decode prompts with skip_special_tokens=True
        prompts =[self.tokenizer.decode(self.tokenizer.encode(prompt), skip_special_tokens=True) for prompt in prompts]
        # Repeat the prompts so that they line up one-to-one with `outputs`.
        if self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel':  # repeat prompts for parallel_adapters
            prompts = prompts * self.num_parallel_adapters
        if self.config.num_return_sequences > 1:
            prompts = self._flatten_list([[prompt] * self.config.num_return_sequences for prompt in prompts])
        if self._is_StepBeamSearch_task and self.config.return_all_search_sequences:
            prompts = self._flatten_list([[prompt] * self.config.num_samples_per_search_step for prompt in prompts])
        # keep only what follows the last occurrence of the prompt; if the prompt
        # is not found, the whole decoded output is used
        responses = [
            output.split(prompt)[-1].strip() if prompt in output else \
                output.strip() for output, prompt in zip(outputs, prompts)]

        if self.prompt_templates == 'qwen_chat':
            # NOTE(review): raises IndexError when a response does not contain
            # 'assistant\n' - confirm this can never happen.
            responses = [
                response.split('assistant\n')[1] for response in responses
            ]

        if "ChatGLM" in self.model_name:
            responses = [chatglm_process_response(response) for response in responses]

        # Regroup the flat list into one sub-list per input.
        if self.config.num_return_sequences > 1:
            responses = [responses[index:index+self.config.num_return_sequences] for index in range(0, len(responses), self.config.num_return_sequences)]

        if self._is_StepBeamSearch_task and self.config.return_all_search_sequences:
            responses = [responses[index:index+self.config.num_samples_per_search_step] for index in range(0, len(responses), self.config.num_samples_per_search_step)]

        if self.num_parallel_adapters > 1 and self.molora_strategy == 'parallel':
            # entries are laid out adapter-major; gather the entries of all
            # adapters that belong to the same input
            chunk_size = len(responses) // self.num_parallel_adapters
            responses = [self._flatten_list([responses[index_x + index_y * chunk_size] for index_y in range(self.num_parallel_adapters)]) for index_x in range(chunk_size)]

        if self.post_process_fn:  # None or list; post_process_fn should not change responses
            for post_process_fn in self.post_process_fn:
                responses = post_process_fn(generation_output, responses)
        role = 'assistant' if self.prompt_templates == 'qwen_chat' else "Assistant"
        # record the model's answer(s) in the chat history
        self.add_to_chat(responses, role=role, parse_str_message=False, allow_stochastic_chat=True)

        if isinstance(responses, list) and len(responses) == 1:
            responses = responses[0]

        return responses

    @staticmethod
    def _interval_sum(value, value_mask):
        interval_indices = torch.nonzero(value_mask, as_tuple=True)[0]
        interval_start = 0
        values = []
        for interval_end in interval_indices:
            values.append(value[interval_start:interval_end].mean())
            interval_start = interval_end
        # eos token also indicates end of a step 
        # do not consider the eos token 
        if interval_start < len(value) - 1:
            values.append(value[interval_start:-1].mean())

        return torch.tensor(values, device=value.device)

    def _get_values(self, generation_output, input_ids):
        """Extract per-step values for the value-related tasks and store them in the chat.

        Args:
            generation_output: task dependent. For ValueLabelPrediction it is a
                tensor of ids or an object with ``.sequences``, and must expose
                ``.values``; for ValueLabelPredictionWithOutput it is iterated
                directly as per-sample values; for StepBeamSearch (with
                return_all_search_sequences) it exposes ``.verifier_scores``.
            input_ids: the (padded) input ids fed to the model.

        Returns:
            The extracted values, or None when no value task applies. When not
            None, they are also written to ``metadata["values"]`` of the last
            utterance of the chat(s).
        """
        if self._is_ValueLabelPrediction_task:
            # output_ids <-- pad_token_ids + actual input_ids + actual output_ids + pad_token_ids
            output_ids = generation_output if isinstance(generation_output, torch.Tensor) \
                else generation_output.sequences
            output_ids = output_ids[:, input_ids.shape[1]:]  # actual output_ids + pad_token_ids
            # one boolean mask per sample; presumably True marks an end-of-step
            # token (decided by _locate_end_of_step_id) - TODO confirm
            value_masks = [
                torch.tensor([  # here role_tag_len is not provided; if role_tag contains end_of_step_id, should provide it in false_eos_ids
                    True if self._locate_end_of_step_id(_output_ids, index, []) else False for index in range(len(_output_ids))
                    ], device=output_ids.device, dtype=torch.bool) for _output_ids in output_ids
            ]
            # get values for step tokens; shift value_mask by 1
            if self.config.value_by_transition_scores:
                values = [self._interval_sum(value, value_mask).to("cpu").tolist() for value, value_mask in zip(generation_output.values, value_masks)]
            else:
                values = [value[value_mask[1:]].to("cpu").tolist() for value, value_mask in zip(generation_output.values, value_masks)]
        elif self._is_ValueLabelPredictionWithOutput_task:
            # get value_masks; value_masks do not consider pad_tokens
            if isinstance(self.chat, list):
                value_masks = [self.process_chat(chat, default_value_label=1).tokenize("Assistant")["value_labels"] for chat in self.chat]
            else:
                value_masks = [self.process_chat(self.chat, default_value_label=1).tokenize("Assistant")["value_labels"]]
            # positions whose value label equals 1 are the ones to keep
            value_masks = [torch.tensor(value_mask, dtype=torch.int32, device=input_ids.device) == 1 for value_mask in value_masks]
            # get values without pad_tokens
            values = [value[_ids != self.tokenizer.pad_token_id] for value, _ids in zip(generation_output, input_ids)]
            # get values for step tokens; shift value_mask by 1
            if self.config.value_by_transition_scores:
                values = [self._interval_sum(value, value_mask).to("cpu").tolist() for value, value_mask in zip(values, value_masks)]
            else:
                values = [value[:-1][value_mask[1:]].to("cpu").tolist() for value, value_mask in zip(values, value_masks)]
        elif self._is_StepBeamSearch_task and self.config.return_all_search_sequences:
            values = generation_output.verifier_scores
            values = values.tolist()
        else:
            values = None

        if values is not None:
            if isinstance(self.chat, list):
                # batch of chats: one entry of `values` per chat
                for chat, value in zip(self.chat, values):
                    if isinstance(values[0], torch.Tensor):
                        value = value.tolist()
                    chat[-1]['metadata'].update({"values": value})
            else:
                # single chat: the whole `values` object is stored as is
                self.chat[-1]['metadata'].update({"values": values})

        return values

    def _get_num_chat_rounds(self):
        """Return, for each chat, its length halved (two utterances per round).

        Returns:
            A one-element list for a single ChatRecords, otherwise one entry
            per chat in the batch.
        """
        # Test `self.chat`, the attribute every other method of this class uses;
        # the previous check read `self._chat`, which this class never assigns.
        if isinstance(self.chat, ChatRecords):
            return [len(self.chat) // 2]
        return [len(chat) // 2 for chat in self.chat]

    def _add_past_key_values(self, generation_output):
        if self.use_chat_cache:
            self._past_key_values = generation_output.past_key_values
            num_chat_rounds = self._get_num_chat_rounds()
            # rearrange _past_key_values if generation_output.sequences has pads in the middle
            if any(_num > 1 for _num in num_chat_rounds) and len(num_chat_rounds) > 1:
                use_flash_attn = getattr(self.model, "use_flash_attn", False)
                # _past_key_values is 1 token shorter than generation_output.sequences
                do_rearrange = False
                current_indices = list(range(generation_output.sequences.shape[1] - 1))
                for _num, output_ids in zip(num_chat_rounds, generation_output.sequences):
                    if _num > 1:
                        pad_indices = [index for index, _id in enumerate(output_ids[:-1]) if _id == self.tokenizer.pad_token_id]
                        non_pad_indices = [index for index, _id in enumerate(output_ids[:-1]) if _id != self.tokenizer.pad_token_id]
                        new_indices = pad_indices + non_pad_indices
                        if new_indices != current_indices:
                            do_rearrange = True
                if do_rearrange:
                    pass

    def response(self, message, **kwargs):
        """Add the message(s) to the chat and produce the model's answer.

        Args:
            message: str, dict, or list.
                For a single message, could be
                    str: e.g., "what is the date today?"
                    dict: {"role": 'Human', 'content': "", "metadata": {...}}
                    List[dict]: list of dict, to pass history
                For multiple message, could be
                    List[str]: each element is a single message
                    List[List[dict]]: each element is a single message
            kwargs:
                Other keywords will be passed to model.generate and thus model.forward.
                They are not forwarded when an engine backend is used.

        Returns:
            ``(responses, values)`` when both are available, otherwise whichever
            of the two is available.

        Raises:
            RuntimeError: if neither responses nor values were produced.
        """
        if self.engine:  # if engine exists, use engine_response
            return self.engine_response(message)
        if self.prompt_templates in ["qwen_chat", "llama_new_chat"]:
            chat_role = 'user'
        else:
            chat_role = "Human"
        self.add_to_chat(message, chat_role)
        input_ids = self._get_input_ids(device=self.model.device)

        with torch.no_grad():
            if self._is_StepBeamSearch_task:
                # prepare fn and kwargs for generation and value
                from utils.rl_utils import MyAutoModelForCausalLMWithValueHead
                from utils.generation_utils import StepBeamSearch
                # prepare value model
                if self.aux_model:
                    assert isinstance(self.aux_model, MyAutoModelForCausalLMWithValueHead), TypeError(
                        "Expect model to be MyAutoModelForCausalLMWithValueHead to use step_beam_search, got {}".format(type(self.aux_model)))
                    value_fn, value_kwargs = self.aux_model.value, {"input_ids": None, "attention_mask": None}
                else:
                    assert isinstance(self.model, MyAutoModelForCausalLMWithValueHead), TypeError(
                        "Expect model to be MyAutoModelForCausalLMWithValueHead to use step_beam_search, got {}".format(type(self.model)))
                    value_fn, value_kwargs = self.model.value_head, {"last_hidden_state": None}
                # prepare generation model
                # with return_policy the call below returns (policy, kwargs)
                # instead of running the generation itself
                kwargs.update({"return_policy": True})
                generation_model = self.model.pretrained_model if isinstance(self.model, MyAutoModelForCausalLMWithValueHead) else self.model
                generate_policy, generate_kwargs = generation_model.generate(
                    input_ids=input_ids,
                    generation_config=self.config,
                    return_dict_in_generate=True,
                    logits_processor=self.logits_processor,
                    past_key_values=self._past_key_values,
                    output_past_key_values=self.use_chat_cache,
                    **kwargs,
                    )

                # init search policy
                generation_output = StepBeamSearch(
                    generate_policy=generate_policy,
                    generate_kwargs=generate_kwargs,
                    value_function=value_fn,
                    value_kwargs=value_kwargs,
                    generation_config=self.config,
                    model_config=self.model.config,
                    device=input_ids.device).step_beam_search(input_ids)
            elif self._is_ValueLabelPrediction_task:
                from utils.rl_utils import MyAutoModelForCausalLMWithValueHead

                if self.aux_model:
                    assert isinstance(self.aux_model, MyAutoModelForCausalLMWithValueHead), TypeError(
                        "Expect model to be MyAutoModelForCausalLMWithValueHead to use step_beam_search, got {}".format(type(self.aux_model)))

                    generation_output = self.model.generate(
                        input_ids=input_ids,
                        generation_config=self.config,
                        return_dict_in_generate=True,
                        logits_processor=self.logits_processor,
                        past_key_values=self._past_key_values,
                        output_past_key_values=self.use_chat_cache,
                        **kwargs,
                        )
                    output_ids = generation_output if isinstance(generation_output, torch.Tensor) \
                        else generation_output.sequences
                    attention_mask = (output_ids != self.tokenizer.pad_token_id).to(torch.int64)
                    # The auxiliary model (type-checked above) scores the generated
                    # sequences; the main model is only used for generation here.
                    value_output = self.aux_model.value(input_ids=output_ids, attention_mask=attention_mask, **kwargs)
                    generation_output.values = value_output.values
                else:
                    assert isinstance(self.model, MyAutoModelForCausalLMWithValueHead), TypeError(
                        "Expect model to be MyAutoModelForCausalLMWithValueHead to use step_beam_search, got {}".format(type(self.model)))

                    generation_output = self.model.generate_and_value(
                        input_ids=input_ids,
                        generation_config=self.config,
                        return_dict_in_generate=True,
                        logits_processor=self.logits_processor,
                        past_key_values=self._past_key_values,
                        output_past_key_values=self.use_chat_cache,
                        **kwargs,
                        )
            elif self._is_ValueLabelPredictionWithOutput_task:
                from utils.rl_utils import MyAutoModelForCausalLMWithValueHead

                assert isinstance(self.model, MyAutoModelForCausalLMWithValueHead), TypeError(
                    "Expect model to be MyAutoModelForCausalLMWithValueHead to use step_beam_search, got {}".format(type(self.model)))

                # nothing is generated here: the given chat is only scored
                attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.int64)
                generation_output = self.model.value(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
            else:
                generation_output = self.model.generate(
                    input_ids=input_ids,
                    generation_config=self.config,
                    return_dict_in_generate=True,
                    logits_processor=self.logits_processor,
                    past_key_values=self._past_key_values,
                    output_past_key_values=self.use_chat_cache,
                    **kwargs,
                    )
        responses = self._get_responses(generation_output, input_ids)
        self._add_past_key_values(generation_output)
        values = self._get_values(generation_output, input_ids)

        if responses is not None and values is not None:
            return responses, values
        elif responses is not None:
            return responses
        elif values is not None:
            return values
        else:
            raise RuntimeError("Both responses and values are None.")

    def engine_response(self, message):
        """Answer through the inference engine (AllSpark or VLLM) instead of a torch model.

        The message is added to the chat with role "Human". A list holding
        exactly one response is returned as that response.
        """
        self.add_to_chat(message, "Human")

        if isinstance(self.engine, AllSparkEngine) and AllSparkEngine is not type(None):
            prompt_ids = self._get_input_ids(device='cpu', padding_side="right").to(torch.int64)
            # 1 everywhere except on pad tokens
            attention_mask = torch.ones_like(prompt_ids, dtype=torch.int64)
            attention_mask[prompt_ids==self.tokenizer.pad_token_id] = 0

            generate_config = self.config.__dict__.copy()
            # the engine is given a total length: prompt plus newly generated tokens
            generate_config['max_length'] = generate_config['max_new_tokens'] + prompt_ids.shape[1]

            engine_inputs = {
                "input_ids": torch.utils.dlpack.to_dlpack(prompt_ids),
                "attention_mask": torch.utils.dlpack.to_dlpack(attention_mask),
            }
            generation_output = self.engine.run_text_generation(
                self.model_name,
                engine_inputs,
                generate_config=generate_config,
            )

            output_ids = torch.utils.dlpack.from_dlpack(generation_output["generated_ids"])
            # _get_responses also appends the answer to the chat
            response = self._get_responses(output_ids, prompt_ids)
        elif isinstance(self.engine, VLLM) and VLLM is not type(None):
            padded_ids = self._get_input_ids(device='cpu').to(torch.int64).to("cpu").tolist()
            # the engine receives the token ids with the pad tokens removed
            prompt_token_ids = [
                [_id for _id in row if _id != self.tokenizer.pad_token_id] for row in padded_ids
            ]
            outputs = self.engine.generate(
                prompts=message,
                sampling_params=self._vllm_config,
                prompt_token_ids=prompt_token_ids,
                )
            response = [output.outputs[0].text.strip() for output in outputs]
            self.add_to_chat(response, role="Assistant", parse_str_message=False, allow_stochastic_chat=True)

        if isinstance(response, list) and len(response) == 1:
            response = response[0]

        return response

    def __call__(self, message, **kwargs):
        return self.response(message, **kwargs)

    def stream_response(self, message):
        self.add_to_chat(message, "Human")
        input_ids = self._get_input_ids(device=self.model.device)

        self.model.eval()
        generation_kwargs = dict(
            input_ids=input_ids, 
            streamer=self.streamer, 
            generation_config=self.config, 
            return_dict_in_generate=True,
            logits_processor=self.logits_processor,
            past_key_values=self._past_key_values,
            )
        with torch.no_grad():
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

        return self.streamer

    def batch_response(self, message):
        return self.response(message)

    def clear_cache(self):
        self.chat = None
        self._past_key_values = None


def get_value_from_nested_dict(nested_dict: dict, key_list: Union[Union[str, int], List[Union[str, int]]]):
    """Look up a value in a nested dict by a single key or by a path of keys.

    Args:
        nested_dict: the (possibly multi-level) container to read from.
        key_list: a single str/int key, or a list of keys describing the path
            from the outermost level down to the wanted value.

    Returns:
        The value stored under the key / at the end of the key path.

    Raises:
        KeyError: if a single (or final) key is missing at its level.
        NotImplementedError: if `key_list` is neither str, int nor list.
    """
    if isinstance(key_list, list):
        if len(key_list) == 1:
            # A one-element path is just a plain key lookup.
            return get_value_from_nested_dict(nested_dict, key_list[0])
        head, tail = key_list[0], key_list[1:]
        return get_value_from_nested_dict(nested_dict[head], tail)

    if not isinstance(key_list, (str, int)):
        raise NotImplementedError("Cannot parse key_list {}".format(key_list))

    if key_list not in nested_dict:
        raise KeyError('Current-level dict has keys: {}; got {}.'.format(list(nested_dict.keys()), key_list))
    return nested_dict[key_list]


def _get_nested_dict_from_one_kv_pair(keys: Union[Union[str, int], List[Union[str, int]]], value):
    """Build a chain of single-entry dicts that ends in `value`.

    Args:
        keys: a single str/int key, or a sequence of keys [k1, k2, k3, ...]
            ordered from the outermost to the innermost level.
        value: anything; stored under the innermost key.

    Returns:
        `{keys: value}` for a single key, otherwise `{k1: {k2: {... value}}}`.
    """
    if isinstance(keys, (str, int)):
        return {keys: value}

    # Wrap the value from the innermost key outwards.
    nested = {keys[-1]: value}
    for outer_key in reversed(keys[:-1]):
        nested = {outer_key: nested}
    return nested


def update_nested_dict(nested_dict: dict, keys: Union[str, List[str]], value):
    """Update an existing nested dict in place using a (keys, value) pair.

    For a list of keys the existing levels are followed as far as they are
    dicts; the remaining keys are created as a fresh chain of dicts that ends
    in `value`. Whatever was stored at the point where the chain is attached
    is overwritten -- including a non-dict leaf that sits in the middle of the
    requested path (this used to raise a TypeError/AttributeError).

    Args:
        nested_dict: a dict with multi-level keys, e.g.,
            {kl1: {kl2: {kl3: value, ...}, ...}, ...}. Modified in place.
        keys: a single str/int key, or a list of keys, e.g., [kl1, kl2, kl3, ...].
        value: anything.

    Returns:
        The same `nested_dict` object, after the update.
    """
    if isinstance(keys, (str, int)):
        nested_dict[keys] = value
        return nested_dict

    node = nested_dict
    depth = 0
    if isinstance(keys, list):
        # Descend only while the path already exists AND is still a dict; the
        # last key is never descended into because it receives the value.
        while depth < len(keys) - 1:
            key = keys[depth]
            if key not in node or not isinstance(node[key], dict):
                break
            node = node[key]
            depth += 1

    # Build the missing part of the chain from the innermost key outwards.
    subtree = value
    for key in reversed(keys[depth + 1:]):
        subtree = {key: subtree}
    node[keys[depth]] = subtree

    return nested_dict


def get_cn_token_combos(
        tokenizer: PreTrainedTokenizer, 
        cn_chars_path: Optional[str] = None,
        consider_special_underscore: bool = False,
        ):
    """ Returns two dicts for the mapping between Chinese characters and tokenized ids.
    cn_char2ids: a flat dict with pairs {cn_char: ids}
    ids2cn_char: a nested dict with pairs {id0: {id1: {id2: cn_char, ...}, ...}, ...}

    All characters of the file are joined with ',' and tokenized in one go; the
    resulting token stream is then split back on the ',' tokens, and the i-th
    group of tokens is attributed to the i-th character of the file.

    Args:
        tokenizer: a tokenizer that can tokenize every character in the file. It must
            provide `get_vocab()` and either `_tokenize()` or the pair
            `encode()` / `convert_ids_to_tokens()`.
        cn_chars_path: path to a UTF-8 text file whose characters are the ones to map.
            Defaults to "./data/CN_characters.txt" when empty.
        consider_special_underscore: some tokenizers may tokenize "▁"+CN_char as one token, e.g., 
            in chatGLM, both '▁阿' and '阿' are in the vocab. Setting this to True to consider "▁"+CN_char.

    Raises:
        FileNotFoundError: if `cn_chars_path` does not point to a file.
    """
    if not cn_chars_path:  # fall back to the default character list
        cn_chars_path = "./data/CN_characters.txt"
    if not os.path.isfile(cn_chars_path):
        raise FileNotFoundError("{} is not a valid file path.".format(cn_chars_path))
    # Explicit encoding: the locale default is not UTF-8 on every platform.
    with open(cn_chars_path, "r", encoding="utf-8") as reader:
        cn_chars = reader.read()

    joined_chars = ','.join(cn_chars)  # characters separated by ','
    if hasattr(tokenizer, '_tokenize'):
        cn_tokens = tokenizer._tokenize(joined_chars)
    else:
        cn_tokens = tokenizer.convert_ids_to_tokens(
            tokenizer.encode(joined_chars, add_special_tokens=False))
    # A leading '▁' token is treated as a separator so it does not get
    # attributed to the first character.
    if cn_tokens and cn_tokens[0] == '▁':
        cn_tokens[0] = ','

    vocab2ids = tokenizer.get_vocab()

    cn_char2ids = {}
    ids2cn_char = {}

    def _register(char_index, tokens):
        # Record the token ids of one character in both mappings.
        cn_char = cn_chars[char_index]
        token_ids = [vocab2ids[token] for token in tokens]
        cn_char2ids[cn_char] = token_ids
        update_nested_dict(ids2cn_char, token_ids, cn_char)
        if consider_special_underscore:
            underscore_cn_char = '▁' + cn_char
            if underscore_cn_char in vocab2ids:
                underscore_id = vocab2ids[underscore_cn_char]
                cn_char2ids[underscore_cn_char] = [underscore_id]
                update_nested_dict(ids2cn_char, underscore_id, underscore_cn_char)

    index = 0
    current_tokens = []
    for cn_token in cn_tokens:
        if cn_token != ',':
            current_tokens.append(cn_token)
        elif current_tokens:
            _register(index, current_tokens)
            index += 1
            current_tokens = []
    # Flush the tokens of the last character (no trailing ',' follows it). It
    # goes through the same path as the others, so the "▁" variant is
    # registered for it as well.
    if current_tokens:
        _register(index, current_tokens)

    return cn_char2ids, ids2cn_char


class PrefixCNCharacterTokenGen:
    """ For a tokenizer that encodes Chinese characters into several tokens, 
    if only part of the tokens are predicted, the character cannot be decoded.
    To prevent this issue, this class formulates a prefix_allowed_tokens_fn that
    only allows a certain set of tokens to be the next token. The set is
    conditioned on previous input_ids.

    The method used here is related to constrained beam search. See ref:
    Cao, N. De, et al. (2021). Autoregressive Entity Retrieval. ICLR.

    Args:
        tokenizer: a tokenizer that encodes Chinese characters into one or three tokens.
        num_scores: the number of scores to predict in the model. It may be different 
            from the length of tokenizer, e.g., in the case of ChatGLM.

    Calls:
        batch_id: not used.
        inputs_ids: tensor of shape (n,), the generated token_ids so far.

    Returns:
        A list of available tokens to generate.
    """
    def __init__(self, tokenizer: PreTrainedTokenizer, num_scores: int = None) -> None:
        # get the mapping between cn tokens and ids
        # in ChatGLMTokenizer, '▁阿' and '阿' are two different tokens
        consider_special_underscore = 'ChatGLMTokenizer' in tokenizer.__class__.__name__
        self.cn_char2ids, self.ids2cn_char = get_cn_token_combos(
            tokenizer, consider_special_underscore=consider_special_underscore)
        # cn_chars are represented using either 1 or 3 tokens
        self.level1tokens = list(set(  # find the first token of 3-token cn_chars
            key for key, value in self.ids2cn_char.items() if isinstance(value, dict)))
        self.level2tokens = list(set(  # find the second token of 3-token cn_chars
            key2 for key1 in self.level1tokens for key2 in self.ids2cn_char[key1]))
        self.all_tokens = list(range(num_scores or len(tokenizer)))
        # Set copies of the two lists: the membership test below runs for every
        # generated token, and scanning a list with tensor comparisons is slow.
        self._level1_set = set(self.level1tokens)
        self._level2_set = set(self.level2tokens)

    def prefix_allowed_tokens_fn(self, batch_id: int, inputs_ids: torch.Tensor) -> List[int]:
        """Return the token ids allowed next, given the ids generated so far.

        If the last id (or the last two ids) start a known multi-token
        character, only the ids that continue that character are allowed;
        otherwise every token is allowed.
        """
        if len(inputs_ids) == 0:
            return self.all_tokens

        # Plain Python ints make set/dict lookups cheap and exact.
        last_id = int(inputs_ids[-1])
        if last_id in self._level1_set:
            return list(self.ids2cn_char[last_id].keys())

        if len(inputs_ids) >= 2 and last_id in self._level2_set:
            prev_id = int(inputs_ids[-2])
            if prev_id in self._level1_set:
                # `last_id` may be a valid second token for some other first
                # token only; constrain only when this exact pair is known.
                completions = self.ids2cn_char[prev_id].get(last_id)
                if isinstance(completions, dict):
                    return list(completions.keys())

        return self.all_tokens
    
    def __call__(self, batch_id: int, inputs_ids: torch.Tensor) -> List[int]:
        return self.prefix_allowed_tokens_fn(batch_id, inputs_ids)
    

class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Replace a score tensor that contains NaN or +inf with a fixed fallback.

    Modified from models.modeling_chatglm.InvalidScoreLogitsProcessor: only
    positive infinity is treated as invalid, because setting scores to negative
    infinity is a legitimate way of preventing unwanted tokens from being generated.
    """
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_are_valid = not (torch.isnan(scores).any() or torch.isposinf(scores).any())
        if scores_are_valid:
            return scores

        # Wipe the whole tensor in place and make index 5 the only likely token.
        # NOTE(review): index 5 is inherited from the ChatGLM implementation;
        # presumably a harmless fallback token in that vocab -- confirm.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
    

def chatglm_process_response(response):
    """Clean up a raw ChatGLM response string.

    The response is stripped, the placeholder "[[训练时间]]" is replaced with
    "2023年", and ASCII punctuation (, ! : ; ?) that directly touches a CJK
    character (U+4E00..U+9FFF) on either side is replaced with its full-width
    counterpart.

    Args:
        response: the decoded model output.

    Returns:
        The cleaned-up string.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (regex for the ASCII mark, full-width replacement). The question mark is
    # a regex metacharacter, hence the escaped form in a raw string.
    punkts = [
        [",", "，"],
        ["!", "！"],
        [":", "："],
        [";", "；"],
        [r"\?", "？"],
    ]
    for ascii_pattern, fullwidth in punkts:
        # CJK character followed by the mark, then mark followed by a CJK character.
        response = re.sub(r"([\u4e00-\u9fff])%s" % ascii_pattern, r"\1%s" % fullwidth, response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % ascii_pattern, r"%s\1" % fullwidth, response)
    return response


def preprocess_identity_tokens_in_chat_dataset(
    data_path: str, 
    human_token: Union[str, re.Pattern],
    assist_token: Union[str, re.Pattern],
    eos_token: str = "<|endoftext|>",
    mode = "readable",
    prepprocess_fn = None,
    ):
    """Replace tokens corresponding to identities in chat datasets.

    The processed instances (each a dict with the single key "instruction") are
    written as JSON next to the input file, as `<name>_<mode><extension>`.

    Args:
        data_path: path of the dataset file; read through `load_datasets`.
        human_token: str or regex expression, human_token will be replaced with <Human>.
        assist_token: str or regex expression, assist_token will be replaced with <Assistant>.
        eos_token: end-of-sequence token inserted between rounds in the
            `efficient` and `extreme` modes; leading occurrences are removed
            from the final instruction.
        mode: one of `readable`, `efficient`, `extreme`
            readable:
                "instruction": [Instruction]\n[A]:\n[output]...\n[H]:\n[Instruction]
                "output": [output]
                tokenized_prompt: [H]:\n{INSTRUCTION}\n[A]:\n{OUTPUT}[eos]
                In this mode, one conversation has multiple instances, each of different rounds.
                This mode must be used on datasets with contexts.
            efficient: 
                "instruction": [Instruction]\n[A]:\n[output][eos]...\n[H]:\n[Instruction]
                "output": [output]
                tokenized_prompt: [H]:\n{INSTRUCTION}\n[A]:\n{OUTPUT}[eos]
                In this mode, one conversation is processed into one instance.
            extreme:
                "instruction": [Instruction][A][output][eos]...[eos][Instruction]
                "output": [output]
                tokenized_prompt: {INSTRUCTION}[A]{OUTPUT}[eos]
                In this mode, one conversation is processed into one instance. No token for [H]. No ":" or "\n".
            If "input" is not empty, [Instruction] = [Instruction] + "\n" + {INPUT}
            If "output" is not empty, [Instruction] = [Instruction] + to_assist_token + {OUTPUT}
        prepprocess_fn: other pre-process function applied before this one.

    Returns:
        The path of the JSON file that was written.

    Raises:
        ValueError: if `mode` is not one of the supported modes.
    """
    # Validate the mode up front; an unknown mode used to surface only later
    # as a NameError on the undefined replacement tokens.
    if mode == 'readable':
        to_human_token = "\n<Human>:\n"
        to_assist_token = "\n<Assistant>:\n"
    elif mode == "efficient":
        to_human_token = eos_token + "\n<Human>:\n"
        to_assist_token = "\n<Assistant>:\n"
    elif mode == "extreme":
        to_human_token = eos_token
        to_assist_token = "<Assistant>"
    else:
        raise ValueError(
            "Unknown mode {!r}; expected one of 'readable', 'efficient', 'extreme'.".format(mode))

    dataset = load_datasets(data_path, columns_to_read=["instruction", "input", "output"])
    num_samples = len(dataset)
    assert num_samples > 0, RuntimeError("Empty datasets at {}".format(data_path))

    def _preprocess_one_instance(instance):
        if prepprocess_fn and callable(prepprocess_fn):
            instance = prepprocess_fn(instance)
        instruction = instance['instruction']
        if isinstance(human_token, str):
            instruction = instruction.replace(human_token, to_human_token)
        else:  # re.Pattern
            instruction = re.sub(human_token, to_human_token, instruction)
        if isinstance(assist_token, str):
            instruction = instruction.replace(assist_token, to_assist_token)
        else:  # re.Pattern
            instruction = re.sub(assist_token, to_assist_token, instruction)
        # Drop leading eos tokens as whole prefixes. str.lstrip(eos_token)
        # would treat eos_token as a *set of characters* and also eat the
        # beginning of the actual text (e.g. "<|endoftext|>text" -> "").
        while eos_token and instruction.startswith(eos_token):
            instruction = instruction[len(eos_token):]
        instruction = instruction.strip()
        if instance.get("input", ""):
            instruction += "\n" + instance["input"].strip()
        if instance.get("output", ""):
            instruction += to_assist_token + instance["output"].strip()
        return {"instruction": instruction}

    instances = [_preprocess_one_instance(instance) for instance in dataset]

    # save processed dataset to json, next to the input file
    data_folder, file_name = os.path.split(data_path)
    base_name, extension = os.path.splitext(file_name)
    output_path = os.path.join(data_folder, '{}_{}{}'.format(base_name, mode, extension))
    print('Save output to {}.'.format(output_path))
    with open(output_path, 'w', encoding='utf-8') as f: 
        f.write(json.dumps(instances, indent=2, ensure_ascii=False))

    return output_path


def deduplicate_chat_dataset_context(data_path: str):
    """Keep only the longest-context instance of each conversation in a chat dataset.

    The dataset is scanned in order. While the instruction of the instance kept
    so far is contained in the next instruction (or the other way round), both
    are regarded as the same conversation and only the longer one is kept. An
    instance that is unrelated to the current one starts a new conversation.
    The result is written as JSON next to the input file, as
    `<name>_dedupl<extension>`.

    Args:
        data_path: path of the dataset file; read through `load_datasets`.

    Returns:
        The path of the JSON file that was written.
    """
    dataset = load_datasets(data_path, columns_to_read=["instruction", "input", "output"])
    num_samples = len(dataset)
    assert num_samples > 0, RuntimeError("Empty datasets at {}".format(data_path))

    instances = []
    current_instance = dataset[0]
    for instance in dataset:
        if current_instance['instruction'] in instance['instruction']:
            # Same conversation with a longer context: keep the longer one.
            current_instance = instance
        elif instance['instruction'] in current_instance['instruction']:
            # Shorter version of the conversation already being tracked.
            continue
        else:
            # A new conversation starts: emit the previous one.
            instances.append(current_instance)
            current_instance = instance
    # Emit the conversation still being tracked; without this the last
    # conversation of the dataset was silently dropped.
    instances.append(current_instance)
    print('Number of samples before/after processing: {} -> {}'.format(num_samples, len(instances)))
    
    # save processed dataset to json, next to the input file
    data_folder, file_name = os.path.split(data_path)
    base_name, extension = os.path.splitext(file_name)
    output_path = os.path.join(data_folder, '{}_{}{}'.format(base_name, "dedupl", extension))
    print('Save output to {}.'.format(output_path))
    with open(output_path, 'w', encoding='utf-8') as f: 
        f.write(json.dumps(instances, indent=2, ensure_ascii=False))

    return output_path


def test_IncorrectAnswer_augment():
    """Manual smoke test: run the "incorrect_answer" augmentation over a dev dataset.

    Every instance of the training file is passed through `PromptProcessorWithDA`
    and tagged with the name of the augmentation that was forced (`name`).
    With `name` left as None no wrong-answer augmentation is forced (all
    probabilities stay at 0.0). Relies on hard-coded local paths for the
    dataset and the base model.
    """
    # from utils.data_utils import (
    # TokenizedPromptProcessorWithDA, prepare_datasets, LPSymmetricShuffleAug
    # )
    from transformers import AutoTokenizer
    from datasets import load_dataset, concatenate_datasets, Features, Value, Sequence

    # Set to one of the keys of p_wrong_map to force that augmentation.
    name = None
    p_wrong_map = {
        'obj_dir_wrong': 0.0,
        'obj_coeff_wrong': 0.0,
        'obj_coeff_wrong_hd': 0.0,
        'obj_coeff_switch_wrong': 0.0,
        'obj_extra_wrong': 0.0,
        'obj_extra_wrong_hd': 0.0,
        'obj_miss_wrong': 0.0,
        'cons_miss_wrong': 0.0,
        'cons_more_wrong': 0.0,
        'cons_more_wrong_hd': 0.0,
        'cons_dir_wrong': 0.0,
        'cons_coeff_wrong': 0.0,
        'cons_coeff_wrong_hd': 0.0,
        'con_switch_coeff_wrong': 0.0,
        'cons_switch_coeff_wrong': 0.0,
        'con_miss_wrong': 0.0,
        'con_extra_wrong': 0.0,
        'con_extra_wrong_hd': 0.0,
        'con_var_wrong': 0.0,
        'con_var_wrong_hd': 0.0,
    }
    if name is None:
        pwrong = 0.0
    else:
        pwrong = 1.0
        p_wrong_map[name] = 1.0

    class DataArguments:
        train_file = ["../copilot_utilities/datasets/dev/en_dev_99_cn.json"]
        validation_file = None
        validation_split = 0.0
        disable_caching = False
        train_on_inputs = False
        max_seq_length = 512
        prompt_templates = "simple_qa"
        data_generation_task = "IncorrectMathChecking"
        data_augmentations = {"name": "incorrect_answer", "p_wrong": pwrong, 'p_wrong_map': p_wrong_map}
        data_const_deduplicate = False

    columns_to_read = ["instruction", "input", "language", "var_description", "objective_description", "constraint_description", "index"]

    base_model_path = "../output/llama2-13b_chat_hf"
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        padding_side="left"  # allow batched inference
        )
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    data_args = DataArguments()
    postprocessor = PromptProcessorWithDA(tokenizer, data_args)
    # The post-processor is applied manually below rather than inside
    # prepare_datasets, so each processed instance can be tagged.
    train_data, validation_data = prepare_datasets(  # use dataset.cleanup_cache_files() to remove cache
        data_args,
        columns_to_read=columns_to_read,
        shuffle=False
        )

    processed_instances = []
    for instance in train_data:
        processed_instance = postprocessor(instance)
        processed_instance['augment_fuction'] = name
        processed_instances.append(processed_instance)

    print('test_over')


if __name__ == "__main__":
    # Ad-hoc script: tokenize (up to) the first 10001 shuffled training
    # prompts and print quantiles of their tokenized lengths.
    import numpy as np
    from transformers import LlamaTokenizer

    padding_side = 'left'  # allow batched inference
    base_model_path = '/root/model_ckpt/llama-7b/'
    tokenizer = MyLlamaTokenizer.from_pretrained(base_model_path, padding_side=padding_side)
    tokenizer.pad_token_id = 0  # set pad_token to '<unk>'

    # Alternative datasets:
    #   '../datasets/belle_cn/belle_data1M_cn.json'
    #   '../datasets/CoT/CoT_data.json'
    train_file = '../datasets/alpha/alpaca_data_cleaned.json'

    train_data = load_dataset("json", data_files=train_file)['train'].shuffle()

    data_args = DataArguments(train_file=train_file, max_seq_length=1024)
    data_args.train_on_inputs = False
    preprocessor = TokenizedPromptProcessor(tokenizer, data_args)

    input_lengths = []
    for instance in train_data:
        input_lengths.append(len(preprocessor(instance)['input_ids']))
        # Stop once the instance with index 10000 has been measured.
        if len(input_lengths) > 10000:
            break

    input_lengths = np.array(input_lengths)
    quantiles = np.quantile(input_lengths, [0.5, 0.9, 0.95, 0.975])
    print(quantiles)
    print('test over')