import json
import random
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity



# Template used to ask an LLM for a short rationale explaining why the correct
# entity fits the masked triple list better than the distractors.
# It has three positional `{}` placeholders, in order: the triple list, the
# correct answer entity, and the option list. In this file it is only formatted
# inside the disabled pipeline code kept as a string in the `__main__` block.
# Because the literal opens and closes on its own lines, the resulting string
# starts and ends with a newline.
qwen_answer_gen_prompt = """
There is one entity replaced by [MASK] in the given triple list with the format of [head, tail, relation]. We will provide you with the correct entity and several options. Please explain why the correct entity is more suitable for the triples compared with other options entities and why other entities are not suitable by considering the known triples.

The triples are as follows:
{}

Correct Answer Entity: {}
Options: {}

Please provide a concise short explanation by considering the known triples. Do not include any additional information.
"""

# Template for the assistant turn of a training sample.
# Two positional `{}` placeholders, in order: the generated explanation (placed
# between <think> tags) and the answer option letter (placed between <answer>
# tags). Like the template above, the string starts and ends with a newline.
answer_prompt = """
<think>
{}
</think>
<answer>{}</answer>
"""

# Template for the user turn of a training sample. It begins with an `<image>`
# tag and has a single `{}` placeholder that receives the option list text.
task_prompt = "<image>Given the multi-modal knowledge graph. One entity in it is replaced by [MASK]. Please select one correct entity from the options. {}"

class NativeSentenceBERT:
    """Sentence encoder: a Hugging Face ``AutoModel`` followed by masked mean pooling.

    The tokenizer and model are loaded with ``from_pretrained(model_name)`` and
    the model is moved to ``device``.

    Args:
        model_name: Name or path passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        device: Device to run the model on. When ``None`` (the default), CUDA is
            used if ``torch.cuda.is_available()`` and the CPU otherwise, so the
            class also works on machines without a GPU.
    """

    def __init__(self, model_name="pretrain", device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring positions masked out by ``attention_mask``.

        Args:
            model_output: Object exposing ``last_hidden_state``
                (assumed shape ``(batch, tokens, hidden)`` -- TODO confirm).
            attention_mask: Mask with one entry per token; it is broadcast over
                the last dimension of the embeddings.

        Returns:
            Tensor holding, per batch row, the mask-weighted mean over dimension 1.
        """
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        # Clamp so a fully masked row divides by a tiny number instead of zero.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def encode(self, texts):
        """Tokenize ``texts`` (padded and truncated) and return pooled embeddings as a NumPy array."""
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        encoded_input = {key: value.to(self.device) for key, value in encoded_input.items()}
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        return self.mean_pooling(model_output, encoded_input["attention_mask"]).cpu().numpy()

    def similarity(self, query, candidates, top_k=30):
        """Return indices of the ``top_k`` candidates most similar to ``query``.

        Args:
            query: Text to embed with :meth:`encode`.
            candidates: Candidate *embeddings*, not texts. The value is passed
                unchanged to ``cosine_similarity``, so it must already be a
                2-D array of vectors (e.g. the output of :meth:`encode`).
            top_k: Maximum number of indices to return; ``0`` yields an empty result.

        Returns:
            Array of row indices into ``candidates``, highest cosine similarity first.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        query_vec = self.encode([query])
        sim_scores = cosine_similarity(query_vec, candidates)[0]
        return self._top_k_indices(sim_scores, top_k)

    @staticmethod
    def _top_k_indices(scores, top_k):
        """Return the indices of the ``top_k`` largest ``scores``, best first."""
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        # Reverse before slicing: the former `[-top_k:]` slice returned *every*
        # index when top_k == 0 because `-0` is just `0`.
        return np.argsort(scores)[::-1][:top_k]


# Upper bound on how many task instances are converted into samples.
_MAX_INSTANCES = 12000

# (output path, slice start, slice end) for each split of the shuffled samples.
_SPLITS = (
    ("dataset/task7_prompts_train.json", 0, 6400),
    ("dataset/task7_prompts_valid.json", 6400, 7200),
    ("dataset/task7_prompts_test.json", 7200, 8000),
)


def _load_json(path):
    """Return the parsed JSON document stored at ``path`` (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _dump_json(obj, path):
    """Write ``obj`` to ``path`` as UTF-8 JSON, keeping non-ASCII characters unescaped."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, ensure_ascii=False)


def _build_instance(data, cot, options, answer_index):
    """Build one chat-format sample from a task record, its rationale and its options."""
    return {
        "messages": [
            {
                "content": task_prompt.format(options),
                "role": "user"
            },
            {
                "content": answer_prompt.format(cot, answer_index),
                "role": "assistant"
            }
        ],
        "images": [data["image"].replace("generate/images/", "m5/")],
        "metadata": {
            "entity": data["entity"],
            "relation": data["relation"],
            "triple": data["triple"],
            "source": data["source"],
            "task_type": 7,
            "answer": data["answer"],
            "answer_index": answer_index
        }
    }


def main():
    """Assemble the task-7 train/valid/test files from the preprocessed inputs.

    Reads the task instances, the option prompts and the generated rationales,
    pairs them by position (at most ``_MAX_INSTANCES`` records), shuffles the
    resulting samples with the unseeded ``random`` module and writes the slices
    listed in ``_SPLITS``.

    Raises:
        ValueError: If the option prompts or generated answers do not cover
            every task instance that will be processed.
    """
    task_data = _load_json("data/processed_instances_with_images_task7.json")

    # Disabled earlier pipeline stages (option sampling and rationale generation)
    # that produced the two JSON files loaded below. Kept verbatim as an inert
    # string for provenance; it references names that are not defined in this file.
    """
    for i in tqdm(range(len(task_data))):
        triple = task_data[i]["triple"]
        answer_entity = task_data[i]["answer"]
        source = task_data[i]['source']
        top_indices = native_bert.similarity(answer_entity, encode_candidates[source])
        top_ents = [candidates[source]['entity'][idx] for idx in top_indices][1:]
        select_options = random.sample(top_ents, 4)
        random.shuffle(select_options)
        opt_prompt = ""
        select_answer_index = random.choice(['A', 'B', 'C', 'D', 'E'])
        for opt in ['A', 'B', 'C', 'D', 'E']:
            if opt == select_answer_index:
                opt_prompt += f"{opt}. {answer_entity}\n"
            else:
                opt_prompt += f"{opt}. {select_options.pop()}\n"
        option_prompts.append((opt_prompt.strip(), select_answer_index))
        gen_prompt = qwen_answer_gen_prompt.format(triple, answer_entity, opt_prompt.strip())
        gen_prompts.append(gen_prompt)
    print(option_prompts[0])
    print(gen_prompts[0])
    json.dump(gen_prompts, open("data/gen_prompts_task7.json", "w"), ensure_ascii=False)
    json.dump(option_prompts, open("data/option_prompts_task7.json", "w"), ensure_ascii=False)
    
    from qwen import call_qwen_batch
    gen_prompts = json.load(open("data/gen_prompts_task7.json", "r"))
    option_prompts = json.load(open("data/option_prompts_task7.json", "r"))
    generated_answers = call_qwen_batch(gen_prompts)
    json.dump(generated_answers, open("data/generated_answers_task7.json", "w"), ensure_ascii=False)
    """
    option_prompts = _load_json("data/option_prompts_task7.json")
    generated_answers = _load_json("data/generated_answers_task7.json")

    # The three inputs are aligned by position, so fail early with a clear
    # message rather than with an IndexError halfway through the loop.
    count = min(len(task_data), _MAX_INSTANCES)
    if len(option_prompts) < count or len(generated_answers) < count:
        raise ValueError(
            f"need {count} option prompts and generated answers, got "
            f"{len(option_prompts)} and {len(generated_answers)}"
        )

    task_prompts = [
        _build_instance(data, cot, options, answer_index)
        for data, cot, (options, answer_index) in zip(
            task_data[:count], generated_answers, option_prompts
        )
    ]

    random.shuffle(task_prompts)
    print(task_prompts[0])
    for path, start, stop in _SPLITS:
        _dump_json(task_prompts[start:stop], path)


if __name__ == "__main__":
    main()
