from source.controller.mcts import MCTS, Node
import numpy as np 
import numpy as np
from source.language_models.lama import Llama2HF
from source.controller.graph import DataGraph
from source.controller.ranking import top_k
import pandas as pd
import os 
from source.language_models.prompt_templates import document_query_prompt
from munch import DefaultMunch



class Q_search(MCTS):
    """Greedy search over node pairs ranked by a learned Q-value estimator.

    Every pair of graph nodes is stored as ``(low_index, high_index)`` and
    scored once with ``self.q_estimator``. Each step expands the pair with the
    highest cached score, and a step is reported as terminal when that score
    exceeds ``self.treshold``.
    """

    def __init__(self, dataset_name, collection, save_graph, save_folder, collect_data=False, q_esti_thought=None, q_esti_parents=None, verbose=False):
        """Forward every argument to ``MCTS.__init__`` and pick the estimator.

        The meaning of the parameters is defined by ``MCTS``; they are passed
        through unchanged.
        """
        super().__init__(
            dataset_name=dataset_name,
            collection=collection,
            save_graph=save_graph,
            save_folder=save_folder,
            collect_data=collect_data,
            q_esti_thought=q_esti_thought,
            q_esti_parents=q_esti_parents,
            verbose=verbose,
        )

        # Assumes MCTS.__init__ stored the estimator as self.q_esti_parents.
        self.q_estimator = self.q_esti_parents
        # Score above which a step is considered terminal. The attribute name
        # (sic) is kept because other code reads it under this spelling.
        self.treshold = 0.22

    def list_possible_pairs(self, index):
        """Return every pair joining ``index`` with another graph node.

        Pairs are ordered ``(smaller_index, larger_index)`` and listed by
        increasing partner index; ``index`` is never paired with itself.
        """
        return [
            (index, other) if other > index else (other, index)
            for other in range(len(self.graph))
            if other != index
        ]

    def best_q_value_pair(self):
        """Return the ``(pair, q_value)`` item with the highest cached Q-value.

        Raises:
            ValueError: if no Q-value has been computed yet.
        """
        return max(self.q_values.items(), key=lambda item: item[1])

    def _predict_pair_q_value(self, pair):
        """Return the estimator's prediction for one ``(low, high)`` pair."""
        # One feature row: the embeddings of both pair members followed by the
        # embedding of node 0 (the root created in _search), joined on axis 1.
        features = np.concatenate(
            (
                [self.node_dict[pair[0]].embedding],
                [self.node_dict[pair[1]].embedding],
                [self.node_dict[0].embedding],
            ),
            axis=1,
        )
        return self.q_estimator.predict(features, num_iteration=self.q_esti_parents.best_iteration)

    def _compute_new_q_values(self, new_node_list):
        """Score and cache every not-yet-scored pair involving the given nodes.

        Args:
            new_node_list: indices of nodes whose pairs must be scored.
        """
        for new_node in new_node_list:
            for pair in self.list_possible_pairs(new_node):
                if pair not in self.q_values:
                    self.q_values[pair] = self._predict_pair_q_value(pair)

    def _step(self):
        """Expand the best-scoring pair and score the pairs of the new node.

        Returns:
            ``(terminal, node)`` where ``terminal`` is the result of comparing
            the best score with ``self.treshold`` and ``node`` is the ``Node``
            stored under the last graph index after the expansion.
        """
        best_pair, best_score = self.best_q_value_pair()
        low, high = best_pair

        # NOTE(review): the expanded pair stays in self.q_values, so it can be
        # selected again on a later step - confirm this is intended.
        if self.node_dict[high].is_document:
            self.node_dict[low].expand_document(high)
        elif self.node_dict[low].is_document:
            # NOTE(review): QMCTS.expansion calls expand_document for a
            # document/thought pair on "med_records" datasets and
            # expand_thought for two thoughts; the two branches below do the
            # opposite - confirm this is intended.
            self.node_dict[low].expand_thought([high])
        else:
            self.node_dict[high].expand_document(low)

        # Assumes the expansion appended exactly one node at the end of the graph.
        new_node = len(self.graph) - 1
        self._compute_new_q_values([new_node])
        terminal = best_score > self.treshold

        return terminal, self.node_dict[new_node]

    def _need_new_document(self, thought=None, query=None):
        """Retrieve new documents, add them to the graph and score their pairs.

        Args:
            thought: text used to ask the LLM for a retrieval query when
                ``query`` is not given.
            query: retrieval query; when ``None`` it is generated by the LLM
                from the graph query and ``thought``.
        """
        if self.verbose:
            print("Need new document")
        if query is None:
            query = self.LLM.query(document_query_prompt(self.graph.query, thought))[1][0]

        first_new_index = len(self.graph)
        new_documents = top_k(
            [query],
            dataset_name=self.dataset_name,
            k=self.new_document_rate,
            retriever=self.retriever,
            tokenizer=self.retriever_tokenizer,
            collection=self.collection,
        )[0]["ctxs"]
        self.graph.add_documents(new_documents)

        # Use the indices the graph really gained instead of assuming that
        # exactly new_document_rate documents were returned and added.
        new_docs = list(range(first_new_index, len(self.graph)))
        self.documents += new_docs
        for doc in new_docs:
            # The instance is not kept here; node_dict is handed to Node,
            # presumably so the node can register itself - the scoring below
            # looks it up there.
            Node(doc, self.graph, [], self.node_dict, self.collect_data, is_document=True, retriever=self.retriever, tokenizer=self.retriever_tokenizer, dataset_name=self.dataset_name, k=self.new_document_rate)
        self._compute_new_q_values(new_docs)

    def _search(self, root, label=None):
        """Run the greedy search for the query ``root``.

        Args:
            root: the query; used to build the graph, as the first retrieval
                query and as input of the final prediction.
            label: forwarded to ``DataGraph``.

        Returns:
            dict with the keys ``"Best thought"``, ``"Best answer"`` and
            ``"Baseline answer"``.
        """
        self.graph = DataGraph(self.scoring_model, self.LLM, root, max_node_count=100, label=label)
        self.documents = []
        self.node_dict = {}
        self.q_values = {}

        # Root node with index 0; its embedding is part of every feature row.
        Node(0, self.graph, [], self.node_dict, self.collect_data, retriever=self.retriever, tokenizer=self.retriever_tokenizer, dataset_name=self.dataset_name, k=self.new_document_rate)

        self._need_new_document(query=root)
        i = 0
        terminal = False

        while i < self.size and not terminal:
            if self.verbose:
                print("step : ", i , "max score : ", max(self.q_values.values()))
            terminal, node = self._step()
            self._need_new_document(thought=self.graph.nodes[node.index]["information"])
            # Never stop before min_thoughts steps have been taken.
            if i < self.min_thoughts:
                terminal = False
            i += 1

        if terminal and i > 0:
            best_thought = self.graph.nodes[node.index]["information"]
        else:
            # No terminal step was reached: fall back to the root node's text.
            best_thought = self.graph.nodes[0]["information"]

        result = {"Best thought": best_thought,
                  "Best answer": self.scoring_model.predict(root, [best_thought]),
                  "Baseline answer": self.graph.baseline_answer}

        if self.save_graph:
            self.graph.display_save(save_folder=self.save_folder)

        return result
        
        





class QMCTS(MCTS):
    """MCTS variant whose expansion step is guided by a Q-value estimator.

    Pairs of graph nodes are stored as ``(low_index, high_index)`` and scored
    lazily with ``self.q_estimator``; ``expansion`` expands the best-scoring
    pair that involves the selected node.
    """

    def __init__(self, dataset_name, size, collection, save_graph, save_folder, collect_data=False, q_esti_thought=None, q_esti_parents=None, verbose=False):
        """Forward every argument to ``MCTS.__init__`` and pick the estimator.

        The meaning of the parameters is defined by ``MCTS``; they are passed
        through unchanged.
        """
        super().__init__(
            dataset_name=dataset_name,
            size=size,
            collection=collection,
            save_graph=save_graph,
            save_folder=save_folder,
            collect_data=collect_data,
            q_esti_thought=q_esti_thought,
            q_esti_parents=q_esti_parents,
            verbose=verbose,
        )

        # Assumes MCTS.__init__ stored the estimator as self.q_esti_parents.
        self.q_estimator = self.q_esti_parents
        # self.treshold and self.new_document_rate are not set here; they are
        # read by this class and presumably provided by MCTS - TODO confirm.

    def list_possible_pairs(self, index):
        """Return the candidate pairs joining ``index`` with other graph nodes.

        For datasets whose name contains ``"med_records"`` only partners that
        are graph nodes of type ``"document"`` and still listed in
        ``self.documents`` are kept; otherwise every other node is a partner.
        Pairs are ordered ``(smaller_index, larger_index)``.
        """
        documents_only = "med_records" in self.dataset_name
        results = []
        for other in range(len(self.graph)):
            if other == index:
                continue
            if documents_only and not (
                self.graph.nodes[other]["type"] == "document" and other in self.documents
            ):
                continue
            results.append((index, other) if other > index else (other, index))
        return results

    def _predict_pair_q_value(self, pair):
        """Return the estimator's prediction for one ``(low, high)`` pair."""
        # One feature row: the embeddings of both pair members followed by the
        # embedding of node 0 (presumably the root/query node), joined on axis 1.
        features = np.concatenate(
            (
                [self.node_dict[pair[0]].embedding],
                [self.node_dict[pair[1]].embedding],
                [self.node_dict[0].embedding],
            ),
            axis=1,
        )
        return self.q_estimator.predict(features, num_iteration=self.q_esti_parents.best_iteration)

    def _compute_new_q_values(self, new_node_list):
        """Score and cache every not-yet-scored pair involving the given nodes.

        Args:
            new_node_list: indices of nodes whose pairs must be scored.
        """
        for new_node in new_node_list:
            for pair in self.list_possible_pairs(new_node):
                if pair not in self.q_values:
                    self.q_values[pair] = self._predict_pair_q_value(pair)

    def _discard_document(self, doc_index):
        """Drop ``doc_index`` from the unused documents if it is still listed.

        When consumed documents are not filtered out of the candidate pairs
        (datasets other than "med_records"), a document can be chosen more
        than once; an unconditional ``list.remove`` would then raise
        ``ValueError``.
        """
        if doc_index in self.documents:
            self.documents.remove(doc_index)

    def expansion(self, node):
        """Expand ``node`` with the partner that has the highest Q-value.

        Args:
            node: the selected ``Node``; only its ``index`` is read here.

        Returns:
            ``(new_node, score)``: the node returned by the expansion call and
            the Q-value of the chosen pair.

        Raises:
            ValueError: if ``node`` has no candidate pair.
        """
        if not self.documents:
            # No unused document left: retrieve new ones, using as thought the
            # text of the best-scored node among the entries from index_batch
            # onwards (in node-index order).
            ordered_scores = sorted(self.graph.scores().items(), key=lambda item: item[0])
            max_node = max(ordered_scores[self.index_batch:], key=lambda item: item[1])[0]
            self._need_new_document(self.graph.nodes[max_node]["information"])

        list_pair = self.list_possible_pairs(node.index)
        for pair in list_pair:
            if pair not in self.q_values:
                self.q_values[pair] = self._predict_pair_q_value(pair)
        best_pair, score = max(
            ((pair, self.q_values[pair]) for pair in list_pair),
            key=lambda item: item[1],
        )
        if self.verbose:
            print("best score : ", score)

        low, high = best_pair
        if self.node_dict[high].is_document:
            self._discard_document(high)
            new_node = self.node_dict[low].expand_document(high)
        elif self.node_dict[low].is_document and "med_records" in self.dataset_name:
            self._discard_document(low)
            new_node = self.node_dict[high].expand_document(low)
        else:
            new_node = self.node_dict[low].expand_thought([high])

        # The prediction is indexed with [0] to store a single value on the node.
        self.graph.nodes[new_node.index]["score"] = score[0]
        return new_node, score

    def search(self, root, label=None):
        """Reset the Q-value cache and run the inherited ``_search``."""
        self.q_values = {}
        return self._search(root, label)

    def _need_new_document(self, thought=None, query=None):
        """Retrieve new documents and add them to the graph as document nodes.

        Q-values for the new documents are not computed here; ``expansion``
        scores pairs on demand.

        Args:
            thought: text used to ask the LLM for a retrieval query when
                ``query`` is not given.
            query: retrieval query; when ``None`` it is generated by the LLM
                from the graph query and ``thought``.
        """
        if self.verbose:
            print("Need new document")
        if query is None:
            query = self.LLM.query(document_query_prompt(self.graph.query, thought))[1][0]

        first_new_index = len(self.graph)
        new_documents = top_k(
            [query],
            dataset_name=self.dataset_name,
            k=self.new_document_rate,
            retriever=self.retriever,
            tokenizer=self.retriever_tokenizer,
            collection=self.collection,
        )[0]["ctxs"]
        self.graph.add_documents(new_documents)

        # Use the indices the graph really gained instead of assuming that
        # exactly new_document_rate documents were returned and added.
        new_docs = list(range(first_new_index, len(self.graph)))
        self.documents += new_docs
        for doc in new_docs:
            # The instance is not kept here; node_dict is handed to Node,
            # presumably so the node can register itself - expansion looks it
            # up there.
            Node(doc, self.graph, [], self.node_dict, self.collect_data, is_document=True, retriever=self.retriever, tokenizer=self.retriever_tokenizer, dataset_name=self.dataset_name, k=self.new_document_rate)

    def _step(self, root):
        """Run one selection / expansion / backpropagation round.

        Args:
            root: passed to ``self.selection`` to pick the node to expand.

        Returns:
            ``(terminal, node)`` where ``terminal`` is the comparison of the
            last expansion score with ``self.treshold`` and ``node`` is the
            last node produced.
        """
        node = self.selection(root)
        # Chain step_depth expansions, each starting from the previous result.
        for _ in range(self.step_depth):
            node, score = self.expansion(node)
        self.backpropagation(node)
        return score > self.treshold, node