from typing import Optional

from dash import Dash, html
from dash_cytoscape import Cytoscape
from torch import Tensor
from torch_geometric.data import Data


def show_graph_cytoscape(
    graph: Data,
    edge_colours: Optional[Tensor] = None,
    edge_widths: Optional[Tensor] = None,
    *,
    debug: bool = True,
):
    """
    Displays the graph using Cytoscape.
    A URL to a locally hosted instance of Dash is printed, where the graph is visualised.

    Parameters:

    - `graph`: The graph to visualise.
    - `edge_colours`: The data that should be displayed as edge colours.
                      Values should be between 0 and 1, inclusive; 0 is drawn black, 1 red.
                      Shape `[graph.num_edges]`.
                      If `None` (the default), every edge is drawn black.
    - `edge_widths`: The data that should be displayed as the thickness of the edges.
                     Values should be between 0 and 1, inclusive; they are mapped linearly
                     onto edge widths from 0 to 1.5.
                     Shape `[graph.num_edges]`.
                     If `None` (the default), every edge has width 0.3.
    - `debug`: Forwarded as the `debug` flag of the Dash server. Defaults to `True`.
    """
    # The first argument of `Dash` is the Flask import name (used to locate assets), not a
    # display title; the browser tab title is set through `title`.
    app = Dash(__name__, title="Graph Visualisation")

    nodes = _get_nodes(graph)
    edges = _get_edges(graph, edge_colours, edge_widths)

    # layouts: "random", "preset", "circle", "concentric", "grid", "breadthfirst", "cose"
    # with `dash_cytoscape.load_extra_layouts()`, more layouts become available:
    #     "cose-bilkent", "cola", "euler", "spread", "dagre", "klay"
    cytoscape = Cytoscape(
        elements=nodes + edges,
        layout={"name": "cose"},  # "cose" is the best default layout, but still not great
        style={
            # fill the entire viewport
            "height": "calc(100vh - 16px)",
            "width": "calc(100vw - 16px)",
        },
        stylesheet=[
            {
                "selector": "node",
                "style": {
                    "background-color": "white",
                    "border-color": "black",
                    "border-width": "0.3",
                    "width": "5",
                    "height": "5",
                },
            },
            {
                "selector": "edge",
                "style": {
                    # `mapData` reads the per-edge "colour"/"width" data fields written by
                    # `_get_edges`; fall back to constants when no data was supplied.
                    "line-color": "black" if edge_colours is None else "mapData(colour, 0, 1, black, red)",
                    "width": "0.3" if edge_widths is None else "mapData(width, 0, 1, 0, 1.5)",
                },
            },
        ],
    )
    app.layout = html.Div([cytoscape])

    # `Dash.run` supersedes `run_server` (deprecated, and removed in Dash 3); fall back to
    # `run_server` on older Dash releases that do not provide `run`.
    run = getattr(app, "run", None) or app.run_server
    run(debug=debug)


def _get_nodes(graph: Data) -> list[dict[str, str | dict[str, str]]]:
    """
    Converts the graph's nodes into the format that Cytoscape expects.
    """
    return [
        {"group": "nodes", "data": {"id": str(node_id)}}
        for node_id in range(graph.num_nodes)
    ]


def _get_edges(
    graph: Data,
    edge_colours: Optional[Tensor],
    edge_widths: Optional[Tensor]
) -> list[dict[str, str | dict[str, str]]]:
    """
    Converts the graph's edges into the format that Cytoscape expects.
    """
    edges = [
        {"group": "edges", "data": {"source": str(node_a), "target": str(node_b)}}
        for node_a, node_b in zip(*graph.edge_index.tolist())
    ]

    if edge_colours is not None:
        for i, edge_colour in enumerate(edge_colours):
            edges[i]["data"]["colour"] = edge_colour

    if edge_widths is not None:
        for i, edge_width in enumerate(edge_widths):
            edges[i]["data"]["width"] = edge_width

    return edges
