import json
from .themis_utils import *
from .llm import *
from vllm import EngineArgs, LLMEngine, SamplingParams
from .prompts_v2 import *
import re
from .nips2024_prompts import *
from tqdm import tqdm
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, ChatTemplateConfig
import torch
from vllm import LLM, SamplingParams
import torch
import pprint
import tiktoken
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig, AutoModel, LlamaTokenizer, LlamaForCausalLM
from accelerate import init_empty_weights,infer_auto_device_map,load_checkpoint_in_model,dispatch_model, load_checkpoint_and_dispatch
from lmdeploy.serve.openai.api_client import APIClient
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
import ipdb
import sys
try:
    from tigerscore import TIGERScorer
except:
    pass
try:
    from util_func import *
except:
    pass
from multiprocessing import Pool
import random
import time

# for autoj-13b
# from vllm import LLM, SamplingParams
import torch
from utils.constants_prompt import build_autoj_input # constants_prompt -> codes/constants_prompt.py

# Critique-generation prompt template. It exposes three placeholders,
# `{query}`, `{evaluated_response}` and `{my_criteria}` (presumably filled
# with `str.format` by the caller), and asks for a 4-step critique that ends
# with a 1-10 score following the rubric table embedded in the text.
# NOTE(review): the name suggests an ablation variant, but the text below
# appears identical to PROMPT_NO_TASK and PROMPT_NO_REF -- confirm whether the
# ablated sections were meant to be removed.
PROMPT_NO_ALL = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---
### NOTICE: the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, including the description of response performance on each **first-tier primary criteria**. In the end, you should provide your final judgement score **x**, ranging from 1 to 10, by following the score rubrics below:

| Score Ranges | Definiation |
| ---------- | ----------- |
| 1 <= x < 3 | The quality is very low, containing numerous severe flaws; there are also other flaws, with Important error criteria. |
| 3 <= x < 5 | The quality is low, making it difficult to fulfill user query; There are many flaws, and a small number of severe flaws may be included. |
| 5 <= x < 7 | The quality is moderate, somewhat addressing the user query; There are a few errors, and a small number of severe errors may be included. |
| 7 <= x < 9 | The quality is approximately the same as the reference response (with the reference response scoring around 8). The response effectively answers user query. |
| 9 <= x < 10 | The quality is better than the reference, perfectly answering the user query in the conversation history. |

## NOTICE!!!
1. Quality scores (1-10) can be expressed as floating-point numbers.
2. Within specific score ranges, the more flaws there are, the lower quality score, and vice versa.
3. You should compare the evaluated response the reference before giving your quality score. Please follow the important guideline as follows: if evaluated response is worse than the reference, its score should be lower.

# Output
Please output your critiques
'''


# Critique prompt variant whose task list omits the "generate the two-tier
# criteria list" step and the score-rubric table. Placeholders: `{query}`,
# `{evaluated_response}`, `{my_criteria}`.
# NOTE(review): the text still announces "4 steps" while listing three, and it
# still contains the "Our Provided Criteria" section -- confirm this is the
# intended ablation.
PROMPT_NO_CRITERIA = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---
### NOTICE: the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate your high-quality reference answer for better critiques
**Step 3:** generate your detailed feedbacks, followed by a summarizaing containing the final judgemen score (ranging from 1 to 10, where higher score denotes better quaulity of evaluated response.)

## NOTICE!!!
1. the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

# Output
Please output your critiques
'''



#PROMPT_NO_CRITERIA = '''# Goal
#You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
#Besides, we also provide some **initial criteria** for you to assit your evaluation.
#
#Three input information are listed as below
#---
### Conversation history
#{query}
#---
#
#---
### Evaluated Response
#{evaluated_response}
#---
#### NOTICE: the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols
#
#---
### Our Provided Criteria
#{my_criteria}
#---
#
## Task
#To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
#**Step 1:** analyze the purpose of the user role in the previous conversation history
#**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
#**Step 3:** generate your high-quality reference answer for better critiques
#**Step 4:** generate your detailed feedbacks, including the description of response performance on each **first-tier primary criteria**. In the end, you should provide your final judgement score **x**, ranging from 1 to 10, by following the score rubrics below:
#
#| Score Ranges | Definiation |
#| ---------- | ----------- |
#| 1 <= x < 3 | The quality is very low, containing numerous severe flaws; there are also other flaws, with Important error criteria. |
#| 3 <= x < 5 | The quality is low, making it difficult to fulfill user query; There are many flaws, and a small number of severe flaws may be included. |
#| 5 <= x < 7 | The quality is moderate, somewhat addressing the user query; There are a few errors, and a small number of severe errors may be included. |
#| 7 <= x < 9 | The quality is approximately the same as the reference response (with the reference response scoring around 8). The response effectively answers user query. |
#| 9 <= x < 10 | The quality is better than the reference, perfectly answering the user query in the conversation history. |
#
### NOTICE!!!
#1. Quality scores (1-10) can be expressed as floating-point numbers.
#2. Within specific score ranges, the more flaws there are, the lower quality score, and vice versa.
#3. You should compare the evaluated response the reference before giving your quality score. Please follow the important guideline as follows: if evaluated response is worse than the reference, its score should be lower.
#
## Output
#Please output your critiques
#'''



# Critique prompt template with `{query}`, `{evaluated_response}` and
# `{my_criteria}` placeholders.
# NOTE(review): despite the "NO_TASK" name, the text below appears identical
# to PROMPT_NO_ALL (the step-1 task analysis is still requested) -- confirm
# whether this ablation was ever applied.
PROMPT_NO_TASK = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---
### NOTICE: the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, including the description of response performance on each **first-tier primary criteria**. In the end, you should provide your final judgement score **x**, ranging from 1 to 10, by following the score rubrics below:

| Score Ranges | Definiation |
| ---------- | ----------- |
| 1 <= x < 3 | The quality is very low, containing numerous severe flaws; there are also other flaws, with Important error criteria. |
| 3 <= x < 5 | The quality is low, making it difficult to fulfill user query; There are many flaws, and a small number of severe flaws may be included. |
| 5 <= x < 7 | The quality is moderate, somewhat addressing the user query; There are a few errors, and a small number of severe errors may be included. |
| 7 <= x < 9 | The quality is approximately the same as the reference response (with the reference response scoring around 8). The response effectively answers user query. |
| 9 <= x < 10 | The quality is better than the reference, perfectly answering the user query in the conversation history. |

## NOTICE!!!
1. Quality scores (1-10) can be expressed as floating-point numbers.
2. Within specific score ranges, the more flaws there are, the lower quality score, and vice versa.
3. You should compare the evaluated response the reference before giving your quality score. Please follow the important guideline as follows: if evaluated response is worse than the reference, its score should be lower.

# Output
Please output your critiques
'''



# Critique prompt template with `{query}`, `{evaluated_response}` and
# `{my_criteria}` placeholders.
# NOTE(review): despite the "NO_REF" name, the text below appears identical
# to PROMPT_NO_ALL and still asks for a reference answer in step 3 -- confirm
# whether this ablation was ever applied.
PROMPT_NO_REF = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---
### NOTICE: the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, including the description of response performance on each **first-tier primary criteria**. In the end, you should provide your final judgement score **x**, ranging from 1 to 10, by following the score rubrics below:

| Score Ranges | Definiation |
| ---------- | ----------- |
| 1 <= x < 3 | The quality is very low, containing numerous severe flaws; there are also other flaws, with Important error criteria. |
| 3 <= x < 5 | The quality is low, making it difficult to fulfill user query; There are many flaws, and a small number of severe flaws may be included. |
| 5 <= x < 7 | The quality is moderate, somewhat addressing the user query; There are a few errors, and a small number of severe errors may be included. |
| 7 <= x < 9 | The quality is approximately the same as the reference response (with the reference response scoring around 8). The response effectively answers user query. |
| 9 <= x < 10 | The quality is better than the reference, perfectly answering the user query in the conversation history. |

## NOTICE!!!
1. Quality scores (1-10) can be expressed as floating-point numbers.
2. Within specific score ranges, the more flaws there are, the lower quality score, and vice versa.
3. You should compare the evaluated response the reference before giving your quality score. Please follow the important guideline as follows: if evaluated response is worse than the reference, its score should be lower.

# Output
Please output your critiques
'''


# Critique prompt template with `{query}`, `{evaluated_response}` and
# `{my_criteria}` placeholders. It requests the full 4-step critique (purpose
# analysis, criteria list, reference answer, feedback with a 1-10 score) but,
# unlike the PROMPT_NO_* variants, carries no score-rubric table and moves the
# citation-symbol notice to a trailing "# NOTICE" section.
PROMPT = '''# Goal
You are a helpful assistant aiming to provide valuable critiques and analysis for the **evaluated response** of the **conversation history** between an user and an assistant.
Besides, we also provide some **initial criteria** for you to assit your evaluation.

Three input information are listed as below
---
## Conversation history
{query}
---

---
## Evaluated Response
{evaluated_response}
---

---
## Our Provided Criteria
{my_criteria}
---

# Task
To generate the valuable critiques, you should follow these 4 steps to generate valuable and accurate critiques:
**Step 1:** analyze the purpose of the user role in the previous conversation history
**Step 2:** generate the two-tier detailed criteria list (if our provided criteria list is not empty, please expand ours)
**Step 3:** generate your high-quality reference answer for better critiques
**Step 4:** generate your detailed feedbacks, followed by a summarizaing containing the final judgemen score (ranging from 1 to 10, where higher score denotes better quaulity of evaluated response.)

# NOTICE
1. the last response in the conversation history contains the citation symbols for better critiques, like [S1] and [S2]. Do NOT critique on these citation symbols

# Output
Please output your critiques
'''




# 
# Instruction line and input template for an error-analysis style evaluator
# (presumably TIGERScore, which this module also loads -- TODO confirm).
# FINETUNE_INPUT placeholders: `{generation_instruction}`, `{input_context}`,
# `{hypothesis_output}`. The trailing backslashes inside the triple-quoted
# string suppress the newline after the opening quotes and before the closing
# quotes, so the rendered text starts at "Instruction:" and ends right after
# "Your evaluation output:".
FINETUNE_INST = "You are evaluating errors in a model-generated output for a given instruction."
FINETUNE_INPUT = """\
Instruction: 
{generation_instruction}
{input_context}

Model-generated Output: 
{hypothesis_output}

For each error you give in the response, please also elaborate the following information:
- error location (the words that are wrong in the output)
- error aspect it belongs to.
- explanation why it's an error, and the correction suggestions.
- severity of the error ("Major" or "Minor"). 
- reduction of score (between 0.5 and 5 given the severity of the error)

Your evaluation output:\
"""


# Prompt template for the UltraCM critique model. Placeholders:
# `{instruction}` and `{completion}`. The text deliberately ends with
# "### Feedback / Overall Score: " so that the model continues by emitting the
# score first; it is consumed by generate_feedback() and
# generate_feedback_batch() below.
ultracm_instruction_template = """Given my answer to an instruction, your role is to provide specific and constructive feedback for me. You should find the best way for me to learn from your feedback and improve my performance. 

You should consider multiple aspects of my answer, including helpfulness, truthfulness, honesty, and to what extent the answer follows instructions.
---

### Instruction
{instruction}

### Answer
{completion}
---

Please act as a teacher and provide specific and constructive feedback. Besides describing the weaknesses of the answer, you should also provide specific suggestions to guide me toward understanding how to improve. Please note, however, that your suggestions should help me better complete the instructions, but you should not introduce new requirements that are not mentioned in the instructions. Your feedback should focus on enhancing my ability to think critically and respond accurately. However, never explicitly provide the reference answer, nor do polite phrases be required. Only respond with concise feedback in chat style. Finally, score the overall quality of the answer from 1 to 10, where 1 is the worst and 10 is the best.

*Format*
### Feedback
Overall Score: [1-10]
[Your feedback]

---

### Feedback
Overall Score: 
"""
def generate_feedback(generator, example, tokenizer):
    """Generate one UltraCM-style feedback string for a single example.

    Args:
        generator: model object exposing ``generate(input_ids, do_sample=...,
            temperature=..., top_p=..., top_k=..., max_new_tokens=...)``
            (looks like the Hugging Face generation API -- TODO confirm).
        example: mapping with the keys ``"instruction"`` and ``"completion"``.
        tokenizer: object exposing ``encode(str)`` and ``decode(ids)``.

    Returns:
        The decoded continuation with the prompt tokens dropped, every
        ``</s>`` removed and surrounding whitespace stripped.
    """
    system_turn = "User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>"
    user_turn = "User: " + ultracm_instruction_template.format(
        instruction=example["instruction"],
        completion=example["completion"],
    ) + "</s>"
    prompt = "\n".join([system_turn, user_turn, "Assistant: "])

    with torch.no_grad():
        prompt_ids = tokenizer.encode(prompt)
        prompt_len = len(prompt_ids)
        # Add a batch dimension of 1 and move the ids to the GPU.
        batch = torch.LongTensor(prompt_ids).unsqueeze(0).cuda()
        generated = generator.generate(
            batch,
            do_sample=True,
            temperature=0.8,
            top_p=0.8,
            top_k=40,
            max_new_tokens=1024,
        )
        # The output echoes the prompt; keep only the newly generated tail.
        continuation = generated[0][prompt_len:]
        decoded = tokenizer.decode(continuation).strip('\n').strip()
        feedback = decoded.replace('</s>', '').strip()
    return feedback


def generate_feedback_batch_default(generator, examples, max_prompts=10):
    """Generate one completion per example using the raw ``question`` as prompt.

    Args:
        generator: object exposing ``generate(prompts, sampling_params)`` and
            returning items with ``.outputs[0].text`` (the vLLM ``LLM``
            interface, judging by the ``SamplingParams`` usage).
        examples: iterable of mappings that each contain a ``'question'`` key.
        max_prompts: upper bound on how many prompts are actually sent to the
            generator. The default of 10 preserves the cap that used to be
            hard-coded here; pass ``None`` to process every example.
            NOTE(review): the cap looks like a debugging leftover -- with it
            the returned list is shorter than ``examples`` whenever more than
            10 are supplied, so callers that zip the two will misalign.

    Returns:
        List of generated texts, one per prompt that was submitted.
    """
    prompts = [example['question'].strip() for example in examples]
    if max_prompts is not None:
        prompts = prompts[:max_prompts]
    sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=1024, n=1)
    outputs = generator.generate(prompts, sampling_params)
    return [output.outputs[0].text for output in outputs]




def generate_feedback_batch(generator, examples):
    """Generate UltraCM-style feedback for a batch of examples.

    Each example (a mapping with ``"instruction"`` and ``"completion"`` keys)
    is rendered into the same three-turn prompt used by the single-example
    path, then all prompts are submitted in one ``generator.generate`` call.

    Returns:
        List with the first generated text for every prompt.
    """
    system_turn = "User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>"

    def _render(example):
        # system turn / user turn / open assistant turn, newline separated
        user_turn = "User: " + ultracm_instruction_template.format(
            instruction=example["instruction"],
            completion=example["completion"],
        ) + "</s>"
        return "\n".join([system_turn, user_turn, "Assistant: "])

    prompts = [_render(example) for example in examples]

    sampling_params = SamplingParams(temperature=0.8, top_k=40, top_p=0.8, max_tokens=1024, n=1)
    results = generator.generate(prompts, sampling_params)
    return [result.outputs[0].text for result in results]



def segment_response(response, min_length=3, criteria_type=''):
    """Insert sequential citation markers ``[S1]``, ``[S2]``, ... into a response.

    Refer to: https://arxiv.org/pdf/2305.14627.pdf

    Two segmentation modes are used:

    * coding criteria types: the text is split on runs of newlines and every
      non-blank line receives a marker (sentence punctuation is unreliable
      inside code, so the sentence-based rules below are not applied);
    * everything else: the text is split on a sentence terminator
      (``.``, ``?``, ``!`` or ``;``) followed by whitespace, and a marker is
      placed in front of each terminator.

    Args:
        response: text to annotate.
        min_length: a sentence whose stripped length is below this value does
            not get its own marker; its terminator is merged into it instead.
        criteria_type: task category; selects the line-based mode when it is
            one of the coding categories.

    Returns:
        The annotated text. In the sentence-based mode at least one marker is
        always present.
    """
    coding_criteria_types = {
        'code_simplification',
        'code_generation',
        'explaining_code',
        'code_correction_rewriting',
        'code_to_code_translation',
        'airoboro2.2_coding',
    }
    if criteria_type in coding_criteria_types:
        # Use runs of "\n" as the separator; the capture group keeps them in
        # the result so the original layout is reproduced by the join below.
        index = 1
        strings = []
        for segment in re.split(r'(\n+)', response):
            if not segment.strip():
                strings.append(segment)
                continue
            if segment[-1] in '.?!;:,+-*~!。，：！？':
                # keep the trailing punctuation after the marker
                strings.append(segment[:-1] + f" [S{index}]{segment[-1]}")
            else:
                strings.append(segment + f" [S{index}]")
            index += 1
        return ''.join(strings)

    # A trailing space lets the regex below (terminator + whitespace) also
    # match a terminator that is the very last character of the response.
    response += ' '
    if not response.strip():
        response = 'Empty Response'
    # ignore the enumeration like "1. ...; 2. ...;"
    segments = [
        segment
        for segment in re.split(r'([\.?!;]\s)', response)
        if segment.strip()
    ]
    # add the brackets sequentially
    index = 1
    if len(segments) == 1:
        # no punctuation, just add the brackets in the end
        new_response = f'{segments[0]} [S{index}]'
        index += 1
    else:
        strings = []
        for segment in segments:
            if segment[0] in '.?!;':
                if len(strings) == 0:
                    strings.append(segment)
                elif len(strings[-1].strip()) >= min_length:
                    # effective citation
                    strings.append(f" [S{index}]{segment}")
                    index += 1
                else:
                    # previous sentence too short: merge the terminator
                    strings[-1] += segment
            else:
                strings.append(segment)

        if strings[-1].strip()[-1] not in '.?!;':
            strings[-1] = strings[-1].strip() + f" [S{index}]"
            # Count this marker too. Previously it was not counted, so the
            # fallback below appended a second "[S1]" to inputs like "ab. cd".
            index += 1
        new_response = ''.join(strings)

    # Fallback: no marker was placed at all. An explicit check is used instead
    # of ``assert`` so the behaviour does not change under ``python -O``.
    if index <= 1:
        if new_response.endswith(('.', '?', '!', ';')):
            punc = new_response[-1]
            new_response = new_response[:-1] + f' [S1]{punc}'
        else:
            new_response += ' [S1].'
    return new_response


class OpenLLM:

    def __init__(self, model_name='api_model', host='0.0.0.0', port=2333):
        self.model_name = model_name
        self.api_host = host
        self.api_port = port
        if model_name == 'themis':
            engine_args = EngineArgs(
                model='/cpfs02/llm/shared/public/lantian/hf_models/PKU-ONELab/Themis', 
                tensor_parallel_size=1,
                max_num_seqs=1024,
                max_num_batched_tokens=2048,
                gpu_memory_utilization=0.98,
                swap_space=16)
            self.pipe = LLMEngine.from_engine_args(engine_args)
            self.gen_config = SamplingParams(
                max_tokens=2048, 
                temperature=0.0, 
                n=1)
            self.model_name_ = model_name
        elif model_name == "autoj-classifier":
            self.PROMPT_INPUT_FOR_SCENARIO_CLS: str = "Identify the scenario for the user's query, output 'default' if you are uncertain.\nQuery:\n{input}\nScenario:\n"
            num_gpus = torch.cuda.device_count()
            model_name_or_dir = "GAIR/autoj-scenario-classifier" # or the local directory to store the downloaded model
            self.llm = LLM(model=model_name_or_dir, tensor_parallel_size=num_gpus)
        elif model_name == 'tigerscore':
            #self.scorer = TIGERScorer(model_name="/cpfs02/llm/shared/public/lantian/hf_models/TIGER-Lab/TIGERScore-7B", use_vllm=True)
            self.scorer = TIGERScorer(model_name="/cpfs02/llm/shared/public/lantian/hf_models/TIGER-Lab/TIGERScore-13B", use_vllm=True)
            self.model_name_ = 'tigerscore'
        elif model_name == 'ultracm':
            # self.tokenizer = LlamaTokenizer.from_pretrained("/cpfs01/user/lantian/openbmb/UltraCM-13b")
            # self.model = LlamaForCausalLM.from_pretrained("/cpfs01/user/lantian/openbmb/UltraCM-13b", device_map="auto")
            self.model_name_ = 'ultracm'
            #path = "/cpfs01/user/lantian/openbmb/UltraCM-13b"
            path = '/cpfs02/llm/shared/public/lantian/hf_models/openbmb/UltraCM-13b'
            self.model = LLM(model=path, tensor_parallel_size=1, trust_remote_code=True)
        elif model_name == 'api_model':
            path = f'{self.api_host}:{self.api_port}'
            self.api_client = APIClient(path)
            self.model_name_ = self.api_client.available_models[0]
        elif model_name == 'wizardlm_7b':
            path = '/cpfs01/shared/public/public_hdd/llmeval/model_weights/hf_hub/models--WizardLM--WizardLM-7B-V1.0/snapshots/b245eca88962e16b8ee4b21eb6b58c2e5f871217'
            path = '/cpfs01/user/lantian/WizardLM/WizardMath-7B-V1.0'
            self.model = LLM(model=path, tensor_parallel_size=1)
        elif model_name == "reward_model":
            # reward_model_path = '/cpfs01/shared/public/public_hdd/lvchengqi/ckpts/reward_model/R-Luyou-1B-8k-D20240130-v1-hf'
            reward_model_path = '/cpfs01/shared/public/public_hdd/lvchengqi/ckpts/reward_model/R-Ampere-7B-8k-D20240126-v1_hf'
            self.model = AutoModel.from_pretrained(
                reward_model_path, trust_remote_code=True, device_map="cuda", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
            ).eval()
            self.tokenizer = AutoTokenizer.from_pretrained(reward_model_path, trust_remote_code=True)
        elif model_name in [
            'internlm2-20b-chat', 
            'internlm2-7b-chat',
            'mixtral-8x7b-instruct',
            'mixtral-8x22b-instruct',
            's2_add_critictuning',
            'llama-3.1-8b-instruct',
            'llama-3.1-70b-instruct',
            'llama-3-8b-instruct',
            'llama-3-70b-instruct',
            'qwen2-7b-instruct',
            'qwen2-72b-instruct',
            '/cpfs02/llm/shared/public/lantian/exp/20240418_aliyun_Ampere_7B_v1_1_FT_v1_0_0_s1_rc48_1660_hf_ckpt'
        ]:
            self.model_name_ = model_name
            if model_name == 'internlm2-20b-chat':
                path = '/cpfs01/shared/public/public_hdd/llmeval/model_weights/hf_hub/models--internlm--internlm2-chat-20b/snapshots/3f710f76f56f8c40dc5dd800dbe66f3341cb2c87'
                tp_num = 1
                with_model_name = True
            elif model_name == 'internlm2-7b-chat' or model_name  == '/cpfs02/llm/shared/public/lantian/exp/20240418_aliyun_Ampere_7B_v1_1_FT_v1_0_0_s1_rc48_1660_hf_ckpt':
                if model_name  == 'internlm2-7b-chat':
                    path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--internlm--internlm2-chat-7b/snapshots/70e6cdc9643ce7e3d9a369fb984dc5f1a1b2cec6'
                else:
                    path = '/cpfs02/llm/shared/public/lantian/exp/20240418_aliyun_Ampere_7B_v1_1_FT_v1_0_0_s1_rc48_1660_hf_ckpt'
                tp_num = 1
                with_model_name = True
            elif model_name == 'qwen2-72b-instruct':
                path = '/cpfs02/llm/shared/public/lantian/hf_models/Qwen/Qwen2-72B-Instruct'
                tp_num = 8
                with_model_name = False
            elif model_name == 'qwen2-7b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--Qwen--Qwen2-7B-Instruct/snapshots/f2826a00ceef68f0f2b946d945ecc0477ce4450c'
                tp_num = 1
                with_model_name = False
            elif model_name == 'llama-3-8b-instruct':
                path = '/cpfs02/llm/shared/public/lantian/hf_models/meta-llama-3-8B-Instruct'
                tp_num = 1
                with_model_name = True
            elif model_name == 'llama-3.1-8b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f'
                tp_num = 1
                with_model_name = False
            elif model_name == 'llama-3.1-70b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--meta-llama--Meta-Llama-3.1-70B-Instruct/snapshots/846357c7ee5e3f50575fd4294edb3d898c8ea100'
                tp_num = 8
                with_model_name = False
            elif model_name == 'llama-3-70b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots/7129260dd854a80eb10ace5f61c20324b472b31c'
                tp_num = 8
                with_model_name = False
            elif model_name == 'mixtral-8x7b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/bbae113847402a22031211225b5ee45c005de7dd'
                tp_num = 8
                with_model_name = False
            elif model_name == 'mixtral-8x22b-instruct':
                path = '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--mistralai--Mixtral-8x22B-Instruct-v0.1/snapshots/52572b20024761cb84b0f72af6f3733586f175f2'
                tp_num = 8
                with_model_name = False
            elif model_name == 's2_add_critictuning':
                path = '/cpfs02/llm/shared/public/lantian/exp/s2_add_critictuning_v01rc1/s2_no_critic_344/aliyun_Ampere_7B_v1.1_enchance_FT_v1.0.0_s1_rc47_s2_no_critictuning_v01rc1/344_hf_ckpt'
                tp_num = 1
                with_model_name = True
            else:
                raise Exception(f'[!] Unknown model:', model_name)
            if with_model_name is True:
                if model_name not in ['llama-3-8b-instruct', 'llama-3-70b-instruct']:
                    backend_config = PytorchEngineConfig(
                        session_len=32768, 
                        model_name='internlm2',
                        tp=tp_num
                    )
                else:
                    backend_config = PytorchEngineConfig(
                        session_len=32768, 
                        model_name='llama3',
                        tp=tp_num
                    )
            else:
                backend_config = PytorchEngineConfig(
                    session_len=32768, 
                    tp=tp_num
                )
            self.gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096)
            #self.gen_config = GenerationConfig(temperature=1.0, max_new_tokens=2048, top_k=50, top_p=0.95)
            self.pipe = pipeline(path, backend_config=backend_config)
            self.model_name_ = model_name
        elif model_name == 'autoj-13b':
            self.model_name_ = 'autoj-13b'
            #model_name_or_dir = '/cpfs01/user/lantian/GAIR/autoj-13b'
            model_name_or_dir = '/cpfs02/llm/shared/public/lantian/hf_models/GAIR/autoj-13b'
            num_gpus = torch.cuda.device_count()
            self.llm = LLM(model=model_name_or_dir, tensor_parallel_size=num_gpus)
        elif model_name == 'llama-2-13b':
            backend_config = PytorchEngineConfig(
                session_len=32768, 
                model_name='llama2', 
                tp=1
            )
            self.gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096)
            path = '/cpfs01/shared/public/public_hdd/llmeval/model_weights/hf_hub/models--meta-llama--Llama-2-13b-chat-hf/snapshots/c2f3ec81aac798ae26dcc57799a994dfbf521496'
            self.pipe = pipeline(path, backend_config=backend_config)
        elif '/cpfs02/llm/shared/public' in model_name:
            # fine-tuned internlms2 model
            self.model_name_ = model_name
            backend_config = PytorchEngineConfig(
                session_len=32768, 
                model_name='internlm2', 
                #model_name='llama3_1',
                #model_name='llama3',
                tp=1
            )
            if 'ultracm' in model_name:
                self.gen_config = GenerationConfig(temperature=1.0, max_new_tokens=4096, top_p=0.95, top_k=50)
            elif 'autoj' in model_name:
                self.gen_config = GenerationConfig(temperature=1.0, max_new_tokens=4096, top_p=0.95, top_k=50)
            elif 'promethues' in model_name or 'prometheus' in model_name:
                self.gen_config = GenerationConfig(temperature=1.0, top_p=1.0, top_k=50, max_new_tokens=4096)
            else:
                #self.gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096, top_k=1.0)
                self.gen_config = GenerationConfig(temperature=0.0, max_new_tokens=4096)
                #self.gen_config = GenerationConfig(temperature=0.5, top_p=0.95, top_k=50, max_new_tokens=4096)
                #self.gen_config = GenerationConfig(temperature=0.7, max_new_tokens=4096)
            #if 'llama_3_1' in model_name:
            #    #self.pipe = pipeline(model_name, backend_config=backend_config, chat_template_config=ChatTemplateConfig.from_json('llama3_1_chat_template.json'))
            #    self.pipe = pipeline(model_name, backend_config=backend_config, chat_template_config=ChatTemplateConfig('llama3_chat'))
            #    #self.pipe = pipeline(model_name, backend_config=backend_config, chat_template_config=ChatTemplateConfig.from_json('llama3_1_chat_template.json'))
            #else:
            self.pipe = pipeline(model_name, backend_config=backend_config)
        elif model_name in ['gpt-3.5-turbo', 'gpt-4-1106-preview', 'claude-instant-1']:
            ########## 
            # api_key = "sk-pAYATXjAo7f1pe7w1c38D095F57f4406Aa08Cb82D6EbCcA0"
            # self.client = OpenAI(api_key=api_key, base_url="https://api.ai-gaochao.cn/v1")
            ##########
            pass
        else:
            raise Exception(f'[!] Unknown model:', model_name)

    @torch.no_grad()
    def chat(self, item, max_new_tokens=2048, retry=20, segmented_response=False):
        '''Produce one evaluation/response for a single sample with the loaded backend.

        The backend is selected by ``self.model_name``; only 'tigerscore',
        'ultracm', 'autoj-13b', 'internlm2' and 'llama-2-13b' are handled.

        Args:
            item: sample dict. The evaluators read ``item['raw_data']['question']``
                and ``item['raw_data']['generation']``; 'internlm2' reads only the
                question; 'llama-2-13b' reads ``item['question']``.
            max_new_tokens: unused here; kept for interface compatibility.
            retry: maximum number of attempts when the backend raises.
            segmented_response: unused here; kept for interface compatibility.

        Returns:
            The generated text (a JSON string for 'tigerscore'), or ``None`` when
            the model is not supported or every attempt failed. Never raises.
        '''
        supported_models = ('tigerscore', 'ultracm', 'autoj-13b', 'internlm2', 'llama-2-13b')
        n = 0
        response = None
        try:
            if self.model_name not in supported_models:
                # The former fallback branch read the undefined names `history`
                # and `msg`, so it could only raise and burn every retry (with a
                # random sleep each time). Fail fast with the same return value.
                print(f'[!] request process WRONG for {self.model_name_}: chat() does not support model {self.model_name}')
                sys.stdout.flush()
                return None

            for _ in range(retry):
                try:
                    if self.model_name == 'tigerscore':
                        # Argument order is (instructions, hypothesis outputs,
                        # input contexts), the same order batch_chat uses.
                        responses = self.scorer.score(
                            [''],
                            [item['raw_data']['generation']],
                            [item['raw_data']['question']])
                        return json.dumps(responses[0])
                    elif self.model_name == 'ultracm':
                        # gradients are already disabled by the method decorator
                        example = {'instruction': item['raw_data']['question'], 'completion': item['raw_data']['generation']}
                        return generate_feedback(self.model, example, self.tokenizer)
                    elif self.model_name == 'autoj-13b':
                        input_single = build_autoj_input(
                            prompt=item['raw_data']['question'],
                            resp1=item['raw_data']['generation'], resp2=None,
                            protocol="single")
                        sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=1024)
                        output = self.llm.generate(input_single, sampling_params)
                        return output[0].outputs[0].text
                    elif self.model_name == 'internlm2':
                        response, _ = self.model.chat(self.tokenizer, item['raw_data']['question'], history=[])
                    else:
                        # 'llama-2-13b' served through the lmdeploy pipeline
                        prompt = [{'role': 'user', 'content': item['question']}]
                        return self.pipe(prompt, gen_config=self.gen_config).text
                except Exception as error:
                    response = None
                    print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                    sys.stdout.flush()
                    time.sleep(random.random() * 5)
                    n += 1
                    continue

                if response is not None:
                    break

            if response is None:
                print(f'[!] request process WRONG for {self.model_name_}')
            else:
                print(f'[!] request process successfully for {self.model_name_}')
            sys.stdout.flush()
            return response
        except Exception as error:
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None
        
    def chat_api(self, history, retry=5):
        '''Send a conversation to ``self.api_client`` and return the reply text.

        Args:
            history: list of chat messages forwarded as ``messages``.
            retry: maximum number of request attempts.

        Returns:
            * the reply content when a non-empty reply was received;
            * ``'LENGTH_EMPTY'`` when the reply is blank and the last streamed
              item reports ``finish_reason == 'length'``;
            * ``None`` when the reply is blank for another reason, when every
              attempt failed, or when any exception was raised.
        '''
        response = None
        try:
            for attempt in range(retry):
                for item in self.api_client.chat_completions_v1(
                    model=self.model_name_,
                    messages=history,
                    temperature=1.0,
                    top_p=0.98
                ):
                    # String items are treated as failures (presumably error
                    # payloads from the client); dict items carry the reply.
                    if isinstance(item, str):
                        response = None
                    else:
                        response = item['choices'][0]['message']['content']
                if response is not None:
                    break
                if attempt == retry - 1:
                    # no attempt left: do not sleep for nothing
                    break
                print(f'[!] meet strange error for {self.model_name_} inference, sleep randomly from 0 to 30s and retry')
                sys.stdout.flush()
                time.sleep(random.random() * 30)

            if response is None:
                print(f'[!] request process WRONG for {self.model_name_}')
            else:
                print(f'[!] request process successfully for {self.model_name_}')
            sys.stdout.flush()

            if isinstance(response, str) and response.strip() == '':
                # `item` is the last streamed item of the final attempt
                if item['choices'][0]['finish_reason'] == 'length':
                    print(f'[!] LLM {self.model_name_} face the length reason, return EMPTY response')
                    sys.stdout.flush()
                    return 'LENGTH_EMPTY'
                print(f'[!] LLM {self.model_name_} face the empty response')
                sys.stdout.flush()
                return None
            return response
        except Exception as error:
            print(f'[!] request process WRONG for {self.model_name_}:', error)
            sys.stdout.flush()
            return None

    @torch.no_grad()
    def batch_chat_close_source(self, msgs, max_new_tokens=2048, retry=20, set_name='', batch_size=2):
        '''Query an API-served model for every sample, ``batch_size`` at a time.

        Each sample's ``item['question']`` is sent as a single user message
        through the module-level ``batch_chat`` function (not this class's
        ``batch_chat`` method) with temperature 0.

        Args:
            msgs: list of sample dicts, each with a 'question' entry.
            max_new_tokens: unused here; kept for interface compatibility.
            retry: number of retries forwarded to the helper as ``retry_num``.
            set_name: unused here; kept for interface compatibility.
            batch_size: number of prompts sent per helper call; must be >= 1.

        Returns:
            A list with the first element of each entry returned by the helper,
            in input order.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        '''
        if batch_size < 1:
            # the original loop never advanced in this case and spun forever
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')

        responses = []
        pbar = tqdm(total=len(msgs))
        for start in range(0, len(msgs), batch_size):
            batch = msgs[start:start + batch_size]
            prompts = [
                {'model': self.model_name, 'messages': [{'role': "user", 'content': item['question']}]}
                for item in batch
            ]
            results = batch_chat(
                prompts,
                temp=0.0,
                model_name=self.model_name,
                debug=False,
                retry_num=retry
            )
            responses.extend(result[0] for result in results)
            # advance by the real batch length so the bar never overshoots
            pbar.update(len(batch))
        pbar.close()
        return responses

    @torch.no_grad()
    def batch_chat(self, msgs, max_new_tokens=2048, retry=20, set_name='', with_reference=True, with_task=True, with_criteria=True):
        '''Evaluate a list of samples with the loaded backend, one output per sample.

        Dispatch on ``self.model_name``:
          * 'themis', 'tigerscore', 'ultracm', 'autoj-13b': dedicated evaluators
            fed with ``msg['raw_data']['question']`` / ``['generation']``.
          * known chat models or '/cpfs02/llm/shared' checkpoints (unless the
            name contains '5_l1_resumm'): prompts built by ``template_maker``.
          * names containing '_st': single-turn prompts built by
            ``template_single_turn_nips2024``; each output is the model text,
            a '----------' separator line, then the segmented evaluated response.
          * names containing 'nips2024': a three-message conversation (task with
            unit test, evaluated response, critique request).
        The lmdeploy-backed branches process the samples in chunks of 32.

        Args:
            msgs: list of sample dicts.
            max_new_tokens: unused here; kept for interface compatibility.
            retry: maximum number of attempts when a backend raises.
            set_name: key into ``instruction_prompts`` / ``instruction_criteria``.
            with_reference, with_task, with_criteria: forwarded to
                ``template_single_turn_nips2024`` ('_st' models only).

        Returns:
            A list of output strings on success. When every attempt failed the
            tuple ``(None, None)`` is returned; this mirrors the historical
            failure value, so callers must check for it. Never raises.
        '''
        n = 0
        try:
            for _ in range(retry):
                try:
                    if self.model_name == 'themis':
                        examples = [
                            {
                                'question': msg['raw_data']['question'],
                                'generation': msg['raw_data']['generation'],
                                'set_name': set_name
                            }
                            for msg in msgs
                        ]
                        return generate_themis(self.pipe, self.gen_config, examples)

                    elif self.model_name == 'tigerscore':
                        instructions = ['' for item in msgs]
                        input_contexts = [item['raw_data']['question'] for item in msgs]
                        hypo_outputs = [item['raw_data']['generation'] for item in msgs]
                        responses = self.scorer.score(instructions, hypo_outputs, input_contexts)
                        return [json.dumps(response) for response in responses]

                    elif self.model_name == 'ultracm':
                        examples = [
                            {'instruction': msg['raw_data']['question'], 'completion': msg['raw_data']['generation']}
                            for msg in msgs
                        ]
                        return generate_feedback_batch(self.model, examples)

                    elif self.model_name == 'autoj-13b':
                        inputs = [
                            build_autoj_input(
                                prompt=msg['raw_data']['question'],
                                resp1=msg['raw_data']['generation'], resp2=None,
                                protocol="single")
                            for msg in msgs
                        ]
                        sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=1024)
                        outputs = self.llm.generate(inputs, sampling_params)
                        return [output.outputs[0].text for output in outputs]

                    elif '5_l1_resumm' not in self.model_name and (self.model_name in ['internlm2-20b-chat', 'internlm2-7b-chat', 'llama-2-13b', 'llama-3.1-8b-instruct', 'llama-3.1-70b-instruct', 'qwen2-7b-instruct', 'qwen2-72b-instruct'] or '/cpfs02/llm/shared' in self.model_name):
                        # internlm2 model or the fine-tuned internlm2 model
                        if self.model_name in ['internlm2-20b-chat', 'internlm2-7b-chat', 'llama-2-13b', 'llama-3.1-8b-instruct', 'llama-3.1-70b-instruct', 'qwen2-7b-instruct', 'qwen2-72b-instruct', 'llama-3-8b-instruct', 'llama-3-70b-instruct'] or 's2_no_critic_' in self.model_name or 's1_rc48_1660_hf_ckpt' in self.model_name:
                            mode = 'zero-shot'
                        elif 'autoj' in self.model_name:
                            mode = 'autoj'
                        elif 'tigerscore' in self.model_name:
                            mode = 'tigerscore'
                        elif 'ultracm' in self.model_name or 'ultrafeedback' in self.model_name:
                            mode = 'ultracm'
                        elif 'promethues' in self.model_name or 'prometheus' in self.model_name:
                            mode = 'promethues'
                        elif 'nips2024' in self.model_name:
                            mode = 'nips2024'
                        # NOTE(review): a checkpoint path matching none of the
                        # keywords above leaves `mode` unbound; the resulting
                        # error is swallowed by the retry loop below.
                        batch_size = 32
                        index = 0
                        pbar = tqdm(total=len(msgs))
                        outputs = []
                        instruction = instruction_prompts[set_name]
                        while index < len(msgs):
                            msgs_ = [
                                template_maker(msg, mode, domain_name=set_name, instruction=instruction) for msg in msgs[index:index+batch_size]
                            ]
                            responses = self.pipe(msgs_, gen_config=self.gen_config)
                            outputs.extend(response.text for response in responses)
                            index += batch_size
                            pbar.update(len(msgs_))
                        return outputs

                    elif '_st' in self.model_name:
                        batch_size = 32
                        index = 0
                        pbar = tqdm(total=len(msgs))
                        outputs = []
                        while index < len(msgs):
                            instruction = instruction_prompts[set_name]
                            my_criteria = instruction_criteria[set_name]
                            assert with_reference in [True, False]
                            assert with_task in [True, False]
                            assert with_criteria in [True, False]
                            if set_name in ['code_exec', 'code_not_exec', 'math_pot']:
                                criteria_type = 'code_generation'
                            else:
                                criteria_type = ''
                            # build the single-turn prompts
                            histories = []
                            evaluated_responses = []
                            for msg in msgs[index:index+batch_size]:
                                if instruction:
                                    content = f'# Instruction\n{instruction}\n\n# Input\n{msg["raw_data"]["question"]}'
                                else:
                                    content = f'# Input\n{msg["raw_data"]["question"]}'
                                ch = [{'role': 'user', 'content': content}]
                                evaluated_response = segment_response(
                                    msg['raw_data']['generation'],
                                    criteria_type=criteria_type)
                                if set_name == 'code_exec':
                                    # attach the execution result of the evaluated code
                                    evaluated_response = '# Evaluated Response (Code)\n' + evaluated_response + '\n# Automatic Execution of Evaluated Code\n' + msg['raw_data']['exec_rest']
                                evaluated_responses.append(evaluated_response)
                                prompt = template_single_turn_nips2024(ch, evaluated_response, my_criteria=my_criteria, with_reference=with_reference, with_task=with_task, with_criteria=with_criteria)
                                histories.append(prompt)
                            responses = self.pipe(histories, gen_config=self.gen_config)
                            assert len(responses) == len(evaluated_responses)
                            outputs.extend(
                                response.text + '\n----------\n' + evaluated_response
                                for response, evaluated_response in zip(responses, evaluated_responses)
                            )
                            index += batch_size
                            pbar.update(len(histories))
                        return outputs

                    elif 'nips2024' in self.model_name:
                        batch_size = 32
                        index = 0
                        pbar = tqdm(total=len(msgs))
                        outputs = []
                        while index < len(msgs):
                            instruction = instruction_prompts[set_name]
                            my_criteria = instruction_criteria[set_name]
                            if set_name in ['code_exec', 'code_not_exec', 'math_pot']:
                                criteria_type = 'code_generation'
                            else:
                                criteria_type = ''
                            # overall_prompts[0] is the full prompt; index 1 is
                            # the no-ref variant and index 2 the no-description one
                            user_prompt = overall_prompts[0].format(my_criteria=my_criteria)

                            # build the conversation history
                            histories = []
                            for msg in msgs[index:index+batch_size]:
                                evaluated_response = segment_response(
                                    msg['raw_data']['generation'],
                                    criteria_type=criteria_type)
                                if set_name == 'code_exec':
                                    evaluated_response = '# Evaluated Response (Code)\n' + evaluated_response + '\n# Automatic Execution of Evaluated Code\n' + msg['raw_data']['exec_rest']
                                histories.append([
                                    {
                                        'role': 'user',
                                        'content': f'# Instruction\n{instruction}.\nPlease make sure Do NOT revise the number and the type of the input arguments of the generation function listed in the Unit Test.\n\n# Input\n{msg["raw_data"]["question"]}\n# Unit Test\n{msg["raw_data"]["unit_test"]}'
                                    },
                                    # the evaluated response plays the assistant turn
                                    {'role': 'assistant', 'content': evaluated_response},
                                    # the critique request, carrying `my_criteria`
                                    {'role': 'user', 'content': user_prompt},
                                ])
                            responses = self.pipe(histories, gen_config=self.gen_config)
                            outputs.extend(response.text for response, _ in zip(responses, histories))
                            index += batch_size
                            pbar.update(len(histories))
                        return outputs
                    else:
                        raise Exception()

                except Exception as error:
                    print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                    sys.stdout.flush()
                    time.sleep(random.random() * 5)
                    n += 1
                    continue

            # Every branch above either returns or raises, so reaching this
            # point means all attempts failed.
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None, None
        except Exception as error:
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None, None

    #def batch_chat(self, msgs, max_new_tokens=2048, retry=20, set_name='', with_reference=True, with_task=True, with_criteria=True):
    @torch.no_grad()
    def batch_chat_multi_turn(self, msgs, max_new_tokens=2048, set_name='', retry=20, with_reference=True, with_task=True, with_criteria=True):
        '''Run the multi-turn critique pipeline (only for our method).

        For every chunk of 32 samples a conversation is opened with the
        'initial_input' prompt and then extended turn by turn through
        ``self.pipe``: task (optional), criteria (optional), reference
        (optional), feedback and summarization.

        Args:
            msgs: list of sample dicts; ``msg['raw_data']`` must provide
                'question' and 'generation', plus 'unit_test' for the code sets
                and 'exec_rest' for 'code_exec'.
            max_new_tokens: unused here; kept for interface compatibility.
            set_name: key into ``instruction_criteria`` / ``instruction_prompts``.
            retry: maximum number of attempts over the whole input.
            with_reference, with_task, with_criteria: a turn is run only when
                its flag is exactly ``True``.

        Returns:
            One string per sample: all assistant turns of its conversation
            joined by a '----------' separator line. ``None`` when every
            attempt failed.

        Raises:
            KeyError: if ``set_name`` is missing from the prompt tables (these
                lookups happen before the retry loop).
        '''
        n = 0
        batch_size = 32
        my_criteria = instruction_criteria[set_name]
        instruction = instruction_prompts[set_name]
        is_code_set = set_name in ['code_exec', 'code_not_exec']

        def run_turn(histories, user_prompt):
            # Ask `user_prompt` in every conversation and record the replies.
            for h in histories:
                h.append({'role': 'user', 'content': user_prompt})
            res = self.pipe(histories, gen_config=self.gen_config)
            assert len(res) == len(histories)
            for r, h in zip(res, histories):
                h.append({'role': 'assistant', 'content': r.text})

        for _ in range(retry):
            try:
                # Start from scratch on every attempt: keeping the results of a
                # failed attempt would duplicate and misalign the outputs.
                responses = []
                pbar = tqdm(total=len(msgs))
                for index in range(0, len(msgs), batch_size):
                    batch = msgs[index:index + batch_size]
                    histories = []
                    for msg in batch:
                        evaluated_response = segment_response(
                            msg['raw_data']['generation'],
                            criteria_type='code_generation' if is_code_set else '')
                        if set_name == 'code_exec':
                            evaluated_response = '# Evaluated Response (Code)\n' + evaluated_response + '\n# Automatic Execution of Evaluated Code\n' + msg['raw_data']['exec_rest']
                        if is_code_set:
                            evaluated_response = '# Unit Test:\n' + msg['raw_data']['unit_test'] + '\n' + evaluated_response

                        query = f'# Instruction\n{instruction}\n\n# Input\n{msg["raw_data"]["question"]}'
                        user_1 = multi_turn_prompts_['initial_input'][0].format(response=evaluated_response, conversation_history=query)
                        histories.append([{'role': 'user', 'content': user_1}])

                    # generate task
                    if with_task is True:
                        run_turn(histories, multi_turn_prompts_['task'][0])

                    # generate criteria
                    if with_criteria is True:
                        if my_criteria:
                            criteria_prompt = multi_turn_prompts_['criteria'][0].format(my_criteria=my_criteria)
                        else:
                            criteria_prompt = multi_turn_prompts_['criteria'][1]
                        run_turn(histories, criteria_prompt)

                    # generate reference
                    if with_reference is True:
                        run_turn(histories, multi_turn_prompts_['reference'][0])

                    # generate feedback
                    run_turn(histories, multi_turn_prompts_['feedback'][0])

                    # generate summarization
                    run_turn(histories, multi_turn_prompts_['summarization'][0])

                    # keep only what the model produced
                    responses.extend(
                        '\n----------\n'.join(utterance['content'] for utterance in h if utterance['role'] == 'assistant')
                        for h in histories
                    )
                    pbar.update(len(batch))
                return responses

            except Exception as error:
                print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                sys.stdout.flush()
                time.sleep(random.random() * 5)
                n += 1
                continue
        return None

    @torch.no_grad()
    def batch_chat_default(self, msgs, max_new_tokens=2048, retry=20, set_name=''):
        '''Run the backend on already-formatted prompts (tigerscore cannot use default mode).

        * 'ultracm': ``msgs`` is handed to ``generate_feedback_batch_default``.
        * 'autoj-13b': every ``msg['question']`` is sampled with vLLM
          (temperature 1.0, at most 1024 new tokens).
        * anything else: every ``msg['question']`` is sent as a single user
          message through ``self.pipe`` in chunks of 32.

        Args:
            msgs: list of sample dicts carrying a 'question' entry.
            max_new_tokens: unused here; kept for interface compatibility.
            retry: maximum number of attempts when the backend raises.
            set_name: unused here; kept for interface compatibility.

        Returns:
            A list with one output per sample on success. When every attempt
            failed the tuple ``(None, None)`` is returned; this mirrors the
            historical failure value, so callers must check for it. Never raises.
        '''
        n = 0
        try:
            for _ in range(retry):
                try:
                    if self.model_name == 'ultracm':
                        return generate_feedback_batch_default(self.model, msgs)
                    elif self.model_name == 'autoj-13b':
                        # All prompts are generated: a leftover `inputs[:10]`
                        # truncation used to drop everything past the tenth
                        # sample, unlike every other backend.
                        inputs = [msg['question'] for msg in msgs]
                        sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=1024)
                        outputs = self.llm.generate(inputs, sampling_params)
                        return [output.outputs[0].text for output in outputs]
                    else:
                        # internlm2 model or the fine-tuned internlm2 model
                        batch_size = 32
                        index = 0
                        pbar = tqdm(total=len(msgs))
                        outputs = []
                        while index < len(msgs):
                            msgs_ = [
                                [{'role': 'user', 'content': msg['question']}] for msg in msgs[index:index+batch_size]
                            ]
                            responses = self.pipe(msgs_, gen_config=self.gen_config)
                            outputs.extend(response.text for response in responses)
                            index += batch_size
                            pbar.update(len(msgs_))
                        return outputs
                except Exception as error:
                    print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                    sys.stdout.flush()
                    time.sleep(random.random() * 5)
                    n += 1
                    continue

            # Every branch above either returns or raises, so reaching this
            # point means all attempts failed.
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None, None
        except Exception as error:
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None, None

    @torch.no_grad()
    def batch_chat_multi_turn_v2(self, msgs, max_new_tokens=2048, set_name='', retry=20, with_reference=True, with_task=True, with_criteria=True):
        '''Run the multi-turn critique dialogue over ``msgs`` in batches (only for our method).

        Every sample gets its own conversation. It starts with the
        ``initial_input`` prompt and is then extended turn by turn; each turn
        is one ``self.pipe`` call for the whole batch. The optional reference,
        task and criteria turns come first, followed by the feedback and
        summarization turns.

        Args:
            msgs: indexable collection of samples. ``msg['raw_data']`` is read
                as a mapping with ``generation`` and ``question`` (plus
                ``unit_test`` for the code sets and ``exec_rest`` for
                ``code_exec``).
            max_new_tokens: not used by this method; generation is controlled
                by ``self.gen_config``.
            set_name: key looked up in ``instruction_criteria`` and
                ``instruction_prompts``.
            retry: maximum number of attempts.
            with_reference: the reference turn runs only when this is ``True``.
            with_task: the task turn runs only when this is ``True``.
            with_criteria: the criteria turn runs only when this is ``True``.

        Returns:
            A list with one string per sample (the assistant turns of its
            conversation joined by a dashed separator line), or ``None`` when
            every attempt failed.
        '''
        batch_size = 32
        separator = '\n----------\n'
        my_criteria = instruction_criteria[set_name]
        instruction = instruction_prompts[set_name]
        is_code_set = set_name in ['code_exec', 'code_not_exec']

        def first_user_turn(msg):
            # Build the opening user message of one conversation.
            raw = msg['raw_data']
            if is_code_set:
                evaluated = segment_response(raw['generation'], criteria_type='code_generation')
                if set_name == 'code_exec':
                    evaluated = '# Evaluated Response (Code)\n' + evaluated + '\n# Automatic Execution of Evaluated Code\n' + raw['exec_rest']
                evaluated = '# Unit Test:\n' + raw['unit_test'] + '\n' + evaluated
            else:
                evaluated = segment_response(raw['generation'], criteria_type='')
            query = f'# Instruction\n{instruction}\n\n# Input\n{raw["question"]}'
            content = multi_turn_prompts_['initial_input'][0].format(response=evaluated, conversation_history=query)
            return {'role': 'user', 'content': content}

        def run_turn(histories, prompt):
            # Ask the same question in every conversation and record the replies.
            for history in histories:
                history.append({'role': 'user', 'content': prompt})
            replies = self.pipe(histories, gen_config=self.gen_config)
            if len(replies) != len(histories):
                # Explicit raise (not assert) so the check survives ``python -O``.
                raise RuntimeError(f'pipeline returned {len(replies)} results for {len(histories)} conversations')
            for reply, history in zip(replies, histories):
                history.append({'role': 'assistant', 'content': reply.text})

        responses = []
        # Index of the first sample whose batch has not completed yet. It lives
        # outside the retry loop so a retry resumes instead of re-appending the
        # batches that already succeeded.
        next_index = 0
        n = 0
        pbar = None

        for _ in range(retry):
            try:
                if pbar is None:
                    pbar = tqdm(total=len(msgs))
                while next_index < len(msgs):
                    stop = min(len(msgs), next_index + batch_size)
                    batch = [msgs[i] for i in range(next_index, stop)]
                    histories = [[first_user_turn(msg)] for msg in batch]

                    #### generate reference
                    if with_reference is True:
                        run_turn(histories, multi_turn_prompts_['reference'][1])

                    # generate task
                    if with_task is True:
                        run_turn(histories, multi_turn_prompts_['task'][0])

                    #### generate criteria
                    if with_criteria is True:
                        if my_criteria:
                            criteria_prompt = multi_turn_prompts_['criteria'][0].format(my_criteria=my_criteria)
                        else:
                            criteria_prompt = multi_turn_prompts_['criteria'][1]
                        run_turn(histories, criteria_prompt)

                    # generate feedback, then summarization
                    run_turn(histories, multi_turn_prompts_['feedback'][0])
                    run_turn(histories, multi_turn_prompts_['summarization'][0])

                    # Build the batch result fully before committing it, so a
                    # failure here cannot leave a partially appended batch.
                    joined = [
                        separator.join(turn['content'] for turn in history if turn['role'] == 'assistant')
                        for history in histories
                    ]
                    responses.extend(joined)
                    next_index = stop
                    pbar.update(len(batch))
                pbar.close()
                return responses

            except Exception as error:
                print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                sys.stdout.flush()
                # Random back-off before the next attempt.
                time.sleep(random.random() * 5)
                n += 1

        if pbar is not None:
            pbar.close()
        return None





    @torch.no_grad()
    def batch_chat_correction(self, msgs, set_name, max_new_tokens=2048, retry=20):
        n = 0
        response = None
        try:
            for _ in range(retry):
                try:
                    #feedback = feedback.split('\n----------\n')[0]

                    template = '''You task is to revise the generation for the query given the feedback. The query, generation, and the feedback are shown as follows:
---
# Query:
{query}
---
# Generation
{generation}
---
# Feedback
{feedback}
---

# NOTICE
* Note each sentence in generation are highlighted with the label like [S1], [S2]. You should revise each sentence based on the suggestions in feedback (each feedback entry focus for some sentences).
* Your revised answer should not contain any label like [S1] and [S2]. 
* If the generation is a python code inside a markdown python format like ```python\n...\n```, your revised answer should also follow the markdown python format, generating your revised code in ```python\n```
* Directly generate your revised answer, and **Do NOT** generate any explanation like "Sure, I would like to ...".
* Please Do NOT add any test cases, only revise the evaluated answer (rationale or code).
* If the generation contains code, Please Keep the input and output format the same as the unit test.
'''

                    index = 0
                    batch_size = 4
                    pbar = tqdm(total=len(msgs))
                    outputs = []

                    while index < len(msgs):
                        if set_name in ['math_pot', 'code_exec', 'code_not_exec']:
                            criteria_type='code_generation'
                        else:
                            criteria_type = ''
                        msgs_ = []
                        for msg in msgs[index:index+batch_size]:
                            feedback = json.loads(msg['raw_data'])['feedback']
                            if '\n----------\n' in feedback:
                                feedback = feedback.split('\n----------\n')[0]
                            # feedback = '\n'.join(feedbacks[-2:])
                            # feedback = feedback.replace('# Ground-Truth Feedback', '')
                            # feedback = feedback.replace('Citation Symbol', 'Where is the Error')

                            # items = re.split('(# Feedback List)', feedback)
                            # items = re.split('(# Reference)', feedback)
                            # try:
                            #     assert len(items) == 3
                            #     feedback = items[1] + items[2] 
                            # except:
                            #     pass
                            if set_name in ['code_exec']:
                                generation = '###### Evaluated Code:\n' + segment_response(json.loads(msg['raw_data'])['generation'], criteria_type=criteria_type) + '\n###### Unit Test:\n' + json.loads(msg['raw_data'])['unit_test']
                                generation += '\n###### Execution Result:\n' + json.loads(msg['raw_data'])['exec_rest'] + '\nThe execution result could assist you in analyzing the flaws in evaluated code.'
                            elif set_name in ['code_not_exec']:
                                generation = '###### Evaluated Code:\n' + segment_response(json.loads(msg['raw_data'])['generation'], criteria_type=criteria_type) + '\n###### Unit Test:\n' + json.loads(msg['raw_data'])['unit_test']
                            else:
                                generation = segment_response(json.loads(msg['raw_data'])['generation'], criteria_type=criteria_type)

                            p = template.format(
                                query=json.loads(msg['raw_data'])['question'],
                                generation=generation,
                                feedback=feedback
                            )
                            msgs_.append(p)
                            # if set_name in ['code_exec', 'code_not_exec']:
                            #     ipdb.set_trace()
                        responses = self.pipe(msgs_, gen_config=self.gen_config)
                        responses = [response.text for response in responses]
                        index += batch_size
                        outputs.extend(responses)
                        pbar.update(len(msgs_))
                    return outputs

                except Exception as error:
                    response = None
                    print(f'[!] meet the error of model {self.model_name_}: {error}; retry {n} times')
                    sys.stdout.flush()
                    time.sleep(random.random() * 5)
                    n += 1
                    continue
                    
                if response is not None:
                    break

            if response is None:
                print(f'[!] request process WRONG for {self.model_name_}')
                sys.stdout.flush()
            else:
                print(f'[!] request process successfully for {self.model_name_}')
                sys.stdout.flush()
            return response, history
        except Exception as error:
            print(f'[!] request process WRONG for {self.model_name_}')
            sys.stdout.flush()
            return None, None






    
def test_chat(chat_data_sample, model):
    '''Send ``chat_data_sample`` to ``model.chat_api`` and print the reply plus a divider line.

    ``chat_api`` must return a two-item result; only the first item is printed.
    '''
    reply, _history = model.chat_api(chat_data_sample)
    divider = '=' * 50
    print(reply)
    print(divider)



def template_maker_nips2024(history, msg, mode='task', domain_name='translation', my_criteria=''):
    '''Append the user prompt for one stage of the dialogue to ``history`` in place.

    Args:
        history: list of chat turns; one ``{'role': 'user', ...}`` dict is
            appended.
        msg: sample being evaluated. Only read in the feedback modes, where
            ``msg['raw_data']['generation']`` (and ``exec_rest`` for the
            ``code_exec`` domain) is used.
        mode: one of ``'task'``, ``'reference'``, ``'feedback'`` or
            ``'feedback_no_ref'``.
        domain_name: data set name; selects the segmentation mode and whether
            the execution result is attached to the evaluated response.
        my_criteria: criteria text substituted into the task prompt.

    Raises:
        ValueError: if ``mode`` is not one of the supported values. In that
            case ``history`` is left untouched.
    '''
    if mode == 'task':
        prompt = task_criteria_prompts[0].format(my_criteria=my_criteria)
    elif mode == 'reference':
        prompt = reference_prompts[0]
    elif mode in ('feedback', 'feedback_no_ref'):
        if domain_name in ('code_exec', 'code_not_exec', 'math_pot'):
            criteria_type = 'code_generation'
        else:
            criteria_type = ''
        response = segment_response(msg['raw_data']['generation'], criteria_type=criteria_type)
        if domain_name == 'code_exec':
            # Executable code additionally carries its automatic execution result.
            response = '###### Evaluated Response\n' + response + '\n###### Automatic Evaluation\n' + msg['raw_data']['exec_rest']
        prompt_table = feedback_prompts if mode == 'feedback' else feedback_no_ref_prompts
        prompt = prompt_table[0].format(evaluated_response=response)
    else:
        # Fail with a clear message instead of an UnboundLocalError on ``prompt``.
        raise ValueError(f'[!] Unknown mode: {mode!r}')
    history.append({'role': 'user', 'content': prompt})


def template_maker(msg, mode='zero-shot', domain_name='translate', instruction=None):
    '''Build the chat messages for one of the baseline evaluators.

    Args:
        msg: sample to evaluate. ``'zero-shot'`` reads ``msg['question']``;
            every other mode reads ``msg['raw_data']['question']`` and
            ``msg['raw_data']['generation']`` (plus ``exec_rest`` for the
            ``code_exec`` domain in the ``autoj`` / ``ultracm`` modes).
        mode: ``'zero-shot'``, ``'promethues'``, ``'autoj'``, ``'ultracm'`` or
            ``'tigerscore'``.
        domain_name: data set name; only ``'code_exec'`` changes the output
            (``autoj`` / ``ultracm`` then include the execution result).
        instruction: task instruction placed in front of the question in the
            ``autoj`` / ``ultracm`` modes.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.

    Raises:
        ValueError: if ``mode`` is not supported.
    '''
    system_turn = {'role': 'system', 'content': "A chat between a curious user and an artificial intelligence expert. The expert gives helpful, specific, and concise answers to the user's questions."}

    def instruction_text():
        # Instruction header followed by the raw question (autoj / ultracm).
        return f'# Instruction: {instruction}\n' + msg['raw_data']['question']

    def evaluated_text():
        # Generation under evaluation; code_exec also shows the execution result.
        if domain_name == 'code_exec':
            return '# Evaluated Code:\n' + msg["raw_data"]['generation'] + '\nAutomatic Execution:\n' + msg['raw_data']['exec_rest']
        return msg['raw_data']['generation']

    # zero-shot baseline template
    if mode == 'zero-shot':
        return [{'role': 'user', 'content': msg['question']}]

    if mode == 'promethues':
        user_content = '''###Task Description:
An instruction (might include an Input inside it), a response to evaluate, a reference answer are given.
1. Write a criteria
2. Write a reference answer 
3. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: \"Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)\"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
{orig_instruction}

###Response to evaluate:
{orig_response}

###Feedback:
'''
        user_content = user_content.format(
            orig_instruction=msg['raw_data']['question'],
            orig_response=msg['raw_data']['generation'],
        )
        # This baseline is sent without a system turn.
        return [{'role': 'user', 'content': user_content}]

    if mode == 'autoj':
        prompt = build_autoj_input(
            prompt=instruction_text(),
            resp1=evaluated_text(),
            resp2=None,
            protocol="single")
        return [system_turn, {'role': 'user', 'content': prompt}]

    if mode == 'ultracm':
        # ``prompt`` is supplied for every domain so both branches agree;
        # str.format ignores keyword arguments the template does not use.
        user_content = ultracm_instruction_template.format(
            instruction=instruction_text(),
            prompt=msg['raw_data']['question'],
            completion=evaluated_text(),
        )
        return [system_turn, {'role': 'user', 'content': user_content}]

    if mode == 'tigerscore':
        FINETUNE_INST = "You are evaluating errors in a model-generated output for a given instruction."
        FINETUNE_INPUT = """\
Instruction: {generation_instruction}
{input_context}


Model-generated Output:
{hypothesis_output}


For each error you give in the response, please also elaborate the following information:
- error location (the words that are wrong in the output)
- error aspect it belongs to.
- explanation why it's an error, and the correction suggestions.
- severity of the error ("Major" or "Minor").
- reduction of score (between 0.5 and 5 given the severity of the error)

Your evaluation output:
"""
        user_content = FINETUNE_INST + '\n' + FINETUNE_INPUT.format(
            generation_instruction='',
            input_context=msg['raw_data']['question'],
            hypothesis_output=msg['raw_data']['generation']
        )
        # Normalise surrounding whitespace to exactly one trailing newline.
        user_content = user_content.strip('\n ') + '\n'
        return [system_turn, {'role': 'user', 'content': user_content}]

    raise ValueError(f'[!] Unknown mode: {mode!r}')


######## single-turn prompt ########
def template_single_turn_nips2024(queries, evaluated_response, my_criteria='', with_reference=True, with_task=True, with_criteria=True):
    '''Build the single-turn critique prompt as a one-message chat.

    Args:
        queries: conversation history, an iterable of dicts with ``role`` and
            ``content``; rendered one ``**role**: content`` line per turn.
        evaluated_response: text of the response to critique.
        my_criteria: criteria text substituted into the prompt.
        with_reference, with_task, with_criteria: ablation switches, judged by
            truthiness. All on selects ``PROMPT`` and all off selects
            ``PROMPT_NO_ALL``. Otherwise the first switch that is off, in the
            order reference, task, criteria, selects ``PROMPT_NO_REF``,
            ``PROMPT_NO_TASK`` or ``PROMPT_NO_CRITERIA``.

    Returns:
        ``[{'role': 'user', 'content': <formatted prompt>}]``.
    '''
    query = '\n'.join(f'**{item["role"]}**: {item["content"]}' for item in queries)

    # Normalise the switches so values such as 1/0 or None pick a template
    # instead of falling through every branch and leaving it unbound.
    use_reference = bool(with_reference)
    use_task = bool(with_task)
    use_criteria = bool(with_criteria)

    if use_reference and use_task and use_criteria:
        chosen = PROMPT
    elif not (use_reference or use_task or use_criteria):
        chosen = PROMPT_NO_ALL
    elif not use_reference:
        chosen = PROMPT_NO_REF
    elif not use_task:
        chosen = PROMPT_NO_TASK
    else:
        # Reference and task are on, so criteria is the one switched off.
        chosen = PROMPT_NO_CRITERIA

    content = chosen.format(query=query, evaluated_response=evaluated_response, my_criteria=my_criteria)
    return [dict(role='user', content=content)]
    
    
if __name__ == "__main__":
    # Manual smoke test: build an OpenLLM client for the configured host/port
    # and print whatever chat_api returns for a single greeting.
    model = OpenLLM(
        model_name='api_model',
        host='http://192.168.148.58',
        port=2333,
    )
    greeting = [{'role': 'user', 'content': 'who are you?'}]
    response = model.chat_api(greeting)
    print(response)


