import json
import random
from tqdm import tqdm

# User-turn instruction for the entity-counting task. The leading "<image>"
# is presumably the placeholder the training framework swaps for the graph
# image listed in each instance's "images" field (TODO confirm).
task_prompt = (
    "<image>Given the multi-modal knowledge graph. "
    "Please count the number of entities in it."
)

# Assistant-turn template with three positional slots, in order:
# the newline-separated entity names, the entity count (inside <think>),
# and the entity count again (inside <answer>).
answer_prompt = "\n".join(
    [
        "",
        "<think>",
        "There are several entities in the given multi-modal knowledge graph:",
        "{}",
        "Therefore, the number of entities is {}",
        "</think>",
        "<answer>{}</answer>",
        "",
    ]
)


def build_instance(data, user_prompt, answer_template):
    """Convert one raw task-1 record into a chat-format training instance.

    Args:
        data: One record from the processed-instances JSON. The keys
            ``entity``, ``relation``, ``triple``, ``source`` and ``image`` are
            read. ``entity`` is newline-joined and measured with ``len``, so it
            is presumably a list of entity-name strings (TODO confirm against
            the generator that wrote the input file).
        user_prompt: Text used verbatim as the user turn.
        answer_template: Format string with three ``{}`` slots. They are filled
            with the newline-joined entities, the entity count, and the count
            again.

    Returns:
        A dict with ``messages`` (a user turn and an assistant turn), ``images``
        (a one-element list holding the rewritten image path) and ``metadata``
        (fields copied from the record plus ``task_type`` = 1).
    """
    entity = data["entity"]
    answer = answer_template.format("\n".join(entity), len(entity), len(entity))
    return {
        "messages": [
            {"content": user_prompt, "role": "user"},
            {"content": answer, "role": "assistant"},
        ],
        # Point the image path at the dataset copy instead of the
        # generation output directory.
        "images": [data["image"].replace("generate/images", "m5/task1")],
        "metadata": {
            "entity": entity,
            "relation": data["relation"],
            "triple": data["triple"],
            "source": data["source"],
            "task_type": 1,
        },
    }


def main(
    input_path="data/processed_instances_with_images_task1.json",
    output_prefix="dataset/task1_prompts",
    max_instances=11000,
    split_sizes=(6400, 800, 800),
    seed=42,
):
    """Build the task-1 (entity counting) dataset and write its splits.

    Reads the JSON list at ``input_path`` and converts its first
    ``max_instances`` records with ``build_instance``, then shuffles them.
    Consecutive slices of sizes ``split_sizes`` are written to
    ``{output_prefix}_train.json``, ``{output_prefix}_valid.json`` and
    ``{output_prefix}_test.json``. Instances beyond ``sum(split_sizes)`` are
    converted but not written.

    Args:
        input_path: Path of the processed-instances JSON file.
        output_prefix: Path prefix for the three split files.
        max_instances: Maximum number of records to convert.
        split_sizes: Sizes of the train, valid and test splits, in that order.
        seed: Seed for the shuffle. A fixed seed makes reruns reproduce the
            same split, whereas an unseeded shuffle would move records between
            train and test on every run. Pass ``None`` for a fresh random split.
    """
    # Explicit UTF-8 on both ends: output uses ensure_ascii=False, so
    # non-ASCII entity names must not depend on the platform's default
    # encoding (which can raise UnicodeEncodeError, e.g. under cp1252).
    with open(input_path, "r", encoding="utf-8") as f:
        task_data = json.load(f)

    task_prompts = [
        build_instance(data, task_prompt, answer_prompt)
        for data in tqdm(task_data[:max_instances])
    ]
    random.Random(seed).shuffle(task_prompts)

    needed = sum(split_sizes)
    if len(task_prompts) < needed:
        print(
            f"warning: only {len(task_prompts)} instances available but "
            f"split_sizes requests {needed}; later splits will be short or empty"
        )

    start = 0
    for split_name, size in zip(("train", "valid", "test"), split_sizes):
        with open(f"{output_prefix}_{split_name}.json", "w", encoding="utf-8") as f:
            json.dump(task_prompts[start:start + size], f, ensure_ascii=False)
        start += size


if __name__ == "__main__":
    main()