# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
import random
from typing import TYPE_CHECKING

import torch
from torch import Tensor
import torch_geometric
from torch_geometric.data import Data

from karger import MetaGraph

if TYPE_CHECKING:  # avoid circular import
    from ._config import SimpleGraphConfig


def generate_simple_graph(config: SimpleGraphConfig) -> Data:
    """
    Builds a random "clustered" graph for the minimum k-cut problem and returns it as a PyG `Data` object.

    The graph has `config.num_nodes` nodes, partitioned into `config.num_clusters` cliques (complete subgraphs) of
    random sizes.
    On top of that, a random number of inter-cluster edges is drawn between `config.min_edges_between_clusters` and
    `config.max_edges_between_clusters`; each of them joins nodes of two distinct clusters.

    The returned `Data` object carries:
    - `edge_index`: the undirected edge list
    - `edge_attr`: edge weights, all 1
    - `y`: ground truth edge labels, 1 for edges between different clusters, 0 for edges inside a cluster
    - `meta_graph`: the `MetaGraph` built from the `Data` object

    By the construction of the graph, the labels describe the unique optimal solution to the minimum k-cut problem
    on this graph, where k equals `config.num_clusters`.
    """
    config.validate()

    # Every cluster gets at least max_edges_between_clusters + 2 nodes (enforced in _generate_cluster_sizes).
    # Worst case, all inserted edges end up between the same two clusters. If the smallest cluster were at most
    # one node larger than that number of edges, isolating a single node of it would be just as cheap as (or
    # cheaper than) cutting the inserted edges, and the intended solution would no longer be the unique optimum.
    cluster_sizes = _generate_cluster_sizes(config)

    # exclusive prefix sum: the id of the first node of each cluster
    index_offsets = list(torch.cumsum(cluster_sizes, dim=0) - cluster_sizes)

    # one clique per cluster, shifted into the cluster's own node id range
    cliques = [
        _create_complete_graph(size) + first_node_id
        for size, first_node_id in zip(cluster_sizes, index_offsets)
    ]

    num_inter_cluster_edges = random.randint(
        config.min_edges_between_clusters,
        config.max_edges_between_clusters,
    )
    inter_cluster_edges = _generate_edges_between_clusters(num_inter_cluster_edges, cluster_sizes, index_offsets)

    intra_cluster_edges = torch.cat(cliques, dim=1)
    # label 0 for edges inside a clique, label 1 for the edges inserted between clusters
    labels = torch.cat((
        torch.zeros(intra_cluster_edges.size(1)),
        torch.ones(inter_cluster_edges.size(1)),
    ))
    edge_index = torch.cat((intra_cluster_edges, inter_cluster_edges), dim=1)
    edge_index, labels = torch_geometric.utils.to_undirected(edge_index=edge_index, edge_attr=labels)

    data = Data(
        edge_index=edge_index,
        num_nodes=config.num_nodes,
        edge_attr=torch.ones_like(labels),
        y=labels,
    )

    # pre-compute MetaGraph to save time during training
    data.meta_graph = MetaGraph.from_pyg(data)

    return data


def _generate_cluster_sizes(config: SimpleGraphConfig) -> Tensor:
    """
    Returns a Tensor of size `[num_clusters]` containing the randomly generated cluster sizes.
    The entries of the Tensor add up to `num_nodes`.
    """
    min_cluster_size = config.max_edges_between_clusters + 2
    cluster_sizes = torch.zeros([config.num_clusters], dtype=torch.int64)

    for i in range(config.num_clusters - 1):
        # this upper limit is there so that there are enough nodes left for the remaining clusters to have the required
        # minimum size
        max_cluster_size = config.num_nodes - cluster_sizes.sum() - (config.num_clusters - i - 1) * min_cluster_size
        cluster_sizes[i] = random.randint(min_cluster_size, max_cluster_size)

    cluster_sizes[-1] = config.num_nodes - cluster_sizes.sum()

    return cluster_sizes


def _create_complete_graph(num_nodes: int) -> Tensor:
    """
    Creates a complete graph (i.e. one where every pair of nodes is connected by an edge).
    Returns a Tensor of size `[2, num_nodes * (num_nodes - 1) / 2]` that contains every edge of the graph,
    where each edge is represented by the ids of the start and end nodes.
    """
    edges = []
    for node_a in range(num_nodes):
        for node_b in range(node_a + 1, num_nodes):
            edges.append([node_a, node_b])
    return torch.tensor(edges, dtype=torch.int64).T


def _generate_edges_between_clusters(
    num_edges_between_clusters: int,
    cluster_sizes: Tensor,
    index_offsets: Tensor,
) -> Tensor:
    """
    Parameters:
    - `num_edges_between_clusters`: The number of edges to add
    - `cluster_sizes`: The number of nodes each cluster contains. Size `[num_clusters]`
    - `index_offsets`: The index of the first node in each cluster. Size `[num_clusters]`

    Returns a Tensor of size `[2, num_edges_between_clusters]` that contains the added edges,
    where each edge is represented by the ids of the start and end nodes.
    Note: In some cases, this might generate fewer than `num_edges_between_clusters` edges, resulting in a smaller
    output Tensor.
    """
    num_clusters = cluster_sizes.size(0)
    edges_between_clusters = []

    for _ in range(num_edges_between_clusters):
        # sample two distinct clusters
        cluster_a, cluster_b = random.sample(range(num_clusters), 2)

        node_a = random.randrange(cluster_sizes[cluster_a]) + index_offsets[cluster_a]
        node_b = random.randrange(cluster_sizes[cluster_b]) + index_offsets[cluster_b]

        if [node_a, node_b] in edges_between_clusters or [node_b, node_a] in edges_between_clusters:
            # edge already exists
            continue

        edges_between_clusters.append([node_a, node_b])

    return torch.tensor(edges_between_clusters, dtype=torch.int64).T
