import json
import random
from tqdm import tqdm
# from qwen import call_qwen_batch

# User-turn text stored in every generated instance.
task_prompt = "<image>Given the multi-modal knowledge graph. Please describe the knowledge presented by it."


# Template for asking an LLM to summarize a list of triples; `{}` is filled
# with the triple list. Only the disabled answer-generation step described in
# `main()` uses it.
qwen_answer_gen_prompt = """
Given a list of triples in the format of [head, tail, relation], please summarize the knowledge presented by these triples in a concise manner.

The triples are as follows:
{}
Please provide a concise summary of the knowledge contained in these triples. Do not include any additional information.
"""


# Input files.
TASK_DATA_PATH = "data/processed_instances_with_images_task5.json"
GENERATED_ANSWERS_PATH = "data/generated_answers_task5.json"

# At most this many leading instances are converted before shuffling.
MAX_INSTANCES = 12000

# (output path, start, stop) slices taken from the shuffled instance list.
SPLITS = (
    ("dataset/task5_prompts_train.json", 0, 6400),
    ("dataset/task5_prompts_valid.json", 6400, 7200),
    ("dataset/task5_prompts_test.json", 7200, 8000),
)


def load_json(path: str):
    """Return the parsed content of the UTF-8 JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def dump_json(obj, path: str) -> None:
    """Write *obj* to *path* as UTF-8 JSON, keeping non-ASCII characters."""
    # ensure_ascii=False emits raw non-ASCII text, so the encoding is pinned
    # to UTF-8 instead of depending on the platform's locale default.
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, ensure_ascii=False)


def build_instance(data: dict, answer) -> dict:
    """Build one chat-style training instance.

    Args:
        data: Record providing the keys ``entity``, ``relation``, ``triple``,
            ``source`` and ``image``.
        answer: Assistant reply to pair with ``task_prompt``.

    Returns:
        A dict with ``messages`` (user prompt + assistant answer), ``images``
        (the image path with ``generate/images/`` replaced by ``m5/``) and
        ``metadata`` copied from *data*.

    Raises:
        KeyError: If *data* lacks one of the required keys.
    """
    processed_image_path = data["image"].replace("generate/images/", "m5/")
    return {
        "messages": [
            {"content": task_prompt, "role": "user"},
            {"content": answer, "role": "assistant"},
        ],
        "images": [processed_image_path],
        "metadata": {
            "entity": data["entity"],
            "relation": data["relation"],
            "triple": data["triple"],
            "source": data["source"],
            # NOTE(review): the files are named "task5" but the type id is 4;
            # presumably a 0-based task index -- confirm.
            "task_type": 4,
        },
    }


def split_dataset(instances: list) -> dict:
    """Map each output path in ``SPLITS`` to its slice of *instances*.

    Slices past the end of *instances* are simply shorter (or empty); no
    error is raised when fewer instances than the split bounds exist.
    """
    return {path: instances[start:stop] for path, start, stop in SPLITS}


def main() -> None:
    """Pair task instances with their answers and write the three splits."""
    task_data = load_json(TASK_DATA_PATH)

    # Disabled step that creates GENERATED_ANSWERS_PATH; it needs
    # `call_qwen_batch` from the `qwen` module (import commented out at the
    # top of the file):
    #     gen_prompts = [qwen_answer_gen_prompt.format(item["triple"])
    #                    for item in task_data]
    #     generated_answers = call_qwen_batch(gen_prompts)
    #     dump_json(generated_answers, GENERATED_ANSWERS_PATH)
    generated_answers = load_json(GENERATED_ANSWERS_PATH)

    limit = min(len(task_data), MAX_INSTANCES)
    # Answers are matched to instances by position, so fail early with a
    # clear message rather than with an IndexError halfway through.
    if len(generated_answers) < limit:
        raise ValueError(
            f"{GENERATED_ANSWERS_PATH} holds {len(generated_answers)} answers "
            f"but {limit} instances need one"
        )

    task_prompts = [
        build_instance(data, answer)
        for data, answer in tqdm(
            zip(task_data[:limit], generated_answers), total=limit
        )
    ]

    # NOTE(review): the shuffle is unseeded, so split membership changes on
    # every run -- confirm that this is intended.
    random.shuffle(task_prompts)
    for path, subset in split_dataset(task_prompts).items():
        dump_json(subset, path)


if __name__ == "__main__":
    main()