import json
import os
import random

from tqdm import tqdm

# User-turn prompt shared by every task-3 instance. The leading "<image>"
# token presumably marks where the instance's "images" entry is injected by
# the training framework — confirm against the consumer of these files.
task_prompt = (
    "<image>Given the multi-modal knowledge graph. "
    "Please count the number of entities that have image information "
    "in the given knowledge graph."
)

# Assistant-turn template for task 3, filled with str.format using four
# positional slots:
#   1. newline-joined names of entities that have an image,
#   2. newline-joined names of entities without an image,
#   3. the image-entity count (inside the <think> reasoning trace),
#   4. the same count again (inside the final <answer> tag).
# NOTE(review): being a triple-quoted literal, the template begins and ends
# with a newline, so every assistant message carries them — confirm intended.
# NOTE(review): the sentence "the number of entities is {}" does not say
# "with images", although the value filled in is the image-entity count.
# Rewording it would change the generated training targets, so it is left
# as-is here.
answer_prompt = """
<think>
There are several entities with images in the given multi-modal knowledge graph:
{}
Other entities without images are:
{}
Therefore, the number of entities is {}
</think>
<answer>{}</answer>
"""


# Only the first MAX_INSTANCES source records are converted. After shuffling,
# only the first 8000 of them are written out (see SPLITS), so the splits are
# a random sample of those records.
# NOTE(review): if the input has fewer than 8000 records the later splits are
# silently short or empty.
MAX_INSTANCES = 11000
# Fixed seed so re-running the script reproduces the same train/valid/test
# partition instead of reshuffling examples across splits on every run.
SPLIT_SEED = 42
INPUT_PATH = "data/processed_instances_with_images_task3.json"
# (output path, start, end) slices of the shuffled instance list.
SPLITS = (
    ("dataset/task3_prompts_train.json", 0, 6400),
    ("dataset/task3_prompts_valid.json", 6400, 7200),
    ("dataset/task3_prompts_test.json", 7200, 8000),
)


def _split_by_mask(entities, image_mask):
    """Return (entities with an image, entities without one), in input order.

    An entity counts as having an image when its mask value equals 1.

    Raises:
        ValueError: if *entities* and *image_mask* differ in length, because a
            mismatched mask cannot be paired with the entities and would yield
            an image count that disagrees with the listed entities.
    """
    if len(entities) != len(image_mask):
        raise ValueError(
            f"image_mask has {len(image_mask)} entries "
            f"but there are {len(entities)} entities"
        )
    with_image = [name for name, flag in zip(entities, image_mask) if flag == 1]
    without_image = [name for name, flag in zip(entities, image_mask) if flag != 1]
    return with_image, without_image


def _build_instance(data):
    """Convert one processed KG record into a chat-format task-3 instance.

    *data* must provide the keys "entity", "relation", "triple", "source",
    "image" and "image_mask" (the keys this function reads).
    """
    entities = data["entity"]
    image_mask = data["image_mask"]
    with_image, without_image = _split_by_mask(entities, image_mask)
    # Derive the count from the listed entities so the reasoning trace and the
    # final answer can never disagree.
    num_image = len(with_image)
    answer = answer_prompt.format(
        "\n".join(with_image), "\n".join(without_image), num_image, num_image
    )
    return {
        "messages": [
            {"content": task_prompt, "role": "user"},
            {"content": answer, "role": "assistant"},
        ],
        # Re-root the image path from the generator's output directory.
        "images": [data["image"].replace("generate/images/", "m5/")],
        "metadata": {
            "entity": entities,
            "relation": data["relation"],
            "triple": data["triple"],
            "source": data["source"],
            "task_type": 3,
            "image_mask": image_mask,
        },
    }


def main():
    """Build task-3 instances from INPUT_PATH and write the shuffled splits."""
    with open(INPUT_PATH, "r", encoding="utf-8") as f:
        task_data = json.load(f)

    task_prompts = [_build_instance(data) for data in tqdm(task_data[:MAX_INSTANCES])]

    # Local RNG instance: deterministic split without touching global state.
    random.Random(SPLIT_SEED).shuffle(task_prompts)

    for path, start, end in SPLITS:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII text, so the file must be
        # opened as UTF-8 rather than the platform default encoding.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(task_prompts[start:end], f, ensure_ascii=False)


if __name__ == "__main__":
    main()