import json
import random
from tqdm import tqdm
from qwen import call_qwen_batch

# User-turn text shared by every generated training instance (it is used
# verbatim as the "user" message content in the main script below).
# NOTE(review): the leading "<image>" tag is presumably the placeholder the
# training framework replaces with the entry from the instance's "images"
# list -- confirm against the consumer of the dataset files.
task_prompt = "<image>Given the multi-modal knowledge graph. Please point out the wrong entity in it."


# Prompt template sent to `call_qwen_batch` to obtain an explanation for each
# corrupted instance. It has three positional `{}` placeholders, filled in
# this order by the main script: the instance's "triple" value, its
# "old_entity" (the original entity) and its "new_entity" (the replacement).
# NOTE(review): the text describes the triple format as [head, tail, relation];
# confirm this matches the ordering actually stored in the "triple" field.
# NOTE(review): "compared with the wrong entity in the wrong entities" reads
# like a typo, but the wording is left untouched because changing it would
# change the prompt the model receives.
qwen_answer_gen_prompt = """
There is one entity replaced by a wrong entity in the given triples list with the format of [head, tail, relation]. We will provide you with the original entity and the new entity. Please explain why the original entity is more plausible for the triples compared with the wrong entity in the wrong entities.

The triples are as follows:
{}

Original entity: {}
New entity: {}

Please provide a concise short explanation. Do not include any additional information.
"""

# Template for the assistant turn of each training instance. The first `{}`
# receives the generated explanation (placed inside <think>...</think>) and
# the second receives the instance's "new_entity" (placed inside
# <answer>...</answer>). Because the literal opens and closes on its own
# lines, the formatted string starts and ends with a newline character.
answer_prompt = """
<think>
{}
</think>
<answer>{}</answer>
"""


def _load_json(path):
    """Return the parsed JSON document stored at ``path``."""
    # A context manager guarantees the handle is closed even if parsing fails.
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _dump_json(payload, path):
    """Write ``payload`` to ``path`` as JSON without escaping non-ASCII text."""
    # Explicit UTF-8 is required because ensure_ascii=False emits raw
    # non-ASCII characters; closing the handle guarantees the data is flushed.
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)


if __name__ == "__main__":
    # Pipeline:
    #   1. load the corrupted-triple instances,
    #   2. ask the model (via call_qwen_batch) for an explanation per instance,
    #   3. wrap each instance in a chat-style record,
    #   4. shuffle and write fixed-size train/valid/test slices.
    task_data = _load_json("data/processed_instances_with_images_task6.json")

    # One generation prompt per instance, in the same order as task_data.
    gen_prompts = [
        qwen_answer_gen_prompt.format(
            item["triple"], item["old_entity"], item["new_entity"]
        )
        for item in task_data
    ]
    # Show a sample prompt for a quick sanity check; skip when there is no data.
    if gen_prompts:
        print(gen_prompts[0])

    generated_answers = call_qwen_batch(gen_prompts)
    # Persist the raw model output first so it survives any later failure.
    _dump_json(generated_answers, "data/generated_answers_task6.json")

    # Answers are matched to instances by position, so a size mismatch would
    # either crash mid-loop or silently pair explanations with wrong instances.
    if len(generated_answers) != len(task_data):
        raise ValueError(
            "call_qwen_batch returned {} answers for {} prompts".format(
                len(generated_answers), len(task_data)
            )
        )

    instances = []
    for index, data in enumerate(task_data):
        cot = generated_answers[index]
        new_entity = data["new_entity"]
        instances.append(
            {
                "messages": [
                    {
                        "content": task_prompt,
                        "role": "user",
                    },
                    {
                        # The answer to "point out the wrong entity" is the
                        # replacement entity, preceded by the explanation.
                        "content": answer_prompt.format(cot, new_entity),
                        "role": "assistant",
                    },
                ],
                # Rewrite the image location from the generation directory to
                # the "m5/" prefix used in the written dataset.
                "images": [data["image"].replace("generate/images/", "m5/")],
                "metadata": {
                    "entity": data["entity"],
                    "relation": data["relation"],
                    "triple": data["triple"],
                    "source": data["source"],
                    "task_type": 6,
                    "old_entity": data["old_entity"],
                    "new_entity": new_entity,
                },
            }
        )

    # NOTE: no seed is set, so the split differs from run to run.
    random.shuffle(instances)

    # (output path, slice start, slice stop) for each split.
    splits = (
        ("dataset/task6_prompts_train.json", 0, 6400),
        ("dataset/task6_prompts_valid.json", 6400, 7200),
        ("dataset/task6_prompts_test.json", 7200, 8000),
    )
    expected_total = splits[-1][2]
    if len(instances) < expected_total:
        # Slicing past the end does not fail; it just produces short or empty
        # splits, so make that visible instead of letting it pass silently.
        print(
            "Warning: only {} instances available, fewer than the {} the "
            "splits expect; some splits will be short or empty.".format(
                len(instances), expected_total
            )
        )
    for path, start, stop in splits:
        _dump_json(instances[start:stop], path)