# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
import random
from collections import Counter
from typing import NewType, TypeAlias, TYPE_CHECKING

import torch
import torch_geometric
from torch_geometric.data import Data

from karger import MetaGraph

if TYPE_CHECKING:  # avoid circular import
    from ._config import NoigenConfig


# Index of the cluster ("colour") a node belongs to; _Graph draws these from range(config.num_clusters).
_NodeColour = NewType("_NodeColour", int)

# An edge as a pair of node indices; the generated graph is treated as undirected, so (a, b) and (b, a) coincide.
_Edge: TypeAlias = tuple[int, int]


class _Graph:
    """
    Edge-list graph that `noigen` builds up incrementally.

    Every node gets a random colour (cluster). The graph keeps two separate budgets:
    - inter-cluster edges, whose endpoints have different colours;
    - intra-cluster edges, whose endpoints have the same colour.
    It rejects any edge whose budget is already exhausted.
    """

    node_colours: list[_NodeColour]
    edges: list[_Edge]
    edge_weights: list[float]  # edge_weights[i] is the weight of edges[i]
    num_inter_cluster_edges_remaining: int
    num_intra_cluster_edges_remaining: int
    inter_cluster_edge_scale: float
    # canonical (smaller node, larger node) form of every edge in `edges`, for O(1) duplicate detection
    _edge_set: set[_Edge]

    def __init__(self, config: NoigenConfig):
        """
        Assigns a random colour to every node and derives the edge budgets from `config`.

        Raises `ValueError` if the budgets can never be used up with the drawn colouring.
        Without this check, edge generation would loop forever. Typical causes:
        - all nodes received the same colour while inter-cluster edges are requested;
        - more edges of one kind are requested than node pairs of that kind exist;
        - `inter_cluster_edge_ratio` > 1, which makes the intra-cluster budget negative.
        """
        self.node_colours = [_NodeColour(random.randrange(config.num_clusters)) for _ in range(config.num_nodes)]
        self.edges = []
        self.edge_weights = []
        self._edge_set = set()
        self.num_inter_cluster_edges_remaining = int(config.num_edges * config.inter_cluster_edge_ratio)
        self.num_intra_cluster_edges_remaining = config.num_edges - self.num_inter_cluster_edges_remaining
        self.inter_cluster_edge_scale = config.inter_cluster_edge_scale
        self._check_edge_budgets_are_reachable()

    def _check_edge_budgets_are_reachable(self) -> None:
        """
        Raises `ValueError` if the edge budgets can never be used up with the current node colouring.

        Must be called before any edge is added. It compares the *remaining* budgets against the total
        number of distinct node pairs of each kind.
        """
        num_inter = self.num_inter_cluster_edges_remaining
        num_intra = self.num_intra_cluster_edges_remaining
        if num_inter < 0 or num_intra < 0:
            raise ValueError(
                f"edge budgets must be non-negative, got {num_inter} inter-cluster "
                f"and {num_intra} intra-cluster edges"
            )

        num_nodes = len(self.node_colours)
        # a cluster of size s contains s * (s - 1) / 2 distinct intra-cluster node pairs
        max_intra = sum(size * (size - 1) // 2 for size in Counter(self.node_colours).values())
        max_inter = num_nodes * (num_nodes - 1) // 2 - max_intra
        if num_inter > max_inter or num_intra > max_intra:
            raise ValueError(
                f"cannot place {num_inter} inter-cluster and {num_intra} intra-cluster edges: "
                f"the random cluster assignment only allows {max_inter} inter-cluster "
                f"and {max_intra} intra-cluster edges"
            )

    @staticmethod
    def _canonical_edge(edge: _Edge) -> _Edge:
        """
        Returns the edge with its endpoints in ascending order, so both orientations map to the same key.
        """
        node_a, node_b = edge
        return (node_a, node_b) if node_a <= node_b else (node_b, node_a)

    def add_edge(self, edge: _Edge):
        """
        Checks whether the edge can be added to the graph, then adds the edge if possible.
        Also generates a random edge weight. Rejected edges are silently ignored.
        """
        if self._validate_edge(edge):
            self.edges.append(edge)
            self._edge_set.add(self._canonical_edge(edge))
            self.edge_weights.append(self._generate_edge_weight(edge))

    def _validate_edge(self, edge: _Edge) -> bool:
        """
        Returns `True` if the edge can be added to the graph.

        An edge is rejected in three cases:
        - it is a self-loop;
        - it already exists in either orientation;
        - the budget for its kind (inter- or intra-cluster) is exhausted.
        On success the matching budget is decremented, so this must only be called
        when the edge is actually going to be added.
        """
        # avoid self-loops
        if edge[0] == edge[1]:
            return False

        # skip if the edge already exists
        # the reference implementation by chekuri et al. doesn't include this check
        # the graph is undirected, so the canonical form covers both orientations;
        # the set lookup avoids scanning the whole edge list on every attempt
        if self._canonical_edge(edge) in self._edge_set:
            return False

        is_inter_cluster_edge = self.node_colours[edge[0]] != self.node_colours[edge[1]]
        if is_inter_cluster_edge:
            if self.num_inter_cluster_edges_remaining == 0:
                # we already have enough inter-cluster edges
                return False
            self.num_inter_cluster_edges_remaining -= 1
        else:
            if self.num_intra_cluster_edges_remaining == 0:
                # we already have enough intra-cluster edges
                return False
            self.num_intra_cluster_edges_remaining -= 1

        return True

    def _generate_edge_weight(self, edge: _Edge) -> float:
        """
        Randomly generates the weight of an edge, drawn uniformly from [0, 100].
        The weight is multiplied by `inter_cluster_edge_scale` if the two nodes have different colours.
        """
        edge_weight = random.uniform(0, 100)

        if self.node_colours[edge[0]] != self.node_colours[edge[1]]:
            edge_weight *= self.inter_cluster_edge_scale

        return edge_weight

    def has_enough_edges(self) -> bool:
        """
        Returns `True` if the graph already has the configured number of edges.
        """
        return self.num_inter_cluster_edges_remaining == 0 and self.num_intra_cluster_edges_remaining == 0

    def to_pyg(self) -> Data:
        """
        Converts this graph to a pytorch geometric `Data` object.

        Every stored edge is mirrored with `torch_geometric.utils.to_undirected`. `_validate_edge` rejects
        duplicates, so both directions carry the edge's original weight in `edge_attr`.
        """
        # reshape(-1, 2) keeps the index two-dimensional even without edges:
        # `torch.tensor([])` is 1-d, and `.T` on it would not give the [2, 0] shape PyG expects
        edge_index = torch.tensor(self.edges, dtype=torch.int64).reshape(-1, 2).T
        edge_weights = torch.tensor(self.edge_weights)
        edge_index, edge_weights = torch_geometric.utils.to_undirected(edge_index=edge_index, edge_attr=edge_weights)
        return Data(edge_index=edge_index, num_nodes=len(self.node_colours), edge_attr=edge_weights)


def noigen(config: NoigenConfig) -> Data:
    """
    Generates a graph for the minimum k-cut problem using the noigen generator.

    # Summary

    Generation proceeds as follows:

    1. Every node is assigned to one of `config.num_clusters` clusters uniformly at random
    2. Consecutive nodes `0 - 1 - ... - (config.num_nodes - 1)` are joined by a Hamiltonian path so that
       the resulting graph is connected
    3. Uniformly random node pairs are added as edges until the configured edge count is reached, while making
       sure that exactly `int(config.num_edges * config.inter_cluster_edge_ratio)` of those edges connect nodes
       of different clusters
    4. Each edge weight is drawn uniformly at random from the interval `[0, 100]`; if the two endpoints belong
       to different clusters, the weight is multiplied by `config.inter_cluster_edge_scale`

    The returned `Data` object additionally carries a precomputed `meta_graph` attribute.

    # Sources

    This is based on the generator first published in
    Nagamochi et al., "Implementing an efficient minimum capacity cut algorithm", Mathematical Programming, 1994.

    This method's implementation is very loosely inspired by the code for the long version of
    Chekuri et al., "Experimental Study of Minimum Cut Algorithms",
    in Proceedings of the Eighth Annual ACM-SIAM Symposium on Discrete Algorithms, 1997.
    The paper's code can be found here: http://www.columbia.edu/~cs2035/codes/cut-src.tar.gz
    """
    config.validate()

    graph = _Graph(config)
    num_nodes = config.num_nodes

    # Hamiltonian path over the nodes in index order, so that the final graph is connected
    # (chekuri et al.'s reference implementation uses a Hamiltonian cycle instead).
    # NOTE: path edges are subject to the same inter-/intra-cluster budgets as every other edge, so a path
    # edge whose budget is already exhausted is silently skipped and connectivity is then not guaranteed.
    for left, right in zip(range(num_nodes - 1), range(1, num_nodes)):
        graph.add_edge((left, right))

    # keep proposing uniformly random node pairs until both edge budgets are used up;
    # the graph itself rejects self-loops, duplicates and edges of an exhausted kind
    while not graph.has_enough_edges():
        graph.add_edge((random.randrange(num_nodes), random.randrange(num_nodes)))

    graph_pyg = graph.to_pyg()
    # attach the MetaGraph now so that it does not have to be recomputed during training
    graph_pyg.meta_graph = MetaGraph.from_pyg(graph_pyg)
    return graph_pyg
