import numpy as np
import torch 
import torch.nn.functional as F

from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline, TfidfPipeline
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.print_utils import *


# Instruction words that RAGPipeline.masking drops before choosing which
# instruction tokens to replace with placeholders.
TokenList = [
    'turn',
    'open',
    'on',
    'in',
    'the',
    'put',
    'place',
]

# Placeholder strings assigned, in order, to the instruction tokens that
# remain after the TokenList words have been dropped.
ReplaceList = [
    'A123',
    'B456',
    'C789',
]


class RAGPipeline:
    """Retrieve the dataset sample closest to a query and grade a think list.

    ``retrieve`` has two modes, selected by whether a reward model was given:

    * no reward model: candidates are filtered by TF-IDF score on the
      instruction, ranked by sentence similarity of the masked state and
      history, and the think list is scored against the top 10% of them;
    * reward model: samples are ranked by cosine similarity of mean-pooled
      embeddings of the state/history text, and the think list is scored
      against the single best sample.
    """

    def __init__(
        self,
        env_name='virtualhome',
        rew_tokenizer=None,
        rew_model=None,
        dataset=None,
        evaluate="train",
    ):
        """Store the configuration and build the similarity pipelines.

        Args:
            env_name: Environment identifier. The reward-model path of
                ``retrieve`` accepts ``'virtualhome'`` and ``'alfred'``.
            rew_tokenizer: Tokenizer called by ``encode``; only needed when
                ``rew_model`` is given.
            rew_model: Encoder used by ``encode``. When ``None``, ``retrieve``
                uses the TF-IDF / sentence-similarity path instead.
            dataset: Collection of samples. ``retrieve`` reads the keys
                ``'instruction'``, ``'state'``, ``'history'`` and
                ``'think_list'`` (plus ``'action'`` on the reward-model path).
            evaluate: ``"train"``, ``"seen"`` or ``"unseen"``; selects the
                threshold offset used by ``get_score_str``.
        """
        self.env_name = env_name
        self.rew_tokenizer = rew_tokenizer
        self.rew_model = rew_model
        self.dataset = dataset

        # similarity pipelines used by the path without a reward model
        self.cossim_pipe = SentenceSimilarityPipeline()
        self.tfidf_pipe = TfidfPipeline()

        # evaluation split, read by get_score_str
        self.evaluate = evaluate

    def retrieve(self, instruction, state, history, think_list):
        """Score ``think_list`` against the most similar dataset sample(s).

        Args:
            instruction: Query instruction text.
            state: Query state text.
            history: Query history text.
            think_list: List of think strings to score. It is not modified.

        Returns:
            A tuple ``(scores, score_str, rag_action)``: one similarity score
            per scored think entry, the matching feedback sentence from
            ``get_score_str`` for each score, and the retrieved sample's
            ``'action'`` (``None`` when no reward model is configured).

        Raises:
            NotImplementedError: On the reward-model path when ``env_name``
                is neither ``'virtualhome'`` nor ``'alfred'``.
        """
        if self.rew_model is None:
            scores = self._score_with_text_similarity(
                instruction, state, history, think_list)
            rag_action = None
        else:
            scores, rag_action = self._score_with_reward_model(
                state, history, think_list)

        score_str = [self.get_score_str(score) for score in scores]
        return scores, score_str, rag_action

    def _score_with_text_similarity(self, instruction, state, history, think_list):
        """Score the think list using TF-IDF filtering and sentence similarity."""
        # 1. keep the samples whose instruction ties for the best TF-IDF score
        instruction_scores = [
            self.tfidf_pipe(instruction, data["instruction"])
            for data in self.dataset
        ]
        max_score = np.max(instruction_scores)
        filtered_dataset = [
            data
            for data, score in zip(self.dataset, instruction_scores)
            if score == max_score
        ]

        print_check(f"1. Filtered by instruction: {len(filtered_dataset)}")

        # 2. rank the remaining samples by masked state + history similarity.
        # The query is preprocessed and masked once, outside the loop, so it
        # is not re-processed for every candidate.
        query_state, query_history, _ = self.masking(
            instruction, PromptTemplate.preprocess(state), history, None)

        similarities = []
        for data in filtered_dataset:
            state_data, history_data, _ = self.masking(
                data['instruction'], PromptTemplate.preprocess(data['state']), data['history'], None)

            state_score = self.cossim_pipe(query_state, state_data)
            history_score = self.cossim_pipe(query_history, history_data)
            similarities.append(state_score + history_score)
        sort_index = np.argsort(similarities)[::-1]
        topk = max(1, int(len(sort_index) * 0.1))
        sort_index = sort_index[:topk]  # retrieve top 10 % (at least one)

        print_check(f"2. Filtered by State and History: {len(sort_index)}")

        # 3. average the per-entry think similarity over the retrieved samples
        _, _, masked_thinks = self.masking(instruction, None, None, think_list)
        scores = [0. for _ in think_list]
        for index in sort_index:
            data = filtered_dataset[index]
            _, _, think_list_data = self.masking(
                data['instruction'], None, None, data['think_list'])

            for i, (think1, think2) in enumerate(zip(masked_thinks, think_list_data)):
                scores[i] += self.cossim_pipe(think1, think2) / len(sort_index)
        return scores

    def _score_with_reward_model(self, state, history, think_list):
        """Score the think list against the sample nearest in embedding space."""
        query = [f"{state}\n{history}"]
        responses = [f"{data['state']}\n{data['history']}" for data in self.dataset]

        # Only scalar similarities are read below, so no autograd graph is
        # needed for these forward passes.
        with torch.no_grad():
            query_embed = self.encode(query)
            response_embed = self.encode(responses)

        similarities = [
            F.cosine_similarity(query_embed[0], embed, dim=0).item()
            for embed in response_embed
        ]
        filtered_data = self.dataset[int(np.argmax(similarities))]

        # both supported environments compare the think lists unmasked
        if self.env_name not in ('virtualhome', 'alfred'):
            raise NotImplementedError

        scores = [
            self.cossim_pipe(think1, think2)
            for think1, think2 in zip(think_list, filtered_data['think_list'])
        ]
        return scores, filtered_data['action']

    def encode(self, sentences: list):
        """Embed ``sentences`` with the reward model using masked mean pooling.

        Args:
            sentences: List of strings passed to ``rew_tokenizer``.

        Returns:
            A tensor with one embedding per sentence: the attention-mask
            weighted mean of the first element of the model output.
        """
        inputs = self.rew_tokenizer(
            sentences, padding=True, truncation=True, return_tensors='pt'
        ).to(self.rew_model.device)
        model_output = self.rew_model(**inputs)
        # mean pooling over the positions selected by the attention mask
        token_embeddings = model_output[0]
        mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # clamp keeps the division finite for an all-zero mask
        return summed / torch.clamp(mask.sum(1), min=1e-9)

    def masking(self, instruction, state, history, think_list):
        """Replace instruction-specific words in the given texts with placeholders.

        The instruction is lower-cased and split on whitespace, words listed
        in ``TokenList`` are dropped, and the remaining distinct words are
        paired in order with the entries of ``ReplaceList``. Words beyond the
        number of available placeholders are left unmasked.

        Args:
            instruction: Instruction text the words are taken from.
            state: Text to mask, or ``None`` to skip.
            history: Text to mask, or ``None`` to skip.
                (assumes a string like ``state`` -- TODO confirm)
            think_list: List of strings to mask, or ``None`` to skip. The
                list passed in is left untouched; a new list is returned.

        Returns:
            A tuple ``(state, history, think_list)`` with the replacements
            applied; entries passed as ``None`` stay ``None``.
        """
        # split() without a separator never yields '' tokens; an empty token
        # would make str.replace insert the placeholder between every char.
        tokens = [
            token for token in instruction.lower().split()
            if token not in TokenList
        ]
        # de-duplicate while keeping first-occurrence order; zip stops at the
        # shorter sequence, so surplus words cannot index past ReplaceList
        replacements = dict(zip(dict.fromkeys(tokens), ReplaceList))

        masked_thinks = None if think_list is None else list(think_list)
        for origin, target in replacements.items():
            if state is not None:
                state = state.replace(origin, target)
            if history is not None:
                history = history.replace(origin, target)
            if masked_thinks is not None:
                masked_thinks = [
                    think.replace(origin, target) for think in masked_thinks
                ]
        return state, history, masked_thinks

    def get_score_str(self, score):
        """Map a similarity score to a feedback sentence about the Think.

        The cut-offs are 0.6 and 0.82, each lowered by 0.1 when ``evaluate``
        is ``"seen"`` or ``"unseen"``.

        Args:
            score: Similarity score to grade.

        Returns:
            One of three feedback sentences (major / moderate / minor revision).

        Raises:
            ValueError: If ``self.evaluate`` is not ``"train"``, ``"seen"``
                or ``"unseen"``.
        """
        offsets = {"train": 0.0, "seen": 0.1, "unseen": 0.1}
        if self.evaluate not in offsets:
            raise ValueError(
                f"Unknown evaluate mode {self.evaluate!r}; "
                f"expected one of {sorted(offsets)}."
            )
        thresh = offsets[self.evaluate]

        # NOTE(review): the 'erros' spelling below is kept byte-for-byte; this
        # text is returned to callers and may be used in prompts elsewhere --
        # confirm before correcting it.
        if score < (0.6 - thresh):
            return "There are many erros in the Think. You need a major revision in the Think."
        if score < (0.82 - thresh):
            return "There are some errors in the Think. You need a moderate revision in the Think."
        return "There is little error in the Think. You need a minor revision in the Think."
