import networkx as nx
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Union, TYPE_CHECKING

from absint_ai.Environment.types.Type import *

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import (
        Environment,
    )  # Import the class only for type checking


def _graph_node_id(addr: Address) -> str:
    """Return the heap-graph node id of ``addr``: ``C(<value>)`` or ``A(<value>)``.

    Raises:
        Exception: If ``addr`` is neither a concrete nor an abstract address.
    """
    addr_type = addr.get_addr_type()
    if addr_type == "concrete":
        return f"C({addr.get_value()})"
    if addr_type == "abstract":
        return f"A({addr.get_value()})"
    raise Exception(f"Unknown address type {addr_type}")


def _allocation_site_id(env: "Environment", addr: Address, n_parts: int) -> str:
    """Return the last ``n_parts`` ``_``-separated pieces of ``addr``'s allocation site.

    Returns ``""`` when the address metadata has no ``allocation_site`` entry or
    the entry is empty/falsy, instead of raising.
    """
    meta = env.get_meta(addr)
    site = meta["allocation_site"] if "allocation_site" in meta else None
    if not site:
        return ""
    return "_".join(site.split("_")[-n_parts:])


def address_to_graph(
    addr: Address, seen_addrs: list = None, env: "Environment" = None
) -> Tuple[nx.DiGraph, str]:
    """Build a directed graph of the heap value at ``addr`` and everything it reaches.

    The object itself becomes a node (``C(<n>)`` for concrete, ``A(<n>)`` for
    abstract addresses) whose ``type`` attribute describes its kind. Each field
    becomes a ``"field"`` node linked from the object. A field's ``values``
    attribute lists its non-object values. Object-valued fields are followed
    recursively and linked by an edge.

    Args:
        addr: Address of the heap value to start from.
        seen_addrs: Addresses already visited during this traversal; shared and
            mutated across the recursive calls. ``None`` starts a new traversal.
        env: Environment providing the heaps and the type predicates.

    Returns:
        ``(graph, node_id)``. ``node_id`` is ``""`` and ``graph`` is empty when
        ``addr`` was already visited, or is bound to ``document``/``element``.

    Raises:
        Exception: For an address of unknown kind or unknown address type.
    """
    if seen_addrs is None:
        seen_addrs = []
    if addr in seen_addrs:
        return nx.DiGraph(), ""
    seen_addrs.append(addr)
    G = nx.DiGraph()

    # Values bound to `document` / `element` are skipped entirely
    # (presumably DOM objects that would swamp the picture — confirm).
    document_record_result = env.lookup("document")
    element_record_result = env.lookup("element")
    if (
        addr in document_record_result.get_all_values()
        or addr in element_record_result.get_all_values()
    ):
        return G, ""

    object_key = _graph_node_id(addr)  # raises on an unknown address type
    if addr.get_addr_type() == "concrete":
        if env.is_object(addr):
            G.add_node(object_key, type="concrete object", object_type="concrete")
        elif env.is_function(addr):
            G.add_node(object_key, type="concrete function", object_type="concrete")
        elif env.is_heap_frame(addr):
            G.add_node(object_key, type="concrete heap frame", object_type="concrete")
        elif env.is_class(addr):
            G.add_node(object_key, type="concrete class", object_type="concrete")
        else:
            raise Exception(f"Unknown object type for {addr}")
        heap_val = env.concrete_heap.get(addr)
    else:  # abstract: _graph_node_id has already rejected anything else
        if env.is_object(addr):
            G.add_node(object_key, type="abstract object", object_type="abstract")
        elif env.is_function(addr):
            # Guarded like heap frames: a missing/empty allocation site yields "".
            G.add_node(
                object_key,
                type="abstract function",
                object_type="abstract",
                id=_allocation_site_id(env, addr, 4),
            )
        elif env.is_heap_frame(addr):
            G.add_node(
                object_key,
                type="abstract heap frame",
                object_type="abstract",
                id=_allocation_site_id(env, addr, 5),
            )
        elif env.is_builtin(addr):
            G.add_node(object_key, type="builtin function", object_type="abstract")
        elif env.is_class(addr):
            G.add_node(object_key, type="abstract class", object_type="abstract")
        else:
            raise Exception(f"Unknown object type for {addr}")
        heap_val = env.abstract_heap.get(addr)

    for field_key in heap_val:
        if field_key == "__proto__":
            # Prototype links are not drawn.
            continue

        if field_key == "__meta__":
            # Of the metadata, only the `__parent__` link is drawn, as a "_par" field.
            meta = heap_val[field_key]
            if "__parent__" not in meta:
                continue
            node_key = f"{object_key}_par"
            G.add_node(node_key, type="field", values=[])
            G.add_edge(object_key, node_key)
            parent_addr = meta["__parent__"]
            if parent_addr:
                # The parent node may not be in this graph; add_edge creates it
                # without attributes if so.
                prefix = "C" if parent_addr.get_addr_type() == "concrete" else "A"
                G.add_edge(node_key, f"{prefix}({parent_addr.get_value()})")
            continue

        if isinstance(field_key, Type):
            node_key = f"{object_key}_{field_key.get_value()}"
        else:
            node_key = f"{object_key}_{field_key}"
        G.add_node(node_key, type="field", values=[])
        G.add_edge(object_key, node_key)

        for value in heap_val[field_key].get_all_values():
            if value == addr:
                # Self-reference: draw a back-edge instead of recursing.
                G.nodes[node_key]["values"].append(value)
                G.add_edge(node_key, object_key)
                continue
            if env.is_object(value):
                # Raises for an unknown address type instead of reusing a stale name.
                node_name = _graph_node_id(value)
                if node_name in G.nodes:
                    G.add_edge(node_key, node_name)
                else:
                    subgraph, node_name = address_to_graph(
                        value, seen_addrs=seen_addrs, env=env
                    )
                    # An empty name means the target was already visited or is
                    # excluded; no edge is drawn from this traversal.
                    if node_name:
                        G = nx.compose(G, subgraph)
                        G.add_edge(node_key, node_name)
            elif env.is_function(value):
                prefix = "C" if value.get_addr_type() == "concrete" else "A"
                G.nodes[node_key]["values"].append(
                    f"(function, {prefix}{value.get_value()})"
                )
                G.nodes[node_key]["type"] = "function"
            else:
                G.nodes[node_key]["values"].append(value.get_value())

    return G, object_key


def visualize_heaps(env: "Environment") -> nx.DiGraph:
    """Merge the graphs of every address on the concrete and abstract heaps.

    Each heap address is traversed as its own root with ``address_to_graph``.
    Graphs with no edges are dropped. The concrete heap is processed first,
    then the abstract heap.

    Args:
        env: Environment whose ``concrete_heap`` and ``abstract_heap`` are drawn.

    Returns:
        A single ``nx.DiGraph`` composed from all non-trivial per-address graphs.
    """
    combined = nx.DiGraph()
    for heap in (env.concrete_heap, env.abstract_heap):
        for address in heap.addresses():
            subgraph, _ = address_to_graph(address, env=env)
            if not subgraph.edges:
                continue
            combined = nx.compose(combined, subgraph)
    return combined


def _stack_address_label(addr) -> str:
    """Return the compact ``C<value>`` / ``A<value>`` label used in the stack view.

    Any address whose type is not ``"concrete"`` is labelled with ``A``.
    """
    prefix = "C" if addr.get_addr_type() == "concrete" else "A"
    return f"{prefix}{addr.get_value()}"


def _stack_value_label(value, env):
    """Format one variable value for the stack view.

    Objects become ``C<n>``/``A<n>``, functions ``(function, C<n>)``/``(function, A<n>)``.
    Anything else is shown via its ``get_value()``.
    """
    if env.is_object(value):
        return _stack_address_label(value)
    if env.is_function(value):
        return f"(function, {_stack_address_label(value)})"
    return value.get_value()


def visualize_stack(env) -> nx.DiGraph:
    """Build a graph of every module's stack frames.

    Nodes are named ``"<module> stack frame: <i>"``. Each frame has an edge from
    the frame below it. Frame 0 of a non-global module hangs off
    ``"global stack frame: 0"``. Node attributes map each variable name to its
    formatted values, plus ``"return values"`` and ``"heap frame"`` entries.
    Those two keys are written after the variables, so they override variables
    with the same names.

    Args:
        env: Environment whose ``stack`` maps module name to a list of frames.

    Returns:
        The stack graph. The previous annotation ``-> list`` was incorrect.
    """
    stack_graph = nx.DiGraph()
    for module_name, stack in env.stack.items():
        for i, stack_frame in enumerate(stack):
            stack_frame_id = f"{module_name} stack frame: {i}"
            stack_graph.add_node(stack_frame_id)
            if i > 0:
                stack_graph.add_edge(
                    f"{module_name} stack frame: {i-1}", stack_frame_id
                )
            elif module_name != "global":
                stack_graph.add_edge("global stack frame: 0", stack_frame_id)

            stack_frame_attrs = {
                var_name: [
                    _stack_value_label(value, env)
                    for value in stack_frame.get_variable(var_name).get_all_values()
                ]
                for var_name in stack_frame.get_variable_names()
            }
            # Return values are labelled by isinstance(Address), not by the
            # env object/function predicates used for variables above.
            stack_frame_attrs["return values"] = [
                _stack_address_label(return_value)
                if isinstance(return_value, Address)
                else return_value.get_value()
                for return_value in stack_frame.get_return_values()
            ]
            stack_frame_attrs["heap frame"] = _stack_address_label(
                stack_frame.get_heap_frame_address()
            )

            nx.set_node_attributes(stack_graph, {stack_frame_id: stack_frame_attrs})
    return stack_graph


def visualize(
    prev_statement: str = "",
    cur_statement: str = "",
    env: "Environment" = None,
) -> None:
    """Show an interactive matplotlib window with the stack and heap of ``env``.

    The left axes draws the graph from ``visualize_stack``. The right axes draws
    the graph from ``visualize_heaps``, coloured by each node's ``type``
    attribute. Both are laid out with Graphviz ``dot`` (via ``nx_agraph``),
    bottom-to-top. Hovering over a node shows its attributes in a tooltip.
    A footer shows the previous and current statements. Blocks in ``plt.show()``.

    Args:
        prev_statement: Text of the previously executed statement.
        cur_statement: Text of the current statement.
        env: Environment whose stack and heaps are drawn.
    """
    G = visualize_stack(env)
    fig, axs = plt.subplots(1, 2)
    fig.set_size_inches(20, 10)
    ax = axs[0]  # type: ignore
    ax2 = axs[1]  # type: ignore
    ax.set_xlabel("Stack")
    ax2.set_xlabel("Heap")
    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog="dot", args="-Grankdir=BT")
    nodes = nx.draw_networkx_nodes(G, pos=pos, ax=ax, node_shape="s", node_size=1000)
    nx.draw_networkx_edges(G, pos=pos, ax=ax)
    H = visualize_heaps(env)
    pos2 = nx.drawing.nx_agraph.graphviz_layout(H, prog="dot", args="-Grankdir=BT")

    # Heap node colour by its "type" attribute.
    mapping = {
        "concrete object": "red",
        "concrete heap frame": "red",
        "concrete function": "red",
        "concrete class": "red",
        "constant": "orange",
        "field": "yellow",
        "function": "blue",
        "abstract object": "green",
        "abstract heap frame": "green",
        "abstract function": "green",
        "builtin function": "green",
        "abstract class": "green",
    }
    # Nodes created implicitly by add_edge (e.g. a parent frame whose own graph
    # was not composed in) carry no "type"; draw them instead of raising KeyError.
    unknown_color = "lightgray"

    ax2.legend(
        handles=[
            mpatches.Patch(color="red", label="concrete address"),
            mpatches.Patch(color="orange", label="constant"),
            mpatches.Patch(color="yellow", label="field"),
            mpatches.Patch(color="blue", label="function"),
            mpatches.Patch(color="green", label="abstract address"),
        ]
    )

    colors = [mapping.get(H.nodes[node].get("type"), unknown_color) for node in H.nodes]
    # Field/constant nodes are named "<object>_<field>"; show only the field part.
    labels = {
        node: (
            node.split("_")[-1]
            if H.nodes[node].get("type") in ("constant", "field")
            else node
        )
        for node in H.nodes
    }
    nodes2 = nx.draw_networkx_nodes(H, pos=pos2, ax=ax2, node_color=colors)
    nx.draw_networkx_labels(H, pos=pos2, labels=labels, ax=ax2, font_size=6)
    nx.draw_networkx_edges(H, pos=pos2, ax=ax2)

    annot = ax.annotate(
        "",
        xy=(0, 0),
        xytext=(20, 20),
        textcoords="offset points",
        bbox=dict(boxstyle="round", fc="w"),
        arrowprops=dict(arrowstyle="->"),
        size=10,
    )
    annot2 = ax2.annotate(
        "",
        xy=(0, 0),
        xytext=(20, 20),
        textcoords="offset points",
        bbox=dict(boxstyle="round", fc="w"),
        arrowprops=dict(arrowstyle="->"),
    )
    annot.set_visible(False)
    annot2.set_visible(False)

    def update_annot(ind, graph1: bool) -> None:  # type: ignore
        """Point the relevant tooltip at the hovered node and fill in its attributes."""
        # The collection indices follow the graph's node iteration order, which
        # is the default nodelist used by draw_networkx_nodes.
        if graph1:
            node = list(G.nodes)[ind["ind"][0]]
            annot.xy = pos[node]
            node_attr = {"node": "_".join(node.split("_")[-3:])}
            node_attr.update(G.nodes[node])
            annot.set_text("\n".join(f"{k}: {v}" for k, v in node_attr.items()))
        else:
            node = list(H.nodes)[ind["ind"][0]]
            annot2.xy = pos2[node]
            node_attr = {"node": node}
            node_attr.update(H.nodes[node])
            annot2.set_text("\n".join(f"{k}: {v}" for k, v in node_attr.items()))

    def hover(event) -> None:  # type: ignore
        """Show or hide each tooltip depending on whether the cursor is over a node."""
        vis = annot.get_visible()
        if event.inaxes == ax:
            cont, ind = nodes.contains(event)
            if cont:
                update_annot(ind, graph1=True)
                annot.set_visible(True)
                fig.canvas.draw_idle()
            elif vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()
        vis2 = annot2.get_visible()
        if event.inaxes == ax2:
            cont, ind = nodes2.contains(event)
            if cont:
                update_annot(ind, graph1=False)
                annot2.set_visible(True)
                fig.canvas.draw_idle()
            elif vis2:
                annot2.set_visible(False)
                fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
    fig.text(
        0.50,
        0.02,
        f"Prev line: {prev_statement}, Current line: {cur_statement}",
        horizontalalignment="center",
        wrap=True,
    )
    plt.show()
