import os
from .themis_utils import *
from criteria_for_qs import question_source_dict
import re
from .prompts_v2 import *
from .nips2024_prompts import *
from tqdm import tqdm
try:
    from tigerscore import TIGERScorer
except:
    pass
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
import ipdb
import re
import json
import time
from vllm import SamplingParams
from transformers import AutoTokenizer, AutoConfig
import openai
import tiktoken

import evaluate_configuration
from model_template import UltraCM, AutoJ


def _parse_arith_token(token):
    """Safely evaluate one whitespace-free token as an arithmetic expression.

    Only int/float literals, unary +/-, binary + - * / // % ** and parentheses
    are accepted; exponents are capped at 100 in magnitude to avoid huge
    computations. Returns the int/float value, or None when the token is not
    such an expression, fails to evaluate (e.g. division by zero), or does not
    yield a real int/float (bools and complex results are rejected).

    This replaces a plain eval(): tokens come from model output, and eval would
    execute arbitrary code such as "__import__('os').system(...)".
    """
    import ast  # local import keeps this helper self-contained

    binary_ops = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a ** b,
    }
    unary_ops = {ast.UAdd: lambda a: +a, ast.USub: lambda a: -a}

    def evaluate(node):
        if isinstance(node, ast.Constant):
            value = node.value
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError('not a real number literal')
            return value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left, right = evaluate(node.left), evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > 100:
                raise ValueError('exponent too large')
            return binary_ops[type(node.op)](left, right)
        raise ValueError('unsupported syntax')

    try:
        value = evaluate(ast.parse(token, mode='eval').body)
    except (SyntaxError, ValueError, TypeError, ArithmeticError,
            RecursionError, MemoryError):
        return None
    # e.g. (-8) ** 0.5 evaluates to a complex number
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return value


def parse_math_result(string):
    '''Extract the final numeric answer from a model's math solution.

    The final answer is expected near the end of the text, so the text is
    scanned from the back and the first token that evaluates to an int/float
    is returned.

    Args:
        string: the solution text; an int/float (returned as-is); or a dict
            whose 'Rationale', 'rationale' or 'Answer' entry (first present,
            in that order) holds the solution.

    Returns:
        The parsed int/float, or '' when no number is found or the input has
        an unsupported type.
    '''
    if isinstance(string, dict):
        for key in ('Rationale', 'rationale', 'Answer'):
            if key in string:
                string = string[key]
                break
        else:
            return ''
    # bools are ints in Python but are not answers
    if isinstance(string, (int, float)) and not isinstance(string, bool):
        return string
    if not isinstance(string, str):
        return ''
    # drop thousands separators ("1,000" -> "1000") before the remaining
    # commas are turned into token boundaries
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', string)
    text = text.replace('\n', ' ').replace('=', '= ').replace(',', ' ')
    for token in reversed(text.split()):
        value = _parse_arith_token(token)
        # `is not None` (not truthiness) so that a final answer of 0 counts
        if value is not None:
            return value
    return ''


# Critique-generation prompt template, variant 2.
# Placeholders: {query} (conversation history), {evaluated_response} and
# {my_criteria}. The text is identical to PROMPT except for the step order:
# here the reference answer is Step 1, in PROMPT it is Step 3.
# The NOTICE refers to citation markers such as [S1]/[S2], the format that
# segment_response() appends to response segments.
# NOTE: this text (typos included) is model input; editing it changes results.
PROMPT_v2 = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** generate your high-quality reference answer for better critiques
**Step 2:** analyze the purpose of the user role in the previous conversation history
**Step 3:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 4:** generate your detailed feedbacks, followed by a summarizaing containing the final judgemen score (ranging from 1 to 10, where higher score denotes better quaulity of evaluated response.)

# NOTICE
1. the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

# Output
Please output your critiques
'''



# Critique-generation prompt template.
# Placeholders: {query} (conversation history), {evaluated_response} and
# {my_criteria}. The model is asked to analyse the user's purpose, expand the
# criteria, write a reference answer (Step 3), then give feedback with a
# 1-10 score. PROMPT_v2 is the same text with the reference answer as Step 1.
# The NOTICE refers to citation markers such as [S1]/[S2], the format that
# segment_response() appends to response segments.
# NOTE: this text (typos included) is model input; editing it changes results.
PROMPT = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, followed by a summarizaing containing the final judgemen score (ranging from 1 to 10, where higher score denotes better quaulity of evaluated response.)

# NOTICE
1. the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

# Output
Please output your critiques
'''


def segment_response(response, min_length=3, criteria_type=''):
    """Split a response into segments and append a citation marker to each.

    Markers have the form ``[S<k>]`` with k = 1, 2, ... so that critiques can
    point at individual sentences.
    Refer to: https://arxiv.org/pdf/2305.14627.pdf

    * Coding criteria types: the response is split on runs of newlines, because
      the sentence rule below would mis-split source code. Every non-blank
      piece gets a marker, placed before its last character when that
      character is punctuation or an operator. A response with no non-blank
      piece gets no marker.
    * Everything else: the response is split at sentence punctuation
      (``.?!;``) followed by whitespace. A marker is placed before each such
      punctuation mark unless the text before it is shorter than
      ``min_length`` characters (e.g. a list number like "1."); such a short
      fragment is merged into the following sentence instead. Trailing text
      without final punctuation also gets a marker. At least one marker is
      always inserted; an empty response becomes "Empty Response [S1]".

    Args:
        response: the text to segment.
        min_length: minimum stripped length of the text preceding a sentence
            boundary for that boundary to receive its own marker.
        criteria_type: category of the question; selects the coding rule for
            the coding categories listed below.

    Returns:
        The response with citation markers inserted.
    """
    coding_criteria_types = {
        'code_simplification',
        'code_generation',
        'explaining_code',
        'code_correction_rewriting',
        'code_to_code_translation',
        'airoboro2.2_coding',
    }
    if criteria_type in coding_criteria_types:
        # the capturing group makes re.split keep the newline runs, so that
        # ''.join() restores the original layout
        index = 1
        strings = []
        for segment in re.split(r'(\n+)', response):
            if not segment.strip():
                strings.append(segment)
                continue
            if segment[-1] in '.?!;:,+-*~!。，：！？':
                strings.append(segment[:-1] + f" [S{index}]{segment[-1]}")
            else:
                strings.append(segment + f" [S{index}]")
            index += 1
        return ''.join(strings)

    # the trailing space lets a final '.', '?', '!' or ';' match the
    # "<punctuation><whitespace>" separator below
    response += ' '
    if not response.strip():
        response = 'Empty Response'
    # capturing group: separators are kept as segments of their own
    segments = [s for s in re.split(r'([\.?!;]\s)', response) if s.strip()]
    index = 1
    if len(segments) == 1:
        # no sentence boundary: a single marker at the end
        new_response = f'{segments[0]} [S{index}]'
        index += 1
    else:
        strings = []
        for segment in segments:
            # A separator is exactly "<punctuation><whitespace>". fullmatch
            # keeps text that merely starts with punctuation (".NET", "...")
            # from being mistaken for one.
            if re.fullmatch(r'[\.?!;]\s', segment):
                if not strings:
                    strings.append(segment)
                elif len(strings[-1].strip()) >= min_length:
                    strings.append(f" [S{index}]{segment}")
                    index += 1
                else:
                    # too short to cite on its own (e.g. "1."): merge it into
                    # the following sentence
                    strings[-1] += segment
            else:
                strings.append(segment)
        if strings[-1].strip()[-1] not in '.?!;':
            # trailing text without final punctuation still needs a marker
            strings[-1] = strings[-1].strip() + f" [S{index}]"
            index += 1
        new_response = ''.join(strings)
    if index == 1:
        # No marker was inserted (every sentence was shorter than min_length).
        # Cite the whole response once, before its final punctuation if any;
        # rstrip() drops the whitespace kept from the last separator so the
        # punctuation check sees the real last character.
        new_response = new_response.rstrip()
        if new_response.endswith(('.', '?', '!', ';')):
            new_response = new_response[:-1] + f' [S1]{new_response[-1]}'
        else:
            new_response += ' [S1].'
    return new_response



#####################################################################################
# for autoj-13b
from vllm import LLM, SamplingParams
import torch
from utils.constants_prompt import build_autoj_input # constants_prompt -> codes/constants_prompt.py


# TIGERScore-style error-analysis prompt: FINETUNE_INST is the instruction
# line and FINETUNE_INPUT the input template, with placeholders
# {generation_instruction}, {input_context} and {hypothesis_output}.
# NOTE: template_maker(mode='tigerscore') defines its own local copies of
# these two strings (with different line breaks) instead of using these.
FINETUNE_INST = "You are evaluating errors in a model-generated output for a given instruction."
FINETUNE_INPUT = """\
Instruction: 
{generation_instruction}
{input_context}

Model-generated Output: 
{hypothesis_output}

For each error you give in the response, please also elaborate the following information:
- error location (the words that are wrong in the output)
- error aspect it belongs to.
- explanation why it's an error, and the correction suggestions.
- severity of the error ("Major" or "Minor"). 
- reduction of score (between 0.5 and 5 given the severity of the error)

Your evaluation output:\
"""


# for ultracm
# UltraCM critique prompt; placeholders {instruction} and {completion}.
# Used by generate_feedback, generate_feedback_batch and
# template_maker(mode='ultracm'). It ends with an open
# "### Feedback / Overall Score: " section for the model to continue.
ultracm_instruction_template = """Given my answer to an instruction, your role is to provide specific and constructive feedback for me. You should find the best way for me to learn from your feedback and improve my performance. 

You should consider multiple aspects of my answer, including helpfulness, truthfulness, honesty, and to what extent the answer follows instructions.
---

### Instruction
{instruction}

### Answer
{completion}
---

Please act as a teacher and provide specific and constructive feedback. Besides describing the weaknesses of the answer, you should also provide specific suggestions to guide me toward understanding how to improve. Please note, however, that your suggestions should help me better complete the instructions, but you should not introduce new requirements that are not mentioned in the instructions. Your feedback should focus on enhancing my ability to think critically and respond accurately. However, never explicitly provide the reference answer, nor do polite phrases be required. Only respond with concise feedback in chat style. Finally, score the overall quality of the answer from 1 to 10, where 1 is the worst and 10 is the best.

*Format*
### Feedback
Overall Score: [1-10]
[Your feedback]

---

### Feedback
Overall Score: 
"""
_ULTRACM_SYSTEM_PROMPT = "User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>"


def _build_ultracm_prompt(example):
    """Build the UltraCM chat prompt for one example.

    The prompt is the system turn, a user turn holding
    ultracm_instruction_template filled with example["instruction"] and
    example["completion"], and an open "Assistant: " turn, joined by newlines.
    """
    conv = [
        _ULTRACM_SYSTEM_PROMPT,
        "User: " + ultracm_instruction_template.format(
            instruction=example["instruction"],
            completion=example["completion"],
        ) + "</s>",
        "Assistant: ",
    ]
    return "\n".join(conv)


def generate_feedback(generator, example, tokenizer):
    """Sample one UltraCM critique with a HuggingFace-style `generate` call.

    Args:
        generator: model exposing `generate(input_ids, **kwargs)`; the input is
            moved to CUDA.
        example: dict with "instruction" and "completion".
        tokenizer: tokenizer matching `generator`.

    Returns:
        The generated critique text, with the prompt tokens and "</s>" removed.
    """
    prompt = _build_ultracm_prompt(example)

    with torch.no_grad():
        input_tokens = tokenizer.encode(prompt)
        length = len(input_tokens)
        input_tokens = torch.LongTensor(input_tokens).unsqueeze(0).cuda()
        response_tokens = generator.generate(input_tokens, do_sample=True, temperature=1.0, top_p=1, max_new_tokens=1024, repetition_penalty=1.2)
        # decode only the newly generated tokens, not the echoed prompt
        response = tokenizer.decode(response_tokens[0][length:]).strip()
        response = response.replace('</s>', '').strip()
    return response


def generate_feedback_batch(generator, examples):
    """Sample one UltraCM critique per example with a vLLM-style `generate`.

    Sampling uses temperature=1.0, top_p=1.0 and at most 1024 new tokens.

    Returns:
        A list with the text of the first completion for each prompt, in the
        order returned by `generator.generate`.
    """
    prompts = [_build_ultracm_prompt(example) for example in examples]
    sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=1024, n=1)
    outputs = generator.generate(prompts, sampling_params)
    return [output.outputs[0].text for output in outputs]
#####################################################################################




def template_maker(msg, mode='autoj'):
    """Build the chat messages for a zero-shot critic baseline.

    Args:
        msg: dict with 'question' (the user query) and 'generation' (the
            response to critique).
        mode: prompt format, one of 'autoj', 'promethues' (sic; this is the
            spelling callers pass), 'ultracm' or 'tigerscore'.

    Returns:
        A list of {'role', 'content'} message dicts. The 'promethues' format
        has no system message; the others start with a shared system message.

    Raises:
        ValueError: if `mode` is not one of the supported formats.
    """
    system_content = "A chat between a curious user and an artificial intelligence expert. The expert gives helpful, specific, and concise answers to the user's questions."

    if mode == 'autoj':
        prompt = build_autoj_input(
            prompt=msg['question'],
            resp1=msg['generation'], resp2=None,
            protocol="single")
        return [
            {'role': 'system', 'content': system_content},
            {'role': 'user', 'content': prompt}
        ]
    if mode == 'promethues':
        # NOTE(review): the text mentions a reference answer and a score rubric
        # that are not supplied, and its numbering repeats (1,2,3,2,3,4). It is
        # kept verbatim because it is model input.
        user_content = '''###Task Description:
An instruction (might include an Input inside it), a response to evaluate, a reference answer are given.
1. Write a criteria
2. Write a reference answer 
3. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: \"Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)\"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
{orig_instruction}

###Response to evaluate:
{orig_response}

###Feedback:
'''
        user_content = user_content.format(
            orig_instruction=msg['question'],
            orig_response=msg['generation']
        )
        return [
            {'role': 'user', 'content': user_content}
        ]
    if mode == 'ultracm':
        user_content = ultracm_instruction_template.format(
            instruction=msg['question'],
            completion=msg['generation'],
        )
        return [
            {'role': 'system', 'content': system_content},
            {'role': 'user', 'content': user_content}
        ]
    if mode == 'tigerscore':
        # Local TIGERScore prompt; its line breaks differ from the module-level
        # FINETUNE_INPUT constant, which is deliberately not reused here.
        tiger_inst = "You are evaluating errors in a model-generated output for a given instruction."
        tiger_input = """\
Instruction: {generation_instruction}
{input_context}


Model-generated Output:
{hypothesis_output}


For each error you give in the response, please also elaborate the following information:
- error location (the words that are wrong in the output)
- error aspect it belongs to.
- explanation why it's an error, and the correction suggestions.
- severity of the error ("Major" or "Minor").
- reduction of score (between 0.5 and 5 given the severity of the error)

Your evaluation output:
"""
        user_content = tiger_inst + '\n' + tiger_input.format(
            generation_instruction='',
            input_context=msg['question'],
            hypothesis_output=msg['generation']
        )
        # normalise the ending to exactly one newline
        user_content = user_content.strip('\n ') + '\n'
        return [
            {'role': 'system', 'content': system_content},
            {'role': 'user', 'content': user_content}
        ]
    raise ValueError(f'[!] Unknown mode: {mode}')




def _keep_feedback_section(critique):
    """Keep only the feedback section of a critique that has one.

    If the critique contains '# Feedback', it is split on a '# Feedbacks'
    header line (falling back to a '# Feedback' header line). When the header
    occurs exactly once, the text after it is returned under a normalised
    '# Feedbacks' header. Otherwise an error is printed and the critique is
    returned unchanged.
    """
    if '# Feedback' not in critique:
        return critique
    for header in ('# Feedbacks\n', '# Feedback\n'):
        parts = critique.split(header)
        if len(parts) == 2:
            return '# Feedbacks\n' + parts[1]
    print('[!] MEET ERROR!!!')
    return critique


def prepare_question_with_template(sample, question_template, task):
    """Fill a prompt template's placeholders from a sample.

    Placeholders per task:
      * "generation": {question}
      * "critique": {question} and {solution} (only {solution} for HumanEval)
      * "correction": {question}, {solution} and {critique} (no {question}
        for HumanEval); the critique is first cut down to its feedback
        section when it has one.

    {solution} receives sample["response"] and {critique} sample["critique"].
    Placeholders are matched by name and substituted in a single pass, so the
    order in the template does not matter, and text inserted for one
    placeholder is never re-scanned for another.

    Raises:
        ValueError: for an unknown task, or when the number of placeholders
            found in the template differs from the number of values to insert.
    """
    question = sample["question"]
    if task == "generation":
        values = {"question": question}
    elif task == "critique":
        response = sample["response"]
        if sample["question_source"] == "HumanEval":
            values = {"solution": response}
        else:
            values = {"question": question, "solution": response}
    elif task == "correction":
        response = sample["response"]
        critique = _keep_feedback_section(sample["critique"])
        if sample["question_source"] == "HumanEval":
            values = {"solution": response, "critique": critique}
        else:
            values = {"question": question, "solution": response, "critique": critique}
    else:
        raise ValueError(f"[!] unknown task: {task}")

    pattern = re.compile(r"\{(" + "|".join(values) + r")\}")
    found = pattern.findall(question_template)
    if len(found) != len(values):
        raise ValueError(
            f"[!] template has {len(found)} placeholders, expected {len(values)}")
    return pattern.sub(lambda m: values[m.group(1)], question_template)

def set_prompt(model_path, task, max_gen_len, dataset_with_prompt, model_from_api=False):
    """Attach a few-shot prompt that fits the model's context to every sample.

    For each sample the question template is filled via
    prepare_question_with_template. Few-shot examples
    (prompt_info_dict["examples"], split on the per-task/per-source separator
    from evaluate_configuration.few_shot_split_sep_by_task) are then added one
    at a time. As soon as the tokenized prompt exceeds
    context_length - max_gen_len, the last added example is dropped. For the
    "\\n---" separator the question is also prefixed with "\\n---\\n" at that
    point. The stripped prompt is stored in sample["final_prompt"].

    The context length comes from evaluate_configuration for API models
    (tokenized with tiktoken), or from the HF config's max_position_embeddings.

    Returns:
        The same list, mutated in place.
    """
    if model_from_api:
        tokenizer = tiktoken.encoding_for_model(model_path)
        context_length = evaluate_configuration.model_from_api_limit_length[model_path]
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        context_length = AutoConfig.from_pretrained(model_path, trust_remote_code=True).max_position_embeddings
    # tokens left for the prompt after reserving room for generation
    prompt_budget = context_length - max_gen_len

    for sample in dataset_with_prompt:
        info = sample["prompt_info_dict"]
        question = prepare_question_with_template(sample=sample,
                                                  question_template=info["question_template"],
                                                  task=task)
        instruction = info["instruction"]
        sep = evaluate_configuration.few_shot_split_sep_by_task[task][sample["question_source"]]
        shots = info["examples"].split(sep)

        prompt = instruction + question
        for n_shots in range(1, len(shots) + 1):
            prompt = instruction + sep.join(shots[:n_shots]) + question
            if len(tokenizer.encode(prompt)) <= prompt_budget:
                continue
            # over budget: fall back to one example fewer
            if sep == "\n---":
                question = "\n---\n" + question
            prompt = instruction + sep.join(shots[:n_shots - 1]) + question
            break
        sample["final_prompt"] = prompt.strip()
    return dataset_with_prompt


def infer_openai(model, api_key, dataset_with_prompt, out_dir, task, prompt_type):
    """Query an OpenAI chat model once per sample and save the answers as JSONL.

    Prompts are built with set_prompt(..., model_from_api=True). Each request
    is tried up to 5 times. After openai.APIError / RateLimitError the wait
    starts at 5 s and doubles each time. A sample whose attempts all fail is
    recorded with an empty result string, so the output stays aligned with the
    input samples.

    Returns:
        A list of {"id", "final_prompt", "<task>_result"} dicts, one per sample.
        The same records are written to
        <out_dir>/<model_name>/<task>/<prompt_type>_result_<timestamp>.jsonl.
    """
    max_gen_len = 512
    max_attempts = 5
    model_name = model.split("/")[-1]
    openai.api_key = api_key
    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_prompt(model_path=model, task=task, max_gen_len=max_gen_len,
                                     dataset_with_prompt=dataset_with_prompt, model_from_api=True)
    print(f"----------Start {task}----------")
    print("Number of samples: ", len(dataset_with_prompt))
    save_path = os.path.join(save_dir,
                             f"{prompt_type}_result_{time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))}.jsonl")
    result_list = []
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in dataset_with_prompt:
            final_prompt = sample["final_prompt"]
            # reset per sample so a failure never reuses the previous result
            result = None
            retry_delay_seconds = 5
            for _ in range(max_attempts):
                try:
                    response = openai.ChatCompletion.create(
                        model=model,
                        messages=[{"role": "user", "content": final_prompt}],
                        max_tokens=max_gen_len,
                        temperature=0,
                        top_p=0.95,
                    )
                except (openai.APIError, openai.error.RateLimitError) as e:
                    print(e)
                    time.sleep(retry_delay_seconds)
                    retry_delay_seconds *= 2
                    continue
                if "choices" in response.keys():
                    result = {"id": sample["id"], "final_prompt": final_prompt,
                              f'{task}_result': response["choices"][0]["message"]["content"]}
                    break
                print(response)
            if result is None:
                print(f"[!] no response for sample {sample['id']} after {max_attempts} attempts")
                result = {"id": sample["id"], "final_prompt": final_prompt, f'{task}_result': ''}
            f.write(json.dumps(result, ensure_ascii=False) + "\n")
            result_list.append(result)
    return result_list

def _fill_correction_template(sample, critique):
    """Fill the sample's question template for a correction prompt.

    {solution} receives the response with citation markers added by
    segment_response, and {critique} receives `critique`. {question} receives
    the question, except for HumanEval, whose templates have no question slot.
    Placeholders are matched by name and substituted in a single pass, so text
    inserted for one slot is never re-scanned for another.

    Raises:
        ValueError: if the number of placeholders found in the template does
            not equal the number of values to insert.
    """
    values = {
        'solution': segment_response(sample['response']),
        'critique': critique,
    }
    if sample["question_source"] != "HumanEval":
        values = {'question': sample["question"], **values}
    template = sample['prompt_info_dict']['question_template']
    pattern = re.compile(r'\{(' + '|'.join(values) + r')\}')
    found = pattern.findall(template)
    if len(found) != len(values):
        raise ValueError(
            f"[!] template has {len(found)} placeholders, expected {len(values)}")
    return pattern.sub(lambda m: values[m.group(1)], template)


def _chat_generate(llm, prompts, batch_size):
    """Send single-turn user prompts to an lmdeploy-style pipeline in batches.

    Uses temperature 0.0 and at most 2048 new tokens. Returns one response
    text per prompt, in the order of `prompts`.
    """
    gen_config = GenerationConfig(temperature=0.0, max_new_tokens=2048)
    outputs = []
    with tqdm(total=len(prompts)) as pbar:
        for start in range(0, len(prompts), batch_size):
            batch = [[{'role': 'user', 'content': prompt}]
                     for prompt in prompts[start:start + batch_size]]
            responses = llm(batch, gen_config=gen_config)
            outputs.extend(response.text for response in responses)
            pbar.update(len(batch))
    return outputs


def infer_hf(model, llm, dataset_with_prompt, out_dir, task, prompt_type):
    '''Run a local model (internlm2 or Llama-2) over the dataset and save the outputs.

    Models whose path contains 'Llama-2' are run through `llm.generate` with
    vLLM SamplingParams. Other models are called as an lmdeploy chat pipeline
    (`llm(messages, gen_config=...)`).

    If the first prompt contains 'Task Description' or 'Final Judgement', the
    prompts are rebuilt from each sample's raw critique and segmented response
    (a correction round). On the non-Llama path, samples with an empty critique
    are then not sent to the model and get '' as their output.

    Returns:
        A list of {"id", "final_prompt", "<task>_result"} dicts. They are also
        written to <out_dir>/<model_name>/<task>/<prompt_type>_result_<timestamp>.jsonl.
    '''
    max_gen_len = 1024
    sampling_params = SamplingParams(temperature=0, max_tokens=max_gen_len, n=1,
                                     stop=["\n\nQuestion:",
                                           "\n\nTable:",
                                           "\n\n---",
                                           "\n\nYou are an expert Python programmer",
                                           "\nAnalysis and verdict:"]
                                     )
    model_name = model.split("/")[-1]
    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_prompt(model_path=model, task=task, max_gen_len=max_gen_len,
                                     dataset_with_prompt=dataset_with_prompt, model_from_api=False)
    print("Number of samples: ", len(dataset_with_prompt))
    inputs = [sample["final_prompt"] for sample in dataset_with_prompt]
    # check `inputs` first: an empty dataset has no first prompt to inspect
    rebuild_from_critique = bool(inputs) and (
        'Task Description' in inputs[0] or 'Final Judgement' in inputs[0])

    if 'Llama-2' in model:
        if rebuild_from_critique:
            inputs = []
            for sample in dataset_with_prompt:
                critique = ('\n---\n# Followings are criteria, task description, high-quality reference, and the entries of our feedback for the answer in the clear Markdown format.\n---\n\n\n'
                            + sample['critique'])
                prompt = _fill_correction_template(sample, critique)
                prompt = prompt.replace('Based on the problems you found, improve your answer', '**Please re-generate your revised answer that solves all the feedback entries that you found.**')
                inputs.append(prompt)
        generations = llm.generate(inputs, sampling_params)
        # restore input order using the numeric request ids
        generations = sorted(generations, key=lambda x: int(x.request_id))
        outputs = [generation.outputs[0].text for generation in generations]
    elif rebuild_from_critique:
        inputs = []
        for sample in dataset_with_prompt:
            critique = ('\n---\n# Followings are criteria, task description, high-quality reference, feedback entries and feedback summarization for the evaluated response\n---\n'
                        + sample['critique'])
            prompt = _fill_correction_template(sample, critique)
            prompt = prompt.replace('Improved solution and answer:', '\n Must complete the "Improved solution" with your new generated rationale.\nImproved solution and answer:')
            prompt = prompt.replace('Improved code:', '### NOTICE!!!!!: Please directly generate your revised code in the JSON format, DONOT ADD any textual explanation in your generated code!!!!! ONLY Generate the revised code!!!!!')
            inputs.append(prompt)
        # Only samples that actually have a critique are sent to the model;
        # the others keep '' so outputs stay aligned with the samples.
        todo = [i for i, sample in enumerate(dataset_with_prompt) if sample['critique']]
        generated = _chat_generate(llm, [inputs[i] for i in todo], batch_size=16)
        outputs = [''] * len(inputs)
        for i, text in zip(todo, generated):
            outputs[i] = text
    else:
        outputs = _chat_generate(llm, inputs, batch_size=8)

    assert len(inputs) == len(outputs)
    result_list = [{"id": sample["id"], "final_prompt": sample['final_prompt'], f'{task}_result': output}
                   for sample, output in zip(dataset_with_prompt, outputs)]

    save_path = os.path.join(save_dir, f"{prompt_type}_result_{time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))}.jsonl")

    with open(save_path, "w", encoding="utf-8") as f:
        for record in result_list:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return result_list


def set_critic_model_prompt(model, dataset):
    """Store each sample's critic-model input in sample["final_prompt"] (in place).

    The input format is chosen by substring match on `model`, in this order:
      * "UltraCM-13b" / "autoj-13b": the stripped prompt string from
        UltraCM.get_prompt / AutoJ.get_prompt
      * "TIGERScore": the tuple ('', question, response)
      * "Themis": the tuple ('', (question, response, question_source))
      * "/cpfs02/llm/shared" (internlm2 tuned on critique data): chat messages
        from template_maker for 'autoj', 'tigerscore', 'ultracm'/'ultrafeedback'
        or 'promethues'/'prometheus' checkpoints; otherwise the tuple
        (instruction, question, response, question_type, question_source)
    Samples are left untouched for any other model.

    Returns:
        The same `dataset` list.
    """
    # (substrings to look for, template_maker mode), checked in order
    tuned_modes = (
        (('autoj',), 'autoj'),
        (('tigerscore',), 'tigerscore'),
        (('ultracm', 'ultrafeedback'), 'ultracm'),
        (('promethues', 'prometheus'), 'promethues'),
    )
    for sample in dataset:
        question, response = sample["question"], sample["response"]
        if "UltraCM-13b" in model:
            sample["final_prompt"] = UltraCM.get_prompt(question=question, response=response).strip()
        elif "autoj-13b" in model:
            sample["final_prompt"] = AutoJ.get_prompt(question=question, response=response).strip()
        elif 'TIGERScore' in model:
            sample["final_prompt"] = ('', question, response)
        elif 'Themis' in model:
            sample["final_prompt"] = ('', (question, response, sample['question_source']))
        elif '/cpfs02/llm/shared' in model:
            mode = next((name for keys, name in tuned_modes
                         if any(key in model for key in keys)), None)
            if mode is None:
                sample["final_prompt"] = ('Please solve the question effectively!', question, response,
                                          sample['question_type'], sample['question_source'])
            else:
                sample["final_prompt"] = template_maker({'question': question, 'generation': response}, mode=mode)
    return dataset


def infer_hf_critic_model(model, llm, dataset, out_dir, task):
    '''Dispatch critique generation to the runner that matches `model`.

    Rules (substring match on `model`, first match wins):
      * 'autoj-13b' or 'UltraCM'            -> _infer_hf_critic_model_v1
      * 'TIGERScore'                        -> _infer_hf_critic_model_v2
      * 'Themis'                            -> _infer_hf_critic_model_themis
      * 'promethues' / 'prometheus'         -> _infer_hf_critic_model_promethues
      * '/cpfs02/llm/shared' or 'transfer_from_tos':
          - also 'autoj'/'ultracm'/'tigerscore'/'ultrafeedback'
                                            -> _infer_hf_critic_model_v3
          - otherwise                       -> _infer_hf_critic_model_v6_resumm_multi_turn

    Returns:
        Whatever the selected runner returns, or None (after printing a
        warning) when no rule matches.
    '''
    if 'autoj-13b' in model or 'UltraCM' in model:
        return _infer_hf_critic_model_v1(model, llm, dataset, out_dir, task)
    if 'TIGERScore' in model:
        return _infer_hf_critic_model_v2(model, llm, dataset, out_dir, task)
    if 'Themis' in model:
        return _infer_hf_critic_model_themis(model, llm, dataset, out_dir, task)
    if 'promethues' in model or 'prometheus' in model:
        return _infer_hf_critic_model_promethues(model, llm, dataset, out_dir, task)
    if '/cpfs02/llm/shared' in model or 'transfer_from_tos' in model:
        if any(key in model for key in ('autoj', 'ultracm', 'tigerscore', 'ultrafeedback')):
            return _infer_hf_critic_model_v3(model, llm, dataset, out_dir, task)
        # Other _infer_hf_critic_model_v6_resumm* variants (v2, multi_turn_v2,
        # multi_turn_v3) were previously toggled here by hand.
        return _infer_hf_critic_model_v6_resumm_multi_turn(model, llm, dataset, out_dir, task)
    print(f'[!] no critic-model runner matches model: {model}')
    return None

def _infer_hf_critic_model_v1(model, llm, dataset, out_dir, task):
    '''Generate critiques with Auto-J / UltraCM through vLLM and save them as JSONL.

    Decoding samples with temperature 1.0 and top_p 1.0, because UltraCM needs
    sampling. It produces at most 1024 new tokens and one completion per prompt.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<last path component of model>/<task>/result_<timestamp>.jsonl``.
    '''
    decoding = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=1024, n=1)
    short_name = model.split("/")[-1]
    print("Number of samples: ", len(dataset))

    run_dir = os.path.join(out_dir, short_name, task)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir, exist_ok=True)
    prepared = set_critic_model_prompt(model, dataset)

    prompts = [item["final_prompt"] for item in prepared]
    # Re-order finished requests by their numeric request id so that each
    # generation lines up with the prompt it came from.
    finished = sorted(llm.generate(prompts, decoding), key=lambda req: int(req.request_id))
    texts = [req.outputs[0].text for req in finished]

    records = [
        {"id": item["id"], "final_prompt": item["final_prompt"], f'{task}_result': text}
        for item, text in zip(prepared, texts)
    ]
    out_file = os.path.join(run_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(out_file, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    return records


def _infer_hf_critic_model_v2(model, llm, dataset, out_dir, task):
    '''Score ``dataset`` with a TIGERScore model and save the results as JSONL.

    ``llm`` must expose ``score(instructions, hypo_outputs, input_contexts)``. It is
    presumably a ``TIGERScorer``; confirm at the call site. Each sample's ``final_prompt``
    is read positionally as ``(instruction, input_context, hypo_output, ...)``. Any
    trailing fields (e.g. question type/source in the 5-tuple prompt form) are ignored.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, where
        the result is the JSON-encoded score output. Also written to
        ``<out_dir>/<last path component of model>/<task>/result_<timestamp>.jsonl``.
    '''
    model_name = model.split("/")[-1]
    print("Number of samples: ", len(dataset))

    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)
    inputs = [sample["final_prompt"] for sample in dataset_with_prompt]
    # Index positionally instead of unpacking `a, b, c`: prompts that carry extra
    # trailing fields would otherwise raise "too many values to unpack".
    instructions = [prompt[0] for prompt in inputs]
    input_contexts = [prompt[1] for prompt in inputs]
    hypo_outputs = [prompt[2] for prompt in inputs]
    outputs = llm.score(instructions, hypo_outputs, input_contexts)
    outputs = [json.dumps(output) for output in outputs]
    result_list = [{"id": sample["id"], "final_prompt": sample["final_prompt"],
                    f'{task}_result': output} for sample, output in zip(dataset_with_prompt, outputs)]
    save_path = os.path.join(save_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list


def _infer_hf_critic_model_v3(model, llm, dataset, out_dir, task):
    '''Generate critiques with an internlm2-style critic through an lmdeploy pipeline.

    Prompts come from ``set_critic_model_prompt`` and are sent in batches of 32.
    Decoding samples with temperature 0.8, top_k 50 and top_p 0.95, up to 2048 new tokens.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<model path with '/' replaced by '_'>/<task>/result_<timestamp>.jsonl``.
    '''
    run_dir = os.path.join(out_dir, model.replace('/', '_'), task)
    print("Number of samples: ", len(dataset))

    if not os.path.exists(run_dir):
        os.makedirs(run_dir, exist_ok=True)
    prepared = set_critic_model_prompt(model, dataset)

    prompts = [item["final_prompt"] for item in prepared]
    chunk_size = 32
    progress = tqdm(total=len(prompts))
    generations = []
    sampling = GenerationConfig(temperature=0.8, max_new_tokens=2048, top_k=50, top_p=0.95)
    for start in range(0, len(prompts), chunk_size):
        chunk = prompts[start:start + chunk_size]
        replies = [reply.text for reply in llm(chunk, gen_config=sampling)]
        generations.extend(replies)
        progress.update(len(chunk))

    records = [
        {"id": item["id"], "final_prompt": item["final_prompt"], f'{task}_result': text}
        for item, text in zip(prepared, generations)
    ]
    out_file = os.path.join(run_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(out_file, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    return records



def _infer_hf_critic_model_v4(model, llm, dataset, out_dir, task):
    '''Multi-turn critique (task -> fixed criteria -> feedback -> summary) via lmdeploy.

    Each sample becomes a conversation with two opening turns: the user instruction/input,
    then the segmented evaluated response as the assistant turn. The critic is then queried
    turn by turn with greedy decoding (temperature 0.0, up to 1024 new tokens):

    1. ``multi_turn_prompts['task']``: the critic's reply is appended.
    2. ``multi_turn_prompts['criteria']`` with empty ``my_criteria``: the critic is NOT
       queried. A hard-coded, accuracy-focused criteria text is appended as the
       assistant reply instead.
    3. ``multi_turn_prompts['feedback']``: the critic's reply is appended.
    4. ``multi_turn_prompts['summarization']``: the critic's reply is appended.

    Only the final summarization reply is kept as ``f"{task}_result"`` (see the
    extraction loop at the end of each batch).

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<last path component of model>/<task>/result_<timestamp>.jsonl``.
    '''
    max_gen_len = 2048  # NOTE(review): unused; gen_config below caps generation at 1024 tokens
    # Decoding in this routine is greedy (temperature 0.0), not sampled.
    model_name = model.split("/")[-1]
    print("Number of samples: ", len(dataset))

    save_dir = os.path.join(out_dir, model_name, task)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)

    # Each final_prompt is indexed below as (instruction, input, evaluated response, ...).
    inputs = [sample["final_prompt"] for sample in dataset_with_prompt]
    batch_size = 32
    index = 0
    pbar = tqdm(total=len(inputs))
    outputs = []
    gen_config = GenerationConfig(temperature=0.0, max_new_tokens=1024)

    while index < len(inputs):
        # build the conversation history: user instruction/input, then the
        # segmented evaluated response as the assistant turn
        histories = []
        for msg in inputs[index:index+batch_size]:
            histories.append([{
                'role': 'user', 
                'content': f'# Instruction\n{msg[0]}\n\n# Input\n{msg[1]}'
            }])
            evaluated_response = segment_response(msg[2])
            histories[-1].append({
                'role': 'assistant',
                'content': evaluated_response})
        ## turn 1: task analysis, answered by the critic
        for his in histories:
            his.append({'role': 'user', 'content': multi_turn_prompts['task']})
        responses = llm(histories, gen_config=gen_config)
        responses = [response.text for response in responses]
        for his, response in zip(histories, responses):
            his.append({'role': 'assistant', 'content': response})
        # turn 2: criteria request; the "answer" is the fixed criteria text below,
        # the critic is not called for this turn (the llm call is commented out)
        for his in histories:
            his.append({'role': 'user', 'content': multi_turn_prompts['criteria'].format(my_criteria='')})
        #responses = llm(histories, gen_config=gen_config)
        #responses = [response.text for response in responses]
        #for his, response in zip(histories, responses):
        for his in histories:
            his.append({'role': 'assistant', 'content': '''## # Two-tier Structure of Criteria\nAccuracy\nDescription: The response/answer must correctly and accurately answer the user question.\n\n### Correctness of Solution\nDescription: Evaluate the accuracy of the solution provided, ensuring that it solves the problem correctly. For code generation tasks, this would involve checking if the code runs without errors and produces the expected output. For mathematical problems, the solution should be mathematically sound and provide the correct answer. For other tasks need logical reasoning, the solution conclusion should be correct and accurate without errors.\nDegree: important.\n### Precision of Details\nDescription: Assess the precision and correctness of the details within the response. This includes verifying that all steps, calculations, or pieces of code are precisely accurate and contribute to the correct final outcome. Any stated facts or definitions should also be error-free.\nDegree: important.\n\n### Consistency with Given Information\nDescription: Determine if the response is consistent with the information provided in the question. The answer should not introduce new assumptions that were not present in the original problem statement and should only use the information given to arrive at the solution\nDegree: important.\n\n### Step-by-Step\nDescription: for questions the solution should contain step-by-step reasoning to get the final response/answer.\nDegree: important'''})
        # turn 3: feedback, answered by the critic
        for his in histories:
            his.append({'role': 'user', 'content': multi_turn_prompts['feedback']})
        responses = llm(histories, gen_config=gen_config)
        responses = [response.text for response in responses]
        for his, response in zip(histories, responses):
            his.append({'role': 'assistant', 'content': response})
        # turn 4: summarization, answered by the critic
        for his in histories:
            his.append({'role': 'user', 'content': multi_turn_prompts['summarization']})
        responses = llm(histories, gen_config=gen_config)
        responses = [response.text for response in responses]
        for his, response in zip(histories, responses):
            his.append({'role': 'assistant', 'content': response})
        # The last two messages are [summarization request, summary reply], so only the
        # summary is kept. NOTE(review): the join suggests more assistant turns (e.g. the
        # feedback) may have been intended; confirm before relying on this output.
        responses = []
        for his in histories:
            utterance = []
            for sample in his[-2:]:
                if sample['role'] == 'assistant':
                    utterance.append(sample['content'])
            r = '\n'.join(utterance)
            responses.append(r)
        outputs.extend(responses)
        pbar.update(len(responses))
        index += batch_size

    result_list = [{"id":sample["id"], "final_prompt":sample["final_prompt"],
                    f'{task}_result': output} for sample, output in zip(dataset_with_prompt, outputs)]

    save_path = os.path.join(save_dir, f"result_{time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))}.jsonl")

    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list



def template_maker_nips2024(history, msg, mode='task', domain_name='translation'):
    '''Append the next user turn of a NIPS-2024 style critique conversation to ``history``.

    Args:
        history: list of ``{'role', 'content'}`` message dicts; mutated in place.
        msg: the sample's prompt sequence. Only ``msg[2]`` (the evaluated response) is
            read, and only when ``mode == 'feedback'``.
        mode: which turn to add: ``'task'``, ``'reference'`` or ``'feedback'``.
        domain_name: dataset domain. In feedback mode, ``'code_exec'``, ``'code_not_exec'``
            and ``'math_pot'`` make ``segment_response`` use ``'code_generation'`` criteria.

    Raises:
        ValueError: if ``mode`` is not one of the supported turns. Previously an unknown
            mode left ``prompt`` unbound and crashed with an UnboundLocalError.
    '''
    if mode == 'task':
        prompt = task_criteria_prompts[0].format(my_criteria='')
    elif mode == 'reference':
        prompt = reference_prompts[0]
    elif mode == 'feedback':
        code_domains = ('code_exec', 'code_not_exec', 'math_pot')
        criteria_type = 'code_generation' if domain_name in code_domains else ''
        response = segment_response(msg[2], criteria_type=criteria_type)
        prompt = feedback_prompts[0].format(evaluated_response=response)
    else:
        raise ValueError(
            f"template_maker_nips2024: unsupported mode {mode!r}; "
            "expected 'task', 'reference' or 'feedback'"
        )
    history.append({'role': 'user', 'content': prompt})


def _infer_hf_critic_model_v5(model, llm, dataset, out_dir, task):
    '''Single-request, four-step critique with an internlm2-style critic served by lmdeploy.

    For every sample the critic sees three turns:
    1. the user's instruction/input;
    2. the segmented evaluated response, as the assistant turn;
    3. one user turn asking it to (a) analyse the user's purpose, (b) write two-tier
       criteria, (c) draft a reference answer and (d) give feedback with a 1-10 score.

    Decoding is greedy (temperature 0.0, up to 2048 new tokens), 32 samples per batch.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<model path with '/' replaced by '_'>/<task>/result_<timestamp>.jsonl``.
    '''
    print("Number of samples: ", len(dataset))
    run_dir = os.path.join(out_dir, model.replace("/", "_"), task)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir, exist_ok=True)
    prepared = set_critic_model_prompt(model, dataset)

    prompt_tuples = [item["final_prompt"] for item in prepared]
    chunk_size = 32
    progress = tqdm(total=len(prompt_tuples))
    generations = []
    greedy = GenerationConfig(temperature=0.0, max_new_tokens=2048)

    template = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **last response** in the previous conversation history.

We have provided our criteria list for you from different evaluation perspective as below
---
# Our Provided Criteria List
{my_criteria}
---

# Task

To generate the valuable feedback, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, followed by a summarizaing containing the final judgemen score (ranging from 1 to 10, where higher score denotes better quaulity of evaluated response.)

# NOTICE
1. the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

# Output
Generate the your content in the clear markdown template format.
'''
    # The criteria slot is always empty, so the request text is the same for every sample.
    critique_request = template.format(my_criteria='')

    for start in range(0, len(prompt_tuples), chunk_size):
        conversations = [
            [
                {'role': 'user', 'content': f'# Instruction\n{m[0]}\n\n# Input\n{m[1]}'},
                {'role': 'assistant', 'content': segment_response(m[2])},
                {'role': 'user', 'content': critique_request},
            ]
            for m in prompt_tuples[start:start + chunk_size]
        ]
        replies = [reply.text for reply in llm(conversations, gen_config=greedy)]
        generations.extend(replies)
        progress.update(len(replies))

    records = [
        {"id": item["id"], "final_prompt": item["final_prompt"], f'{task}_result': text}
        for item, text in zip(prepared, generations)
    ]
    out_file = os.path.join(run_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(out_file, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    return records


def _infer_hf_critic_model_v6_resumm_v2(model, llm, dataset, out_dir, task):
    '''Single-turn PROMPT_v2 critique; identical to ``_infer_hf_critic_model_v6_resumm``.

    The two routines were line-for-line copies, so this one delegates rather than
    duplicating the code. Arguments, return value and the JSONL file written are
    exactly those of ``_infer_hf_critic_model_v6_resumm``.
    '''
    return _infer_hf_critic_model_v6_resumm(model, llm, dataset, out_dir, task)




def _infer_hf_critic_model_v6_resumm(model, llm, dataset, out_dir, task):
    '''Single-turn critique built from ``template_single_turn_nips2024`` (PROMPT_v2) via lmdeploy.

    Each sample's instruction/input becomes the user query and its response is segmented
    with ``segment_response``. Decoding is greedy (temperature 0.0, up to 4096 new tokens),
    32 samples per batch.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample. Each result
        is the critic's text, a ``----------`` separator line, then the segmented response
        that was critiqued. Also written to
        ``<out_dir>/<model path with '/' replaced by '_'>/<task>/result_<timestamp>.jsonl``.
    '''
    print("Number of samples: ", len(dataset))
    run_dir = os.path.join(out_dir, model.replace("/", "_"), task)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir, exist_ok=True)
    prepared = set_critic_model_prompt(model, dataset)
    prompt_tuples = [item["final_prompt"] for item in prepared]
    chunk_size = 32
    progress = tqdm(total=len(prompt_tuples))
    collected = []
    greedy = GenerationConfig(temperature=0.0, max_new_tokens=4096)
    for start in range(0, len(prompt_tuples), chunk_size):
        requests = []
        segmented = []
        for m in prompt_tuples[start:start + chunk_size]:
            conversation = [{'role': 'user', 'content': f'# Instruction\n{m[0]}\n\n# Input\n{m[1]}'}]
            answer = segment_response(m[2])
            requests.append(template_single_turn_nips2024(conversation, answer))
            segmented.append(answer)
        replies = llm(requests, gen_config=greedy)
        assert len(replies) == len(segmented)
        for reply, answer in zip(replies, segmented):
            collected.append(reply.text + '\n----------\n' + answer)
        progress.update(len(requests))
    records = [
        {"id": item["id"], "final_prompt": item["final_prompt"], f'{task}_result': text}
        for item, text in zip(prepared, collected)
    ]
    out_file = os.path.join(run_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(out_file, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    return records


######## single-turn prompt ########
def template_single_turn_nips2024(queries, evaluated_response, my_criteria='', with_reference=True, with_task=True, with_criteria=True):
    '''Render a single-turn critique request from the ``PROMPT_v2`` template.

    Args:
        queries: conversation messages (dicts with ``role`` and ``content``). Each one is
            rendered on its own line as ``**role**: content``.
        evaluated_response: the response to be critiqued.
        my_criteria: optional criteria text substituted into the template.
        with_reference, with_task, with_criteria: accepted for API compatibility;
            currently unused.

    Returns:
        A one-element message list ``[{'role': 'user', 'content': <rendered prompt>}]``.
    '''
    rendered_turns = []
    for turn in queries:
        rendered_turns.append(f'**{turn["role"]}**: {turn["content"]}')
    content = PROMPT_v2.format(
        query='\n'.join(rendered_turns),
        evaluated_response=evaluated_response,
        my_criteria=my_criteria,
    )
    return [{'role': 'user', 'content': content}]

def _infer_hf_critic_model_v6_resumm_multi_turn(model, llm, dataset, out_dir, task):
    '''Multi-turn critique pipeline for internlm2-style critics served by lmdeploy.

    For every sample a conversation is grown turn by turn. Each step appends a user
    prompt from ``multi_turn_prompts_`` followed by the critic's reply:

    1. initial input (conversation + segmented evaluated response), no reply yet
    2. task analysis   (``['task'][0]``): skipped for the 'no_task' / 'no_all' ablations
    3. criteria        (``['criteria'][1]``): skipped for 'no_criteria' / 'no_all'
    4. reference       (``['reference'][0]``): skipped for 'no_all' (see the note on 'no_ref')
    5. feedback        (``['feedback'][1]``)
    6. summarization   (``['summarization'][0]``)

    Decoding is greedy (temperature 0.0, up to 4096 new tokens), 16 samples per batch.

    Args:
        model: model path; ablation switches are detected as substrings of it.
        llm: callable taking a batch of conversations and ``gen_config=``, returning one
            item with ``.text`` per conversation (presumably an lmdeploy pipeline).
        dataset: samples passed through ``set_critic_model_prompt``. Each ``final_prompt``
            must be a 5-tuple ``(instruction, question, response, question_type, question_source)``.
        out_dir: root output directory.
        task: task name, used in the output path and result key.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample. Each
        result is every assistant reply joined by a ``----------`` separator line. Also
        written to ``<out_dir>/<model path with '/' replaced by '_'>/<task>/result_<timestamp>.jsonl``.
    '''
    # Ablation switch read from the model path; the checks are order-sensitive.
    # NOTE(review): 'no_ref' deliberately maps to None (its 'no_ref' value was disabled),
    # so such models still run the reference turn. Confirm this is intended.
    if 'no_ref' in model:
        modemode = None
    elif 'no_criteria' in model:
        modemode = 'no_criteria'
    elif 'no_task' in model:
        modemode = 'no_task'
    elif 'no_all' in model:
        modemode = 'no_all'
    else:
        modemode = None

    model_name = model.replace("/", "_")
    print("Number of samples: ", len(dataset))
    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)
    inputs = [sample["final_prompt"] for sample in dataset_with_prompt]
    batch_size = 16
    gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096)

    def _add_turn(histories, user_content):
        # Append the same user prompt to every conversation, query the critic once for
        # the whole batch, and append each reply as that conversation's assistant turn.
        for h in histories:
            h.append({'role': 'user', 'content': user_content})
        res = llm(histories, gen_config=gen_config)
        assert len(res) == len(histories)
        for h, r in zip(histories, res):
            h.append({'role': 'assistant', 'content': r.text})

    outputs = []
    index = 0
    with tqdm(total=len(inputs)) as pbar:
        while index < len(inputs):
            histories = []
            for msg in inputs[index:index + batch_size]:
                assert len(msg) == 5
                # Code-generation answers are segmented with code-aware rules.
                criteria_type = 'code_generation' if msg[3] == 'Code Generation' else ''
                evaluated_response = segment_response(msg[2], criteria_type=criteria_type)
                query = f'# Instruction\n{msg[0]}\n\n# Input\n{msg[1]}'
                user_1 = multi_turn_prompts_['initial_input'][0].format(response=evaluated_response, conversation_history=query)
                histories.append([{'role': 'user', 'content': user_1}])
            if modemode not in ['no_task', 'no_all']:
                _add_turn(histories, multi_turn_prompts_['task'][0])
            if modemode not in ['no_criteria', 'no_all']:
                _add_turn(histories, multi_turn_prompts_['criteria'][1])
            if modemode not in ['no_ref', 'no_all']:
                _add_turn(histories, multi_turn_prompts_['reference'][0])
            _add_turn(histories, multi_turn_prompts_['feedback'][1])
            _add_turn(histories, multi_turn_prompts_['summarization'][0])
            outputs.extend(
                '\n----------\n'.join(u['content'] for u in h if u['role'] == 'assistant')
                for h in histories
            )
            pbar.update(len(histories))
            index += len(histories)

    result_list = [{"id": sample["id"], "final_prompt": sample["final_prompt"], f'{task}_result': output}
                   for sample, output in zip(dataset_with_prompt, outputs)]
    save_path = os.path.join(save_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list

def _infer_hf_critic_model_promethues(model, llm, dataset, out_dir, task):
    '''Generate critiques with a Prometheus critic through an lmdeploy pipeline.

    Each request is ``[final_prompt[0]]``: a one-element list wrapping the first item of
    the prompt built by ``set_critic_model_prompt``. Requests are sent 32 at a time.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<last path component of model>/<task>/result_<timestamp>.jsonl``.
    '''
    model_name = model.split("/")[-1]
    print("Number of samples: ", len(dataset))

    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    # This is the only generation config used. A greedy vLLM SamplingParams that used to
    # be assigned first was dead code, because it was overwritten right away.
    # NOTE(review): this samples (temperature 1.0), so outputs are non-deterministic;
    # confirm greedy decoding was not intended.
    gen_config = GenerationConfig(
        temperature=1.0,
        top_k=50,
        top_p=1.0,
        max_new_tokens=2048)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)
    inputs = [[sample["final_prompt"][0]] for sample in dataset_with_prompt]

    batch_size = 32
    outputs = []
    with tqdm(total=len(inputs)) as pbar:
        for start in range(0, len(inputs), batch_size):
            batch = inputs[start:start + batch_size]
            outputs.extend(response.text for response in llm(batch, gen_config=gen_config))
            pbar.update(len(batch))
    result_list = [{"id": sample["id"], "final_prompt": sample["final_prompt"],
                    f'{task}_result': output} for sample, output in zip(dataset_with_prompt, outputs)]
    save_path = os.path.join(save_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list




def _infer_hf_critic_model_themis(model, llm, dataset, out_dir, task):
    '''Generate evaluations with a Themis critic via ``generate_themis`` and save them as JSONL.

    ``generate_themis(llm, gen_config, inputs)`` receives element 1 of each sample's
    ``final_prompt`` and greedy vLLM sampling params (temperature 0.0, up to 2048 tokens).
    NOTE(review): element 1 is presumably the Themis-formatted input built by
    ``set_critic_model_prompt``; confirm.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample, also
        written to ``<out_dir>/<last path component of model>/<task>/result_<timestamp>.jsonl``.
    '''
    model_name = model.split("/")[-1]
    print("Number of samples: ", len(dataset))
    gen_config = SamplingParams(max_tokens=2048, temperature=0.0, n=1)

    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)
    inputs = [sample["final_prompt"][1] for sample in dataset_with_prompt]
    outputs = generate_themis(llm, gen_config, inputs)
    result_list = [{"id": sample["id"], "final_prompt": sample["final_prompt"],
                    f'{task}_result': output} for sample, output in zip(dataset_with_prompt, outputs)]
    save_path = os.path.join(save_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list


def _infer_hf_critic_model_v6_resumm_multi_turn_v2(model, llm, dataset, out_dir, task):
    '''Multi-turn critique variant that asks for the reference answer first.

    After the initial input turn (conversation + segmented evaluated response), the critic
    is asked, in order, using these ``multi_turn_prompts_`` entries:
    ``['reference'][1]``, ``['task'][0]``, ``['criteria'][1]``, ``['feedback'][1]``,
    ``['summarization'][0]``. Decoding is greedy (temperature 0.0, up to 4096 new tokens),
    16 samples per batch. Each ``final_prompt`` must be a 5-tuple.

    Returns:
        list of ``{"id", "final_prompt", f"{task}_result"}`` dicts, one per sample. Each
        result is every assistant reply joined by a ``----------`` separator line. Also
        written to ``<out_dir>/<model path with '/' replaced by '_'>/<task>/result_<timestamp>.jsonl``.
    '''
    print("Number of samples: ", len(dataset))
    run_dir = os.path.join(out_dir, model.replace("/", "_"), task)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir, exist_ok=True)
    prepared = set_critic_model_prompt(model, dataset)
    prompt_tuples = [item["final_prompt"] for item in prepared]
    chunk_size = 16
    progress = tqdm(total=len(prompt_tuples))
    generations = []
    greedy = GenerationConfig(temperature=0.0, max_new_tokens=4096)
    # (prompt key, variant index) for each critic turn, in the order they are asked.
    turn_plan = (('reference', 1), ('task', 0), ('criteria', 1), ('feedback', 1), ('summarization', 0))

    def ask(conversations, question):
        # Broadcast one user turn to the whole batch and record each critic reply.
        for conv in conversations:
            conv.append({'role': 'user', 'content': question})
        answers = llm(conversations, gen_config=greedy)
        assert len(answers) == len(conversations)
        for conv, answer in zip(conversations, answers):
            conv.append({'role': 'assistant', 'content': answer.text})

    cursor = 0
    while cursor < len(prompt_tuples):
        conversations = []
        for m in prompt_tuples[cursor:cursor + chunk_size]:
            assert len(m) == 5
            seg_kind = 'code_generation' if m[3] == 'Code Generation' else ''
            segmented = segment_response(m[2], criteria_type=seg_kind)
            history_text = f'# Instruction\n{m[0]}\n\n# Input\n{m[1]}'
            opening = multi_turn_prompts_['initial_input'][0].format(response=segmented, conversation_history=history_text)
            conversations.append([{'role': 'user', 'content': opening}])
        for key, variant in turn_plan:
            ask(conversations, multi_turn_prompts_[key][variant])
        generations.extend(
            '\n----------\n'.join(turn['content'] for turn in conv if turn['role'] == 'assistant')
            for conv in conversations
        )
        progress.update(len(conversations))
        cursor += len(conversations)
    records = [
        {"id": item["id"], "final_prompt": item["final_prompt"], f'{task}_result': text}
        for item, text in zip(prepared, generations)
    ]
    out_file = os.path.join(run_dir, f"result_{time.strftime('%m_%d_%H_%M_%S')}.jsonl")
    with open(out_file, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    return records



def _v3_multi_turn_step(llm, histories, user_contents, gen_config):
    '''Run one dialogue turn for a batch of conversations, mutating them in place.

    Appends ``user_contents[i]`` as a user message to ``histories[i]``, sends the
    whole batch to ``llm`` in a single call, then appends each reply's ``.text``
    as the assistant message of the corresponding conversation.

    Args:
        llm: callable invoked as ``llm(histories, gen_config=gen_config)``; must
            return one result (with a ``.text`` attribute) per conversation.
        histories: list of conversations, each a list of ``{'role', 'content'}`` dicts.
        user_contents: one user-message string per conversation (same length as
            ``histories``).
        gen_config: generation config forwarded unchanged to ``llm``.
    '''
    for history, content in zip(histories, user_contents):
        history.append({'role': 'user', 'content': content})
    replies = llm(histories, gen_config=gen_config)
    assert len(replies) == len(histories)
    for history, reply in zip(histories, replies):
        history.append({'role': 'assistant', 'content': reply.text})


def _infer_hf_critic_model_v6_resumm_multi_turn_v3(model, llm, dataset, out_dir, task):
    '''Run the five-turn (v3) critique dialogue over ``dataset`` and save the results.

    Originally documented as "internlm2 model with lmdeploy inference".

    For each sample, a conversation is seeded with the v3 initial prompt built from
    the instruction and input only (the evaluated response is NOT in this turn).
    The model is then asked, turn by turn, for:
    reference -> task -> criteria -> feedback (the evaluated response is injected
    here) -> summarization.
    All assistant turns of a conversation are joined with ``'\\n----------\\n'``
    to form that sample's output.

    Args:
        model: model identifier; ``/`` is replaced by ``_`` to build the output dir.
        llm: batched chat callable, ``llm(list_of_conversations, gen_config=...)``,
            returning objects with a ``.text`` attribute.
        dataset: samples passed to ``set_critic_model_prompt``.
        out_dir: root directory for results.
        task: task name; used as a sub-directory and as the ``f'{task}_result'`` key.

    Returns:
        List of ``{"id", "final_prompt", f"{task}_result"}`` dicts. The same records
        are also written as JSON lines to
        ``<out_dir>/<model_name>/<task>/result_<timestamp>.jsonl``.
    '''
    model_name = model.replace("/", "_")
    print("Number of samples: ", len(dataset))
    save_dir = os.path.join(out_dir, model_name, task)
    os.makedirs(save_dir, exist_ok=True)
    dataset_with_prompt = set_critic_model_prompt(model, dataset)
    inputs = [sample["final_prompt"] for sample in dataset_with_prompt]
    batch_size = 16
    # NOTE(review): an older comment here said "compatible for ultracm, must with
    # sampling", but decoding is configured with temperature=0.0 (intended as
    # deterministic). Confirm which is wanted.
    gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096)
    outputs = []

    # Context manager guarantees the progress bar is closed (original leaked it).
    with tqdm(total=len(inputs)) as pbar:
        for start in range(0, len(inputs), batch_size):
            batch = inputs[start:start + batch_size]
            histories = []
            evaluated_responses = []
            for msg in batch:
                # Layout as used below:
                #   msg[0] instruction, msg[1] input, msg[2] response to evaluate,
                #   msg[3] category.
                # msg[4] (question source, per the original variable name) is unused here.
                assert len(msg) == 5
                criteria_type = 'code_generation' if msg[3] == 'Code Generation' else ''
                evaluated_responses.append(segment_response(msg[2], criteria_type=criteria_type))
                query = f'# Instruction\n{msg[0]}\n\n# Input\n{msg[1]}'
                # v3 initial prompt only sees the conversation; the response is
                # supplied later, in the feedback turn.
                initial_prompt = multi_turn_prompts_v3['initial_input'][0].format(conversation_history=query)
                histories.append([{'role': 'user', 'content': initial_prompt}])

            n_conv = len(histories)
            # 1) reference
            _v3_multi_turn_step(llm, histories, [multi_turn_prompts_['reference'][1]] * n_conv, gen_config)
            # 2) task
            _v3_multi_turn_step(llm, histories, [multi_turn_prompts_['task'][0]] * n_conv, gen_config)
            # 3) criteria
            _v3_multi_turn_step(llm, histories, [multi_turn_prompts_['criteria'][1]] * n_conv, gen_config)
            # 4) feedback on the evaluated response.
            # NOTE(review): this template comes from multi_turn_prompts_, not
            # multi_turn_prompts_v3. If it has no '{response}' placeholder,
            # str.format silently ignores the kwarg and the response never reaches
            # the model. If it contains other literal braces, format raises.
            # Confirm the template.
            feedback_template = multi_turn_prompts_['feedback'][1]
            _v3_multi_turn_step(
                llm,
                histories,
                [feedback_template.format(response=er) for er in evaluated_responses],
                gen_config,
            )
            # 5) summarization
            _v3_multi_turn_step(llm, histories, [multi_turn_prompts_['summarization'][0]] * n_conv, gen_config)

            responses = [
                '\n----------\n'.join(u['content'] for u in h if u['role'] == 'assistant')
                for h in histories
            ]
            assert len(responses) == len(evaluated_responses)
            outputs.extend(responses)
            pbar.update(len(histories))

    result_list = [
        {"id": sample["id"], "final_prompt": sample["final_prompt"], f'{task}_result': output}
        for sample, output in zip(dataset_with_prompt, outputs)
    ]
    timestamp = time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))
    save_path = os.path.join(save_dir, f"result_{timestamp}.jsonl")
    with open(save_path, "w", encoding="utf-8") as f:
        for sample in result_list:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    return result_list



