# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions related to text generation using model finetuned with LoRA.
Modified from https://github.com/tloen/alpaca-lora
"""
import os
import sys

# Make modules under the current working directory importable.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

#----------------------------------------------------------------------------
# handle output redirection
# This runs before the remaining imports of the file, so everything printed
# afterwards goes to the log file. Only the `--save_path=<value>` spelling is
# recognised here.
redirect_path = None
STDOUT = sys.stdout  # handle on the original stream, kept before sys.stdout is replaced below
for argv in sys.argv[1:]:
    if argv.startswith('--save_path='):
        # split on the first '=' only, so that a path containing '=' is kept intact
        redirect_path = argv.split('=', 1)[1]
        if os.path.isdir(redirect_path):
            redirect_path = redirect_path.rstrip('/')
        elif redirect_path.endswith(('json', 'jsonl')):
            # a json path: the log is placed next to the folder that holds the json
            redirect_path = os.path.dirname(redirect_path).rstrip('/')
        elif redirect_path == "None":
            break
        else:
            raise NameError("Cannot parse save_path. Expect a folder or a json path; got {}".format(redirect_path))
        break
if isinstance(redirect_path, str) and os.path.isdir(redirect_path):
    # e.g. folder "a/b" -> log file "a/b.log"; the file stays open for the process lifetime
    redirect_path = redirect_path + '.log'
    sys.stdout = sys.stderr = open(redirect_path, 'w')
#----------------------------------------------------------------------------

import copy
from dataclasses import dataclass, field
import gradio as gr
import json
import logging
from typing import Optional
import torch
from tqdm import tqdm
import transformers
import warnings

# Catch when user should re-install transformers library
# NOTE: the check reads the underscore-prefixed `transformers._import_structure` mapping, and since
# it is an `assert` it is skipped entirely when Python runs with -O.
assert "LlamaTokenizer" in transformers._import_structure["models.llama"], ImportError(
    """LLaMA is now in HuggingFace's main branch.
    Please reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git
    """)
from transformers import (
    GenerationConfig,
    HfArgumentParser,
    set_seed,
)

from examples.run_lora import (
    ModelArguments, 
    GenerationArguments,
    MODEL_CLASS, 
    CONFIG_CLASS, 
    DTYPE_CLASS, 
    get_tokenizer,
    AutoConfig,
    AutoModelForCausalLM,
)
from utils.data_utils import (
    ALL_PROMPT_TEMPLATES,
    load_datasets,
    DEFAULT_LLAMA_SYSTEM_PROMPT, 
    DEFAULT_LLAMA_NEW_SYSTEM_PROMPT,
    DEFAULT_QWEN_SYSTEM_PROMPT,
    TextGenerator,
    ProbsGenerator,
    ChatGenerator,
    load_jsons,
    get_filtered_data_by_indices,
    granite_guardian_post_process_fn,
)
from utils.generation_utils import generation_post_init, PythonCodeInterpreter
from utils.metric_utils import SimpleVotingEnsemble
from utils.peft_utils import MoPeftModel
from utils.reward_model_utils import (
    get_rm_ensemble_instance, 
    ensemble_by_yes_probs, 
    RMGeneratorRetainPorbsProcessor
)
from utils.rl_utils import MyAutoModelForCausalLMWithValueHead, load_rl_weights_into_model
from global_const import CKPT_FOLDER, LLAMA_SIZE

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class ExtendedModelArguments(ModelArguments):
    """Arguments related to the model itself.

    See examples.run_lora.ModelArguments for more details.

    ``model_name_or_path`` may be given as a single folder, as a dict
    ``{'adapter_name': 'folder'}`` or as a list of those, either as real Python
    objects (e.g. from a JSON config) or as their string representation (from
    the command line). After ``__post_init__`` it holds the parsed value with
    ``~`` expanded in every path.
    """
    # below I overwrite the metadata of model_name_or_path from ModelArguments of examples.run_lora
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The LoRA model checkpoint for weights initialization."
            "Could be a dict {'adapter_name':'some_path'} if multiple lora "
            "checkpoints are to be loaded."
        },
    )
    merge_lora: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Merge the LoRA model checkpoint into the base model before inference."
        },
    )

    @staticmethod
    def is_ckpt_folder(folder):
        """Assert that ``folder`` is a directory that contains a checkpoint file.

        Despite the predicate-like name, nothing is returned: an AssertionError
        is raised when ``folder`` is not a directory or when none of the known
        checkpoint file names exists inside it.
        """
        assert os.path.isdir(folder), FileExistsError(
            'model_name_or_path must be a valid folder; got {}.'.format(folder))
        ckpt_file_names = (
            "pytorch_model.bin",
            "adapter_model.bin",
            "pytorch_model.safetensors",
            "adapter_model.safetensors",
        )
        assert any(
            os.path.exists(os.path.join(folder, file_name)) for file_name in ckpt_file_names
        ), FileNotFoundError(
            f'Cannot find a ckpt file in {folder}.')

    def check_model_name_or_path(self, model_name_or_path):
        """Validate one checkpoint spec and expand ``~`` in its path(s).

        Args:
            model_name_or_path: a folder path, a dict mapping adapter names to
                folder paths, or a list of either.

        Returns:
            The spec with every path tilde-expanded. Dicts and lists are
            updated in place and returned; a single path is returned as a new
            string.

        Raises:
            AssertionError: if a dict repeats a path, or a path is not a folder
                holding a checkpoint file (see ``is_ckpt_folder``).
        """
        # in case of a dict, multiple model_name_or_path are provided for a single base model.
        if isinstance(model_name_or_path, dict):
            # ensure each path is unique
            assert len(set(model_name_or_path.values())) == len(model_name_or_path.values()), ValueError(
                "Same path is provided multiple times in model_name_or_path {}".format(model_name_or_path))
            # check each path is valid
            for adapter_name, path in list(model_name_or_path.items()):
                path = os.path.expanduser(path)
                model_name_or_path[adapter_name] = path
                self.is_ckpt_folder(path)
            return model_name_or_path
        # in case of a list, multiple model_name_or_path are provided for multiple base models.
        if isinstance(model_name_or_path, list):
            for position, _model_name_or_path in enumerate(model_name_or_path):
                model_name_or_path[position] = self.check_model_name_or_path(_model_name_or_path)
            return model_name_or_path
        # a single folder: expand '~' here as well, consistently with the dict case
        path = os.path.expanduser(model_name_or_path)
        self.is_ckpt_folder(path)
        return path

    def __post_init__(self):
        """Parse and validate ``model_name_or_path`` and check option compatibility.

        Raises:
            ValueError: if ``merge_lora`` is requested together with ``load_in_8bit``.
        """
        super().__post_init__()
        if self.model_name_or_path:
            model_name_or_path = self.model_name_or_path
            # Only a str needs parsing; a dict/list (e.g. from a JSON config) is used as is.
            if isinstance(model_name_or_path, str) and (
                (model_name_or_path.startswith("{") and model_name_or_path.endswith("}"))  # lazy-check dict in str
                or (model_name_or_path.startswith("[") and model_name_or_path.endswith("]"))  # lazy-check list in str
            ):
                # NOTE(security): eval() executes arbitrary code taken from the command line;
                # consider ast.literal_eval if the arguments may come from an untrusted source.
                model_name_or_path = eval(model_name_or_path)
            self.model_name_or_path = self.check_model_name_or_path(model_name_or_path)
        else:
            logger.warning('model_name_or_path is not provided, which means the base model is to be evaluated.')
        if self.merge_lora and self.load_in_8bit:
            raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode.")


@dataclass
class DataArguments:
    """ Arguments pertaining to what data we are going to input our model for eval.

    Several fields accept the string representation of a Python literal (list / dict) on the
    command line; ``__post_init__`` parses those strings. Values that are already lists / dicts
    (e.g. coming from a JSON config file) are accepted unchanged.
    """
    input_source: str = field(
        default="preset",
        metadata={
            "help": "Where to get input sequences. The sources below are supported:"
            "preset: several pre-set inputs are tested."
            "terminal: a terminal interface is used."
            "web: a gradio interface is used."
            "[json path]: if a local json path is given, the json will be loaded and evaluated."
        },
    )
    is_input_source_data_paths: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether the input_source contains a dict of file paths; with key being the json name"
            "of save_path, and value being the input path. In this case, need to load the paths first, then"
            "the actual data."
        },
    )
    save_path: Optional[str] = field(
        default=None,
        metadata={"help": "If provided, the outputs of all evaluations will be saved to a json file."},
    )
    overwrite_save_path: Optional[bool] = field(
        default=False,
        metadata={"help": "If save_path is an existing file and overwrite_save_path is True, will overwrite save_path."},
    )
    disable_caching: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to disable dataset caching and clear cache. Only applicable when input_source is [json path]."},
    )
    share_link: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to create a publicly shareable link for the interface. Only applicable when input_source is web."},
    )
    prompt_templates: Optional[str] = field(
        default="simple_qa",
        metadata={
            "help": "The template to use in prompts. Could be a str in {} to use preset templates,"
            "or a list of two strings, corresponding to prompt with input placeholder, and prompt without input placeholder.".format(ALL_PROMPT_TEMPLATES)
        },
    )
    system_prompt: Optional[str] = field(
        default="default",
        metadata={
            "help": "The template to use as system prompt. Could be any str, including special cases: "
            "None means using empty system prompt; default means using the default system prompt of current prompt_templates."
        },
    )
    enforce_cn_chars: Optional[bool] = field(
        default=False,
        metadata={
            "help": "When parts of CN character tokens are detected, whether to force the following tokens to be in the possible "
            "token sets. Set this to True to avoid generating partial CN characters which cannot be decoded."
        },
    )
    eval_batch_size: Optional[int] = field(
        default=8,
        metadata={
            "help": "Evaluation batch size."
        },
    )
    pseudo_master_port: Optional[str] = field(
        default="29501",
        metadata={
            "help": "To compile megatron dependencies, a master port must be provided if DP is not initialized."
        },
    )
    starting_index: Optional[int] = field(
        default=0,
        metadata={
            "help": "The starting index if input_source is file."
        },
    )
    filter_data_by_indices: Optional[str] = field(
        default=None,
        metadata={"help": "Whether to filter the test data by indices. May provide a list or a file path that saves the list."},
    )
    include_by_indices: Optional[bool] = field(
        default=True,
        metadata={"help": "When filter_data_by_indices is provided, whether to include samples from filter_data_by_indices or exclude them."},
    )
    max_mu_seq_len: Optional[int] = field(
        default=0,
        metadata={
            "help": "The maximum allowed moving average mean of sequence length, used to dynamically control batch size."
        },
    )
    data_generation_task: Optional[str] = field(
        default='SimpleMathFormulation',
        metadata={
            "help": "Optionally a str or list of generation task names."
            "For example, 'SimpleMathFormulation', 'LPProblemDescriptionGeneration', 'IncorrectMathChecking' "
        }
    )
    save_steps: Optional[int] = field(
        default=-1,
        metadata={
            "help": "Save the generation results every this number of steps; disabled when negative."
        }
    )
    # ValueLabelPrediction task related
    end_of_step_id: Optional[int] = field(
        default=-1,
        metadata={
            "help": "A token id indicating the end of step (position to predict value label)."
            "could be None or negative to indicate that the value label should be predicted at every token."
        },
    )
    false_eos_ids: Optional[str] = field(
        default="[]",
        metadata={"help": "List[List[int]] indicate the cases that should not be considered as the end of step. "},
    )
    # ChatRecords
    role_tags: Optional[str] = field(
        default=None,
        metadata={
            "help": "Some tokenizers may have special tokens for the role tags (e.g., <｜User｜>, <｜Assistant｜>), instead of"
            "standard format <bos_token> + role (e.g., <|im_start|>assistant). For these tokenizers, need to provide"
            "role_tags dict in order to get a correct label, e.g., {'user': '<｜User｜>','assistant':'<｜Assistant｜>'}"
        },
    )
    role_map: Optional[str] = field(
        default=None,
        metadata={
            "help": "Some tokenizers cannot handle the roles in chat. Provide a role map if you do not want the role to be"
            "automatically detected and mapped. "
        },
    )

    def __post_init__(self):
        """Normalize and validate the raw argument values.

        NOTE(security): several string arguments are parsed with eval(), which executes
        arbitrary code; consider ast.literal_eval if arguments may come from an untrusted source.

        Raises:
            AssertionError: on an invalid input_source, save_path, prompt_templates,
                role_tags or role_map.
            FileExistsError: if save_path exists and overwrite_save_path is False.
        """
        # super().__post_init__()
        # NOTE(review): 'rm_present' looks like it could be a typo of 'rm_preset' — confirm
        # against the code that consumes input_source before changing it.
        if self.input_source not in ['preset','rm_present', 'terminal', 'web']:  # file_path
            if isinstance(self.input_source, str):
                if self.input_source.startswith("[") and self.input_source.endswith("]"):
                    self.input_source = eval(self.input_source)
                else:
                    self.input_source = [self.input_source]  # keep it a list for simplicity
            else:
                # already a sequence of paths (e.g. from a JSON config)
                self.input_source = list(self.input_source)
            self.input_source = [os.path.expanduser(input_source) for input_source in self.input_source]
            assert all(os.path.isfile(input_source) for input_source in self.input_source), FileExistsError(
                f'some input_source {self.input_source} is not a valid file path.')
            assert all(input_source.endswith(("json", "jsonl")) for input_source in self.input_source), ValueError(
                f"Expect input_source to be a json/jsonl file; got {self.input_source}.")
        if self.save_path in ['None', '']:
            self.save_path = None
        if self.save_path:
            self.save_path = os.path.expanduser(self.save_path)
            if os.path.isdir(self.save_path):
                # a folder was given: use a default file name inside it
                self.save_path = os.path.join(self.save_path, 'gen_results.json')
            assert self.save_path.endswith(('.json', '.jsonl')), NotImplementedError(
                'save_path should be a json file path; got {}.'.format(self.save_path))
            if os.path.isfile(self.save_path) and not self.overwrite_save_path:
                raise FileExistsError(f'save_path {self.save_path} exists. Consider setting overwrite_save_path to True.')
        # check prompt_templates
        if (
            isinstance(self.prompt_templates, str)
            and self.prompt_templates.startswith(('[', '('))
            and self.prompt_templates.endswith((']', ')'))
        ):
            self.prompt_templates = eval(self.prompt_templates)
        if isinstance(self.prompt_templates, str):
            assert self.prompt_templates in ALL_PROMPT_TEMPLATES, NotImplementedError(
                "Prompt template {} has not been implemented.".format(self.prompt_templates))
        if self.system_prompt is None or self.system_prompt in {"", " ", "None"}:
            self.system_prompt = ""
        elif self.system_prompt == "default":
            if self.prompt_templates == "llama_qa":
                self.system_prompt = DEFAULT_LLAMA_SYSTEM_PROMPT
            elif self.prompt_templates == "qwen_qa":
                self.system_prompt = DEFAULT_QWEN_SYSTEM_PROMPT
            else:
                # NOTE(review): system_prompt keeps the literal value "default" in this case —
                # confirm that the consumers of system_prompt handle it.
                warnings.warn(
                    "Current prompt_templates {} does not have default system prompt.".format(self.prompt_templates))
        elif os.path.isfile(self.system_prompt):  # load system prompt from file
            with open(self.system_prompt, "r") as f:
                self.system_prompt = f.read()
            # turn literal "\n" sequences of the file into real line breaks
            self.system_prompt = self.system_prompt.replace("\\n", '\n')
        if self.max_mu_seq_len <= 0:
            self.max_mu_seq_len = 0
        # self.filter_data_by_indices can also be a file, but the handling of file is done inside data loader,
        # so that its content will not be logged in generate_lora
        if self.filter_data_by_indices == "None":
            self.filter_data_by_indices = None
        elif (
            isinstance(self.filter_data_by_indices, str)
            and self.filter_data_by_indices.startswith("[")
            and self.filter_data_by_indices.endswith("]")
        ):
            self.filter_data_by_indices = eval(self.filter_data_by_indices)
        if (
            isinstance(self.data_generation_task, str)
            and self.data_generation_task.startswith("[")
            and self.data_generation_task.endswith("]")
        ):
            self.data_generation_task = eval(self.data_generation_task)
        self._is_StepBeamSearch_task = self._is_subtask("StepBeamSearch", self.data_generation_task)
        self._is_ValueLabelPrediction_task = self._is_subtask("ValueLabelPrediction", self.data_generation_task)
        self._is_ValueLabelPredictionWithOutput_task = self._is_subtask("ValueLabelPredictionWithOutput", self.data_generation_task)
        self._is_GuardianRiskDetection_task = self._is_subtask("GuardianRiskDetection", self.data_generation_task)
        if isinstance(self.false_eos_ids, str):
            if self.false_eos_ids.startswith("[") and self.false_eos_ids.endswith("]"):
                self.false_eos_ids = eval(self.false_eos_ids)
            else:
                self.false_eos_ids = []
        elif self.false_eos_ids is None:
            self.false_eos_ids = []
        # any other value (an already-parsed list) is kept as given
        self.role_tags = self._parse_dict_arg(self.role_tags)
        assert self.role_tags is None or isinstance(self.role_tags, dict), RuntimeError(
            "Fail to parse {}; expect it to be a dict in str.".format(self.role_tags))
        self.role_map = self._parse_dict_arg(self.role_map)
        assert self.role_map is None or isinstance(self.role_map, dict), RuntimeError(
            "Fail to parse {}; expect it to be a dict in str.".format(self.role_map))

    @staticmethod
    def _parse_dict_arg(value):
        """Parse an optional dict argument that may arrive as a string.

        "None" and "" map to None, a string starting with "{" is evaluated, and anything
        else (including an already-built dict) is returned unchanged for the caller to validate.
        """
        if isinstance(value, str):
            if value in {"None", ""}:
                return None
            if value.startswith("{"):
                return eval(value)
        return value

    @staticmethod
    def _is_subtask(task1, task2):
        """Return True if every task name in ``task1`` is listed in ``task2``.

        Both arguments may be a single name or a list of names; names are compared as whole
        list items, not as substrings.
        """
        task1 = [task1] if isinstance(task1, str) else task1
        task2 = [task2] if isinstance(task2, str) else task2
        return all(task in task2 for task in task1)


@dataclass
class ExtendedGenerationArguments(GenerationArguments):
    """ Arguments related to the model generation.

    GenerationConfig contains a comprehensive list of arguments. But unlike TrainingArguments, GenerationConfig 
    is not dataclass with fields, meaning that we could not parse it like we parse other arguments. So here we 
    make this dataclass.

    For more arguments, see examples.run_lora.GenerationArguments and transformers.generation.configuration_utils.GenerationConfig.
    """
    seed: Optional[int] = field(
        default=None,
        metadata={"help": "Provide a positive seed for replicable generation. If a negative value is given, no seed will be used."},
    )
    use_cache: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use cache in generation"},
    )
    # Step beam search related.
    search_num_beams: Optional[int] = field(
        default=1,
        metadata={"help": "The number of beams in step beam search"},
    )
    max_search_steps: Optional[int] = field(
        default=1,
        metadata={"help": "Number of beam search steps in step beam search."},
    )
    num_samples_per_search_step: Optional[int] = field(
        default=10,
        metadata={"help": "Max number of samples when search at each step."},
    )
    max_new_token_per_step: Optional[int] = field(
        default=100,
        metadata={"help": "Max step sequence length when generate"},
    )
    value_by_transition_scores: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use average transition scores as verifier."
        }
    )
    dedup_mode: Optional[bool] = field(
        default=True,
        metadata={
            "help": "whether to prefer different step while ranking by value function"
        }
    )
    return_all_search_sequences: Optional[bool] = field(
        default=False,
        metadata={
            "help": "whether to return all the end sequences generated while searching"
        }
    )
    # code interpreter related.
    enable_code_interpreter: Optional[bool] = field(
        default=None,
        metadata={"help": "Whether to use code interpreter"},
    )
    code_prefix: Optional[str] = field(default=None, metadata={"help": "Tag before code block"},)
    code_suffix: Optional[str] = field(default=None, metadata={"help": "Tag after code block"},)
    code_output_prefix: Optional[str] = field(default=None, metadata={"help": "Tag before code output block"},)
    code_output_suffix: Optional[str] = field(default=None, metadata={"help": "Tag after code output block"},)
    code_timeout_length: Optional[int] = field(
        default=5,
        metadata={"help": "Maximum running time for each code piece (in seconds)."},
    )

    def __post_init__(self):
        """Normalize the seed and validate the code-interpreter tags.

        Raises:
            AssertionError: if the code interpreter is enabled but one of the four tags is
                missing (not a str, or the literal "None").
        """
        super().__post_init__()
        # seed defaults to None, which cannot be compared with an int; any seed below 1
        # (zero included) is treated as "no seed".
        if self.seed is not None and self.seed < 1:
            self.seed = None
        if self.enable_code_interpreter:
            assert isinstance(self.code_prefix, str) and self.code_prefix != "None", ValueError("code_prefix must be provided.")
            assert isinstance(self.code_suffix, str) and self.code_suffix != "None", ValueError("code_suffix must be provided.")
            assert isinstance(self.code_output_prefix, str) and self.code_output_prefix != "None", ValueError("code_output_prefix must be provided.")
            assert isinstance(self.code_output_suffix, str) and self.code_output_suffix != "None", ValueError("code_output_suffix must be provided.")
            # tags given on the command line carry literal "\n"; turn them into real line breaks
            self.code_prefix = self.code_prefix.replace("\\n", "\n")
            self.code_suffix = self.code_suffix.replace("\\n", "\n")
            self.code_output_prefix = self.code_output_prefix.replace("\\n", "\n")
            self.code_output_suffix = self.code_output_suffix.replace("\\n", "\n")
            # To disable the warning:
            # huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
            os.environ["TOKENIZERS_PARALLELISM"] = "false"


def save_gen_results(gen_results, starting_index=0, save_path=None):
    """Write the generation results to ``save_path`` as an indented JSON document.

    Nothing happens when ``save_path`` or ``gen_results`` is empty. Otherwise every result gets
    an ``index`` key (its position plus ``starting_index``, set in place) and, when resuming
    (``starting_index > 0`` and the file exists), the records already stored in the file are
    placed in front of the new ones.

    Args:
        gen_results: list of result dicts.
        starting_index: index of the first result of this run.
        save_path: target json path, or None to skip saving.
    """
    if not (save_path and gen_results):
        return

    for position, record in enumerate(gen_results):
        record['index'] = position + starting_index
    logger.info('Save output to {}.'.format(save_path))

    # when starting_index > 0, we combine results with previous results
    resuming = os.path.isfile(save_path) and starting_index > 0
    if resuming:
        previous_results = load_jsons(save_path)
        if len(previous_results) < starting_index:
            # fewer stored records than expected: keep them all and just append
            logger.warning("Existing file at {} does not contain {} instances. Are you sure you want to start at {}?".format(
                save_path, starting_index, starting_index))
        else:
            # drop whatever was stored at or after starting_index
            previous_results = previous_results[:starting_index]
        gen_results = previous_results + gen_results

    with open(save_path, 'w', encoding='utf-8') as out_file:
        serialized = json.dumps(gen_results, indent=2, ensure_ascii=False)
        out_file.write(serialized)


def load_model_base_and_adapter(
    base_model_path, 
    model_config, 
    adapter_path, 
    model_args, 
    data_args, 
    tokenizer,
    index=0,
    ):
    """Load the base model, then optionally its LoRA adapter and a value head.

    Args:
        base_model_path: path handed to ``from_pretrained`` for the base model.
        model_config: config object handed to ``from_pretrained``.
        adapter_path: LoRA checkpoint location; falsy to use the base model only.
        model_args: model arguments (base_model, dtype, 8-bit, adapter and merge options).
        data_args: data arguments; its task flags decide whether a value head is attached.
        tokenizer: supplies the pad / bos / eos token ids copied into the model config.
        index: position of this model when several models are loaded.

    Returns:
        The loaded model, wrapped in MyAutoModelForCausalLMWithValueHead when a value head
        is required.
    """
    # the model class is looked up by the base model name without its last '-' separated part
    family_name = model_args.base_model.rsplit('-', 1)[0]
    loader_cls = MODEL_CLASS.get(family_name, AutoModelForCausalLM)

    # load base model
    base_model = loader_cls.from_pretrained(
        base_model_path,
        load_in_8bit=model_args.load_in_8bit,
        torch_dtype=DTYPE_CLASS[model_args.torch_dtype],
        device_map='auto',
        config=model_config,
        use_flash_attention_2=model_args.use_flash_attn,
    )

    # replace methods like generate, sample, with customized ones
    model = generation_post_init(base_model)

    # load lora
    if adapter_path:
        model = MoPeftModel.from_pretrained(
            model,
            adapter_path,
            adapter_name=model_args.adapter_name,
            is_trainable=False,  # inference mode
            molora_strategy=model_args.molora_strategy,
        )
        logger.info('LoRA weights loaded from {}.'.format(adapter_path))
        if model_args.merge_lora:
            model = model.merge_and_unload()
            logger.info('LoRA weights merged into base model.')

    model.config.use_cache = False  # do not return the last key/values attentions
    # align the special token ids of the model config with the tokenizer
    for token_attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        setattr(model.config, token_attr, getattr(tokenizer, token_attr))

    # A value head is attached for the step-beam-search / value-label tasks, except for the
    # first model (index 0) when a list of checkpoints is given; the "with output" task
    # always gets one.
    needs_value_head = (
        (
            (data_args._is_StepBeamSearch_task or data_args._is_ValueLabelPrediction_task)
            and not (isinstance(model_args.model_name_or_path, list) and index == 0)
        )
        or data_args._is_ValueLabelPredictionWithOutput_task
    )
    if not needs_value_head:
        return model

    # load value head
    model = MyAutoModelForCausalLMWithValueHead(model, model_args)
    model.is_peft_model = not model_args.merge_lora
    model.device = model.pretrained_model.device
    if adapter_path:
        model, valid_ckpt_path = load_rl_weights_into_model(model, adapter_path)
        if valid_ckpt_path:
            logger.info('Value head weights loaded from {}.'.format(valid_ckpt_path))
        else:
            logger.warning('Did not detect value head checkpoint from:\n{}\nIs this intended?'.format(adapter_path))

    return model


def call_model_generator_on_file(model_generator, data_args, input_source, save_path):
    """ Call model_generator to process one file or one set of files.

    The data is processed in batches of ``data_args.eval_batch_size`` samples, starting at
    ``data_args.starting_index``. Results are saved every ``data_args.save_steps`` batches
    (when positive) and once more at the end. A CUDA out-of-memory error stops the loop,
    is logged, and the results collected so far are still saved and returned.

    Args:
        model_generator: generator object; ``batch_response`` is called on every batch, and
            ``config`` / ``prompt_templates`` / ``add_to_chat`` / ``chat`` / ``clear_cache``
            are used depending on the branch.
        data_args: DataArguments
        input_source: str or List[str], the file(s) to load the dataset from.
        save_path: json path the results are written to (nothing is written if falsy).

    Returns:
        List[dict], the generation results.
    """
    step = 0  # number of batches fully processed so far
    gen_results = []
    starting_index = data_args.starting_index

    logger.info("\n\n=================================================")
    logger.info("Processing input_source:\n{}\n".format(input_source))
    if starting_index > 0:
        logger.info("Starting from sample index {}".format(starting_index))

    # prompt_templates is a preset name (str) or a list of custom template strings;
    # only a preset name can designate a chat template.
    is_chat_template = (
        isinstance(data_args.prompt_templates, str)
        and data_args.prompt_templates.endswith("chat")
    )
    if is_chat_template:
        try:
            # check input arguments
            columns_to_read = ["chat", "instruction", "input", "output"]
            dataset = load_datasets(input_source, columns_to_read=columns_to_read)
            filter_data_by_indices = getattr(data_args, "filter_data_by_indices", None)
            if filter_data_by_indices:
                dataset = get_filtered_data_by_indices(
                    dataset, filter_data_by_indices, include_by_indices=getattr(data_args, "include_by_indices", True))
            num_samples = len(dataset)

            for index in tqdm(range(starting_index, num_samples, data_args.eval_batch_size), desc='Processing'):
                batch = dataset[index:min(index + data_args.eval_batch_size, num_samples)]
                chats, instructions = batch.get('chat', None), batch.get('instruction', None)
                if chats is not None:  # chats is List[List[Dict]], each List[Dict] is a chat to be converted to ChatRecords
                    if data_args._is_ValueLabelPredictionWithOutput_task:
                        queries = chats  # chat cannot be stochastic
                        for chat in chats:
                            for utterance in chat:
                                if isinstance(utterance["content"], list):
                                    utterance["content"] = utterance["content"][0]
                    else:
                        # drop the Assistant utterances, except those whose task ends with ".Q"
                        queries = [[utterance for utterance in chat if (utterance["role"] != "Assistant") or \
                            utterance["metadata"]["task"].endswith(".Q")] for chat in chats]
                    # for data starting with system message, add system message to chat
                    # for the rest, add an empty message that adds nothing
                    if any("system" in query[0]["role"] for query in queries):
                        system_prompts = [([query[0]] if query[0]['role'] == "system" else [{}]) for query in queries]
                        queries = [(query[1:] if query[0]['role'] == "system" else query) for query in queries]
                        model_generator.add_to_chat(system_prompts)

                    if data_args.eval_batch_size > 1:  # assign index within batch
                        for batch_pos, query in enumerate(queries):
                            query[0]['metadata'].update({"index": batch_pos})
                    round_index = 1
                    if data_args._is_ValueLabelPredictionWithOutput_task:
                        queries_current = [query[:(2*round_index)] for query in queries if len(query) >= 2*round_index]
                    else:
                        queries_current = [query[:(2*round_index-1)] for query in queries if len(query) >= 2*round_index-1]
                    # one iteration per chat round, until no chat has a further round
                    while len(queries_current) > 0:
                        _ = model_generator.batch_response(queries_current)
                        chat = [record.clone() for record in model_generator.chat]
                        model_generator.clear_cache()
                        # assign chat back to queries
                        if not data_args._is_ValueLabelPredictionWithOutput_task:
                            if data_args.eval_batch_size > 1:
                                for _chat in chat:
                                    batch_pos = _chat[0]['metadata']['index']
                                    queries[batch_pos] = _chat + queries[batch_pos][(2*round_index-1):]
                            else:
                                queries[0] = chat[0] + queries[0][(2*round_index-1):]
                        # update queries_current
                        round_index += 1
                        if data_args._is_ValueLabelPredictionWithOutput_task:
                            queries_current = [query[:(2*round_index)] for query in queries if len(query) >= 2*round_index]
                        else:
                            queries_current = [query[:(2*round_index-1)] for query in queries if len(query) >= 2*round_index-1]
                    # removes index within batch
                    if data_args.eval_batch_size > 1:
                        for query in queries:
                            query[0]['metadata'].pop("index")
                    gen_results.extend([{"index": None, "chat": chat} for chat in queries])
                elif instructions is not None:  # instructions is List[str]
                    inputs = batch.get('input', [""]*len(instructions))
                    if data_args._is_ValueLabelPrediction_task or (data_args._is_StepBeamSearch_task and model_generator.config.return_all_search_sequences):
                        responses, values = model_generator.batch_response(instructions)
                        if isinstance(responses, str):
                            responses = [responses]
                        gen_results.extend([
                            {'instruction': inst, 'input': input_, 'output': output, 'values': value} \
                                for inst, input_, output, value in zip(instructions, inputs, responses, values)])
                    elif data_args._is_ValueLabelPredictionWithOutput_task:
                        responses = batch.get('output', None)  # responses is List[str] and each response cannot be stochastic
                        if isinstance(responses[0], list):
                            responses = [response[0] for response in responses]
                        queries = [[
                            {"role": "Human", "content": instruction, "metadata": {}},
                            {"role": "Assistant", "content": output, "metadata": {}},
                        ] for instruction, output in zip(instructions, responses)]
                        values = model_generator.batch_response(queries)
                        gen_results.extend([
                            {'instruction': inst, 'input': input_, 'output': output, 'values': value} \
                                for inst, input_, output, value in zip(instructions, inputs, responses, values)])
                    else:
                        outputs = model_generator.batch_response(instructions)
                        if isinstance(outputs, str):
                            outputs = [outputs]
                        gen_results.extend([
                            {'instruction': inst, 'input': input_, 'output': output} \
                                for inst, input_, output in zip(instructions, inputs, outputs)])
                    model_generator.clear_cache()

                step += 1
                if data_args.save_steps > 0 and step % data_args.save_steps == 0:
                    save_gen_results(gen_results, starting_index, save_path)

        except torch.cuda.OutOfMemoryError:  # issubclass(torch.cuda.OutOfMemoryError, BaseException) --> True
            # report the number of completed batches, not a loop variable
            logger.error("CUDA out of memory. {} batches processed".format(step))
    else:
        if model_generator.config.output_answer_probs:
            columns_to_read = ["instruction", "input", "output"]
        elif model_generator.prompt_templates in ["granite_guardian_qa"]:
            columns_to_read = ["instruction", "input", "output", "risk_name"]
        else:
            columns_to_read = ["instruction", "input"]
        # load the file(s) passed to this call, as the chat branch does
        # NOTE(review): filter_data_by_indices is only applied in the chat branch — confirm
        # whether that is intended.
        dataset = load_datasets(input_source, columns_to_read=columns_to_read)
        num_samples = len(dataset)
        try:
            for index in tqdm(range(starting_index, num_samples, data_args.eval_batch_size), desc='Processing'):
                batch = dataset[index:min(index + data_args.eval_batch_size, num_samples)]
                instructions, outputs = batch['instruction'], batch.get("output", None)
                actual_batch_size = len(instructions)
                inputs = batch.get('input', [""]*len(instructions))
                if model_generator.config.output_answer_probs:
                    responses = model_generator.batch_response(list(zip(instructions, inputs, outputs)))
                else:
                    # turn the column-wise batch into one dict per sample
                    instructions_dict = [{key: values[row] for key, values in batch.items()} for row in range(actual_batch_size)]
                    responses = model_generator.batch_response(instructions_dict)
                if model_generator.prompt_templates in ["granite_guardian_qa"]:  # output is risk_prediction
                    gen_results.extend([
                        {
                            'instruction': inst, 
                            'input': input_, 
                            'output': output, 
                            'risk_name': risk_name, 
                            'risk_prediction': response
                        } for inst, input_, output, risk_name, response in zip(
                            instructions, inputs, outputs, batch["risk_name"], responses)])
                else:
                    gen_results.extend([
                        {
                            'instruction': inst, 
                            'input': input_, 
                            'output': response
                        } for inst, input_, response in zip(instructions, inputs, responses)])

                step += 1
                if data_args.save_steps > 0 and step % data_args.save_steps == 0:
                    save_gen_results(gen_results, starting_index, save_path)

        except torch.cuda.OutOfMemoryError:  # issubclass(torch.cuda.OutOfMemoryError, BaseException) --> True
            logger.error("CUDA out of memory. {} batches processed".format(step))

    save_gen_results(gen_results, starting_index, save_path)

    return gen_results

def parse_args(args):
    """Parse the command-line arguments into model, data and generation dataclasses.

    Args:
        args: List of argument strings (the script passes ``sys.argv[1:]``). When the
            list holds exactly one entry ending in ".json", that entry is treated as
            the path of a JSON file containing the arguments.

    Returns:
        The tuple ``(model_args, data_args, generation_args)``.
    """
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ExtendedModelArguments, DataArguments, ExtendedGenerationArguments))

    if len(args) == 1 and args[0].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        # NOTE: abspath must receive the path string itself, not the enclosing list.
        json_path = os.path.abspath(args[0])
        model_args, data_args, generation_args = parser.parse_json_file(json_file=json_path)
    else:
        model_args, data_args, generation_args = parser.parse_args_into_dataclasses(args)

    return model_args, data_args, generation_args


def get_generation_config(data_args, generation_args, tokenizer):
    """Assemble the GenerationConfig used for decoding.

    The config is created from a deep copy of ``generation_args``' attributes and then
    completed with the tokenizer's special token ids, the role settings from
    ``data_args`` and a few task-dependent switches.

    Args:
        data_args: Parsed data arguments (role tags/map, task flags, prompt template).
        generation_args: Parsed generation arguments; every attribute is forwarded to
            ``GenerationConfig``.
        tokenizer: Tokenizer providing ``pad_token_id``, ``bos_token_id`` and ``eos_token_id``.

    Returns:
        The fully configured ``GenerationConfig`` instance (also logged at INFO level).
    """
    # Build from a deep copy so the config never shares mutable values with generation_args.
    config_kwargs = copy.deepcopy(vars(generation_args))
    generation_config = GenerationConfig(**config_kwargs)

    # Set seed before initializing model.
    if generation_config.seed:
        set_seed(generation_config.seed)

    # Special token ids are taken from the tokenizer.
    for token_attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        setattr(generation_config, token_attr, getattr(tokenizer, token_attr))

    # Role tags and role map are taken from the data arguments.
    for role_attr in ("role_tags", "role_map"):
        setattr(generation_config, role_attr, getattr(data_args, role_attr))

    # Step-wise tasks additionally need the end-of-step / false-eos token ids.
    is_stepwise_task = (
        data_args._is_StepBeamSearch_task
        or data_args._is_ValueLabelPrediction_task
        or data_args._is_ValueLabelPredictionWithOutput_task
    )
    if is_stepwise_task:
        generation_config.end_of_step_id = data_args.end_of_step_id
        generation_config.false_eos_ids = data_args.false_eos_ids

    # Reward-model ensembling and answer-probability output both need per-step scores
    # returned in a dict.
    uses_reward_model = generation_config.ensemble_method == "reward_model"
    wants_answer_probs = generation_config.output_answer_probs == True
    if uses_reward_model or wants_answer_probs:
        generation_config.return_dict_in_generate = True
        generation_config.output_scores = True

    # The granite-guardian template and the GuardianRiskDetection task only switch on scores.
    is_guardian_template = data_args.prompt_templates in ["granite_guardian_qa"]
    is_guardian_task = data_args.data_generation_task == "GuardianRiskDetection"
    if is_guardian_template or is_guardian_task:
        generation_config.output_scores = True

    logger.info(generation_config)
    return generation_config


def load_model(base_model_prefix, base_model_path, model_args, data_args, generation_args, tokenizer):
    """Load the model config, then the base model together with its adapter(s).

    Side effect: for 'llama3' / 'llama3.3' prefixes, ``data_args.prompt_templates`` (and
    possibly ``data_args.system_prompt``) are rewritten to their "new" llama variants.

    Args:
        base_model_prefix: Model family name used to pick the config class and
            family-specific setup.
        base_model_path: Path handed to ``from_pretrained`` and to the model loader.
        model_args: Parsed model arguments; ``model_name_or_path`` may be a str, dict,
            None, or a list of adapter paths.
        data_args: Parsed data arguments.
        generation_args: Parsed generation arguments (code-interpreter settings).
        tokenizer: Tokenizer whose ``pad_token_id`` is copied into the model config
            (for every family except 'qwen').

    Returns:
        A single model, or a list of models when ``model_name_or_path`` is a list.

    Raises:
        AssertionError: If a list of adapters is given for an unsupported task.
        TypeError: If ``model_name_or_path`` has an unsupported type.
    """
    is_qwen = base_model_prefix == 'qwen'

    # Family-specific preparation that has to happen before the config is loaded.
    if base_model_prefix in {'llama3', 'llama3.3'}:
        if data_args.prompt_templates == "llama_qa":
            data_args.prompt_templates = "llama_new_qa"
            if data_args.system_prompt == DEFAULT_LLAMA_SYSTEM_PROMPT:
                from datetime import datetime

                today = datetime.today().strftime("%d %b %Y")
                data_args.system_prompt = DEFAULT_LLAMA_NEW_SYSTEM_PROMPT.format(TODAY=today)
        elif data_args.prompt_templates == "llama_chat":
            data_args.prompt_templates = "llama_new_chat"
    elif is_qwen:
        from models.qwen.megatron_utils import compile_megatron_dependencies

    # Every family loads its config the same way; AutoConfig is the fallback class.
    config_class = CONFIG_CLASS.get(base_model_prefix, AutoConfig)
    model_config = config_class.from_pretrained(base_model_path)

    if is_qwen:
        # qwen gets megatron-related settings on a copy of the config and does not
        # receive the tokenizer's pad_token_id.
        model_config = copy.deepcopy(model_config)
        model_config.fp16 = model_args.torch_dtype == "float16"
        model_config.bf16 = model_args.torch_dtype == "bfloat16"
        model_config.tensor_model_parallel_size = 1
        model_config.micro_batch_size = 1
        model_config.masked_softmax_fusion = True
        model_config.gradient_accumulation_fusion = False
        compile_megatron_dependencies(
            model_config, master_port=data_args.pseudo_master_port)
    else:
        model_config.pad_token_id = tokenizer.pad_token_id
        if base_model_prefix == 'recur-qwen2.5':
            # Forward the recurrence hyper-parameters from model_args to the config.
            recur_attrs = (
                "recur_strategy",
                "recur_times",
                "num_prelude_layers",
                "num_coda_layers",
                "input_injection_type",
                "state_init_strategy",
                "init_std",
                "attn_to_recur_key_values",
                "ln_after_recur",
            )
            for attr_name in recur_attrs:
                setattr(model_config, attr_name, getattr(model_args, attr_name))

    # load the model
    adapter_spec = model_args.model_name_or_path
    if isinstance(adapter_spec, (str, dict, type(None))):
        model = load_model_base_and_adapter(
            base_model_path, model_config, adapter_spec, model_args, data_args, tokenizer)
    elif isinstance(adapter_spec, list):
        assert (
            data_args._is_StepBeamSearch_task
            or data_args._is_ValueLabelPrediction_task
        ), NotImplementedError("Loading multiple models is supported in only a few tasks, such as StepBeamSearch and ValueLabelPrediction.")
        # One model per adapter path; the position in the list is passed along as well.
        model = [
            load_model_base_and_adapter(
                base_model_path, model_config, adapter_path, model_args, data_args, tokenizer, index)
            for index, adapter_path in enumerate(adapter_spec)
        ]
    else:
        raise TypeError("Expect model_name_or_path to be one of str, dict, list, or None; got {}".format(type(adapter_spec)))

    if not generation_args.enable_code_interpreter:
        return model

    from utils.solver_utils import PythonExecutor
    # TODO when merge_lora is False, model is MoPEFT class, and code_interpreter is not assigned
    # NOTE(review): if `model` is a list (multi-adapter case) this attribute assignment
    # looks like it would fail -- confirm whether that combination is meant to be supported.
    executor = PythonExecutor(
        share_runtime_in_batch=False,
        get_answer_from_stdout=True,
        timeout_length=generation_args.code_timeout_length)
    model.code_interpreter = PythonCodeInterpreter(
        tokenizer,
        code_prefix=generation_args.code_prefix,
        code_suffix=generation_args.code_suffix,
        code_output_prefix=generation_args.code_output_prefix,
        code_output_suffix=generation_args.code_output_suffix,
        executor=executor,
    )

    return model


def get_model_generator(model, model_args, data_args, generation_config, post_process_fn, tokenizer):
    """Wrap the loaded model in the generator class matching the current configuration.

    ``ProbsGenerator`` is used when ``generation_config.output_answer_probs`` is True,
    ``ChatGenerator`` when the prompt template name ends with "chat", and
    ``TextGenerator`` otherwise. The wrapped model is switched to eval mode.

    Args:
        model: The loaded model (as returned by ``load_model``).
        model_args: Parsed model arguments; ``model_name_or_path`` determines the
            number of parallel adapters.
        data_args: Parsed data arguments (prompt template, system prompt, ...).
        generation_config: Generation configuration passed to the generator.
        post_process_fn: Optional post-processing callable passed to the generator.
        tokenizer: Tokenizer passed to the generator.

    Returns:
        The constructed generator.

    Raises:
        AssertionError: If ``molora_strategy`` is "parallel" and the listed adapters do
            not all contain the same number of parallel adapters.
    """
    adapter_spec = model_args.model_name_or_path
    if isinstance(adapter_spec, dict):
        num_parallel_adapters = len(adapter_spec)
    elif isinstance(adapter_spec, list):
        # A dict entry bundles several adapters; any other entry counts as a single adapter.
        adapter_counts = [
            len(adapter_path) if isinstance(adapter_path, dict) else 1
            for adapter_path in adapter_spec
        ]
        if model_args.molora_strategy == "parallel":
            # Fixed: the original compared against the undefined name `num_parallel_adapter`.
            assert all(count == adapter_counts[0] for count in adapter_counts[1:]), RuntimeError(
                "In case of parallel molora_strategy, all num_parallel_adapters must be equal."
            )
        num_parallel_adapters = adapter_counts[0]
    else:
        num_parallel_adapters = 1

    if generation_config.output_answer_probs == True:
        model_generator = ProbsGenerator(
            model, generation_config, tokenizer,
            prompt_templates=data_args.prompt_templates,
            system_prompt=data_args.system_prompt,
            enforce_cn_chars=data_args.enforce_cn_chars,
            post_process_fn=post_process_fn,
            num_parallel_adapters=num_parallel_adapters,
            )
        print("probs generater")
    elif data_args.prompt_templates.endswith("chat"):
        model_generator = ChatGenerator(
            model, generation_config, tokenizer,
            prompt_templates=data_args.prompt_templates,
            system_prompt=data_args.system_prompt,
            post_process_fn=post_process_fn,
            num_parallel_adapters=num_parallel_adapters,
            max_mu_seq_len=data_args.max_mu_seq_len,
            use_chat_cache=False,
            data_generation_task=data_args.data_generation_task,
            )
    else:
        model_generator = TextGenerator(
            model, generation_config, tokenizer,
            prompt_templates=data_args.prompt_templates,
            system_prompt=data_args.system_prompt,
            enforce_cn_chars=data_args.enforce_cn_chars,
            post_process_fn=post_process_fn,
            num_parallel_adapters=num_parallel_adapters,
            max_mu_seq_len=data_args.max_mu_seq_len,
            )
    # set model to eval mode
    model_generator.model.eval()

    return model_generator


def main(args):
    """Run generation end to end: parse arguments, load the model, serve the input source.

    The behaviour depends on ``data_args.input_source``:
    'preset' and 'rm_preset' run built-in examples, 'terminal' reads instructions from
    stdin, 'web' launches a gradio interface, and a list is treated as json path(s)
    processed by ``call_model_generator_on_file``.

    Args:
        args: Command-line argument strings (the script passes ``sys.argv[1:]``),
            forwarded to ``parse_args``.

    Raises:
        NotImplementedError: If ``data_args.input_source`` is none of the supported values.
    """
    # Step 1. Prepare arguments and logging
    model_args, data_args, generation_args = parse_args(args)

    # Setup logging; the local rank is -1 by default
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # print
    logger.info('Available cuda devices: {}'.format(torch.cuda.device_count()))  # several cuda may be given
    logger.info('Pytorch version: {}; transformers version {}'.format(torch.__version__, transformers.__version__))
    logger.info("\nData related arguments:\n %s", data_args)
    logger.info("\nModel arguments:\n %s", model_args)
    logger.info("\nGeneration related arguments:\n %s", generation_args)

    # Step 2. Configure tokenizer and model
    base_model_path = os.path.join(CKPT_FOLDER, model_args.base_model)
    # everything before the last '-' of the base model name is used as the family prefix
    base_model_prefix = model_args.base_model.rsplit('-', 1)[0]
    # set up tokenizer and generation_config
    tokenizer = get_tokenizer(base_model_prefix, base_model_path)
    generation_config = get_generation_config(data_args, generation_args, tokenizer)
    post_process_fn = None
    if generation_config.ensemble_method == "reward_model":
        post_process_fn = RMGeneratorRetainPorbsProcessor(tokenizer)
    # The granite-guardian model and the GuardianRiskDetection task share one post-processor
    # (it replaces the reward-model one when both conditions apply).
    if base_model_prefix == "granite-guardian-3.1" or data_args._is_GuardianRiskDetection_task:
        from functools import partial
        post_process_fn = partial(
            granite_guardian_post_process_fn, tokenizer, safe_token="No", unsafe_token="Yes", nlogprobs=20)
    # configure model_config and load model
    model = load_model(base_model_prefix, base_model_path, model_args, data_args, generation_args, tokenizer)

    # Step 3. Set up interface.
    model_generator = get_model_generator(model, model_args, data_args, generation_config, post_process_fn, tokenizer)

    # do generation for different input_source
    starting_index = 0
    input_source = data_args.input_source
    if input_source == 'preset':
        instructions = [
            "妈妈的妈妈叫什么？",
            "Tell me about the president of Mexico in 2019.",
            "Tell me about the king of France in 2019.",
            "List all Canadian provinces in alphabetical order.",
            "Write a Python program that prints the first 10 Fibonacci numbers.",
            "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
            "Tell me five words that rhyme with 'shock'.",
            "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
            "Count up from 1 to 500.",
        ]
        gen_results = []
        for instruction in instructions:
            response = model_generator(instruction)
            if data_args.save_path:
                gen_results.append({"instruction": instruction, "response": response})
            logger.info("Human:\n{}\n".format(instruction))
            logger.info("Assistant:\n{}\n".format(response))

        # Step 4. Additional save gen_results.
        save_gen_results(gen_results, starting_index, data_args.save_path)

    elif input_source == 'rm_preset':
        instance = {
            "instruction": "卖椰子的人必须使用人力车或牛车来运送椰子。人力车每辆可运载 50 个椰子，单程费用为 10 元。牛车每辆可运载 30 个椰子，单程费用为 8 元。卖家最多有 200 元可用于运输椰子。人力车的数量不得超过牛车的数量。如何设计运输，可以最大化运输的椰子数量?",
            "input": "",
            "output": [
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：x <= y，10 * x + 8 * y <= 200，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：x <= y，10 * x + 8 * y <= 200，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：10 * x + 8 * y <= 200，x <= y，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：x <= y，10 * x + 8 * y <= 200，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：10 * x + 8 * y <= 200，x <= y，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：x <= y，10 * x + 8 * y <= 200，x >= 0，y >= 0。",
            "变量：人力车的数量，牛车的数量。分别定义为：x，y。\n目标：maximize 50 * x + 30 * y。\n约束条件：10 * x + 8 * y <= 200，x <= y，x >= 0，y >= 0。"
            ],
            "index": 0,
            "lang": "cn"
            }
        answer_batch = instance["output"]
        rm_instances = get_rm_ensemble_instance(instance)
        instructions = [rm_instance['instruction'] for rm_instance in rm_instances]
        # gen_results stays empty in this mode: the ensemble response is only logged.
        gen_results = []
        responses = [model_generator(instruction) for instruction in instructions]
        ensemble_response = ensemble_by_yes_probs(answer_batch, responses)
        instruction = instance['instruction']
        logger.info("Human:\n{}\n".format(instruction))
        logger.info("Assistant:\n{}\n".format(ensemble_response))

        # Step 4. Additional save gen_results.
        save_gen_results(gen_results, starting_index, data_args.save_path)

    elif input_source == 'terminal':
        # set up a second logger that prints to terminal
        # (STDOUT is the original sys.stdout captured before any output redirection)
        consoleHandler = logging.StreamHandler(STDOUT)
        consoleHandler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s -   %(message)s"))
        console_logger = logging.getLogger('console')
        console_logger.addHandler(consoleHandler)
        gen_results = []
        while True:
            console_logger.info("Please input instruction:")
            instruction = input()
            if instruction in {'break', 'quit', 'exit', 'quit()', 'exit()'}:
                break
            if instruction in {'clear'} and hasattr(model_generator, 'clear_cache'):
                model_generator.clear_cache()
                continue
            if generation_config.streaming:
                response_iterator = model_generator.stream_response(instruction)
                console_logger.info("Response:\n")
                response = []
                for new_text in response_iterator:
                    console_logger.info(new_text)
                    response.append(new_text)
                if data_args.save_path:
                    gen_results.append({"instruction": instruction, "response": "".join(response)})
            else:
                response = model_generator(instruction)
                if data_args.save_path:
                    gen_results.append({"instruction": instruction, "response": response})
                console_logger.info("Response:\n{}".format(response))

        # Step 4. Additional save gen_results.
        save_gen_results(gen_results, starting_index, data_args.save_path)

    elif input_source == 'web':
        gen_results = []  # TODO support saving gen_results
        demo = gr.Interface(
            fn=model_generator,
            inputs=[
                gr.components.Textbox(
                    lines=2,
                    label="Instruction",
                    placeholder="Tell me about alpacas.",
                ),
                gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            ],
            outputs=[
                # Fixed: the output box used the deprecated input-side `gr.inputs` namespace;
                # use the same component class as the inputs above.
                gr.components.Textbox(
                    lines=5,
                    label="Output",
                )
            ],
            title="LoRA",
            description="This is a sandbox for LLM models trained with LoRA.",
        )
        server_app, local_url, share_url = demo.launch(
            server_name="127.0.0.1", share=data_args.share_link, inbrowser=True)
        logger.info("local_url: {}".format(local_url))
        logger.info("share_url: {}".format(share_url))

        # Step 4. Additional save gen_results.
        save_gen_results(gen_results, starting_index, data_args.save_path)

    elif isinstance(input_source, list):  # [json paths]
        if data_args.is_input_source_data_paths:
            # The listed jsons map output names to input paths; results go into one folder.
            save_folder = data_args.save_path
            if save_folder.endswith(".json"):
                save_folder = os.path.dirname(save_folder)

            input_sources = load_jsons(input_source)
            # A non-empty list of dicts is flattened into a single {name: path} mapping.
            if isinstance(input_sources, list) and input_sources and isinstance(input_sources[0], dict):
                input_sources = {name: input_path for sub_source in input_sources for name, input_path in sub_source.items()}
            assert isinstance(input_sources, dict), TypeError(
                "Expect input_sources to be dict; got {}".format(type(input_sources)))
            assert all(os.path.isfile(input_path) for input_path in input_sources.values()), FileNotFoundError(
                "Detect invalid file_path:\n{}".format("\n".join([input_path for input_path in input_sources.values() if not os.path.isfile(input_path)])))
            logger.info("Processing {} input_sources...".format(len(input_sources)))

            for json_name, input_path in input_sources.items():
                save_path = os.path.join(save_folder, json_name)
                call_model_generator_on_file(model_generator, data_args, input_path, save_path)
                torch.cuda.empty_cache()
        else:
            call_model_generator_on_file(model_generator, data_args, input_source, data_args.save_path)
    else:
        raise NotImplementedError('Cannot parse input_source {}'.format(data_args.input_source))


# Script entry point: forward the command-line arguments (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])