import torch
import pickle as pkl
import heapq

from rank_bm25 import BM25Okapi
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder

def top_k_indices(lst, k):
    """Return the positions of the ``k`` largest values in ``lst``, best first.

    ``k`` is capped at ``len(lst)``; a non-positive ``k`` yields an empty list.
    Equal values are ordered by their position (earlier first), following
    ``heapq.nlargest``'s sorted-equivalence guarantee.
    """
    count = min(k, len(lst))
    best_pairs = heapq.nlargest(count, enumerate(lst), key=lambda pair: pair[1])
    return [position for position, _ in best_pairs]

class EmbeddingFnsClass(object):
    """Callable that embeds sentences with a DPR context encoder and memoises results.

    Calling an instance maps a list of sentences to a stacked tensor of the
    encoder's ``pooler_output`` rows. Every embedding computed is stored in
    ``self.cache`` keyed by the sentence string, so a sentence is encoded at
    most once per instance (apart from duplicates inside one uncached batch).
    The cache is never evicted.
    """

    def __init__(self, model_path):
        # Tokenizer and encoder are loaded from the same checkpoint path/name.
        self.kg_tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_path)
        self.kg_model = DPRContextEncoder.from_pretrained(model_path)
        # sentence -> embedding row; unbounded, lives as long as the instance.
        self.cache = {}

    def __call__(self, sentences):
        """Embed ``sentences``, reusing cached vectors where available.

        Args:
            sentences: sized sequence of strings to embed.

        Returns:
            A tensor that stacks one embedding per input sentence along dim 0,
            in the same order as ``sentences``. Embeddings are computed under
            ``torch.no_grad()`` and do not require grad.

        Raises:
            RuntimeError: raised by ``torch.stack`` when ``sentences`` is empty.
        """
        return_embeddings = [None] * len(sentences)
        missing_positions = []
        inference_sentences = []
        for pos, sent in enumerate(sentences):
            cached = self.cache.get(sent)
            if cached is None:
                missing_positions.append(pos)
                inference_sentences.append(sent)
            else:
                return_embeddings[pos] = cached

        if inference_sentences:
            encoded = self.kg_tokenizer(inference_sentences, return_tensors="pt", padding=True, truncation=True)
            # Inference only. Without no_grad, each cached row would keep its
            # whole batch's autograd graph alive, so memory would grow with the cache.
            with torch.no_grad():
                embeddings = self.kg_model(
                    input_ids=encoded['input_ids'],
                    attention_mask=encoded['attention_mask'],
                ).pooler_output
            # Misses were recorded in input order, so they line up with the batch rows.
            for pos, sent, emb in zip(missing_positions, inference_sentences, embeddings):
                self.cache[sent] = emb
                return_embeddings[pos] = emb
        return torch.stack(return_embeddings, dim=0)


class InContextExpert(object):
    """Source of expert demonstration steps for in-context prompting.

    ``expert_dataset`` is unpickled from ``path``. It is a mapping from a key
    to a list of step records. ``make_context`` indexes it by execution
    strings, so the keys are presumably execution/sub-task names (verify
    against the dataset builder). The code reads each record's
    ``'timesteps'``, ``'action'``, ``'task_instruction'`` and
    ``'knowledge_graph'`` entries.
    """

    def __init__(self, path):
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            self.expert_dataset = pkl.load(f)

    def make_zsp_context(self):
        """Group each key's records into per-episode lists of actions.

        A record whose ``'timesteps'`` equals 1 starts a new episode. Records
        before the first such marker form their own leading episode.

        Returns:
            dict mapping every dataset key to a list of episodes. Each episode
            is the list of ``'action'`` values in stored order. A key with no
            records maps to an empty list.
        """
        data = {}
        for key, records in self.expert_dataset.items():
            episodes = []
            current = []
            for record in records:
                if record['timesteps'] == 1 and current:
                    episodes.append(current)
                    current = []
                current.append(record['action'])
            if current:
                episodes.append(current)
            data[key] = episodes
        return data

    def make_context(self, agent, executions, icl_prompt_template, queries, num_examples=2, num_edges=12):
        """Render the expert steps whose knowledge-graph context best matches the agent's.

        Args:
            agent: object exposing ``retrieve(...)`` and ``embedding_fns``. Its
                retrieved edges (``temp_kg``) form the BM25 query.
            executions: non-empty list of execution strings. If the first one
                is ``"Explore the home"`` or begins with the word ``"Find"``,
                every dataset key is searched and ``queries`` is used for
                retrieval instead. Otherwise each execution must be a dataset key.
            icl_prompt_template: callable
                ``(kg_text, task_instruction, prev_action, action) -> prompt``.
            queries: retrieval queries used in the explore/find case.
            num_examples: maximum number of examples returned (default 2).
            num_edges: edges requested from every retrieval call (default 12).

        Returns:
            ``(temp_kg, icl_prompt)``. ``temp_kg`` is the agent's retrieved edge
            strings. ``icl_prompt`` holds at most ``num_examples`` rendered
            prompts, best BM25 score first, and is empty when there are no
            candidate expert steps.

        Raises:
            KeyError: if a non-explore execution is not a dataset key.
        """
        if executions[0] == "Explore the home" or executions[0].split()[:1] == ["Find"]:
            listed_data = [record for records in self.expert_dataset.values() for record in records]
            executions = queries
        else:
            listed_data = [record for execution in executions for record in self.expert_dataset[execution]]

        temp_kg = agent.retrieve(executions, num_edges=num_edges, return_type="str_list")
        if not listed_data:
            # BM25Okapi cannot be built from an empty corpus (division by zero).
            return temp_kg, []

        # First state edge containing the substring "hold", if any. Expert
        # steps whose retrieved edges include that exact edge are preferred.
        hold_condition = next((kg for kg in temp_kg if "hold" in kg), None)

        # Retrieve once per record and reuse it for both the filtered and fallback corpora.
        all_search_data = [
            d['knowledge_graph'].retrieve(executions, embedding_fns=agent.embedding_fns, num_edges=num_edges, return_type="str_list")
            for d in listed_data
        ]
        idx_map = []
        if hold_condition is not None:
            idx_map = [i for i, edges in enumerate(all_search_data) if hold_condition in edges]
        if not idx_map:
            # No "hold" edge, or no expert step shares it: search every step.
            idx_map = list(range(len(all_search_data)))
        search_data = [all_search_data[i] for i in idx_map]

        bm25 = BM25Okapi(search_data)
        scores = bm25.get_scores(temp_kg)

        icl_prompt = []
        for idx in top_k_indices(scores, num_examples):
            data_idx = idx_map[idx]
            example = listed_data[data_idx]
            if example['timesteps'] != 1:
                # Assumes an episode's steps are stored consecutively, so the
                # preceding record holds the previous action. TODO confirm.
                prev_action = listed_data[data_idx - 1]['action']
            else:
                prev_action = "First step"
            icl_prompt.append(icl_prompt_template(", ".join(search_data[idx]), example['task_instruction'], prev_action, example['action']))
        return temp_kg, icl_prompt