import json
import os
import time
import argparse
import traceback
from os.path import exists
import openai
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from vllm import LLM, SamplingParams
from prompts.RAL_extractor import EXTRACTION_PROMPT, EXTRACTION_PROMPT_EACH
from operator import itemgetter


def load_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed objects.

    The file is decoded as UTF-8, matching what ``save_jsonl`` writes
    (``ensure_ascii=False`` output is not safe to read with a
    platform-dependent default encoding). Blank or whitespace-only lines
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(raw) for raw in f if raw.strip()]

def make_query(user_input, system_prompt, examples):
    """Assemble a raw chat prompt using Llama-3-style header tokens.

    Args:
        user_input: Text of the final user turn.
        system_prompt: Text placed in the system turn (may be empty).
        examples: Iterable of few-shot dicts with ``'user'`` and
            ``'assistant'`` keys; an empty iterable adds no few-shot turns.

    Returns:
        The prompt string, ending with an opened assistant header so the
        model continues with the assistant reply.
    """
    # NOTE(review): the original author was unsure whether the tokenizer
    # also prepends <|begin_of_text|>; confirm to avoid a doubled BOS token.
    segments = [
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    ]
    # Each few-shot pair closes the current user turn, emits the assistant
    # reply, and reopens a user turn for whatever comes next.
    segments.extend(
        f"{shot['user']} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n {shot['assistant']} <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        for shot in examples
    )
    segments.append(f"{user_input} <|eot_id|>")
    segments.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(segments)

def get_content(line):
    """Fill the extraction prompt template for one scoring item.

    Args:
        line: Dict with ``'instruction'``, ``'question'``, ``'output'`` and
            ``'rule'`` keys. ``'output'`` may be ``None``.

    Returns:
        The formatted prompt. ``SYS_MSG_EACH`` is used when the rule
        contains the substring ``'each'``, otherwise ``SYS_MSG``.
    """
    # Truncate long fields so the prompt stays bounded in size.
    instruction = line['instruction'][:6000]
    question = line['question']
    raw_output = line['output']
    # A missing model output is passed to the template as the literal "None".
    output = raw_output[:4000] if raw_output is not None else "None"
    template = SYS_MSG_EACH if 'each' in line['rule'] else SYS_MSG
    return template.format(instruction=instruction, response=output, question=question)

def get_payload(line):
    """Build an OpenAI-style chat-completion request dict for one item.

    Args:
        line: Scoring item forwarded unchanged to ``get_content``.

    Returns:
        A dict with the model name, a single user message holding the
        formatted prompt, and fixed sampling settings.
    """
    user_message = {"role": "user", "content": get_content(line)}
    return {
        "model": "gpt-4o",
        "messages": [user_message],
        "max_tokens": 8192,
        "temperature": 0.0,
        "top_p": 0.95,
        "stream": True,
    }


def save_jsonl(entry, sava_path):
    """Append ``entry`` to a JSON Lines file as a single UTF-8 line.

    Args:
        entry: JSON-serializable object to write.
        sava_path: Path of the output file; it is opened in append mode, so
            existing content is kept.
    """
    with open(sava_path, mode='a', encoding='utf-8') as sink:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        serialized = json.dumps(entry, ensure_ascii=False)
        sink.write(serialized + "\n")


def get_answer(input_data: dict, retry=30):
    """Generate the extraction answer for one item and append it to disk.

    The prompt is built with ``get_content``/``make_query`` and sent to the
    module-level vLLM engine (``SYS_VLLM`` with ``SYS_PARA``). The OpenAI
    payload from ``get_payload`` is not sent anywhere here; it is only
    stored next to the result.

    Exactly one record is written per call: the entry with ``'ass'`` set to
    the generated text, or with ``'ass'`` set to the string ``"None"`` once
    every attempt has failed or produced an empty generation.

    Args:
        input_data: Dict with ``'data'`` (the item, mutated in place with
            ``'ass'`` and ``'payload'``) and ``'save_path'`` (output file).
        retry: Number of additional attempts allowed after the first one.
            A negative value makes the call a no-op.

    Returns:
        None.
    """
    entry, save_path = input_data['data'], input_data['save_path']
    if retry < 0:
        return

    # Bound before the try so the fallback record can always reference it,
    # even when get_payload() itself is the call that raised.
    payload = None
    while retry >= 0:
        try:
            ### gpt (request is built for record-keeping only)
            payload = get_payload(entry)

            ### llama3-8b
            content = get_content(entry)
            query = make_query(content, "", "")
            output = SYS_VLLM.generate(query, SYS_PARA)
            generation = output[0].outputs[0].text
        except Exception as e:
            # Back off briefly, report the failure, and try again.
            time.sleep(1.2)
            retry -= 1
            print(f"retry:剩余{retry}次")  # message reads "N retries remaining"
            print(e)
            print(traceback.format_exc())
            continue

        if generation:
            entry['ass'] = generation
            entry['payload'] = payload
            save_jsonl(entry, save_path)
            return

        # None or empty generation: spend one retry, but do not write a
        # record for this attempt.
        retry -= 1

    # Attempts exhausted: write a single placeholder record.
    entry['ass'] = "None"
    entry['payload'] = payload
    save_jsonl(entry, save_path)


def run_extraction(save_path, datas, num_pool):
    """Run ``get_answer`` over every item using a thread pool.

    Args:
        save_path: Output file path handed to each job.
        datas: Iterable of scoring items; falsy items are dropped.
        num_pool: Number of worker threads.
    """
    jobs = []
    for item in datas:
        if not item:
            continue
        jobs.append({"data": item, "eval_model": "gpt-4-1106-preview", "save_path": save_path})

    with ThreadPoolExecutor(max_workers=num_pool) as pool:
        results = pool.map(get_answer, jobs)
        progress = tqdm(results, total=len(jobs), desc='Processing', ncols=100)
        # Drain the iterator so every job finishes and the bar advances;
        # the results themselves are not used.
        for _ in progress:
            pass


def get_data(data_path, llm_output_path):
    """Join benchmark items with LLM outputs and flatten them per question.

    Items and outputs are matched on ``main_id`` rather than by position,
    so a missing, extra, or reordered output line can no longer pair an
    instruction with another item's generation. Items with no matching
    output are reported and skipped. If several output lines share a
    ``main_id``, the first one in the file is used.

    Args:
        data_path: JSON file holding a list of items, each with
            ``'main_id'``, ``'instruction'`` and ``'scoring_questions'``.
        llm_output_path: JSON Lines file whose records carry ``'main_id'``
            and ``'generated'``. Blank lines are ignored.

    Returns:
        A list of dicts (``main_id``, ``point_id``, ``instruction``,
        ``rule``, ``question``, ``output``), ordered by ``main_id``, with
        one entry per scoring question whose ``rule`` is not ``None``.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    with open(llm_output_path, 'r', encoding='utf-8') as f:
        outputs = [json.loads(line) for line in f if line.strip()]

    # Index outputs by id; setdefault keeps the first record on duplicates.
    outputs_by_id = {}
    for record in outputs:
        outputs_by_id.setdefault(record['main_id'], record)

    datas = []
    # Sorting keeps the result in main_id order.
    for d in sorted(data, key=itemgetter('main_id')):
        o = outputs_by_id.get(d['main_id'])
        if o is None:
            print(f"get_data: no LLM output for main_id={d['main_id']!r}; skipping")
            continue
        for q in d['scoring_questions']:
            # Questions without a rule are skipped.
            if q['rule'] is None:
                continue
            datas.append({
                "main_id": d["main_id"],
                "point_id": q["point_id"],
                "instruction": d['instruction'],
                "rule": q['rule'],
                "question": q['question'],
                "output": o['generated'],
            })

    return datas


def main_run(args):
    """Load the joined scoring items and run extraction over them.

    Args:
        args: Parsed command-line namespace providing ``data_path``,
            ``llm_output_path``, ``output_path`` and ``num_pool``.
    """
    extraction_inputs = get_data(args.data_path, args.llm_output_path)
    run_extraction(
        save_path=args.output_path,
        datas=extraction_inputs,
        num_pool=args.num_pool,
    )


if __name__ == "__main__":
    # Command-line options; every path/credential option is a string that
    # defaults to "", and --num_pool is the worker-thread count.
    parser = argparse.ArgumentParser()
    for _option in ("--data_path", "--llm_output_path"):
        parser.add_argument(_option, type=str, default="")
    parser.add_argument("--num_pool", type=int, default=1)
    for _option in ("--output_path", "--api_key", "--api_base"):
        parser.add_argument(_option, type=str, default="")
    args = parser.parse_args()

    # OpenAI client configuration taken from the command line.
    openai.api_key = args.api_key
    openai.api_base = args.api_base

    # Module-level prompt templates read by get_content().
    SYS_MSG, SYS_MSG_EACH = EXTRACTION_PROMPT, EXTRACTION_PROMPT_EACH

    # Module-level sampling settings and vLLM engine read by get_answer().
    SYS_PARA = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=8192)
    SYS_VLLM = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    main_run(args)
