import json
import random
from tqdm import tqdm

# User turn of every instance. The leading "<image>" token is presumably the
# image placeholder expected by the downstream training framework (each
# instance carries exactly one image path) — confirm against the trainer.
task_prompt = (
    "<image>Given the multi-modal knowledge graph. "
    "Please count the number of knowledge triples in it."
)

# Assistant-turn template with three positional "{}" slots, filled in order by
# str.format: (1) the newline-separated triples listed inside <think>,
# (2) the count stated at the end of the reasoning, (3) the <answer> content.
# The text starts and ends with a newline.
answer_prompt = "\n".join(
    [
        "",
        "<think>",
        "There are several knowledge triples in the given multi-modal knowledge graph:",
        "{}",
        "Therefore, the number of triples is {}",
        "</think>",
        "<answer>{}</answer>",
        "",
    ]
)


if __name__ == "__main__":
    task_data = json.load(open("data/processed_instances_with_images_task4.json", "r"))
    task_prompts = []
    count = 0
    for data in tqdm(task_data):
        if count >= 12000:
            break
        count += 1
        entity = data["entity"]
        relation = data["relation"]
        triple = data["triple"]
        source = data["source"]
        image_path = data["image"]
        processed_image_path = image_path.replace("generate/images/", "m5/")
        triple_list = ["{}--{}--{}".format(t[0], t[2], t[1]) for t in triple]
        answer = answer_prompt.format('\n'.join(triple_list), len(triple_list), triple_list)
        data_instance = {
            "messages": [
                {
                    "content": task_prompt,
                    "role": "user"
                },
                {
                    "content": answer,
                    "role": "assistant"
                }
            ],
            "images": [processed_image_path],
            "metadata": {
                "entity": entity,
                "relation": relation,
                "triple": triple,
                "source": source,
                "task_type": 4,
            }
        }
        # print(data_instance)
        task_prompts.append(data_instance)
    random.shuffle(task_prompts)
    json.dump(task_prompts[0: 6400], open("dataset/task4_prompts_train.json", "w"), ensure_ascii=False)
    json.dump(task_prompts[6400: 7200], open("dataset/task4_prompts_valid.json", "w"), ensure_ascii=False)
    json.dump(task_prompts[7200: 8000], open("dataset/task4_prompts_test.json", "w"), ensure_ascii=False)