import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from PIL import Image
from pyvis.network import Network
from tqdm.autonotebook import tqdm


def draw_image(data, ax):
    """Show ``data`` as a 28x28 grayscale picture on the matplotlib axes ``ax``.

    Args:
        data: A torch tensor holding 28 * 28 = 784 values. It is reshaped with
            ``Tensor.view``, so its memory layout must permit that view. It is
            moved to the CPU before conversion to a NumPy array.
        ax: Axes to draw into; ``ax.imshow`` is called exactly once.
    """
    as_grid = data.view(28, 28)
    host_array = as_grid.cpu().numpy()
    ax.imshow(host_array, cmap="gray")


def _render_node_image(row, path, draw_function, class_labels, max_num_examples):
    """Render one node's thumbnail (example panels plus optional pie chart) to ``path``.

    Up to ``max_num_examples`` items of ``row["data"]`` are drawn with
    ``draw_function(data=..., ax=...)``. When ``class_labels`` is truthy and the
    row has a ``"class_proportions"`` entry, a pie chart of those proportions is
    drawn in the last grid cell. Panels are laid out on a near-square grid large
    enough to hold all of them. The figure is rasterized and saved as an RGB image.
    """
    examples = row["data"] if "data" in row else []
    n_examples = 0 if draw_function is None else min(len(examples), max_num_examples)
    show_pie = bool(class_labels) and "class_proportions" in row
    # Always allocate at least one cell so an (empty) thumbnail is still produced.
    n_panels = max(n_examples + int(show_pie), 1)
    n_rows = int(np.ceil(np.sqrt(n_panels)))
    # Ceil (not floor) so that n_rows * n_cols >= n_panels and no panel is dropped.
    n_cols = int(np.ceil(n_panels / n_rows))

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    try:
        axs = axs.ravel()
        # The pie (if any) takes the last cell; since the grid holds n_examples + 1
        # cells in that case, examples never land on it.
        example_axes = axs[:-1] if show_pie else axs
        for j, ax in enumerate(example_axes):
            ax.axis("equal")
            ax.set_axis_off()
            if j < n_examples:
                draw_function(data=examples[j], ax=ax)

        if show_pie:
            # class_labels is still ``True`` when no dataset was available to
            # resolve it; pass no labels rather than a bool to Axes.pie.
            pie_labels = None if class_labels is True else class_labels
            pie_ax = axs[-1]
            pie_ax.pie(row["class_proportions"], labeldistance=0.6, labels=pie_labels)
            pie_ax.axis("equal")
            pie_ax.set_axis_off()

        fig.canvas.draw()
        # Build the image from the buffer's own (H, W, 4) shape instead of
        # get_width_height(), which can disagree with the buffer on scaled canvases.
        img = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")
    finally:
        # Release the figure even if draw_function raises.
        plt.close(fig)

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    img.save(path)


def data_graph(
    node_df,
    edge_df,
    dataset=None,
    draw_function=None,
    class_labels=True,
    node_title_formatter=lambda i, row: row["title"] if "title" in row else str(row),
    node_label_formatter=lambda i, row: row["label"] if "label" in row else str(i),
    node_size_formatter=lambda row: row["size"] if "size" in row else 10,
    edge_title_formatter=lambda row: row["title"] if "title" in row else "",
    edge_label_formatter=lambda row: row["label"] if "label" in row else "",
    edge_value_formatter=lambda row: row["value"] if "value" in row else 1,
    max_images=3000,
    max_num_examples=3,
    save_file="./graph.html",
):
    """Build an interactive pyvis graph whose nodes are rendered as image thumbnails.

    Each row of ``node_df`` becomes a node whose id is the row's index label.
    For the first ``max_images`` rows, a thumbnail is rendered to
    ``images/<node id>.png``. It shows up to ``max_num_examples`` entries of the
    row's ``"data"`` (drawn with ``draw_function``). When class labels are
    enabled and the row has ``"class_proportions"``, it also shows a pie chart
    of those proportions. Each row of ``edge_df`` becomes an edge. The graph is
    written to ``save_file`` as HTML with physics disabled.

    Args:
        node_df: Table iterated with ``iterrows()`` (presumably a pandas
            DataFrame). Optional columns are ``"data"``, ``"class_proportions"``,
            ``"title"``, ``"label"`` and ``"size"``. Every other column is
            attached to the node as a string attribute.
        edge_df: Table iterated with ``iterrows()`` whose index labels unpack
            into ``(A, B)`` node ids (presumably a 2-level MultiIndex). Optional
            columns are ``"title"``, ``"label"`` and ``"value"``.
        dataset: Optional indexable where ``dataset[i][1]`` is the class label
            of item ``i``. It is used only to resolve ``class_labels=True``.
        draw_function: Callable ``f(data=..., ax=...)`` that draws one example
            onto a matplotlib axes. If ``None``, no example panels are drawn.
        class_labels: ``True`` derives the labels via ``torch.unique`` over
            ``dataset``'s labels. A sequence is used as pie labels directly. A
            falsy value omits the pie chart.
        node_title_formatter, node_label_formatter, node_size_formatter:
            Callables producing the node's pyvis ``title``/``label``/``size``.
        edge_title_formatter, edge_label_formatter, edge_value_formatter:
            Callables producing the edge's pyvis ``title``/``label``/``value``.
        max_images: Number of leading rows of ``node_df`` for which thumbnails
            are rendered.
        max_num_examples: Maximum number of example panels per thumbnail.
        save_file: Output HTML path.
    """
    if class_labels is True and dataset is not None:
        class_labels = torch.unique(torch.tensor([dataset[i][1] for i in range(len(dataset))])).tolist()

    # Row keys not copied onto the node: those passed explicitly to add_node
    # ("shape" included, to avoid a duplicate-keyword TypeError) and the bulky data.
    reserved_keys = {"label", "title", "size", "image", "data", "shape"}

    G = nx.Graph()
    node_bar = tqdm(node_df.iterrows(), total=len(node_df), desc="Adding Nodes")
    for position, (i, row) in enumerate(node_bar):
        image_path = f"images/{i}.png"
        # Count rows rather than comparing the index label, which need not be an int.
        if position < max_images:
            _render_node_image(row, image_path, draw_function, class_labels, max_num_examples)
        # A node with no rendered (or previously cached) thumbnail falls back to a
        # plain dot instead of a broken image link.
        shape = "image" if os.path.exists(image_path) else "dot"
        G.add_node(
            i,
            title=node_title_formatter(i, row),
            label=node_label_formatter(i, row),
            image=image_path,
            shape=shape,
            size=node_size_formatter(row),
            **{k: str(v) for k, v in row.items() if k not in reserved_keys},
        )

    edge_bar = tqdm(edge_df.iterrows(), total=len(edge_df), desc="Adding Edges")
    for (A, B), row in edge_bar:
        G.add_edge(
            A,
            B,
            title=edge_title_formatter(row),
            label=edge_label_formatter(row),
            value=edge_value_formatter(row),
        )
        # refresh=False lets tqdm redraw at its normal interval instead of on every edge.
        edge_bar.set_postfix({"Nodes": G.number_of_nodes(), "Edges": G.number_of_edges()}, refresh=False)
    print(f"Number of Nodes: {G.number_of_nodes()}\nNumber of Edges: {G.number_of_edges()}")

    nt = Network(height="1000px", width="100%")
    nt.from_nx(G)
    nt.show_buttons()
    nt.toggle_physics(False)
    nt.save_graph(save_file)
