import networkx as nx
import numpy as np 
import matplotlib.pyplot as plt
import copy
from pyvis.network import Network
import datetime

class DataGraph(nx.DiGraph):
    """Directed graph of "thought" and "document" nodes scored by a model.

    Node attributes:
      * ``information`` -- the node's text (``""`` for the root thought).
      * ``type`` -- ``"thought"`` or ``"document"``.
      * ``score`` -- numeric score; only thought nodes carry it.

    New node ids are ``len(self)`` at insertion time, so they are consecutive
    integers as long as no node is removed. Edges point from each parent node
    to the thought generated from it.

    The scoring mode is selected by ``label``:
      * ``None`` -- heuristic scoring via ``model.evaluate_heuristic``.
      * ``"Q"`` -- every score is 0.
      * anything else -- training mode, scored via ``model.evaluate(..., label)``.
    """

    def __init__(self, model, LLM, query, input=None, max_node_count=100, label=None, verbose=False):
        """Create the graph with a single root thought (id 0) scored on the empty thought.

        Args:
            model: object exposing ``evaluate``, ``evaluate_heuristic`` and ``predict``.
            LLM: language model passed to ``transformation.apply`` in :meth:`update`.
            query: the query this graph is built for.
            input: text thoughts are evaluated against; defaults to ``query``.
            max_node_count: side length :meth:`matrix` pads the adjacency matrix to.
            label: scoring-mode selector (see the class docstring).
            verbose: when True, print the baseline and newly computed scores.
        """
        super().__init__()
        self.query = query
        self.max_node_count = max_node_count
        self.label = label
        self.verbose = verbose
        # Only a real label (neither None nor the "Q" sentinel) enables training-mode scoring.
        self.training = label is not None and label != "Q"
        self.input = query if input is None else input
        self.model = model
        self.LLM = LLM

        # Score of the empty thought "" under the selected scoring mode.
        if self.training:
            self.baseline = model.evaluate(self.input, [""], label)[0]
        elif self.label == "Q":
            self.baseline = 0
        else:
            self.baseline = float(model.evaluate_heuristic(self.input, [""])[0])

        self.baseline_answer = model.predict(self.input, [""])[0]
        if self.verbose:
            print("Baseline : ", self.baseline)
        # Root node: an empty thought carrying the baseline score.
        self.add_node(len(self), information="", type="thought", score=self.baseline)

    def compute_new_scores(self, new_thoughts):
        """Return one score per entry of ``new_thoughts`` according to the scoring mode."""
        if self.training:
            return self.model.evaluate(self.input, new_thoughts, self.label)
        if self.label == "Q":
            return len(new_thoughts) * [0]
        return self.model.evaluate_heuristic(self.input, new_thoughts)

    def update(self, transformation):
        """Apply ``transformation`` and add every generated thought as a new node.

        ``transformation.apply(self.LLM, self)`` must return ``(prompts, new_thoughts)``.
        Each new thought gets an edge from every node id listed in
        ``transformation.input_thoughts`` and ``transformation.input_documents``.

        Returns:
            The scores computed for ``new_thoughts``.
        """
        prompts, new_thoughts = transformation.apply(self.LLM, self)
        new_scores = self.compute_new_scores(new_thoughts)
        if self.verbose:
            print("new scores : ", new_scores)
        # The parent set is identical for every new thought; build it once.
        parents = np.concatenate((transformation.input_thoughts, transformation.input_documents))
        for i, thought in enumerate(new_thoughts):
            self.add_thought(thought, parents, float(new_scores[i]), prompts[i])
        return new_scores

    def add_thought(self, thought, parents, score, prompt):
        """Add ``thought`` as a new thought node with an edge from each node in ``parents``.

        ``prompt`` is currently unused.
        """
        node_id = len(self)
        self.add_node(node_id, information=thought, type="thought", score=score)
        for parent in parents:
            # parents may hold numpy floats (np.concatenate output); node ids are ints.
            self.add_edge(int(parent), node_id)

    def document_nodes(self):
        """Return the ids of all document nodes, in insertion order."""
        return [n for n in self.nodes if self.nodes[n]["type"] == "document"]

    def thought_nodes(self):
        """Return the ids of all thought nodes, in insertion order."""
        return [n for n in self.nodes if self.nodes[n]["type"] == "thought"]

    def add_documents(self, collection):
        """Add each item of ``collection`` as a document node.

        Each item must be a mapping with a ``"text"`` key. Every ``$$`` in the
        text is stored with a backslash in front of each dollar sign.
        """
        start = len(self)
        # Raw string: '\$' is an invalid escape in a plain literal (SyntaxWarning on
        # Python 3.12+); r"\$\$" is the same value without the warning.
        # Built as a list first so a missing "text" key adds no nodes at all.
        documents = [
            (start + i, {"information": doc["text"].replace("$$", r"\$\$"), "type": "document"})
            for i, doc in enumerate(collection)
        ]
        self.add_nodes_from(documents)

    def matrix(self):
        """Return the adjacency matrix zero-padded to ``max_node_count`` squared, as float32.

        Rows and columns follow node iteration (insertion) order.

        Raises:
            ValueError: if the graph has more than ``max_node_count`` nodes.
        """
        pad = self.max_node_count - len(self)
        if pad < 0:
            raise ValueError(
                f"graph has {len(self)} nodes, more than max_node_count={self.max_node_count}"
            )
        adjacency = nx.to_numpy_array(self)
        return np.pad(adjacency, (0, pad), mode="constant", constant_values=0).astype(np.float32)

    def scores(self, nodes=None):
        """Map node ids to their scores.

        With ``nodes=None`` every node is included and document nodes map to 0.
        Otherwise only the thought nodes among ``nodes`` are included.
        """
        if nodes is None:
            return {n: self.nodes[n]["score"] if self.nodes[n]["type"] == "thought" else 0 for n in self.nodes}
        return {n: self.nodes[n]["score"] for n in nodes if self.nodes[n]["type"] == "thought"}

    def draw_sub_graph(self, nodes, save_path="results/test_subgraph.png"):
        """Draw the graph with ``nodes`` and the path through them highlighted in red.

        ``nodes`` is treated as a path: consecutive entries are joined by a red edge.
        If a ``pos`` attribute has been set on the instance it is used as the layout
        (this class never sets it); otherwise a Kamada-Kawai layout is computed.
        Edges are coloured by their ``weight`` attribute when any edge has one.
        The figure is saved to ``save_path`` and then cleared.
        """
        nodes = list(nodes)
        path_edges = list(zip(nodes, nodes[1:]))
        pos = getattr(self, "pos", None)
        if pos is None:
            pos = nx.kamada_kawai_layout(self)

        weighted = nx.get_edge_attributes(self, "weight")
        if weighted:
            edges, weights = zip(*weighted.items())
            nx.draw(self, pos, node_color='b', edgelist=list(edges), edge_color=list(weights), edge_cmap=plt.cm.Blues)
        else:
            # add_thought creates unweighted edges; unpacking an empty mapping would fail.
            nx.draw(self, pos, node_color='b')
        nx.draw_networkx_nodes(self, pos, nodelist=nodes, node_color='r')
        nx.draw_networkx_edges(self, pos, edgelist=path_edges, edge_color='r', width=2)
        plt.savefig(save_path)
        # Clear so later drawings do not pile onto this figure.
        plt.clf()

    def display_save(self, save_folder):
        """Save a Kamada-Kawai drawing (``.png``) and a GML export (``.gml``) of the graph.

        Both files are named ``save_folder`` + a ``YYYYmmdd-HHMMSS`` timestamp.
        ``save_folder`` is concatenated as-is, so it must end with a path separator
        to refer to a directory.
        """
        stem = save_folder + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        nx.draw_kamada_kawai(self, with_labels=True)
        plt.savefig(stem + ".png")
        plt.clf()
        nx.write_gml(self, stem + ".gml")
   