import json
import os
import time
import argparse
import re
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import openai
import traceback
from vllm import LLM, SamplingParams
from prompts.RAL_evaluator import EVALUATION_PROMPT
from operator import itemgetter


def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        file_path: Path to a file containing one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 (the other readers/writers in this file do
    # the same) instead of relying on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # rather than handed to json.loads, which would raise on them.
        return [json.loads(raw) for raw in f if raw.strip()]

def make_query(user_input, system_prompt, examples):
    """Assemble a single prompt string in the header/eot special-token chat format.

    Args:
        user_input: Text of the final user turn.
        system_prompt: Text placed in the system turn.
        examples: Iterable of few-shot turns; each item is indexed with
            ``'user'`` and ``'assistant'``. An empty iterable adds no turns.

    Returns:
        The prompt, ending with an opened assistant header so the model
        continues with the assistant's reply.
    """
    # NOTE: open question carried over from the original code — will the
    # tokenizer add the first <|begin_of_text|> itself?
    pieces = [
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    ]

    # Each example closes the pending user turn, gives the assistant answer,
    # and re-opens a user turn for whatever comes next.
    for shot in examples:
        pieces.append(
            f"{shot['user']} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n {shot['assistant']} <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        )

    pieces.append(f"{user_input} <|eot_id|>")
    pieces.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(pieces)

def get_content(line):
    """Fill the evaluation prompt template for one scoring item.

    Args:
        line: Mapping with keys ``'instruction'``, ``'question'`` and
            ``'output'``. ``'output'`` may be ``None``.

    Returns:
        The module-level ``SYS_MSG`` template formatted with ``input``,
        ``output`` and ``question``.
    """
    # Instruction and output are truncated (6000 / 4000 characters) before
    # being inserted into the template.
    instruction = line['instruction'][:6000]
    question = line['question']
    output = line['output']
    # Identity check against None; a missing output is rendered as the
    # literal text 'None' in the prompt.
    output = output[:4000] if output is not None else 'None'
    return SYS_MSG.format(input=instruction, output=output, question=question)

def get_payload(line):
    """Build a chat-completion style request dict for one scoring item.

    Args:
        line: Scoring item passed through to ``get_content``.

    Returns:
        A dict with the model name, a single user message holding the
        formatted prompt, and the sampling/streaming settings.
    """
    user_message = {"role": "user", "content": get_content(line)}
    # Keyword order fixes the key order of the resulting dict.
    return dict(
        model="gpt-4o",
        messages=[user_message],
        max_tokens=8192,
        temperature=0.0,
        top_p=0.3,
        stream=True,
    )


def save_jsonl(entry, sava_path):
    """Append one record to a JSON Lines file.

    Args:
        entry: JSON-serialisable object to write.
        sava_path: Path of the file to append to (created if missing).
    """
    with open(sava_path, mode='a', encoding='utf-8') as sink:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        serialized = json.dumps(entry, ensure_ascii=False)
        sink.write(f"{serialized}\n")


def _parse_point_judge(generation):
    """Turn the judge model's free-text answer into a boolean verdict."""
    verdicts = re.findall(r'答案：是|答案：否', generation)
    if len(verdicts) == 1:
        # Exactly one explicit "answer: yes/no" marker — trust it.
        return "是" in verdicts[0]
    # Zero or several markers: count as yes only if the text contains a
    # "yes" character and no "no" character anywhere.
    return "是" in generation and "否" not in generation


def get_answer(input_data: dict, retry=30):
    """Judge one scoring item with the vLLM model and append the result.

    Args:
        input_data: Dict with ``'data'`` (the scoring item, mutated in place
            with ``point_judge``, ``point_explanation`` and ``payload``) and
            ``'save_path'`` (JSONL file the record is appended to).
        retry: Number of retries allowed after the first attempt. A negative
            value makes the call a no-op.

    Returns:
        The updated entry on success. ``None`` if ``retry`` was negative or
        if every attempt failed; in the latter case a fallback record with
        ``point_judge=False`` and ``point_explanation="None"`` is still
        written to ``save_path``.
    """
    entry, save_path = input_data['data'], input_data['save_path']
    if retry < 0:
        return None

    # Bound up front so the failure path below can always reference it,
    # even when get_payload itself is what raised.
    payload = None

    # Iterative retry: each item is saved exactly once and the result of a
    # successful later attempt is actually returned to the caller.
    while retry >= 0:
        try:
            # The OpenAI-style payload is not sent anywhere here; it is only
            # stored alongside the result.
            payload = get_payload(entry)

            # Generation goes through the module-level vLLM engine.
            content = get_content(entry)
            query = make_query(content, "", "")
            output = SYS_VLLM.generate(query, SYS_PARA)
            generation = output[0].outputs[0].text

            if not generation:
                # Treat a missing/empty generation as a failed attempt
                # instead of judging and saving it.
                raise ValueError("empty generation")

            entry['point_judge'] = _parse_point_judge(generation)
            entry['point_explanation'] = generation
            entry['payload'] = payload
            save_jsonl(entry, save_path)
            return entry
        except Exception as e:
            time.sleep(1.2)
            retry -= 1
            print(f"retry:剩余{retry}次")
            print(e)
            print(traceback.format_exc())

    # All attempts failed: record a negative verdict so the item still
    # appears in the output file.
    entry['point_judge'] = False
    entry['point_explanation'] = "None"
    entry['payload'] = payload
    save_jsonl(entry, save_path)
    return None


def run_evaluation(save_path, datas, num_pool):
    """Run ``get_answer`` over all scoring items using a thread pool.

    Args:
        save_path: JSONL file that each worker appends its result to.
        datas: Scoring items; falsy items are dropped.
        num_pool: Maximum number of worker threads.
    """
    jobs = []
    for item in datas:
        if not item:
            continue
        jobs.append({"data": item, "eval_model": "gpt-4-1106-preview", "save_path": save_path})

    with ThreadPoolExecutor(max_workers=num_pool) as pool:
        results = pool.map(get_answer, jobs)
        progress = tqdm(results, total=len(jobs), desc='Processing', ncols=100)
        # Drain the iterator so every job finishes and the bar advances;
        # the individual results are not needed here.
        for _ in progress:
            pass


def get_data(data_path, llm_output_path):
    """Build the list of model-judged scoring items.

    Args:
        data_path: JSON file holding a list of records, each with
            ``'main_id'``, ``'instruction'`` and ``'scoring_questions'``.
        llm_output_path: JSONL file of model outputs, each with
            ``'main_id'`` and ``'generated'``.

    Returns:
        One dict per scoring question whose ``'rule'`` is ``None``, carrying
        ``main_id``, ``point_id``, ``instruction``, ``question`` and the
        matching ``output``. Items are ordered by ``main_id``, then by their
        position in ``scoring_questions``.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    with open(llm_output_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        outputs = [json.loads(raw) for raw in f if raw.strip()]

    # Pair records by main_id rather than by position: zipping two sorted
    # lists silently mismatches instructions and outputs as soon as one file
    # has a missing or extra id. The first output seen for an id is kept.
    outputs_by_id = {}
    for o in outputs:
        outputs_by_id.setdefault(o['main_id'], o)

    datas = []
    for d in sorted(data, key=itemgetter('main_id')):
        main_id = d['main_id']
        o = outputs_by_id.get(main_id)
        if o is None:
            # No generation to judge for this record.
            print(f"warning: no generated output for main_id={main_id}; skipping")
            continue
        for q in d['scoring_questions']:
            # Questions that carry a rule are skipped here; only rule-less
            # questions are collected for the model judge.
            if q['rule'] is not None:
                continue
            datas.append({
                "main_id": main_id,
                "point_id": q["point_id"],
                "instruction": d['instruction'],
                "question": q['question'],
                "output": o['generated'],
            })
    return datas


def main_run(args):
    """Load the scoring items and evaluate them.

    Args:
        args: Parsed command-line namespace; reads ``data_path``,
            ``llm_output_path``, ``output_path`` and ``num_pool``.
    """
    evaluation_items = get_data(args.data_path, args.llm_output_path)
    run_evaluation(
        save_path=args.output_path,
        datas=evaluation_items,
        num_pool=args.num_pool,
    )


if __name__ == "__main__":
    # Command-line options as (flag, type, default).
    cli_options = (
        ("--data_path", str, ""),
        ("--llm_output_path", str, ""),
        ("--num_pool", int, 1),
        ("--output_path", str, ""),
        ("--api_key", str, ""),
        ("--api_base", str, ""),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    # OpenAI client settings taken from the command line.
    openai.api_key = args.api_key
    openai.api_base = args.api_base

    # Module-level globals read by get_content / get_answer.
    SYS_MSG = EVALUATION_PROMPT
    SYS_PARA = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=8192)
    SYS_VLLM = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    main_run(args)
