import argparse
import json
import os
import yaml
from openai import OpenAI, AsyncOpenAI
import re


class ChatWrapper:
    """Turn a text prompt into an object-locator request for a chat backend.

    The prompt templates are read from ``./prompt.yaml`` (relative to the
    current working directory) when the wrapper is constructed.
    """

    def __init__(self, chat):
        """Load the prompt templates and keep a reference to the chat backend.

        Args:
            chat: Object exposing ``get_response(message=...)``.
        """
        with open("./prompt.yaml", "r", encoding='utf-8') as template_file:
            self.prompt = yaml.load(template_file, Loader=yaml.FullLoader)
        self.chat = chat

    def get_obj_response(self, prompt):
        """Ask the chat backend to locate the objects described by ``prompt``.

        Args:
            prompt: Text substituted into the ``object_locator`` template
                through its ``{prompt}`` placeholder.

        Returns:
            Whatever ``self.chat.get_response`` returns for the two-message
            (system + user) conversation.
        """
        template = self.prompt['object_locator']
        conversation = [
            {"role": "system", "content": "You are a brilliant object_locator."},
            {"role": "user", "content": template.format(prompt=prompt)},
        ]
        return self.chat.get_response(message=conversation)



def convert_data(data, image_size):
    """Convert located objects into a list of bounding-box annotations.

    Args:
        data: Mapping whose values are dicts with a ``"coordinates"`` entry
            (four numbers unpacked as ``x1, y1, x2, y2``) and a ``"phrase"``
            entry. The coordinates are presumably fractions of the image
            size in [0, 1] -- confirm against the prompt template.
        image_size: Factor each coordinate is multiplied by before being
            truncated to an int.

    Returns:
        One dict per value of ``data``, in iteration order, with ``"bbox"``
        given as ``[x1, y1, width, height]``, an empty ``"mask"`` list, an
        empty ``"category_name"`` and the phrase as ``"caption"``.
    """

    def _to_annotation(entry):
        # Scale the corner coordinates and truncate them to integers.
        left, top, right, bottom = [
            int(value * image_size) for value in entry["coordinates"]
        ]
        # Corner pair -> origin plus width/height.
        return {
            "bbox": [left, top, right - left, bottom - top],
            "mask": [],
            "category_name": "",
            "caption": entry["phrase"],
        }

    return [_to_annotation(entry) for entry in data.values()]

class SyncChat:
    """Synchronous chat-completion client bound to a single model name."""

    def __init__(self, model, api_key,base_url):
        """Create the underlying ``OpenAI`` client.

        Args:
            model: Model name sent with every request.
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Base URL handed to the ``OpenAI`` client.
        """
        self.sync_client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def get_response(self, message, temperature=0.2, max_tokens=1024):
        """Send a conversation and return the content of the first choice.

        Args:
            message: List of chat messages forwarded as ``messages``.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.

        Returns:
            ``choices[0].message.content`` of the completion response.
        """
        request = {
            "model": self.model,
            "messages": message,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        completion = self.sync_client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content

def _load_jsonl(path):
    """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def _find_step6_payload(response):
    """Return the text following the ``Step 6`` line of ``response``, or None if absent."""
    match = re.search(r'Step 6\n(.+)$', response, re.DOTALL)
    return match.group(1) if match else None


def run(datafile, args):
    """Query the chat model for one prompt and write its layout annotations.

    The prompt is taken from line ``args.index`` (1-based) of the JSON-lines
    file ``datafile``. The raw model reply is saved to ``args.prompt_res``.
    If the reply contains a ``Step 6`` section, the JSON after it is
    converted with ``convert_data`` and written to ``args.output_json``.

    Args:
        datafile: Path to a JSON-lines file whose records have a ``"prompt"`` key.
        args: Namespace providing ``output_dir``, ``index``, ``prompt_res``
            and ``output_json``.

    Raises:
        IndexError: If ``args.index`` is not within ``1..number of records``.
        json.JSONDecodeError: If a data line or the ``Step 6`` payload is
            not valid JSON.
    """
    # Placeholder endpoint and credentials; fill in real values before running.
    base_url = "###"
    api_key = "###"
    image_size = 512

    os.makedirs(args.output_dir, exist_ok=True)
    chat = SyncChat("gpt-4-turbo", api_key=api_key, base_url=base_url)
    chat_wrapper = ChatWrapper(chat)

    datalist = _load_jsonl(datafile)
    # Reject 0 and negative values explicitly: plain list indexing would
    # silently wrap around and pick a record from the end of the file.
    if not 1 <= args.index <= len(datalist):
        raise IndexError(
            f"--index must be between 1 and {len(datalist)}, got {args.index}"
        )

    prompt = datalist[args.index - 1]["prompt"]
    print(prompt)
    bbox = chat_wrapper.get_obj_response(prompt=prompt)
    print(bbox)
    # Keep the raw reply so a failed parse can be inspected afterwards.
    with open(args.prompt_res, 'w', encoding='utf-8') as res_file:
        res_file.write(bbox)

    json_str = _find_step6_payload(bbox)
    if json_str is None:
        print(f'No "Step 6" section found in the model response; {args.output_json} was not written.')
        return

    json_data = json.loads(json_str)
    print(json_data)

    data = {
        "caption": prompt + " high quality. professional photo.",
        "width": image_size,
        "height": image_size,
        "annos": convert_data(json_data, image_size),
    }
    with open(args.output_json, 'w', encoding='utf-8') as f:
        json.dump(data, f)
        

if __name__ == "__main__":
    # Command-line entry point: parse the options and process a single prompt.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_json", type=str, default="./output/MHalu_json/output_0.json", help="path of the annotation JSON file to write")
    # Must differ from --output_json, otherwise the raw reply is overwritten.
    parser.add_argument("--prompt_res", type=str, default="./output/MHalu_json/prompt_res_0.txt", help="path of the file receiving the raw model response")
    # Must be a directory: it is created before any output file is opened.
    parser.add_argument("--output_dir", type=str, default="./output/MHalu_json", help="root folder for output (created if missing)")
    parser.add_argument("--index", type=int, default=18, help="1-based line number of the prompt to process in the data file")
    parser.add_argument("--datafile", type=str, default='/path/to/MHaluBench/text-to-image.json', help="JSON-lines file containing the prompts")
    args = parser.parse_args()
    run(args.datafile, args)