from typing import Literal, Optional

import torch
from torch import Tensor
from torch_geometric.data import Data

from ._cytoscape import show_graph_cytoscape
from ._networkx import show_graph_nx


def visualise_graph(
    graph: Data,
    edge_features: Optional[Tensor] = None,
    edge_weight_exponent: float = 1,
    frontend: Literal["cytoscape", "networkx"] = "networkx",
):
    """
    Draws `graph`, encoding per-edge data as edge colour and edge weights as line width.

    Parameters:

    - `graph`: The graph to draw
    - `edge_features`: The edge data to visualise using colour.
                       If `None`, the graph's ground truth labels (`graph.y`) will be used.
                       Shape `[graph.num_edges]`. The values are passed to the frontend unchanged
                       (no normalisation is applied here).
    - `edge_weight_exponent`: This exponent is applied to all edge weights before displaying them as line widths.
                              This is useful if the edge weights differ by orders of magnitude, because it makes the
                              thin lines more visible.
                              A lower value means a stronger effect:
                              A value of 1 means that no change is made, and a value of 0 means that all edges are
                              drawn with the same width. A negative value inverts the ordering, so that small
                              (positive) weights are drawn thickest.
                              The input is ignored if the graph does not have edge weights
                              (i.e. `graph.edge_attr` is `None`).
                              For example, setting `edge_weight_exponent=-1.2` can be useful to visualise TSP graphs.
    - `frontend`: The frontend used to display the graph.
                  Must be `"cytoscape"` or `"networkx"`.

    Raises:

    - `ValueError`: If `frontend` is not `"cytoscape"` or `"networkx"`. This is checked before anything
                    else is computed or printed.
    """
    # Resolve the frontend first, so an invalid value fails fast instead of only failing after
    # the helpers below have printed status messages (or raised unrelated errors of their own).
    if frontend == "cytoscape":
        show_graph = show_graph_cytoscape
    elif frontend == "networkx":
        show_graph = show_graph_nx
    else:
        raise ValueError(f'frontend must be "cytoscape" or "networkx", but was "{frontend}"')

    edge_colours = _get_edge_colours(graph, edge_features)
    edge_widths = _get_edge_widths(graph, edge_weight_exponent)
    show_graph(graph, edge_colours, edge_widths)


def _get_edge_colours(graph: Data, edge_features: Optional[Tensor]) -> Optional[Tensor]:
    """
    Chooses the per-edge data that will be encoded as edge colour, and prints which source was used.

    Preference order:
    1. the explicitly provided `edge_features`,
    2. the graph's ground truth labels (`graph.y`),
    3. nothing (`None`).

    The chosen tensor is returned unchanged; no normalisation or shape check happens here.
    NOTE(review): the frontends presumably expect shape `[graph.num_edges]` with values in [0, 1] — confirm.
    """
    if edge_features is None:
        ground_truth = graph.y
        if ground_truth is None:
            print("Graph has no ground truth labels, and no edge features were provided")
        else:
            print("Using graph's ground truth labels")
        # Either the labels, or `None` when the graph has no labels either.
        return ground_truth

    print("Using provided edge features")
    return edge_features


def _get_edge_widths(graph: Data, edge_weight_exponent: float) -> Optional[Tensor]:
    """
    Converts each edge weight to the width that the edge should be drawn with.

    Every weight is raised to `edge_weight_exponent`, and the result is divided by its maximum,
    so the widest edge gets width 1. For non-negative weights this yields values between 0 and 1, inclusive.

    Returns a floating-point Tensor with the same shape as `graph.edge_attr`
    (presumably `[graph.num_edges]` — confirm against the data generator).
    Returns `None` if the graph has no edge weights, has no edges, or all edges have the same weight.
    """
    weights = graph.edge_attr
    if weights is None:
        print("Graph does not have edge weights")
        return None
    if weights.numel() == 0:
        # `Tensor.max()` / `Tensor.min()` raise on empty tensors, so handle edgeless graphs explicitly.
        print("Graph has no edges. Not displaying weights")
        return None
    if weights.max() == weights.min():
        print("All edges have the same weight. Not displaying weights")
        return None

    # torch refuses to raise integer tensors to negative integer powers (RuntimeError),
    # so promote integer weights to floating point first. Values are otherwise unchanged.
    if not weights.is_floating_point():
        weights = weights.to(torch.get_default_dtype())
    edge_widths = torch.pow(weights, edge_weight_exponent)
    # NOTE: with a negative exponent, a weight of exactly 0 becomes inf, and the normalisation
    # below then turns that edge's width into NaN (inf / inf).
    return edge_widths / edge_widths.max()


if __name__ == "__main__":
    # Demo: draw a small synthetic clustered graph with the default frontend.
    # NOTE(review): `data_generation` is imported as a top-level module, while this file itself uses
    # relative imports (`from ._cytoscape ...`) — confirm how this demo is meant to be launched.
    from data_generation import SimpleGraphConfig

    demo_config = SimpleGraphConfig(
        num_nodes=20,
        num_clusters=3,
        min_edges_between_clusters=4,
        max_edges_between_clusters=4,
    )
    visualise_graph(demo_config.generate_graph())
