import os
import tqdm
import time
import openai
import random
import glob
from diskcache import Cache
import threading
import queue
import time
import argparse

import os
import json
import datetime
import traceback

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

# Base address of the chat endpoint; left empty here.
openai_api = ''
api_key_RM = None
# Worker-thread count used further down in this script.
TNUM = 50
# Start-of-run timestamp rendered as MMDD_HH_MM_SS.
date = f"{datetime.datetime.now():%m%d_%H_%M_%S}"

print(f"OPENAI_API_ADDR:{openai_api}")

# Command-line interface: two required paths plus an optional API key.
parser = argparse.ArgumentParser()
for _flag, _options in (
    ("--json-path", {"required": True}),
    ("--output-path", {"required": True}),
    ("--gpt-key", {"default": ""}),
):
    parser.add_argument(_flag, type=str, **_options)
args = parser.parse_args()

class OpenAIApiException(Exception):
    """Exception carrying a message together with a numeric error code.

    Args:
        msg: Human-readable description of the failure.
        error_code: Code associated with the failure (callers in this file
            pass the HTTP status code).
    """

    def __init__(self, msg, error_code):
        # Forward both values so ``args`` stays ``(msg, error_code)``.
        super().__init__(msg, error_code)
        self.msg, self.error_code = msg, error_code

class OpenAIApiProxy():
    """Small HTTP client for an OpenAI-style ``/v1/chat/completions`` endpoint.

    All requests go through one ``requests.Session`` whose adapter retries
    POST requests on the status codes listed in ``__init__``.
    """

    def __init__(self, api_key="", base_url=None):
        """Create the session and remember the credentials.

        Args:
            api_key: Bearer token; when empty no ``Authorization`` header is sent.
            base_url: Optional endpoint base address. When ``None`` the
                module-level ``openai_api`` value is used at call time.
        """
        retry_strategy = Retry(
            total=500,  # overall retry budget
            backoff_factor=1,  # wait-time factor between retries
            status_forcelist=[429, 500, 502, 503, 504],  # status codes that trigger a retry
            allowed_methods=["POST"]  # only POST requests are retried
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        # Create the session and attach the retry logic to both schemes.
        self.session = requests.Session()
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)
        self.api_key = api_key
        self.base_url = base_url

    def call(self, model_name, prompt, headers=None, max_tokens=512):
        """Send ``prompt`` as a single user message and return the decoded JSON reply.

        Args:
            model_name: Value of the ``model`` field of the request.
            prompt: Text placed in the single user message.
            headers: Optional extra HTTP headers. The mapping is copied, never
                modified, so nothing leaks between calls.
            max_tokens: Value of the ``max_tokens`` field; when it is 10 or
                less, a newline is also sent as the stop sequence.

        Returns:
            The response body parsed from JSON.

        Raises:
            OpenAIApiException: If the final response status is not 200.
        """
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "stop": "\n" if max_tokens <= 10 else None,
            "temperature": 0.001,
        }

        # Work on a private copy: mutating a shared default dict (or the
        # caller's dict) would carry Authorization into unrelated requests.
        request_headers = dict(headers) if headers else {}
        request_headers.setdefault('Content-Type', 'application/json')
        if self.api_key:
            request_headers['Authorization'] = "Bearer " + self.api_key

        base = openai_api if self.base_url is None else self.base_url
        url = base + '/v1/chat/completions'
        response = self.session.post(url, headers=request_headers, data=json.dumps(payload))

        if response.status_code != 200:
            err_msg = "access openai error, status code: %s" % (response.status_code)
            raise OpenAIApiException(err_msg, response.status_code)
        return json.loads(response.text)

# Disk-backed cache stored under ./cache.
cache = Cache('./cache')
# Only pass a key when one was supplied on the command line.
proxy = OpenAIApiProxy(api_key=args.gpt_key) if args.gpt_key != "" else OpenAIApiProxy()

def get_prompt(query):
    """Build the few-shot poem-decomposition prompt for ``query``.

    The conversation (system turn, task rules, assistant acknowledgement and
    the request holding four worked examples plus ``query``) is flattened
    into one string, each turn rendered as ``"<role>: <content>"`` followed
    by a blank line.

    Args:
        query: Text to decompose; inserted with ``str.format``.

    Returns:
        The flattened prompt string.
    """
    task_rules = '''我现在需要你对一首诗歌的文本进行拆解，最终以json的格式返回拆解后的内容。我会给你一些正确答案的示例，拆解的规范如下：
    1. 你需要判断文本中是否含有标题，如果含有标题，那么把标题作为返回的json中的\"标题\"的部分，如果不含有标题，那么标题部分返回一个空字符串，注意不要把诗歌第一句变成标题；
    2. 你需要把文本中诗歌部分的内容按句子拆解开，拆解到每个逗号或者顿号，按照拆开的顺序，以list的形式返回到json中\"诗歌\"中，标点符号不需要出现在list中；
    3. 诗歌中的其他部分则不出现在返回的文件中；
    4. 如果诗歌内容中有书名号《》或者引号，请去掉。
    '''
    assistant_ack = '明白了，我会根据您提供的示例和规范来进行拆解。请提供需要答案示例，以及需要拆解的文本。'
    answer_header = '正确答案示例如下：\n'

    # NOTE(review): examples 3 and 4 use the same text and answer, and the
    # last list entry of that answer keeps a trailing "。" although rule 2
    # says punctuation is not needed in the list. Kept as-is here because
    # the prompt text also feeds the cache key.
    free_verse = "当我看见那枚钉子\n静静地躺在地板上，\n我心生莫名的情感，\n如诗人笔下涌动的篇章。\n\n它，虽貌不惊人，平凡无奇，\n却承载着历史的重量，\n无数的故事因它而起，\n多少的辉煌与哀伤。\n"
    free_verse_answer = '''
    {
        "标题": "",
        "诗歌": ["当我看见那枚钉子", "静静地躺在地板上", "我心生莫名的情感", "如诗人笔下涌动的篇章", "它", "虽貌不惊人", "平凡无奇", "却承载着历史的重量", "无数的故事因它而起", "多少的辉煌与哀伤。"]
    }\n
    '''

    # Pieces of the final user turn, in output order.
    request_pieces = (
        "示例1：\n",
        "【题目】诉衷情·忆往昔\n\n【作者】现代•无名氏\n\n【词牌】诉衷情\n\n【韵部】平韵，韵字：先、烟、天、年、边、巅、传、川、篇\n\n【正文】\n\n清风抚碧波，忆往昔，泪点点。\n携手共游天地间，山水烟云绕足边。\n\n笑语声犹在耳畔，转眼已是数年离散。\n梦里相逢醉瑶巅，醒来惟有孤影伴。\n\n情难忘，意难传，相思如川怎撑船？\n愿寄一纸平安篇，海角天涯心相连。\n",
        answer_header,
        '''
    {
        "标题": "诉衷情·忆往昔",
        "诗歌": ["清风抚碧波", "忆往昔","泪点点","携手共游天地间","山水烟云绕足边","笑语声犹在耳畔","转眼已是数年离散","梦里相逢醉瑶巅","醒来惟有孤影伴","情难忘","意难传","相思如川怎撑船","愿寄一纸平安篇","海角天涯心相连"]
    }\n
    ''',
        "示例2：\n",
        "月夜独酌青天外，清风伴我醉仙途。\n繁星似海波光涌，天地豪情一饮无。\n",
        answer_header,
        '''
    {
        "标题": "",
        "诗歌": ["月夜独酌青天外", "清风伴我醉仙途", "繁星似海波光涌", "天地豪情一饮无"]
    }\n
    ''',
        "示例3：\n",
        free_verse,
        answer_header,
        free_verse_answer,
        "示例4：\n",
        free_verse,
        answer_header,
        free_verse_answer,
        '\n需要进行拆解的文本如下：\n',
        '{}\n\n'.format(query),
        '现在请给出回答，请注意：1. 请返回json格式的内容，包括标题和诗歌两项；2. 请不要返回json以外的其他内容；3. 去掉诗歌内容中的引号或者书名号；4. 注意不要把诗歌第一句变成标题，没有标题就返回空字符串。请回答：',
    )

    turns = (
        ('system', 'You are a helpful assistant.'),
        ('user', task_rules),
        ('assistant', assistant_ack),
        ('user', "".join(request_pieces)),
    )
    return "".join(f"{role}: {text}\n\n" for role, text in turns)

def gpt_eval(query, loop, model="gpt-4-1106-preview", max_tokens=200, debug=False, try_count=3, use_cache=True):
    """Ask the model to decompose ``query`` and return its stripped reply.

    The reply is cached under the key ``"<prompt>\\t<loop>"``. A failed
    request is retried until ``try_count`` attempts have been made; the
    previous code incremented the counter and re-raised on the first failure,
    so it never retried.

    Args:
        query: Poem text passed to ``get_prompt``.
        loop: Value mixed into the cache key.
        model: Model name forwarded to ``proxy.call``.
        max_tokens: Accepted for interface compatibility.
            NOTE(review): it is not forwarded to ``proxy.call``, which uses
            its own default; confirm whether that is intended.
        debug: When true, print the first line of the raw reply.
        try_count: Total number of attempts; values below 1 mean one attempt.
        use_cache: When true, return a cached reply if one exists. The reply
            is written to the cache either way.

    Returns:
        The reply text with surrounding whitespace removed.

    Raises:
        Exception: Whatever the last failed attempt raised.
    """
    content = get_prompt(query)
    cache_key = '{}\t{}'.format(content, loop)
    if use_cache and cache_key in cache:
        return cache[cache_key]

    attempts = max(1, try_count)
    for attempt in range(1, attempts + 1):
        try:
            resp = proxy.call(model, content)
            raw_reply = resp["choices"][0]["message"]["content"]
        except Exception:
            # Out of attempts: surface the original error to the caller.
            if attempt == attempts:
                raise
            print("try_count ", attempt)
            continue

        if debug:
            print(raw_reply.split('\n')[0].strip())
        score = raw_reply.strip()
        cache[cache_key] = score
        return score

def query_eval(query, loop=2, use_cache=True):
    """Evaluate ``query`` once through ``gpt_eval``.

    Args:
        query: Poem text to decompose.
        loop: Forwarded to ``gpt_eval`` (where it is part of the cache key).
            When it is below 1 no request is made.
        use_cache: Forwarded to ``gpt_eval``; it was previously accepted but
            ignored.

    Returns:
        The result of ``gpt_eval``, or ``None`` when ``loop`` is below 1.
    """
    # The old ``for`` loop returned on its first iteration, so one call is
    # all that ever happened.
    if loop < 1:
        return None
    return gpt_eval(query, loop, use_cache=use_cache)

LIMIT = 200000
EVAL_LOOP_PER_CASE = 1

# State shared between the producer and the consumer threads.
threads = []
lock = threading.Lock()
# Unbounded on purpose: the script fills the queue completely before any
# consumer thread starts, so a bounded queue would make put() block forever
# once the input exceeds the bound.
buffer = queue.Queue()
results = []

def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content.

    The file is decoded as UTF-8 explicitly so non-ASCII text does not depend
    on the platform's default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
    
def load_json_lines(path):
    """Read a JSON Lines file and return its records as a list.

    The file is decoded as UTF-8. Lines containing only whitespace are
    skipped instead of raising a JSON decode error.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        One parsed object per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

print("-----------------------")
print("Poem query gpt4.")
# Gather the records of every file matching "<json-path>*.jsonl".
data = []
for jsonl_path in glob.glob(f"{args.json_path}*.jsonl"):
    print("file:", jsonl_path)
    data += load_json_lines(jsonl_path)

def line_format(line):
    """Return ``line`` as trimmed text with newline placeholders expanded.

    The value is converted with ``str()``, surrounding whitespace is removed,
    and both the two-character sequence backslash-n and the ``<n>`` marker
    become real newlines.
    """
    # str() hands back an exact str unchanged, so no type check is needed.
    text = str(line).strip()
    for marker in ('\\n', '<n>'):
        text = text.replace(marker, '\n')
    return text

def consumer():
    """Worker loop: evaluate queued items until the queue is drained.

    Each queue item is ``(index, query, dic, pbar)``. On success ``dic`` gets
    ``gpt_res = 1`` and ``gpt_decom`` (the model reply) and is appended to
    ``results``. When the evaluation returns ``None``, ``dic`` is marked with
    ``gpt_res = 0`` and not appended. Errors are printed and the item is
    skipped. The progress bar advances once per item in every case.
    """
    while True:
        # A non-blocking get closes the race in "check empty(), then get()":
        # another worker could take the last item between the two calls and
        # leave this thread blocked forever.
        try:
            _index, query, dic, pbar = buffer.get_nowait()
        except queue.Empty:
            return

        try:
            res = query_eval(query.strip(), loop=EVAL_LOOP_PER_CASE)
            if res is None:
                dic["gpt_res"] = 0
                continue
            dic["gpt_res"] = 1
            dic["gpt_decom"] = res
            # results is shared by all worker threads.
            with lock:
                results.append(dic)
        except Exception as e:
            traceback.print_exc()
            print("expectation", e)
            print(f'{threading.current_thread().name} Error {query}')
        finally:
            # Runs for the ``continue`` path as well.
            buffer.task_done()
            pbar.update(1)

def producer():
    """Queue every loaded record for the consumer threads.

    For each record in ``data`` the tuple ``(idx, record["gen"], record,
    pbar)`` is put on ``buffer``; ``pbar`` is one progress bar shared by all
    items.

    Raises:
        KeyError: If a record has no ``"gen"`` field.
    """
    # Pass the item count as ``total``: the old code handed a constant 0 to
    # tqdm's first positional parameter, which is the iterable, so the bar
    # never knew how many items to expect.
    pbar = tqdm.tqdm(total=len(data), desc=f'producer {threading.current_thread().name}')

    for idx, each in enumerate(data):
        buffer.put((idx, each["gen"], each, pbar))

# Fill the queue first: consumers stop as soon as they find it empty, so they
# must not be started before the work has been queued.
producer()
threads.extend(threading.Thread(target=consumer) for _ in range(TNUM))
for c_thread in threads:
    c_thread.start()
for c_thread in threads:
    c_thread.join()

# One JSON object per line. ensure_ascii=False writes non-ASCII characters
# verbatim, so the file encoding is pinned to UTF-8 instead of relying on the
# platform default.
with open(args.output_path, "w", encoding="utf-8") as file:
    for d in results:
        file.write(json.dumps(d, ensure_ascii=False) + "\n")
print("write file: ", args.output_path)
