import json
import copy
import torch
import pickle
import random
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def capture(data_path='viquae', output_file=None):
    """Embed each test question and its first knowledge passage with NV-Embed-v2.

    Reads ``/data/<data_path>/test.json``, a list of records that carry
    ``'id'``, ``'vqa_question'`` and ``'knowledge'`` keys. For each record it
    encodes:

    * the question, using the retrieval instruction prefix;
    * the record's first knowledge passage, using an empty instruction.

    The result is pickled as a list of ``{'id', 'question', 'knowledge'}``
    dicts, in the same order as the JSON records.

    Args:
        data_path: Dataset name. It builds the input path and the default
            output path.
        output_file: Pickle file to write. Defaults to
            ``/embedding/knowledge/<data_path>_vqa.pkl``.

    Notes:
        Needs a CUDA device, because the model is moved there with
        ``model.cuda()``. Embeddings are copied to CPU before pickling so the
        file can be unpickled on machines without a GPU.
    """
    # Runtime strings are kept verbatim: changing the prefix would change the embeddings.
    query_prefix = "Instruct: Given a question, retrieve information that answer the question\nQuery: "
    max_length = 8192
    if output_file is None:
        output_file = f'/embedding/knowledge/{data_path}_vqa.pkl'

    # Read the data before loading the (large) model so a bad path fails fast.
    with open(f'/data/{data_path}/test.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    model = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True)
    model.cuda()

    result = []
    # Inference only. This guards against autograd tracking in case encode()
    # does not already disable it.
    with torch.no_grad():
        for each in tqdm(data):
            question_embedding = model.encode([each['vqa_question']], instruction=query_prefix, max_length=max_length)[0]
            knowledge_embedding = model.encode([each['knowledge'][0]], instruction="", max_length=max_length)[0]
            result.append(dict(
                id=each['id'],
                # Pickled CUDA tensors can only be unpickled where CUDA is
                # available, so store CPU copies.
                question=question_embedding.cpu(),
                knowledge=knowledge_embedding.cpu(),
            ))

    with open(output_file, 'wb') as f:
        pickle.dump(result, f)


def _first_distinct(items, limit):
    """Return up to ``limit`` items from ``items`` in order, skipping repeats.

    Membership is tested against a list rather than a set, so unhashable
    items also work. The list never holds more than ``limit`` entries.
    """
    picked = []
    for item in items:
        if item not in picked:
            picked.append(item)
            if len(picked) >= limit:
                break
    return picked


def _topk_rows(queries, keys, k):
    """Return, per query row, the indices of the ``k`` most cosine-similar keys, best first."""
    queries = F.normalize(queries, p=2, dim=1)
    keys = F.normalize(keys, p=2, dim=1)
    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(k, keys.shape[0])
    return torch.topk(queries @ keys.T, k=k, dim=1).indices.tolist()


def cosine(data_path='viquae', embedding_file=None, seed=None):
    """Attach similarity-ranked knowledge passages to every test record.

    Loads ``data/<data_path>/test.json`` and a pickled list of per-record
    embeddings. Each embedding entry is a dict with ``'question'`` and
    ``'knowledge'`` tensors, and position i must correspond to record i.
    The output is ``data/<data_path>/testkn.json``. In it, each record's
    ``'knowledge'`` list is replaced by three fields:

    * ``one_knowledge``: the record's first knowledge passage.
    * ``three_knowledge``: up to 3 distinct passages, taken in similarity order
      from the 10 passages closest to the record's own passage embedding.
    * ``rag_knowledge``: up to 10 distinct passages, taken in similarity order
      from the 20 passages closest to the record's question embedding.

    Both lists are shuffled before writing.

    Args:
        data_path: Dataset name. It selects the input and output directory.
        embedding_file: Pickle file to read. Defaults to
            ``embedding/<data_path>.pkl``.
            NOTE(review): ``capture()`` writes to
            ``/embedding/knowledge/<data_path>_vqa.pkl`` instead. Confirm which
            file is intended.
        seed: If given, shuffling uses a private ``random.Random(seed)`` so the
            output is reproducible. Otherwise the global ``random`` state is
            used, as before.

    Raises:
        ValueError: The embeddings are not aligned with the JSON records, or a
            record's own index is missing from its ranked candidates.
    """
    if embedding_file is None:
        embedding_file = f'embedding/{data_path}.pkl'
    with open(f'data/{data_path}/test.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(embedding_file, 'rb') as f:
        embedding = pickle.load(f)

    # Row i of each similarity matrix is mapped back to data[i], so the two
    # sequences must line up one-to-one.
    if len(embedding) != len(data):
        raise ValueError(
            f'{embedding_file} has {len(embedding)} entries but test.json has {len(data)} records')
    for pos, (emb, record) in enumerate(zip(embedding, data)):
        if 'id' in emb and 'id' in record and emb['id'] != record['id']:
            raise ValueError(
                f'embedding/data misaligned at position {pos}: {emb["id"]!r} != {record["id"]!r}')

    if data:
        knowledges = torch.stack([e['knowledge'] for e in embedding])
        questions = torch.stack([e['question'] for e in embedding])
        # Rankings depend only on the order of similarities, so the original
        # *100 scaling is unnecessary.
        answer_indices = _topk_rows(knowledges, knowledges, 10)
        question_indices = _topk_rows(questions, knowledges, 20)
    else:
        # torch.stack rejects an empty list; nothing to rank.
        answer_indices = question_indices = []

    rng = random if seed is None else random.Random(seed)
    result = []
    for i, each in enumerate(tqdm(data)):
        # Each record's own index must appear among its ranked candidates.
        # These are explicit raises (not asserts) so they survive `python -O`.
        if i not in answer_indices[i]:
            raise ValueError(
                f'record {i}: own knowledge not among its top-{len(answer_indices[i])} knowledge neighbours')
        if i not in question_indices[i]:
            raise ValueError(
                f'record {i}: own knowledge not among top-{len(question_indices[i])} matches for its question')
        three_knowledge = _first_distinct(
            (data[j]['knowledge'][0] for j in answer_indices[i]), 3)
        # NOTE(review): only the first 10 distinct passages of the top-20 are
        # kept, so the record's own passage can still be dropped here even
        # though the check above passed.
        rag_knowledge = _first_distinct(
            (data[j]['knowledge'][0] for j in question_indices[i]), 10)
        rng.shuffle(three_knowledge)
        rng.shuffle(rag_knowledge)
        temp = copy.deepcopy(each)
        temp['one_knowledge'] = temp.pop('knowledge')[0]
        temp['three_knowledge'] = three_knowledge
        temp['rag_knowledge'] = rag_knowledge
        result.append(temp)

    with open(f'data/{data_path}/testkn.json', 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    # Running this file as a script performs only the embedding pass.
    # cosine() is not invoked here and must be called separately.
    capture()