import json
import random
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity


# Template asking a model to justify why the correct relation fits a masked
# triple list better than the other options. It has three positional `{}`
# slots, in order: the triple list, the correct answer relation, the options.
# It is not referenced anywhere else in the visible part of this script.
# NOTE(review): the phrase "other options entities" in the prompt text looks
# like a leftover from an entity-prediction variant -- confirm before reuse.
qwen_answer_gen_prompt = """
There is one relation (edge) replaced by [MASK] in the given triple list with the format of [head, tail, relation]. We will provide you with the correct answer and several options. Please explain why the answer relation is more suitable for the triples compared with other options entities and why other relations are not suitable by considering the known triples.

The triples are as follows:
{}

Correct Answer Relation: {}
Options: {}

Please provide a concise short explanation by considering the known triples. Do not include any additional information.
"""

# Template for the assistant turn of a training sample. It has two positional
# `{}` slots: the reasoning text placed inside <think>...</think> and the
# final answer placed inside <answer>...</answer>. The main script fills them
# with the generated explanation and the answer index, respectively.
answer_prompt = """
<think>
{}
</think>
<answer>{}</answer>
"""

# Template for the user turn. It starts with an `<image>` marker and has a
# single `{}` slot that the main script fills with the options.
task_prompt = "<image>Given the multi-modal knowledge graph. One relation in it is replaced by [MASK]. Please select one correct relation from the options. {}"

class NativeSentenceBERT:
    """Sentence encoder: a Hugging Face model followed by masked mean pooling.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``.
    Sentence vectors are the attention-mask-weighted mean of the model's last
    hidden states.
    """

    def __init__(self, model_name="pretrain"):
        """Load the tokenizer and model identified by *model_name*.

        Args:
            model_name: Name or path forwarded to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Prefer the GPU, but fall back to the CPU so the class still works on
        # machines without CUDA instead of raising during construction.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(model_name).to(self.device)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions enabled in the mask.

        Args:
            model_output: Object exposing ``last_hidden_state``.
            attention_mask: Tensor that, after a trailing ``unsqueeze``, is
                expandable to the shape of ``last_hidden_state``; zero entries
                exclude the corresponding token from the average.

        Returns:
            Tensor with the token dimension (dim 1) reduced away.
        """
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        # Clamp so a fully masked row divides by a tiny number instead of zero.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def encode(self, texts):
        """Embed *texts* in a single forward pass.

        Args:
            texts: Input accepted by the tokenizer (the callers in this file
                pass lists of strings).

        Returns:
            NumPy array of mean-pooled embeddings, one row per input text.
        """
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        encoded_input = {key: value.to(self.device) for key, value in encoded_input.items()}
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        return self.mean_pooling(model_output, encoded_input["attention_mask"]).cpu().numpy()

    def similarity(self, query, candidates, top_k=30):
        """Return indices of the candidates most similar to *query*.

        Args:
            query: Text to embed and compare.
            candidates: Candidate embedding matrix. It is passed to
                ``cosine_similarity`` as-is (it is NOT encoded here), so it
                must already contain vectors, e.g. the output of ``encode``.
            top_k: Maximum number of indices to return.

        Returns:
            NumPy array of row indices into *candidates*, ordered from most to
            least similar. It holds at most ``top_k`` entries and is empty
            when ``top_k`` is not positive.
        """
        # ``arr[-0:]`` selects the whole array, so a non-positive top_k has to
        # be handled explicitly rather than left to the slice below.
        if top_k <= 0:
            return np.array([], dtype=np.intp)
        query_vec = self.encode([query])
        sim_scores = cosine_similarity(query_vec, candidates)[0]
        # argsort is ascending: take the tail and reverse it for best-first.
        return np.argsort(sim_scores)[-top_k:][::-1]


# Pre-compute embeddings for the 'relation' candidate list stored under every
# key of data/candidates.json. The results are kept in `encode_candidates`,
# keyed the same way as the JSON file.
with open("data/candidates.json", "r", encoding="utf-8") as f:
    candidates = json.load(f)
encode_candidates = {}
native_bert = NativeSentenceBERT()
print("Initialize Relation Embeddings.")
# Upper bound on how many strings go through the encoder in a single call.
ENCODE_CHUNK_SIZE = 20000
for key in candidates:
    ent_cands = candidates[key]['relation']
    if len(ent_cands) <= ENCODE_CHUNK_SIZE:
        embeddings = native_bert.encode(ent_cands)
    else:
        # Chunk boundaries are derived from the list length, so the encoder
        # never receives an empty slice and no chunk exceeds the size limit,
        # however long the candidate list is.
        embeddings = np.concatenate(
            [
                native_bert.encode(ent_cands[start:start + ENCODE_CHUNK_SIZE])
                for start in range(0, len(ent_cands), ENCODE_CHUNK_SIZE)
            ],
            axis=0,
        )
    encode_candidates[key] = embeddings
    print(embeddings.shape)



def _load_json(path):
    """Parse the JSON file at *path*, closing the handle afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _dump_json(obj, path):
    """Write *obj* to *path* as JSON, keeping non-ASCII characters verbatim."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # is pinned to UTF-8 instead of depending on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False)


def _build_instance(data, cot, options, answer_index):
    """Assemble one chat-style training sample.

    Args:
        data: Task record providing the keys "entity", "relation", "triple",
            "source", "image" and "answer".
        cot: Explanation text placed inside the <think> section.
        options: Value substituted into the user prompt.
        answer_index: Value placed inside the <answer> section and recorded
            in the metadata.

    Returns:
        Dict with "messages" (user + assistant turns), "images" and
        "metadata".
    """
    processed_image_path = data["image"].replace("generate/images/", "m5/")
    return {
        "messages": [
            {
                "content": task_prompt.format(options),
                "role": "user"
            },
            {
                "content": answer_prompt.format(cot, answer_index),
                "role": "assistant"
            }
        ],
        "images": [processed_image_path],
        "metadata": {
            "entity": data["entity"],
            "relation": data["relation"],
            "triple": data["triple"],
            "source": data["source"],
            # NOTE(review): the input/output files are named "task8" while
            # this id is 7 -- confirm whether task ids are zero-based.
            "task_type": 7,
            "answer": data["answer"],
            "answer_index": answer_index
        }
    }


if __name__ == "__main__":
    # Only the first MAX_INSTANCES task records are turned into samples.
    MAX_INSTANCES = 12000

    task_data = _load_json("data/processed_instances_with_images_task8.json")
    option_prompts = _load_json("data/option_prompts_task8.json")
    generated_answers = _load_json("data/generated_answers_task8.json")

    task_prompts = []
    # The three files are index-aligned: record i pairs with answer i and
    # option prompt i. Indexing (rather than zip) keeps a length mismatch a
    # loud IndexError instead of a silent truncation.
    for i, data in enumerate(task_data[:MAX_INSTANCES]):
        options, answer_index = option_prompts[i]
        task_prompts.append(
            _build_instance(data, generated_answers[i], options, answer_index)
        )

    random.shuffle(task_prompts)
    print(task_prompts[0])
    # After shuffling: 6400 train / 800 valid / 800 test. Any samples past
    # index 8000 are not written.
    _dump_json(task_prompts[0: 6400], "dataset/task8_prompts_train.json")
    _dump_json(task_prompts[6400: 7200], "dataset/task8_prompts_valid.json")
    _dump_json(task_prompts[7200: 8000], "dataset/task8_prompts_test.json")
