# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
import math
from typing import TYPE_CHECKING

import networkx as nx
from networkx import Graph as GraphNX
import torch_geometric
from torch_geometric.data import Data

if TYPE_CHECKING:  # avoid circular import
    from ._config import TSPGraphConfig


def generate_tsp_graph(config: TSPGraphConfig) -> Data:
    """
    Build a random Euclidean TSP instance with `config.num_nodes` nodes.

    The instance is a complete, undirected graph without self-loops.
    Every node gets a position drawn uniformly at random from the unit square, which is used as its node feature.
    Every edge is weighted with the Euclidean distance between its two endpoints;
    the weights are stored as a flat (1-D) `edge_attr`.

    The underlying NetworkX `Graph` is kept on the result in the `networkx` field.

    Apart from using the unit square instead of a larger square, this matches the
    "Random Uniform Euclidean Instances" generator described in
    Johnson et al., "Experimental Analysis of Heuristics for the STSP",
    in The Traveling Salesman Problem and its Variations, 2002.

    Instances of this kind have been used for evaluation in
    - Dai et al., "Learning Combinatorial Optimization Algorithms over Graphs", NeurIPS, 2017
    - Joshi et al., "An Efficient Graph Convolutional Network Technique for the Travelling Salesman Problem",
      arXiv preprint, 2019
    - Kool et al., "Attention, Learn to Solve Routing Problems!", ICLR, 2019

    :param config: graph configuration; only `config.num_nodes` is read here
    :return: the PyTorch Geometric `Data` object describing the instance
    """
    # no validation step: every config is considered valid

    nx_graph = generate_graph_nx(config.num_nodes)

    # convert to PyTorch Geometric, grouping positions into the node features and weights into `edge_attr`
    to_pyg = torch_geometric.utils.from_networkx
    data = to_pyg(nx_graph, group_node_attrs=["pos"], group_edge_attrs=["weight"])

    # `edge_attr` is flattened to stay consistent with the noigen graph generator
    # (nx.draw() also crashes when this is not done)
    data.edge_attr = data.edge_attr.flatten()

    # keep the original NetworkX graph reachable from the result
    data.networkx = nx_graph

    return data


def generate_graph_nx(num_nodes: int, *, seed: int | None = None) -> GraphNX:
    """
    Create a complete NetworkX graph whose edges are weighted by Euclidean distance.

    Node positions come from `nx.random_geometric_graph` and are stored in the "pos" node attribute.
    The distance between the two endpoints of every edge is stored in the "weight" edge attribute.

    :param num_nodes: number of nodes of the generated graph
    :param seed: optional seed forwarded to `nx.random_geometric_graph` to make the node positions reproducible;
        with the default `None`, NetworkX falls back to its default source of randomness
    :return: the generated NetworkX graph
    """
    # positions lie in the unit square, so no two nodes are further apart than sqrt(2);
    # any radius > sqrt(2) therefore connects every pair of nodes
    graph_nx = nx.random_geometric_graph(num_nodes, radius=10, seed=seed)

    # random_geometric_graph does not store the distances itself, so add them as edge weights;
    # `edge_data` is the graph's own attribute dict, so writing to it updates the graph in place
    for node_1, node_2, edge_data in graph_nx.edges(data=True):
        pos_1 = graph_nx.nodes[node_1]["pos"]
        pos_2 = graph_nx.nodes[node_2]["pos"]
        edge_data["weight"] = math.dist(pos_1, pos_2)

    return graph_nx
