import os
import tqdm
import time
import openai
import random
import glob
from diskcache import Cache
import threading
import queue
import time
import argparse

import os
import json
import datetime
import traceback

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry


# Base URL of the OpenAI-compatible endpoint; '/v1/chat/completions' is appended
# per request. Taken from the OPENAI_API_ADDR environment variable (the label this
# script prints) and falls back to the previous hard-coded empty string.
openai_api = os.environ.get('OPENAI_API_ADDR', '')

print("OPENAI_API_ADDR:" + openai_api)
# Number of consumer threads; the work queue is sized relative to it.
TNUM = 500
# Start timestamp as MMDD_HH_MM_SS. NOTE(review): not referenced elsewhere in this file.
date = datetime.datetime.now().strftime('%m%d_%H_%M_%S')

parser = argparse.ArgumentParser()
parser.add_argument("--json-path", type=str, required=True,
                    help="glob prefix; every file matching '<json-path>*.jsonl' is loaded")
parser.add_argument("--output-path", type=str, required=True,
                    help="JSON Lines output file; the run is skipped if it already exists")
parser.add_argument("--gpt-key", type=str, default="",
                    help="bearer token for the endpoint; empty sends no Authorization header")
args = parser.parse_args()

class OpenAIApiException(Exception):
    """Raised when the chat-completions endpoint answers with a non-200 status.

    Attributes:
        msg: human-readable description of the failure.
        error_code: status code of the failed response (the raise site in
            ``OpenAIApiProxy.call`` passes ``response.status_code``).
    """

    def __init__(self, msg, error_code):
        # Forward msg to Exception so str(exc), exc.args and printed tracebacks
        # carry the message instead of being empty.
        super().__init__(msg)
        self.msg = msg
        self.error_code = error_code

class OpenAIApiProxy:
    """Minimal client for an OpenAI-compatible ``/v1/chat/completions`` endpoint.

    All requests go through one ``requests.Session`` whose adapters retry POSTs
    on 429/5xx responses, so transient rate-limit and server errors are absorbed
    here rather than by callers.
    """

    def __init__(self, api_key=""):
        """Create the session and attach the retry policy.

        Args:
            api_key: bearer token sent as ``Authorization``; an empty string
                sends no auth header.
        """
        retry_strategy = Retry(
            total=500,  # max number of retries (urllib3 counts retries, not the first attempt)
            backoff_factor=1,  # base factor for urllib3's exponential sleep between retries
            status_forcelist=[429, 500, 502, 503, 504],  # HTTP statuses that trigger a retry
            allowed_methods=["POST"],  # POST is not retried by default, so opt in explicitly
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        # Reuse one session for every call so connections are pooled; mount the
        # retrying adapter for both schemes.
        self.session = requests.Session()
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)
        self.api_key = api_key

    def call(self, model_name, prompt, headers=None, max_tokens=512, timeout=None):
        """POST a single-turn chat completion and return the decoded JSON body.

        Args:
            model_name: value for the request's ``model`` field.
            prompt: sent as the content of a single ``user`` message.
            headers: optional extra HTTP headers. The mapping is copied and never
                mutated; ``Content-Type`` defaults to ``application/json`` and
                ``Authorization`` is set from ``api_key`` when one was given.
            max_tokens: completion budget; values <= 10 also stop at the first newline.
            timeout: forwarded to ``requests`` (seconds or a (connect, read) tuple);
                ``None`` keeps the previous wait-indefinitely behaviour.

        Returns:
            The parsed JSON response body.

        Raises:
            OpenAIApiException: the final response status was not 200.
            Errors raised by ``requests`` (e.g. connection failures, exhausted
            retries) propagate unchanged.
        """
        params_gpt = {
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            # Very short answers are cut at the first newline; otherwise no stop
            # sequence (serialised as null).
            "stop": "\n" if max_tokens <= 10 else None,
            "temperature": 0.001,
        }

        # Build a fresh header dict per call. The old ``headers={}`` default was a
        # single dict shared by every call and instance, so an Authorization header
        # written by a keyed proxy leaked into later calls of a key-less one.
        request_headers = dict(headers) if headers else {}
        request_headers.setdefault('Content-Type', 'application/json')
        if self.api_key:
            request_headers['Authorization'] = "Bearer " + self.api_key

        url = openai_api + '/v1/chat/completions'
        response = self.session.post(
            url,
            headers=request_headers,
            data=json.dumps(params_gpt),
            timeout=timeout,
        )

        if response.status_code != 200:
            err_msg = "access openai error, status code: %s" % (response.status_code)
            raise OpenAIApiException(err_msg, response.status_code)
        return json.loads(response.text)

# Persistent on-disk memo of model replies, read and written by gpt_eval.
cache = Cache('./cache')
# --gpt-key defaults to "", which is also OpenAIApiProxy's own api_key default,
# so the value can be forwarded unconditionally: an empty key still means the
# proxy sends no Authorization header.
proxy = OpenAIApiProxy(api_key=args.gpt_key)

# Few-shot examples as (raw model output, expected decomposition) pairs. Each
# expected item keeps its own punctuation, as rule 2 of the instructions requires.
_FEW_SHOT_EXAMPLES = (
    (
        "放飞梦想，砥砺前行",
        ["放飞梦想，砥砺前行"],
    ),
    (
        "当然可以！请参考以下建议：\n\n1. Clean Clothes, Fresh Start.\n2. Washing Wisdom, Pure Results.\n3. Expertly Cleaned, Always Bright.\n4. Sparkling Clean, Ready to Wear.\n5. Tough Stains, No Problem.\n6. Laundry Love, Fresh as New.\n7. Bright Whites, Bold Colors.\n8. Wash with Us, Wear with Pride.\n9. Clean Clothes, Happy Hearts.\n10. Fluff and Fold, Fresh as a Rose.\n\n以上标语均使用五个英文单词，简洁明了地传达了洗衣店的特色和理念。请根据您的喜好和需求选择合适的标语。",
        ["Clean Clothes, Fresh Start.", "Washing Wisdom, Pure Results.", "Expertly Cleaned, Always Bright.", "Sparkling Clean, Ready to Wear.", "Tough Stains, No Problem.", "Laundry Love, Fresh as New.", "Bright Whites, Bold Colors.", "Wash with Us, Wear with Pride.", "Clean Clothes, Happy Hearts.", "Fluff and Fold, Fresh as a Rose."],
    ),
    (
        "当然可以！以下是几个六字汉字组合的建议，你可以从中挑选或加以修改：\n\n1. **步韵潮流天地**：此名体现了时尚和潮流元素，同时“步韵”也暗示了舒适行走的意境。\n2. **足下风采无限**：强调了穿在脚下的鞋子带来的风采与无界限的可能性。\n3. **锦绣履程开端**：结合了“锦绣前程”的成语，寓意着穿着美鞋开始一段美好旅程。\n4. **雅步风情万种**：此名体现了优雅的步伐和多姿多彩的风情。\n5. **行者天地任我**：传达了穿好鞋即可行走天下、自由探索的理念。\n6. **踏月追风轻盈**：借用诗句中的意象，形容穿鞋时的轻便舒适和愉悦感受。\n7. **漫步四季轮回**：突显了穿好鞋可以轻松穿梭于不同季节的美好体验。\n8. **足迹艺术之旅**：此名强调了行走的足迹就像艺术一样值得欣赏和探索。\n9. **风尚步青云台**：结合了“青云直上”的成语，比喻穿好鞋即可步步高升、成功成名。\n10. **千里马行天下**：此名体现了穿好鞋即可像千里马一样行走天下、大展宏图的精神。\n\n以上这些名字都考虑到了独特性、文化意蕴和品牌定位等要素。在选择时还需要考虑目标市场和消费者的喜好来做出决定。",
        ["步韵潮流天地", "足下风采无限", "锦绣履程开端", "雅步风情万种", "行者天地任我", "踏月追风轻盈", "漫步四季轮回", "足迹艺术之旅", "风尚步青云台", "千里马行天下"],
    ),
    (
        "Dive into the magical world of sea life, waves of wonder await at the aquarium! 🌊✨",
        ["Dive into the magical world of sea life, waves of wonder await at the aquarium! 🌊✨"],
    ),
)


def _format_example_answer(items):
    """Render one expected answer as the indented ``{"内容": [...]}`` snippet shown to the model."""
    # json.dumps quotes (and escapes) each item; ensure_ascii=False keeps Chinese
    # text and emoji as-is instead of \u escapes.
    quoted = ', '.join(json.dumps(item, ensure_ascii=False) for item in items)
    return '\n    {\n        "内容": [' + quoted + ']\n    }\n\n    '


def get_prompt(query):
    """Build the flattened few-shot prompt asking the model to extract titles/slogans from ``query``.

    The conversation (system line, extraction rules, an assistant acknowledgement,
    then the worked examples followed by ``query``) is rendered as plain text, one
    ``"<role>: <content>"`` chunk per message separated by blank lines, because
    the caller sends it to the API as a single user message.

    Args:
        query: the text to decompose; inserted verbatim via ``str.format``.

    Returns:
        The complete prompt string.
    """
    # Rule 1 names the same "内容" key that the examples and the closing reminder
    # use; it previously asked for a "回答" key, contradicting them.
    instructions = (
        '我现在需要你对一些标题/名字/口号/标语/个性签名/朋友圈文案等文字性内容进行拆解，这些内容可能是中文，也可能是英文。最终需要你以json的格式返回拆解后的内容。我会给你一些正确答案的示例，拆解的规范如下：\n'
        '    1. 你需要把文本中的文字性内容的答案，返回到json中"内容"的list中，答案中可能含有特殊字符，也需要保留。丢弃其他与答案无关的内容，比如对答案的解释。\n'
        '    2. 需要保留文字性内容的标点符号。\n'
        '    3. 如果有多个回答，请将所有答案放在同一个list中。\n'
        '    '
    )
    acknowledgement = '明白了，我会根据您提供的示例和规范来进行拆解。请提供需要答案示例，以及需要拆解的文本。'

    request_parts = []
    for number, (sample, answer) in enumerate(_FEW_SHOT_EXAMPLES, start=1):
        request_parts.append("示例{}：\n".format(number))
        request_parts.append(sample)
        # Leading newline keeps the label from being glued onto the sample text.
        request_parts.append('\n正确答案示例如下：\n')
        request_parts.append(_format_example_answer(answer))
    request_parts.append('\n需要进行拆解的文本如下：\n')
    request_parts.append('{}\n\n'.format(query))
    request_parts.append('现在请给出回答，请注意：1. 请返回json格式的内容，包括内容一项；2. 请不要返回json以外的其他内容；3. 丢弃其他与答案无关的内容，比如对答案的解释。请回答：')

    messages = [
        ('system', 'You are a helpful assistant.'),
        ('user', instructions),
        ('assistant', acknowledgement),
        ('user', ''.join(request_parts)),
    ]
    return ''.join('{}: {}\n\n'.format(role, text) for role, text in messages)

def gpt_eval(query, loop, model="gpt-4", max_tokens=200, debug=False, try_count=3, use_cache=True):
    """Ask the chat model to decompose ``query`` and return its raw reply text.

    Replies are memoised in the module-level disk ``cache`` under a key built
    from the full prompt and ``loop``.
    NOTE(review): the key does not include ``model``, so switching models reuses
    replies cached for another model.

    Args:
        query: text to decompose; rendered into the prompt by ``get_prompt``.
        loop: only used as part of the cache key.
        model: model name sent to the endpoint.
        max_tokens: NOTE(review): accepted but not forwarded, so ``proxy.call``
            uses its own default (512). Forwarding it would silently shrink the
            completion budget to 200 tokens; confirm before wiring it through.
        debug: print the first line of each freshly fetched reply.
        try_count: total number of attempts before the last error is re-raised
            (values below 1 mean a single attempt).
        use_cache: when False, skip the cache lookup (fresh replies are still stored).

    Returns:
        The stripped reply text.

    Raises:
        Whatever the final failed attempt raised, e.g. ``OpenAIApiException``,
        a ``requests`` error, or ``KeyError`` for a malformed response.
    """
    content = get_prompt(query)
    cache_key = '{}\t{}'.format(content, loop)
    if use_cache and cache_key in cache:
        return cache[cache_key]

    # The previous retry (``try_count += 1; if try_count >= 1: raise``) always
    # re-raised on the first failure, so it never actually retried.
    attempts = max(1, try_count)
    for attempt in range(1, attempts + 1):
        try:
            resp = proxy.call(model, content)
            reply = resp["choices"][0]["message"]["content"]
            break
        except Exception:
            if attempt >= attempts:
                raise
            print("try_count ", attempt)
            # Linear back-off on top of the HTTP adapter's own 429/5xx retries.
            time.sleep(1 * attempt)

    if debug:
        print(reply.split('\n')[0].strip())
    score = reply.strip()
    cache[cache_key] = score
    return score

def query_eval(query, loop=2, use_cache=True):
    """Evaluate ``query`` once via ``gpt_eval`` and return its reply.

    ``loop`` does not repeat the request; it is forwarded only as part of
    gpt_eval's cache key. A non-positive ``loop`` short-circuits to ``None``
    without calling the model.

    Args:
        query: text to decompose.
        loop: cache-key component; ``<= 0`` disables the call.
        use_cache: forwarded to ``gpt_eval`` (it was previously accepted but dropped).

    Returns:
        The reply string, or ``None`` when ``loop <= 0``.
    """
    if loop <= 0:
        return None
    return gpt_eval(query, loop, use_cache=use_cache)

LIMIT = 200_000  # NOTE(review): not referenced elsewhere in this file
EVAL_LOOP_PER_CASE = 1  # passed to query_eval as `loop` for every record

# State shared by the main thread, the producer and the consumer threads.
threads = []                            # consumer Thread objects, joined before the output is written
lock = threading.Lock()                 # guards appends to `results`
buffer = queue.Queue(maxsize=2 * TNUM)  # bounded work queue; put() blocks while it is full
results = []                            # records that received a model reply

def load_json(path):
    """Parse the whole file at ``path`` as a single JSON document.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale default, which fails on non-ASCII (e.g. Chinese) data on some hosts.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
    
def load_json_lines(path):
    """Read a JSON Lines file and return one parsed object per non-blank line.

    The file is decoded as UTF-8. Blank or whitespace-only lines, such as a
    trailing empty line, are skipped instead of raising ``json.JSONDecodeError``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

print("-----------------------")
print("Slogan query gpt4.")
# Load every '<json-path>*.jsonl' file. Sorting makes the record order (and the
# producer's indices) reproducible across runs.
input_pattern = args.json_path + "*.jsonl"
input_files = sorted(glob.glob(input_pattern))
if not input_files:
    # Fail loudly: an empty run would still create the output file, and its mere
    # existence makes every later run skip the work.
    raise SystemExit("no input files match {!r}".format(input_pattern))
data = []
for input_file in input_files:
    data.extend(load_json_lines(input_file))

def line_format(line):
    """Normalise a raw text field: coerce to ``str``, trim it, then expand newline markers.

    Both the two-character escape ``\\n`` and the ``<n>`` placeholder become real
    newlines. Trimming happens first, so markers at either edge survive as newlines.
    """
    text = line if type(line) is str else str(line)
    text = text.strip()
    for marker in ('\\n', '<n>'):
        text = text.replace(marker, '\n')
    return text

def consumer(idle_timeout=3):
    """Worker loop: take items from ``buffer``, query the model, collect replies.

    Each item is ``(index, query, dic, pbar)`` as enqueued by ``producer``. On
    success the reply is stored on ``dic`` (``gpt_res = 1``, ``gpt_decom`` = reply)
    and ``dic`` is appended to the shared ``results`` list under ``lock``.
    Exceptions are printed and that item is dropped.

    The worker exits once no item arrives within ``idle_timeout`` seconds. A
    blocking ``get`` with a timeout replaces the old ``empty()``-then-``get()``
    check: when several threads raced for the last items, the losers blocked in
    ``get()`` forever and the final ``join()`` hung. The timeout also covers the
    start-up window before the producer begins filling the queue.
    """
    while True:
        try:
            _index, query, dic, pbar = buffer.get(timeout=idle_timeout)
        except queue.Empty:
            break
        try:
            res = query_eval(query.strip(), loop=EVAL_LOOP_PER_CASE)
            if res is None:
                # NOTE(review): the record is flagged but not collected, so it does
                # not appear in the output file -- confirm this is intended.
                dic["gpt_res"] = 0
                continue
            dic["gpt_res"] = 1
            dic["gpt_decom"] = res
            # append() mutates the shared list in place; no `global` needed.
            with lock:
                results.append(dic)
        except Exception as e:
            traceback.print_exc()
            time.sleep(1)
            print("expectation", e)
            print(f'{threading.current_thread().name} Error {query}')
        finally:
            # Runs for every dequeued item, including the `continue` path above.
            buffer.task_done()
            pbar.update(1)

def producer():
    """Enqueue every loaded record for the consumer threads.

    Each record's ``gen`` field is the text to decompose; it is queued as
    ``(index, text, record, pbar)``, sharing one progress bar across all items.
    ``put`` blocks while ``buffer`` is full, so this returns only after the last
    record has been enqueued.
    """
    # total= is required: tqdm's first positional parameter is the iterable, so
    # tqdm.tqdm(size) produced a bar with no total, percentage or ETA.
    pbar = tqdm.tqdm(total=len(data), desc=f'producer {threading.current_thread().name}')
    for idx, record in enumerate(data):
        buffer.put((idx, record["gen"], record, pbar))

# Skip the whole run if the output file already exists.
if not os.path.exists(args.output_path):
    threads.extend(threading.Thread(target=consumer) for _ in range(TNUM))
    for c_thread in threads:
        c_thread.start()

    # Feed the queue from the main thread; the consumers are already waiting on it.
    producer()
    for c_thread in threads:
        c_thread.join()

    # ensure_ascii=False writes raw Chinese text, so pin the file encoding instead
    # of inheriting the platform locale. Write to a temporary file first and move
    # it into place, because the existence check above treats any file at
    # output_path as a finished run; a half-written file must never appear there.
    tmp_path = args.output_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as out_f:
        for d in results:
            out_f.write(json.dumps(d, ensure_ascii=False) + "\n")
    os.replace(tmp_path, args.output_path)
    print("write file: ", args.output_path)
else:
    # Report that the output already exists ("已经产出" = "already produced").
    print(f"已经产出{args.output_path}")
