import math
from typing import Iterable, Set

import networkx as nx

from function_call_agent.graph_search.graph_tool import graph_to_tree_string, draw_dependency_graph


class GraphPruner:
    """Prune an API/parameter dependency graph to the part a set of target APIs needs.

    For every target API the pruner walks *backwards* over incoming edges
    (``G.predecessors``), scores each candidate predecessor and keeps those that
    pass a depth-dependent threshold and an alpha/beta window. Kept edges are
    collected as ``(node, pred, edge_data)`` tuples, i.e. reversed relative to
    the ``pred -> node`` edge they describe in ``G``.

    NOTE(review): edge attributes ``'weight'`` and ``'label'`` are read via
    ``.cast()``, so they are presumably wrapper objects rather than plain
    numbers/strings — confirm against the graph builder.
    """

    def __init__(self, G: nx.DiGraph, data_type):
        """
        Args:
            G: Dependency graph. Nodes are classified as API/param by a name
                prefix (``'api'`` / ``'param'``) or by their ``'type'`` attribute.
            data_type: Dataset name; ``'api_bank'`` makes every edge count as
                required (see ``is_required``).
        """
        self.G = G
        # d.get('type') rather than d['type']: a node with neither a known name
        # prefix nor a 'type' attribute is left unclassified instead of raising.
        self.api_nodes = [n for n, d in G.nodes(data=True)
                          if n.startswith('api') or d.get('type') == 'api']
        self.param_nodes = [n for n, d in G.nodes(data=True)
                            if n.startswith('param') or d.get('type') == 'param']
        # After keeping a predecessor with score s, the child call's alpha floor
        # becomes max(alpha, s * alpha_decay) ...
        self.alpha_decay = 0.85
        # ... and its beta ceiling becomes min(beta, s * beta_growth).
        self.beta_growth = 1.15
        # Param nodes are no longer expanded once more than this many edges have
        # been collected (the edge list is shared across all targets in prune()).
        self.max_nodes = 24
        self.data_type = data_type

    def _heuristic_threshold(self, decay_level: int) -> float:
        """Return the minimum score a param's predecessor needs at this depth.

        Starts at 0.5, shrinks by 10% per level and is floored at 0.3.
        """
        return max(0.3, 0.5 * (0.9 ** decay_level))

    @staticmethod
    def _edge_weight(edge_data):
        """Return ``edge_data['weight'].cast()``, or 0 when the edge or its weight is missing."""
        if not edge_data or 'weight' not in edge_data:
            return 0
        return edge_data['weight'].cast()

    def _record_edge(self, edge_data_list, node, pred):
        """Append ``(node, pred, data of G's pred->node edge)`` unless already present."""
        edge_info = (node, pred, self.G.get_edge_data(pred, node))
        if edge_info not in edge_data_list:
            edge_data_list.append(edge_info)

    def _keep_and_descend(self, node, pred, score, target_api, target_param, decay_level,
                          alpha, beta, graph_degree, visited, cnt, edge_data_list) -> Set[str]:
        """Record the edge node<-pred as kept and recurse from `pred` one level deeper.

        The child call's window is tightened around `score`: alpha can only rise
        and beta can only fall. Returns ``{pred}`` plus the child's relevant set.
        """
        self._record_edge(edge_data_list, node, pred)
        child_relevant = self._trace_back(
            pred, target_api, target_param, decay_level + 1,
            max(alpha, score * self.alpha_decay),
            min(beta, score * self.beta_growth),
            graph_degree, visited, cnt + 1, edge_data_list=edge_data_list,
        )
        return {pred} | child_relevant

    def _trace_back(self, node: str, target_api, target_param, decay_level: int,
                    alpha: float, beta: float, graph_degree, visited, cnt, edge_data_list) -> Set[str]:
        """Recursively collect predecessors of `node` that are relevant to `target_api`.

        Args:
            node: Node whose incoming edges are examined.
            target_api: The API being served; a predecessor's direct edge weight
                to it contributes to the predecessor's score.
            target_param: The input of `target_api` this trace started from; also
                contributes its direct edge weight to the score.
            decay_level: Recursion depth (0 at the first call).
            alpha: Current score floor used by the alpha cut.
            beta: Current score ceiling used by the beta cut.
            graph_degree: Optional depth cap; recursion stops at
                ``min(graph_degree, 5)`` levels. ``None`` disables this cap.
            visited: Nodes already examined in this trace; mutated in place.
            cnt: Counter threaded through the recursion (incremented, not read).
            edge_data_list: Accumulator of kept edges; mutated in place.

        Returns:
            Set of relevant predecessor nodes found below `node`.
        """
        if graph_degree is not None and decay_level >= min(graph_degree, 5):
            return set()
        node_type = 'api' if node.startswith('api') else 'param'
        if node_type == 'param' and len(edge_data_list) > self.max_nodes:
            return set()

        # Examine predecessors in descending order of their edge weight into `node`.
        predecessors = sorted(
            self.G.predecessors(node),
            key=lambda n: self._edge_weight(self.G.get_edge_data(n, node)),
            reverse=True,
        )
        # Both depend only on the depth, not on the individual predecessor.
        base_decay = 1 / (1 + math.sqrt(decay_level))
        current_threshold = self._heuristic_threshold(decay_level)

        current_relevant = set()
        for pred in predecessors:
            edge_data = self.G.get_edge_data(pred, node)
            # 'depends_on' edges are never traced through.
            if edge_data['label'].cast() in ["depends_on"]:
                continue
            if pred in visited or not self.is_required(node, pred):
                continue
            # Marked visited even if the cuts below reject it, so it is never rescored.
            visited.add(pred)

            score = (self._edge_weight(edge_data)
                     + self._edge_weight(self.G.get_edge_data(pred, target_api))
                     + self._edge_weight(self.G.get_edge_data(pred, target_param))) / 3 * base_decay

            if node_type == 'param':
                # Alpha cut: below both the depth threshold and the alpha floor.
                if score < current_threshold and score < alpha:
                    continue
                # Beta cut: a score above the ceiling ends the scan of the
                # remaining predecessors.
                if score > beta:
                    break
                # Between the alpha floor and the threshold: neither kept nor cut.
                if score < current_threshold:
                    continue
            # API nodes keep every eligible predecessor; param nodes keep those
            # that reached the threshold.
            current_relevant |= self._keep_and_descend(
                node, pred, score, target_api, target_param, decay_level,
                alpha, beta, graph_degree, visited, cnt, edge_data_list)
        return current_relevant

    def is_required(self, api, param):
        """Return whether the edge param -> api should be followed.

        Always True when `api` is itself a parameter node (prefix ``'param-'``)
        or when ``data_type == 'api_bank'``. Otherwise the API node's
        ``prop['input_param']`` text is evaluated to a mapping and the entry for
        `param` (with its first 6 characters, presumably ``'param-'``, stripped)
        must have ``is_required == 'true'``.
        """
        if api.startswith('param-'):
            return True
        if self.data_type == 'api_bank':
            return True
        # SECURITY NOTE(review): eval() executes arbitrary code stored in the graph's
        # node properties. If that data is not fully trusted, switch to
        # ast.literal_eval (literal-only payloads) after confirming the stored format.
        input_params = eval(self.G.nodes[api]['prop']['input_param'])
        return input_params[param[6:]]['is_required'] == 'true'

    def _get_relevant_nodes(self, target_api: str, graph_degree, edge_data_list) -> Set[str]:
        """Return `target_api` plus every node traced back from its required inputs.

        Each required, non-'depends_on' predecessor of `target_api` is kept (its
        edge recorded in `edge_data_list`) and traced back with a fresh `visited`
        set and the initial window alpha=0.4, beta=0.9.
        """
        relevant = set()
        for target_param in self.G.predecessors(target_api):
            if self.G.get_edge_data(target_param, target_api)['label'].cast() == 'depends_on':
                continue
            if not self.is_required(target_api, target_param):
                continue
            self._record_edge(edge_data_list, target_api, target_param)
            relevant.add(target_param)
            relevant |= self._trace_back(target_param, target_api, target_param,
                                         decay_level=0, alpha=0.4, beta=0.9, graph_degree=graph_degree,
                                         visited={target_api, target_param}, cnt=0,
                                         edge_data_list=edge_data_list)
        relevant.add(target_api)
        return relevant

    def prune(self, target_apis: Iterable[str], graph_degree=None) -> nx.DiGraph:
        """Build the pruned dependency graph for `target_apis`.

        Args:
            target_apis: API node names to keep dependencies for. It is iterated,
                so pass a collection of names (a bare string would be iterated
                character by character).
            graph_degree: Optional depth cap forwarded to the back-trace.

        Returns:
            A new ``nx.DiGraph``. Its nodes are the relevant nodes within 5 hops
            of their target API, carrying their attributes from ``G``. Its edges
            are the kept ones, reversed relative to ``G``. Endpoints of kept edges
            that fail the hop filter are still added by ``add_edges_from``, but
            without attributes.
        """
        sub_nodes = set()
        edge_data_list = []
        for target_api in target_apis:
            relevant_nodes = self._get_relevant_nodes(target_api, graph_degree=graph_degree,
                                                      edge_data_list=edge_data_list)
            # Every relevant node was reached by walking predecessors back from
            # target_api, so a path to target_api exists.
            sub_nodes |= {n for n in relevant_nodes
                          if nx.shortest_path_length(self.G, n, target_api) <= 5}

        pruned_graph = nx.DiGraph()
        # Only node attributes are taken from G; G's own edges are not copied.
        pruned_graph.add_nodes_from(self.G.subgraph(sub_nodes).nodes(data=True))
        pruned_graph.add_edges_from(edge_data_list)
        return pruned_graph

class Graph_Search_AB:
    """Prune a dependency graph around target APIs and render it as a tree string.

    Thin facade over ``GraphPruner`` plus ``graph_to_tree_string``; the
    constructor only stores its arguments, all work happens in ``get_graph``.
    """

    def __init__(self, target_apis, graph, graph_degree=None, data_type=None):
        self.target_apis, self.graph = target_apis, graph
        self.graph_degree, self.data_type = graph_degree, data_type

    def get_graph(self):
        """Return ``(tree_str, pruned_graph)`` for ``self.target_apis``.

        ``pruned_graph`` is the result of ``GraphPruner.prune`` on ``self.graph``;
        ``tree_str`` is its rendering by ``graph_to_tree_string``.
        """
        pruned = GraphPruner(self.graph, self.data_type).prune(
            target_apis=self.target_apis,
            graph_degree=self.graph_degree,
        )
        rendered = graph_to_tree_string(pruned, target_apis=self.target_apis)
        return rendered, pruned



