# Content score (内容分)

import os
import tqdm
import time
import openai
import random
# import pandas as pd
from diskcache import Cache
import threading
import queue
import time
import argparse

import os
import json
import datetime
import traceback

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

# Base address of the OpenAI-compatible service; the chat-completions path is
# appended to it when a request is sent.
openai_api = ''
# api_key_RM = None

print(f"OPENAI_API_ADDR:{openai_api}")

# Number of consumer threads; the work queue is sized relative to this value.
TNUM = 500
# Start-up timestamp formatted as month-day_hour_minute_second.
date = f"{datetime.datetime.now():%m%d_%H_%M_%S}"

# Command-line interface: input JSONL path, output path, optional API key.
parser = argparse.ArgumentParser()
parser.add_argument("--json-path", required=True, type=str)
parser.add_argument("--output-path", required=True, type=str)
parser.add_argument("--gpt-key", default="", type=str)
args = parser.parse_args()

class OpenAIApiException(Exception):
    """Error raised when the chat-completions endpoint answers with a failure.

    Attributes:
        msg: Human-readable description of the failure.
        error_code: Status code associated with the failure.
    """

    def __init__(self, msg, error_code):
        # Store both pieces of information as plain attributes for callers.
        self.msg, self.error_code = msg, error_code

class OpenAIApiProxy():
    """Small HTTP client for an OpenAI-compatible chat-completions endpoint.

    The instance owns a ``requests`` session whose HTTP and HTTPS adapters
    retry POST requests on the status codes listed in ``status_forcelist``.
    """

    def __init__(self, api_key=""):
        """Create the retrying session.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header. When
                empty, no ``Authorization`` header is added.
        """
        retry_strategy = Retry(
            total=500,  # maximum retry count
            backoff_factor=1,  # wait-time factor between retries
            status_forcelist=[429, 500, 502, 503, 504],  # status codes that trigger a retry
            allowed_methods=["POST"]  # only POST requests are retried
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        # Create the session and attach the retry logic to both schemes.
        self.session = requests.Session()
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)
        self.api_key = api_key

    def call(self, model_name, prompt, headers=None, max_tokens=512):
        """Send ``prompt`` as a single user message and return the parsed reply.

        Args:
            model_name: Value of the ``model`` field in the request body.
            prompt: Text placed in the single user message.
            headers: Optional extra HTTP headers. The mapping is copied, so the
                caller's object is never modified.
            max_tokens: Value of the ``max_tokens`` field. When it is 10 or
                less, generation is also stopped at the first newline.

        Returns:
            The response body decoded from JSON.

        Raises:
            OpenAIApiException: If the final response status is not 200.
        """
        params_gpt = {
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "stop": "\n" if max_tokens <= 10 else None,
            "temperature": 0.001,
        }

        # Build a fresh header dict per call. A shared mutable default would
        # keep the Authorization header of one instance alive for every later
        # call, including calls made by instances that have no API key.
        request_headers = dict(headers) if headers else {}
        request_headers.setdefault('Content-Type', 'application/json')
        if self.api_key:
            request_headers['Authorization'] = "Bearer " + self.api_key

        url = openai_api + '/v1/chat/completions'
        response = self.session.post(url, headers=request_headers, data=json.dumps(params_gpt))

        if response.status_code != 200:
            err_msg = "access openai error, status code: %s" % (response.status_code)
            raise OpenAIApiException(err_msg, response.status_code)
        return json.loads(response.text)

# Disk-backed cache of model replies, stored relative to the working directory.
cache = Cache('../cache')
# Pass the key through only when one was supplied on the command line;
# otherwise rely on the constructor default.
proxy = OpenAIApiProxy(api_key=args.gpt_key) if args.gpt_key != "" else OpenAIApiProxy()

def get_prompt(prompt, query):
    """Flatten the scoring conversation into one ``role: text`` transcript.

    The transcript has four turns: a system line, the (Chinese) scoring
    rubric, a canned assistant acknowledgement, and the concrete request that
    embeds the question and its answers.

    Args:
        prompt: The question text (str).
        query: The answers to be scored, already rendered as text (str).

    Returns:
        A single string in which every turn is written as
        ``"<role>: <content>"`` followed by a blank line.
    """
    rubric = \
    '''
    你是一个打分专家，这里需要你对一个问答进行评估。
    我会为你提供一个问题，这个问题要求写一个标题/名字/口号/标语/个性签名/朋友圈文案其中的某种类型的文字性内容，同时这个问题中明确了这个文字性内容是专门为了某个事物创作的。
    然后我会给你一些对应的回答，这些回答以list形式出现，其中每个元素是一个答案，可能有一个答案，也可能有多个答案，回答可能是中文，也可能是英文。
    你需要首先抽取出问题中内容相关的指令，具体而言包括了文字性内容的类型，和针对的是哪一种事物。
    你需要判断答案的风格是否满足了文字性内容的类型，并打分（满分2分），再判断答案的描写是否和事物契合，并打分（满分2分），将二者加和作为总分。
    最终对于每一个答案你只需要给出上述答案的总分，不需要回答其他任何内容，请将所有答案的分数按照对应的顺序放在同一个list中。
    '''

    # The final user turn carries the actual question and answers.
    judging_request = "".join((
        "现在给你的问题是：\n", prompt, '\n',
        "给你的回答是：\n", query, '\n',
        "请给出你的打分：",
    ))

    turns = (
        ('system', 'You are a helpful assistant.'),
        ('user', rubric),
        ('assistant', '明白了，我会根据您提供的示例和规范来进行打分。'),
        ('user', judging_request),
    )
    return "".join(role + ": " + text + "\n\n" for role, text in turns)

# def gpt_eval(query, loop, model="gpt-4-1106-preview", max_tokens=200, debug=False, try_count=3, use_cache=True):
def gpt_eval(prompt, query, loop, model="gpt-4", max_tokens=200, debug=False, try_count=3, use_cache=True):
    """Ask the model to score ``query`` for ``prompt`` and return its raw reply.

    Args:
        prompt: The question text.
        query: The answers to be scored, rendered as text.
        loop: Value mixed into the cache key together with the prompt text.
        model: Model name passed to the proxy.
        max_tokens: Currently NOT forwarded to the proxy, which therefore uses
            its own default. NOTE(review): forwarding it would change the
            effective token limit of existing runs, so it is left as is.
        debug: When true, print the first line of the model reply.
        try_count: Maximum number of attempts (at least one is always made).
        use_cache: When true, return a cached reply if one exists. Fresh
            replies are always written to the cache.

    Returns:
        The stripped text content of the first choice in the response.

    Raises:
        Exception: Whatever the last failed attempt raised, once all attempts
            are used up.
    """
    content = get_prompt(prompt, query)
    cache_key = '{}\t{}'.format(content, loop)
    if use_cache and cache_key in cache:
        return cache[cache_key]

    attempts = max(1, try_count)
    for attempt in range(1, attempts + 1):
        try:
            resp = proxy.call(model, content)
            reply = resp["choices"][0]["message"]["content"]
        except Exception:
            # The previous retry branch could never run: the counter was
            # incremented and then compared with ">= 1", so every failure was
            # re-raised immediately. Retry with a linearly growing pause.
            if attempt >= attempts:
                raise
            print("try_count ", attempt)
            time.sleep(1 * attempt)
            continue

        if debug:
            print(reply.split('\n')[0].strip())
        score = reply.strip()
        cache[cache_key] = score
        return score

def query_eval(prompt, query, loop=2, use_cache=True):
    """Score one question/answer pair with a single model evaluation.

    Args:
        prompt: The question text.
        query: The answers to be scored, rendered as text.
        loop: Evaluation count setting. Only one evaluation is performed (the
            original loop returned on its first pass); the value is passed on
            unchanged and ends up in the cache key.
        use_cache: Forwarded to ``gpt_eval``; previously accepted but ignored.

    Returns:
        The model's reply, or ``None`` when ``loop`` is not positive.
    """
    if loop <= 0:
        return None
    return gpt_eval(prompt, query, loop, use_cache=use_cache)

# Tunables. EVAL_LOOP_PER_CASE is the ``loop`` value handed to query_eval for
# every item; LIMIT is not referenced by any active code visible in this file.
EVAL_LOOP_PER_CASE = 1
LIMIT = 200000

# State shared between the main thread and the consumer threads.
results = []                             # scored records, appended under ``lock``
lock = threading.Lock()                  # guards ``results``
threads = []                             # consumer Thread objects
buffer = queue.Queue(maxsize=2 * TNUM)   # bounded work queue

def load_json(path):
    """Read a whole file as one JSON document.

    Args:
        path: Path of the JSON file.

    Returns:
        The decoded JSON value.
    """
    # Decode as UTF-8 explicitly so non-ASCII content does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
    
def load_json_lines(path):
    """Read a JSON Lines file: one JSON document per line.

    Lines containing only whitespace (for example a trailing empty line) are
    skipped instead of failing to parse.

    Args:
        path: Path of the JSONL file.

    Returns:
        A list with one decoded value per non-blank line, in file order.
    """
    # Decode as UTF-8 explicitly so non-ASCII content does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Print a banner, then load every input record from the JSONL file.
print("-----------------------", "Slogan reward content.", sep="\n")
data = load_json_lines(args.json_path)

def line_format(line):
    """Normalise a value into display text with real line breaks.

    The value is converted to ``str``, surrounding whitespace is removed, and
    both the two-character sequence backslash-n and the ``<n>`` marker are
    turned into newline characters.

    Args:
        line: Any value; non-strings are converted with ``str()``.

    Returns:
        The cleaned-up text.
    """
    text = str(line).strip()
    for marker in ('\\n', '<n>'):
        text = text.replace(marker, '\n')
    return text

def consumer():
    """Worker loop: take queued items, score them, and collect the results.

    Each queue item is a tuple ``(index, prompt, answer, record, pbar)``. The
    score is stored under ``record["rewards"]["content"]`` and the record is
    appended to the shared ``results`` list. Items whose evaluation returns
    ``None`` or raises are not appended. The worker exits as soon as it finds
    the queue empty.
    """
    # Initial pause before the first poll, so the queue can be filled first.
    time.sleep(3)
    while True:
        # A non-blocking get avoids the check-then-get race: with
        # "while not empty(): get()" a worker could block forever after
        # another worker took the last item, and join() would never return.
        try:
            _index, prompt, ans, dic, pbar = buffer.get_nowait()
        except queue.Empty:
            break

        try:
            res = query_eval(prompt, ans.strip(), loop=EVAL_LOOP_PER_CASE)
            dic["rewards"]["content"] = res
            if res is not None:
                with lock:
                    results.append(dic)
        except Exception as e:
            traceback.print_exc()
            time.sleep(1)
            print("expectation", e)
            print(f'{threading.current_thread().name} Error {ans}')
        finally:
            # Always account for the item, whether it succeeded or not.
            buffer.task_done()
            pbar.update(1)

def producer():
    """Put one work item per input record on the shared queue.

    For every record the question is read from ``conversations[0]["value"]``
    when that key exists, otherwise from ``prompt[0]["value"]``; the answers
    are read from ``decom_dic["内容"]``. Empty ``rewards`` and ``explanation``
    dicts are created on the record when missing. ``buffer.put`` blocks while
    the queue is full.
    """
    size = len(data)

    # ``total`` must be passed by keyword: tqdm's first positional parameter
    # is the iterable, so a bare count leaves the bar without a total.
    pbar = tqdm.tqdm(total=size, desc=f'producer {threading.current_thread().name}')
    for idx, d in enumerate(data):
        if "conversations" in d:
            prompt = d["conversations"][0]["value"]
        else:
            prompt = d["prompt"][0]["value"]
        response = d["decom_dic"]["内容"]
        d.setdefault("rewards", {})
        d.setdefault("explanation", {})
        buffer.put((idx, prompt, str(response), d, pbar))

# Run the pipeline only when the output file does not exist yet.
# The producer runs in the main thread after the consumers have been started;
# the consumers pause briefly before polling the queue.
if not os.path.exists(args.output_path):
    for _ in range(TNUM):
        threads.append(threading.Thread(target=consumer))
    for c_thread in threads:
        c_thread.start()

    producer()
    for c_thread in threads:
        c_thread.join()

    # Records are written in completion order, one JSON document per line.
    # The file is opened as UTF-8 explicitly because ensure_ascii=False emits
    # non-ASCII characters verbatim, which would fail on a non-UTF-8 locale.
    with open(args.output_path, "w", encoding="utf-8") as file:
        for d in results:
            file.write(json.dumps(d, ensure_ascii=False) + "\n")
else:
    print(f"已经产出{args.output_path}")
