import json
import os
import random

from tqdm import tqdm

# User-turn prompt: asks the model to count the distinct relations shown in
# the knowledge-graph image.
task_prompt1 = "<image>Given the multi-modal knowledge graph. Please count the number of different relations in it."

# Assistant-turn template with three positional fields, filled in order with:
#   1. the relation names, one per line,
#   2. the relation count (shown in the <think> reasoning),
#   3. the relation count again (the final <answer>).
answer_prompt1 = (
    "\n"
    "<think>\n"
    "There are several different relations in the given multi-modal knowledge graph:\n"
    "{}\n"
    "Therefore, the number of different relations is {}\n"
    "</think>\n"
    "<answer>{}</answer>\n"
)

# Alternative prompt/answer pair for counting *all* relation instances
# (kept for reference; currently unused):
#
# task_prompt2 = "<image>Given the multi-modal knowledge graph. Please count the number of relations among entities."
#
# answer_prompt2 = """
# <think>
# There are several relations in the given multi-modal knowledge graph:
# {}
# Therefore, the number of all relations is {}
# </think>
# <answer>{}</answer>
# """

if __name__ == "__main__":
    task_data = json.load(open("data/processed_instances_with_images_task2.json", "r"))
    task_prompts = []
    count = 0
    for data in tqdm(task_data):
        entity = data["entity"]
        relation = data["relation"]
        triple = data["triple"]
        source = data["source"]
        image_path = data["image"]
        if count >= 12000:
            break
        count += 1
        processed_image_path = image_path.replace("generate/images/", "m5/")
        task_prompt = task_prompt1
        relation_list = "\n".join(relation)
        answer = answer_prompt1.format(relation_list, len(relation), len(relation))
        
        data_instance = {
            "messages": [
                {
                    "content": task_prompt,
                    "role": "user"
                },
                {
                    "content": answer,
                    "role": "assistant"
                }
            ],
            "images": [processed_image_path],
            "metadata": {
                "entity": entity,
                "relation": relation,
                "triple": triple,
                "source": source,
                "task_type": 2
            }
        }
        # print(data_instance)
        task_prompts.append(data_instance)
    random.shuffle(task_prompts)
    json.dump(task_prompts[0: 6400], open("dataset/task2_prompts_train.json", "w"), ensure_ascii=False)
    json.dump(task_prompts[6400: 7200], open("dataset/task2_prompts_valid.json", "w"), ensure_ascii=False)
    json.dump(task_prompts[7200: 8000], open("dataset/task2_prompts_test.json", "w"), ensure_ascii=False)