import numpy as np
from source.controller.retriever.contriever import Contriever
from transformers import AutoTokenizer
from source.language_models.lama import Llama2HF
from source.controller.graph import DataGraph
from source.controller.transformation import Transformation
from source.controller.ranking import top_k
from scipy.special import softmax
from source.models.model import QA_model
import pandas as pd
import os 
import lightgbm as lgb
import torch

def choose_thought(graph, thought_sample_size):
    """Draw one thought node at random, favouring the best-scoring ones.

    The ``thought_sample_size`` highest-scoring thought nodes of ``graph`` are
    kept, their scores are converted to probabilities with a softmax, and a
    single node is sampled from that distribution.

    Args:
        graph: Object exposing ``thought_nodes()`` and ``nodes[id]["score"]``.
        thought_sample_size: Upper bound on the number of candidates kept.

    Returns:
        A length-1 ``int64`` NumPy array holding the sampled node id.
    """
    candidates = graph.thought_nodes()

    # Pair each candidate with its score, then rank best-first. list.sort is
    # stable, so equally scored nodes keep their original relative order.
    ranked = [(node_id, graph.nodes[node_id]["score"]) for node_id in candidates]
    keep = min(thought_sample_size, len(candidates))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    top_ids, top_scores = zip(*ranked[:keep])

    probabilities = softmax(np.array(top_scores))
    drawn = np.random.choice(top_ids, 1, replace=False, p=probabilities)
    return drawn.astype("int64")


class Node:
    """Search-tree vertex that mirrors one node of the underlying graph.

    Creating a node registers it in ``node_dict`` under ``index`` and appends
    it to the ``children`` list of every parent.

    Args:
        index: Id of the mirrored node inside ``graph``.
        graph: Graph object holding the scored nodes.
        parents: Tree nodes this node was derived from.
        node_dict: Shared ``index -> Node`` registry; updated in place.
        collect_data: Flag stored on the node and passed on to descendants.
        is_document: True for document nodes, whose value starts at 0 instead
            of the graph score.
        retriever, tokenizer, dataset_name, k: Stored as-is and passed on to
            descendants.
    """

    def __init__(self, index, graph, parents, node_dict, collect_data=False, is_document=False, retriever=None, tokenizer=None, dataset_name=None, k=None):
        self.parents = parents
        self.index = index
        self.visits = 1
        # Documents carry no score of their own; thoughts start from the graph score.
        self.value = 0 if is_document else graph.nodes[index]["score"]
        self.children = []

        self.node_dict = node_dict
        node_dict[index] = self

        self.graph = graph
        self.collect_data = collect_data
        self.is_document = is_document
        self.retriever = retriever
        self.tokenizer = tokenizer
        self.dataset_name = dataset_name
        self.k = k

        # Make the new node reachable from each of its parents.
        for parent in parents:
            parent.children.append(self)

    def is_leaf(self):
        """Return True when the node has no children yet."""
        return not len(self.children)

    def _spawn(self, index, parents, is_document=False):
        """Create a node that shares this node's graph, registry and settings."""
        return Node(
            index,
            self.graph,
            parents,
            self.node_dict,
            self.collect_data,
            is_document=is_document,
            retriever=self.retriever,
            tokenizer=self.tokenizer,
            dataset_name=self.dataset_name,
            k=self.k,
        )

    def expand_document(self, document, output_size=1):
        """Combine this node with ``document`` and return the resulting child.

        The graph is updated with a ``Transformation`` of this node and the
        document. A tree node for the document is created on first use and
        looked up in the registry afterwards. The child mirrors the last node
        of the graph and has this node and the document node as parents.
        """
        step = Transformation(self.graph.query, [self.index], [document], output_size, model=self.graph.model)
        self.graph.update(step)

        if document in self.node_dict:
            document_node = self.node_dict[document]
        else:
            document_node = self._spawn(document, [], is_document=True)

        return self._spawn(len(self.graph) - 1, [self, document_node])

    def expand_thought(self, thoughts, output_size=1):
        """Combine this node with ``thoughts`` and return the resulting child.

        All of ``thoughts`` feed the ``Transformation``; only the first one is
        linked as the second parent of the child.
        """
        step = Transformation(self.graph.query, [self.index] + list(thoughts), [], output_size, self.graph.model)
        self.graph.update(step)
        partner = self.node_dict[thoughts[0]]
        return self._spawn(len(self.graph) - 1, [self, partner])
        


class MCTS():
    """Tree search that grows a ``DataGraph`` from thoughts and retrieved documents.

    Each step selects a leaf with the UCT rule, expands it with either a
    queued document or another thought, and back-propagates along the parent
    links. A search ends when a newly created node scores above
    ``self.threshold`` or after ``size`` steps.

    Args:
        size: Maximum number of search steps per query.
        dataset_name: Retrieval corpus name forwarded to ``top_k``.
        collection: Forwarded to ``top_k`` as ``collection``.
        save_graph: When true, ``graph.display_save`` is called after a search.
        save_folder: Folder passed to ``graph.display_save``.
        collect_data: Flag stored on every ``Node`` created by the search.
        verbose: Print progress information.
        q_esti_thought: Optional LightGBM model file, loaded into
            ``self.q_esti_thought`` (``None`` when not given).
        q_esti_parents: Optional LightGBM model file, loaded into
            ``self.q_esti_parents`` (``None`` when not given).
        eval_score: Stored as ``self.eval_score``; not read by the methods below.
    """

    def __init__(self, size=10, dataset_name="wikipedia_dump", collection=None, save_graph=False, save_folder=None, collect_data=False, verbose=False, q_esti_thought=None, q_esti_parents=None, eval_score=False):
        self.dataset_name = dataset_name
        self.LLM = Llama2HF(config_path="source/language_models/config.json", model_name="TheBloke/Llama-2-70B-chat-AWQ", verbose=verbose)

        # Choose the device first: an unconditional .to("cuda") raises on a
        # host without CUDA before any availability check can run.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.retriever = Contriever.from_pretrained("facebook/contriever").to(device)
        self.retriever_tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
        self.retriever.eval()

        self.scoring_model = QA_model(self.LLM, verbose=verbose)

        # Always define both attributes so later reads cannot raise AttributeError.
        self.q_esti_thought = None
        if q_esti_thought is not None:
            self.q_esti_thought = lgb.Booster(model_file=q_esti_thought)

        self.q_esti_parents = None
        if q_esti_parents is not None:
            self.q_esti_parents = lgb.Booster(model_file=q_esti_parents)

        self.collection = collection

        self.c = np.sqrt(2)  # exploration weight used in UCT_score
        self.nb_output = 1  # output_size passed to the node expansions
        self.p_document = 0.7  # lower bound on the chance of a document expansion
        self.nb_input = 1  # not read by the methods below
        self.new_document_rate = 10  # documents requested per retrieval call
        self.thought_sample_size = 5  # candidate pool size for choose_thought

        self.threshold = 0.49  # a node scoring above this ends the search
        self.min_thoughts = 0  # steps during which termination is suppressed

        self.step_depth = 1  # expansions chained within one step

        self.size = size
        self.save_graph = save_graph
        self.save_folder = save_folder
        self.collect_data = collect_data
        self.verbose = verbose
        self.eval_score = eval_score

    def search(self, root, label=None):
        """Run a search for the query ``root``; see ``_search`` for the result."""
        return self._search(root, label=label)

    def _search(self, root, label=None):
        """Build a fresh graph for ``root`` and search until termination.

        Args:
            root: Query the graph is built around.
            label: Forwarded to ``DataGraph`` as ``label``.

        Returns:
            A dict with the keys ``"Best thought"``, ``"Best answer"`` and
            ``"Baseline answer"``. The best thought is the information of the
            node that ended the search, or of node 0 when no step crossed the
            threshold.
        """
        self.graph = DataGraph(self.scoring_model, self.LLM, root, max_node_count=100, label=label, verbose=self.verbose)
        self.documents = []
        self.node_dict = {}
        root_node = Node(0, self.graph, [], self.node_dict, self.collect_data, retriever=self.retriever, tokenizer=self.retriever_tokenizer, dataset_name=self.dataset_name, k=self.new_document_rate)

        self._need_new_document(query=root)
        i = 0
        # The root alone may already be good enough, in which case no step runs.
        terminal = float(self.graph.nodes[0]["score"]) > self.threshold

        while i < self.size and not terminal:
            if self.verbose:
                print("step : ", i)
                print("max score : ", max(self.graph.scores().values()))
                if i > 0:
                    print("last_score : ", self.graph.nodes[len(self.graph)-1]["score"])

            terminal, node = self._step(root_node)

            # Force at least ``min_thoughts`` steps before allowing termination.
            if i < self.min_thoughts:
                terminal = False

            i += 1

        # ``node`` only exists once at least one step has run, hence ``i > 0``.
        if terminal and i > 0:
            best_thought = self.graph.nodes[node.index]["information"]
        else:
            best_thought = self.graph.nodes[0]["information"]

        result = {"Best thought": best_thought,
                  "Best answer": self.scoring_model.predict(root, [best_thought]),
                  "Baseline answer": self.graph.baseline_answer}

        if self.save_graph:
            self.graph.display_save(save_folder=self.save_folder)

        torch.cuda.empty_cache()
        return result

    def _step(self, root):
        """Run one select / expand / back-propagate cycle from ``root``.

        Returns:
            ``(terminal, node)`` where ``node`` is the last node created and
            ``terminal`` tells whether its score exceeds ``self.threshold``.
        """
        node = self.selection(root)
        for _ in range(self.step_depth):
            node = self.expansion(node)
        self.backpropagation(node)
        torch.cuda.empty_cache()
        return float(self.graph.nodes[node.index]["score"]) > self.threshold, node

    def UCT_score(self, node):
        """Return the UCT score of ``node``.

        Computed as mean value plus ``c * sqrt(ln(N) / n)``, where ``n`` is the
        node's visit count and ``N`` the summed visit count of its parents.
        """
        parent_visits = np.sum([parent.visits for parent in node.parents])
        exploitation = node.value / node.visits
        exploration = self.c * np.sqrt(np.log(parent_visits) / node.visits)
        return exploitation + exploration

    def _need_new_document(self, thought=None, query=None):
        """Retrieve a new batch of documents and queue them for expansion.

        Args:
            thought: Text used to ask the LLM for a retrieval query when no
                ``query`` is supplied.
            query: Retrieval query. Replaced by ``graph.query`` when the
                dataset name contains ``"med_records"``.
        """
        if self.verbose:
            print("Need new document")

        if "med_records" in self.dataset_name:
            query = self.graph.query
        if query is None:
            query = self.LLM.query(self.scoring_model.document_query_prompt(self.graph.query, thought))[1][0]
        if not isinstance(query, str):
            query = query[0]

        index = len(self.graph)
        # Remember how many nodes were scored before this batch; ``expansion``
        # uses it to look only at nodes scored afterwards.
        self.index_batch = len(self.graph.scores())
        if self.new_document_rate > 0:
            new_documents = top_k([query], dataset_name=self.dataset_name, k=self.new_document_rate, retriever=self.retriever, tokenizer=self.retriever_tokenizer, collection=self.collection)[0]["ctxs"]
        else:
            new_documents = []
        self.graph.add_documents(new_documents)
        # Queue only the nodes that were really appended: retrieval may yield
        # fewer than ``new_document_rate`` documents, and queueing the full
        # range would leave indices that point at no node.
        # NOTE(review): assumes add_documents appends nodes at consecutive
        # indices starting at the previous len(graph) - confirm in DataGraph.
        self.documents.extend(range(index, len(self.graph)))
        torch.cuda.empty_cache()

    def selection(self, node):
        """Descend from ``node`` along the best UCT child until a leaf is reached."""
        while not node.is_leaf():
            node = max(node.children, key=self.UCT_score)
        return node

    def expansion(self, node):
        """Expand ``node`` with a queued document or with another thought.

        A document is used with probability
        ``max(p_document, len(documents) / (len(graph) - 1))`` while the queue
        is non-empty. Taking the last queued document triggers a new retrieval
        driven by the best thought scored since the previous retrieval.

        Raises:
            RuntimeError: If a thought expansion is needed but ``node`` is the
                only thought node in the graph.
        """
        if self.documents and np.random.rand() < max(self.p_document, len(self.documents) / (len(self.graph) - 1)):
            document = self.documents.pop(0)
            if not self.documents:
                scored = sorted(self.graph.scores().items(), key=lambda item: item[0])
                # Prefer nodes scored since the last batch. When there are
                # none, fall back to every scored node rather than failing
                # on max() of an empty slice.
                candidates = scored[self.index_batch:] or scored
                max_node = max(candidates, key=lambda item: item[1])[0]
                self._need_new_document(self.graph.nodes[max_node]["information"])
            return node.expand_document(document, self.nb_output)

        # Resampling below can only succeed if some other thought exists;
        # without this guard the loop would spin forever.
        thought_ids = list(self.graph.thought_nodes())
        if thought_ids and all(t == node.index for t in thought_ids):
            raise RuntimeError("cannot expand with a thought: the node being expanded is the only thought in the graph")

        thoughts = choose_thought(self.graph, self.thought_sample_size)
        while thoughts[0] == node.index:
            thoughts = choose_thought(self.graph, self.thought_sample_size)
        return node.expand_thought(thoughts, self.nb_output)

    def backpropagation(self, node):
        """Propagate visits and value from ``node`` up through all its ancestors.

        Every parent gains one visit and the current value of the child it is
        reached from, then the update recurses into that parent.
        """
        # NOTE(review): the value added at each level is the child's already
        # updated total, and an ancestor reachable through several parents is
        # updated once per path - confirm this is the intended scheme.
        for parent in node.parents:
            parent.visits += 1
            parent.value += node.value
            self.backpropagation(parent)

