import matplotlib.pyplot as plt
import networkx as nx
from collections import defaultdict


def draw_dependency_graph(G, target_apis, save_path='test.png'):
    """Plot a tool-dependency graph left-to-right, save it to disk, then show it.

    Node colours: members of ``target_apis`` are gold, other nodes whose
    ``'type'`` attribute is ``'api'`` are green, every other node is salmon.
    networkx's own edge lines are drawn with zero width; instead a coloured
    arrow is annotated for every API->parameter edge (red, curved) and every
    parameter->API edge (blue, straight). Edges between nodes of any other
    type combination are therefore not visible.

    Args:
        G: directed graph whose nodes carry a ``'type'`` attribute
            (``'api'`` or ``'param'``).
        target_apis: node names to highlight; they are also joined into the
            title, so they must be strings.
        save_path: file the figure is written to. Defaults to ``'test.png'``,
            the name that used to be hard-coded.

    Side effects:
        Opens a new figure, turns on matplotlib interactive mode and sets the
        global ``font.family`` rcParam (presumably so the full-width
        characters in the title render). The layout uses networkx's graphviz
        (AGraph) interface, so graphviz/pygraphviz must be installed.
    """
    plt.figure(figsize=(16, 10))
    plt.ion()
    plt.rcParams['font.family'] = 'Arial Unicode MS'
    # Graphviz 'dot' with rankdir=LR lays the dependency chain out left to right.
    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog='dot', args='-Grankdir=LR')

    node_colors = []
    for node in G.nodes():
        if node in target_apis:
            node_colors.append('gold')
        elif G.nodes[node]['type'] == 'api':
            node_colors.append('#8ECB8E')
        else:
            node_colors.append('#FFA07A')

    # width=0 hides networkx's straight edges; typed arrows are annotated below.
    # (arrowsize is irrelevant here because arrows=False.)
    nx.draw(G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color='#666666',
            node_size=3000,
            font_size=12,
            font_weight='bold',
            arrows=False,
            width=0)

    for u, v in G.edges():
        if G.nodes[u]['type'] == 'api' and G.nodes[v]['type'] == 'param':
            # API -> parameter: red arrow drawn along an arc (rad=0.2).
            arrowprops = dict(arrowstyle="->",
                              color="#FF6B6B",
                              lw=3,
                              connectionstyle="arc3,rad=0.2")
        elif G.nodes[u]['type'] == 'param' and G.nodes[v]['type'] == 'api':
            # Parameter -> API: blue straight arrow; shrinkA/shrinkB (points)
            # pull both ends back from the node centres.
            arrowprops = dict(arrowstyle="->",
                              color="#4D96FF",
                              lw=3,
                              shrinkA=15, shrinkB=15)
        else:
            continue
        plt.annotate("", xy=pos[v], xytext=pos[u], arrowprops=arrowprops)

    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label='target API',
                   markerfacecolor='gold', markersize=15),
        plt.Line2D([0], [0], marker='o', color='w', label='normal API',
                   markerfacecolor='#8ECB8E', markersize=15),
        plt.Line2D([0], [0], marker='o', color='w', label='parameter node',
                   markerfacecolor='#FFA07A', markersize=15),
        plt.Line2D([0], [0], color='#FF6B6B', lw=3, label='API→parameter'),
        plt.Line2D([0], [0], color='#4D96FF', lw=3, label='parameter→API')
    ]
    plt.legend(handles=legend_elements,
               loc='upper left',
               fontsize=12,
               framealpha=0.9)

    title = f"Tool dependency graph（target：{', '.join(target_apis)}）"
    plt.title(title, fontsize=14, pad=20)

    plt.gca().set_facecolor('#F5F5F5')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    # Save before show(): if show() blocks and the user closes the window,
    # the figure is gone and savefig would write an empty image.
    plt.savefig(save_path)
    plt.show()


def graph_to_tree_string(pruned_graph, target_apis, reverse=True, max_param_children=3):
    """Render the dependency neighbourhood of each target API as a text tree.

    For every API in ``target_apis`` a tree rooted at that API is printed.
    The children of a node are its predecessors in ``tree``, which is
    ``pruned_graph`` itself or, when ``reverse`` is true, its reversed view.
    Only children whose connecting edge has label ``'input'`` or ``'output'``
    are followed.

    Args:
        pruned_graph: networkx-style directed graph. Every printed node needs
            a ``'prop'`` dict, which may contain ``'description'`` and (for
            parameters) ``'param_type'``. Edges are read for ``'label'`` and
            ``'weight'``.
        target_apis: iterable of root node names.
        reverse: walk the reversed graph (default) instead of the graph itself.
        max_param_children: maximum number of children shown under a node
            whose name starts with ``'param-'``. A child already shown under
            a given parameter, anywhere in this call's output, is not shown
            under it again.

    Returns:
        One string: for each target API a header line, a dashed rule and its
        tree, followed by a legend line.

    Side effects:
        For each target API, writes a ``'decayed_sim'`` attribute on every
        node through ``tree``. The value is the weight of the node->root edge
        in ``tree``, or 0 if there is no such edge. networkx views share node
        data with the underlying graph, so this lands on ``pruned_graph``.
    """
    # With reverse=True, predecessors in the view are successors in pruned_graph.
    tree = nx.reverse_view(pruned_graph) if reverse else pruned_graph

    api_structures = []
    # parameter node -> children already printed under it (shared across all
    # trees rendered by this call).
    global_visited = defaultdict(set)
    # Lazily filled cache: node -> {"api_desc": ..., "param_type": ...}.
    node_descriptions = {}

    def describe(node):
        """Return (and cache) the description record for *node*."""
        if node not in node_descriptions:
            prop = pruned_graph.nodes[node]['prop']
            if node.startswith('api'):
                node_descriptions[node] = {
                    "api_desc": prop.get('description', "Function is not defined"),
                    "param_type": None,
                }
            else:
                node_descriptions[node] = {
                    "api_desc": prop.get('description', "Parameter is not defined"),
                    "param_type": prop.get('param_type', "str"),
                }
        return node_descriptions[node]

    def build_layer(node, root, visited_nodes, depth=0, prefix=""):
        """Return the text lines for *node* and its subtree.

        *visited_nodes* holds the nodes on the current root-to-node path; each
        child receives its own copy, so sibling subtrees do not affect each
        other.
        """
        children = [
            child for child in tree.predecessors(node)
            if tree.get_edge_data(child, node).get('label') in ('input', 'output')
        ]
        if node.startswith('param-'):
            # Cap the children of a parameter and skip any child already shown
            # under this parameter earlier in this call. This runs before the
            # cycle check below, so it also applies when *node* is a repeat.
            shown = global_visited[node]
            fresh = []
            for child in children:
                if len(fresh) == max_param_children:
                    break
                if child not in shown:
                    fresh.append(child)
                    shown.add(child)
            children = fresh

        is_api = node.startswith('api')
        node_type = "API" if is_api else "Parameter"
        main_line = f"{'│   ' * (depth - 1)}{prefix}● {node} [{node_type}]"
        record = describe(node)
        text_ = str(record['api_desc']).replace('\n', '|')
        if is_api:
            if node != root:
                main_line += f" (similarity: {tree.nodes[node]['decayed_sim']:.2f})"
            desc_line = f"\n{'│   ' * depth}├── [Function] {text_}"
        else:
            desc_line = (f"\n{'│   ' * depth}├── [Type] {record['param_type']}"
                         f" [Description] {text_}")

        if node in visited_nodes:
            # Node already on the current path (a cycle): an API is dropped,
            # a parameter is printed once more but not expanded.
            return [] if is_api else [main_line + desc_line]
        visited_nodes.append(node)

        lines = [main_line + desc_line]
        last = len(children) - 1
        for i, child in enumerate(children):
            branch = "└── " if i == last else "├── "
            lines.extend(build_layer(child, root, visited_nodes.copy(), depth + 1, branch))
        return lines

    for api in target_apis:
        # Similarity shown next to non-root APIs: weight of the direct edge to
        # the current root (0 when there is none).
        for node in tree.nodes():
            edge = tree.get_edge_data(node, api)
            tree.nodes[node]['decayed_sim'] = 0 if edge is None else edge.get('weight', 0)

        header = f"\nTool dependency tree structure（root node：{api}）\n{'-' * 40}"
        body = '\n'.join(build_layer(api, api, []))
        api_structures.append(f"{header}\n{body}")

    return '\n'.join(api_structures) + "\n\nExplanation：●=node, ├──=branch, └──=final branch"
