import os

import networkx as nx
from nebula3.Config import Config
from nebula3.gclient.net import ConnectionPool

from function_call_agent.graph_search.graph_search_alpha_bata import Graph_Search_AB
from function_call_agent.graph_search.graoh_search_heuristic import HybridSearcher


class NebulaConnection:
    """Context manager that initialises a NebulaGraph ``ConnectionPool`` and closes it on exit.

    Usage::

        with NebulaConnection() as pool:
            ...

    Args:
        addresses: list of ``(host, port)`` graphd endpoints. When omitted, a
            single endpoint is built from the ``NEBULA_HOST`` / ``NEBULA_PORT``
            environment variables, falling back to ``DEFAULT_HOST`` /
            ``DEFAULT_PORT``.
        max_connection_pool_size: value assigned to
            ``Config.max_connection_pool_size`` (defaults to the previously
            hard-coded 10).
    """

    DEFAULT_HOST = '127.0.0.1'
    DEFAULT_PORT = 9669  # NebulaGraph graphd's default service port

    def __init__(self, addresses=None, max_connection_pool_size=10):
        if addresses is None:
            # The endpoint used to be an unresolved placeholder; read it from the
            # environment so deployments can configure it without code edits.
            addresses = [(
                os.environ.get('NEBULA_HOST', self.DEFAULT_HOST),
                int(os.environ.get('NEBULA_PORT', self.DEFAULT_PORT)),
            )]
        self.addresses = list(addresses)
        self.config = Config()
        self.config.max_connection_pool_size = max_connection_pool_size
        self.pool = ConnectionPool()

    def __enter__(self):
        """Initialise the pool and return it; close it again if initialisation fails."""
        try:
            # __exit__ is not called when __enter__ raises, so clean up here.
            if not self.pool.init(self.addresses, self.config):
                raise RuntimeError(
                    'Failed to initialise NebulaGraph connection pool for {}'.format(self.addresses))
        except Exception:
            self.pool.close()
            raise
        return self.pool

    def __exit__(self, *args):
        """Close the pool; exceptions from the with-body are not suppressed."""
        self.pool.close()


class GraphLoader:
    """Load a NebulaGraph space into an in-memory ``networkx.DiGraph``.

    Vertices whose ``Node__`` tag value is ``'api'`` are also recorded in
    ``self.apis`` (vertex id -> API metadata dict). ``self.apis`` is created once
    per loader, so it accumulates entries across ``load_graph`` calls.

    Args:
        connection_pool: an initialised nebula3 ``ConnectionPool``.
        user: session user name (default ``'root'``, as previously hard-coded).
        password: session password (default ``'nebula'``, as previously hard-coded).
    """

    # Edge properties copied onto every networkx edge.
    _EDGE_PROP_KEYS = ('call_cnt', 'dependence_cnt', 'dependence_rate', 'weight', 'label')

    def __init__(self, connection_pool, user='root', password='nebula'):
        self.conn_pool = connection_pool
        self.user = user
        self.password = password
        self.apis = {}

    @staticmethod
    def _execute(session, statement):
        """Run *statement* on *session*; raise ``RuntimeError`` if Nebula reports failure."""
        result = session.execute(statement)
        # A failed ResultSet carries no data, so iterating it would fail with an
        # unrelated error; surface Nebula's own message instead.
        if not result.is_succeeded():
            raise RuntimeError('Nebula query failed: {!r}: {}'.format(statement, result.error_msg()))
        return result

    def _parse_node(self, node):
        """Convert a Nebula vertex into ``(node_id, attrs)``; API vertices are also stored in ``self.apis``."""
        node_id = node.get_id().cast()
        node_type = node.prop_values('Node__')[0].cast()
        props = dict(zip(node.prop_names('Props__'), node.prop_values('Props__')))
        if node_type == 'api':
            prop_dict = {
                'description': props['description'].cast(),
                'input_param': props['input_param'].cast(),
                'output_param': props['output_param'].cast(),
                # NOTE(review): eval() executes whatever text is stored in the graph
                # database. If param_mapping is always a plain Python literal (e.g. a
                # dict repr), ast.literal_eval is a safe drop-in — confirm and switch.
                'param_mapping': eval(props['param_mapping'].cast()),
                'ori_name': props['ori_name'].cast(),
            }
            self.apis[node_id] = prop_dict
        else:
            # Unlike the API branch, these values are left uncast (kept as returned
            # by prop_values); preserved as-is for downstream consumers.
            prop_dict = {
                'description': props['description'],
                'param_type': props['param_type'],
            }
        return node_id, {'type': node_type, 'name': node_id, 'prop': prop_dict}

    @classmethod
    def _parse_edge(cls, relationship):
        """Convert a Nebula edge into a ``(src_id, dst_id, props)`` tuple for networkx."""
        props = relationship.properties()
        # Values are stored uncast, exactly as returned by properties().
        edge_props = {key: props[key] for key in cls._EDGE_PROP_KEYS}
        return relationship.start_vertex_id().cast(), relationship.end_vertex_id().cast(), edge_props

    def load_graph(self, graph=None) -> nx.DiGraph:
        """Read every vertex and edge of space *graph* into a new ``DiGraph``.

        Args:
            graph: Nebula space name; defaults to ``'tool_graph_apibank'``.

        Returns:
            A ``networkx.DiGraph`` whose node attributes are ``type``, ``name``
            and ``prop``, and whose edge attributes are ``_EDGE_PROP_KEYS``.

        Raises:
            RuntimeError: if any Nebula statement reports failure.
        """
        if graph is None:
            graph = 'tool_graph_apibank'
        nodes = {}
        edges = []
        with self.conn_pool.session_context(self.user, self.password) as session:
            self._execute(session, 'USE ' + graph)

            for record in self._execute(session, "MATCH (n) RETURN n"):
                node_id, attrs = self._parse_node(record.get_value_by_key('n').as_node())
                nodes[node_id] = attrs

            for record in self._execute(session, "MATCH ()-[r]->() RETURN r"):
                edges.append(self._parse_edge(record.get_value_by_key('r').as_relationship()))

        G = nx.DiGraph()
        G.add_nodes_from(nodes.items())
        G.add_edges_from(edges)
        return G


class Graph_Search:
    """Load a tool graph from NebulaGraph and run a graph search over it.

    Args:
        search_type: ``'alpha_beta'`` (a new ``Graph_Search_AB`` per call) or
            ``'hybrid'`` (one ``HybridSearcher`` created lazily and reused).
        graph_degree: forwarded unchanged to the searcher.
        data_type: ``'api_bank'`` selects space ``tool_graph_apibank``; any other
            value selects ``tool_graph_toolbench``.
    """

    def __init__(self, search_type=None, graph_degree=None, data_type=None):
        self.search_type = search_type
        self.graph_degree = graph_degree
        self.data_type = data_type
        space_name = 'tool_graph_apibank' if data_type == 'api_bank' else 'tool_graph_toolbench'
        # The pool is only needed while loading; it is closed when the with-block exits.
        with NebulaConnection() as pool:
            self.loader = GraphLoader(pool)
            self.graph = self.loader.load_graph(space_name)
        self.apis = self.loader.apis
        self.searcher = None
        self.sub_graph = None

    def run(self, target_apis):
        """Alias for :meth:`get_graph_search`."""
        return self.get_graph_search(target_apis)

    def get_graph_search(self, target_apis):
        """Search the loaded graph for *target_apis* and return the searcher's info.

        The sub-graph produced by the search is stored in ``self.sub_graph``.

        Raises:
            ValueError: if ``self.search_type`` is neither ``'alpha_beta'`` nor ``'hybrid'``.
        """
        mode = self.search_type
        if mode not in ('alpha_beta', 'hybrid'):
            raise ValueError('Not Found search model')

        if mode == 'alpha_beta':
            # The alpha-beta searcher is bound to its targets, so rebuild it each call.
            self.searcher = Graph_Search_AB(
                target_apis,
                graph_degree=self.graph_degree,
                graph=self.graph,
                data_type=self.data_type,
            )
            info, self.sub_graph = self.searcher.get_graph()
        else:
            if self.searcher is None:
                self.searcher = HybridSearcher(
                    self.graph,
                    graph_degree=self.graph_degree,
                    data_type=self.data_type,
                )
            info, self.sub_graph = self.searcher.hybrid_search(target_apis)
        return info


if __name__ == '__main__':
    # Smoke run: hybrid search over the ToolBench graph for a single target API.
    searcher = Graph_Search(search_type='hybrid', graph_degree=None, data_type='tool_bench')
    print(searcher.run(target_apis=['api-job_details_for_indeed']))


