from typing import Optional

from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, TwoSlopeNorm
import networkx as nx
from torch import Tensor
import torch_geometric
from torch_geometric.data import Data


def show_graph_nx(
    graph: Data,
    edge_colours: Optional[Tensor] = None,
    edge_widths: Optional[Tensor] = None,
):
    """
    Displays the graph using networkx.

    Parameters:

    - `graph`: The graph to visualise. `graph.x` is passed as the node positions, so for graphs with
               coordinates (e.g. TSP instances) it should hold one 2D position per node; if it is None,
               networkx's default layout is used.
    - `edge_colours`: The data that should be displayed as edge colours, in the order of `graph.edge_index`.
                      0 is drawn black, positive values as shades of red (the largest value fully red)
                      and negative values as shades of blue (the most negative value fully blue).
                      Shape `[graph.num_edges]`.
    - `edge_widths`: The data that should be displayed as the thickness of the edges,
                     in the order of `graph.edge_index`.
                     Values should be between 0 and 1, inclusive (they are scaled by 10).
                     Shape `[graph.num_edges]`.
    """
    edge_colourmap = None
    # None lets networkx autoscale; only used when edge colours are given
    edge_vmin = edge_vmax = None
    if edge_colours is not None:
        # nx.draw()'s edge_color input is silently ignored if it's a pytorch tensor;
        # detach()/cpu() make the conversion work for tensors that require grad or live on the GPU
        edge_colours = edge_colours.detach().cpu().numpy()

        # the value 0 should be shown as black, positive values as shades of red, negative values as shades of blue
        if edge_colours.size > 0 and edge_colours.min() < 0:
            # with negative values, normalise the edge colours such that 0 lands in the middle (black) of the map
            edge_colourmap = LinearSegmentedColormap.from_list("blue_to_black_to_red", ["blue", "black", "red"])
            normalisation = TwoSlopeNorm(vcenter=0)
            edge_colours = normalisation(edge_colours)
            # the values are already normalised to [0, 1]; pin that range so networkx does not rescale
            # them to their own min/max (which would move 0 away from black, e.g. for all-negative input)
            edge_vmin, edge_vmax = 0.0, 1.0
        else:
            edge_colourmap = LinearSegmentedColormap.from_list("black_to_red", ["black", "red"])
            # anchor the colour scale at 0 so that 0 is black even if no edge actually has the value 0
            edge_vmin = 0.0
            edge_vmax = float(edge_colours.max()) if edge_colours.size > 0 else 1.0

    # line widths larger than roughly 10 make the edges occlude each other
    edge_widths = 1 if edge_widths is None else 10 * edge_widths.detach().cpu().numpy()

    graph_nx = torch_geometric.utils.to_networkx(graph)
    # pass the edges explicitly in edge_index order: by default networkx iterates them grouped by source
    # node (with duplicates merged), which can misalign them with edge_colours / edge_widths
    edge_list = [tuple(edge) for edge in graph.edge_index.t().tolist()]
    nx.draw(
        graph_nx,
        # nodes
        with_labels=True,
        font_size=10,
        node_color="white",
        edgecolors="black",  # draws circles around nodes (networkx forwards this to matplotlib's scatter)
        pos=graph.x,  # for non-TSP graphs, graph.x is None and the normal layouting is used
        # edges
        edgelist=edge_list,
        arrows=False,  # don't draw arrow heads because the graph is undirected anyways
        edge_color=edge_colours,
        edge_cmap=edge_colourmap,
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        width=edge_widths,
    )
    plt.show()
