# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
import random
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.distributions.uniform import Uniform
import torch_geometric
from torch_geometric.data import Data

from karger import MetaGraph
# TODO make these functions public, maybe restructure
from .._simple_graph._generator import _generate_cluster_sizes, _generate_edges_between_clusters

if TYPE_CHECKING:  # avoid circular import
    from ._config import SimpleGraphPlusConfig


def generate_simple_graph_plus(config: SimpleGraphPlusConfig) -> Data:
    """
    Builds a random clustered graph with `config.num_nodes` nodes, split into `config.num_clusters` clusters.

    Between `config.min_edges_between_clusters` and `config.max_edges_between_clusters` edges (the exact number is
    drawn at random) connect nodes of different clusters.
    Inside every cluster, each node gets a degree strictly larger than the total weight of all edges between
    clusters, so that no minimum cut can isolate a single node from the rest of the graph.

    The returned graph is undirected (both directions of every edge are stored) and carries

    - `edge_attr`: the edge weights,
    - `y`: the edge ground truth labels, 1 for an edge between different clusters and 0 for an edge within a cluster,
    - `meta_graph`: a pre-computed `MetaGraph` of the graph.
    """
    config.validate()

    # cluster layout and edges between clusters: same procedure as generate_simple_graph
    cluster_sizes = _generate_cluster_sizes(config)
    # index_offsets[i] is the id of the first node of cluster i
    index_offsets = [cluster_sizes[:cluster].sum() for cluster in range(cluster_sizes.size(0))]
    inter_edge_count = random.randint(config.min_edges_between_clusters, config.max_edges_between_clusters)
    inter_edges = _generate_edges_between_clusters(inter_edge_count, cluster_sizes, index_offsets)

    # the summed weight of the edges between clusters is the bound every node degree has to exceed
    inter_weights = _generate_edge_weights(config, inter_edges.size(1))
    cut_weight = inter_weights.sum()

    # one subgraph per cluster (similar to generate_simple_graph), shifted from local to global node ids
    intra_edges = torch.cat(
        [
            _generate_cluster(config, size, degree_lower_limit=cut_weight) + offset
            for size, offset in zip(cluster_sizes, index_offsets)
        ],
        dim=1,
    )
    # TODO consider the case with random edge weights
    intra_weights = _generate_edge_weights(config, intra_edges.size(1))

    # edges within clusters come first, so their labels (0) precede the labels (1) of edges between clusters
    directed_edges = torch.cat((intra_edges, inter_edges), dim=1)
    directed_weights = torch.cat((intra_weights, inter_weights))
    directed_labels = torch.cat((torch.zeros(intra_edges.size(1)), torch.ones(inter_edges.size(1))))

    edge_index, (edge_weights, edge_labels) = torch_geometric.utils.to_undirected(
        edge_index=directed_edges,
        edge_attr=(directed_weights, directed_labels),
    )
    graph_pyg = Data(edge_index=edge_index, num_nodes=config.num_nodes, edge_attr=edge_weights, y=edge_labels)

    _sanity_check(graph_pyg, degree_lower_limit=cut_weight)

    # pre-compute MetaGraph to save time during training
    graph_pyg.meta_graph = MetaGraph.from_pyg(graph_pyg)

    return graph_pyg


def _generate_edge_weights(config: SimpleGraphPlusConfig, num_edges: int) -> Tensor:
    """
    Generates edge weights according to `config.edge_weight_range`.

    If `config.edge_weight_range` is `None`, all edge weights are 1.
    Random edge weights (any other value of `config.edge_weight_range`) are planned but not supported yet.

    Args:
        config: Configuration; only `config.edge_weight_range` is read.
        num_edges: Number of weights to generate.

    Returns:
        A Tensor of size `[num_edges]` filled with ones.

    Raises:
        NotImplementedError: If `config.edge_weight_range` is not `None`.
    """
    if config.edge_weight_range is not None:
        # TODO draw the weights uniformly at random from the range, i.e.
        #      `Uniform(*config.edge_weight_range).sample([num_edges])`, once the cluster generation
        #      (see the TODOs in `_generate_cluster`) can deal with non-unit edge weights
        raise NotImplementedError(
            "Random edge weights are not supported yet, `config.edge_weight_range` must be None")

    return torch.ones([num_edges])


def _generate_cluster(config: SimpleGraphPlusConfig, num_nodes: int, degree_lower_limit: float) -> Tensor:
    """
    Creates a subgraph in which each node has a degree strictly larger than `degree_lower_limit`.

    Nodes are visited in increasing order. Each node that has not reached the required degree yet is connected to
    randomly chosen other nodes (rejection sampling, no self-loops, no parallel edges) until it has.
    All edges count with weight 1.

    Args:
        config: Currently unused.
        num_nodes: Number of nodes of the subgraph; node ids are `0 .. num_nodes - 1`.
            May be an `int` or a 0-dim integer tensor.
        degree_lower_limit: Exclusive lower bound for the degree of every node.
            May be a `float` or a 0-dim tensor.

    Returns:
        A Tensor of size `[2, num_edges]` (dtype int64) that contains every edge of the subgraph exactly once,
        where each edge is represented by the ids of the start and end nodes.
        `num_edges` is random.

    Raises:
        ValueError: If a node is already connected to all other nodes but its degree still does not exceed
            `degree_lower_limit`, i.e. the cluster is too small for the requested degree.
    """
    # the caller iterates over tensors, so both values may arrive as 0-dim tensors
    num_nodes = int(num_nodes)
    degree_lower_limit = float(degree_lower_limit)

    edges = []
    # neighbours[v] holds the nodes adjacent to v; gives O(1) degree and duplicate-edge lookups
    neighbours = [set() for _ in range(num_nodes)]

    for node_a in range(num_nodes):
        adjacent = neighbours[node_a]

        # TODO consider the case with random edge weights (the degree is then no longer the neighbour count)
        while len(adjacent) <= degree_lower_limit:
            # without this check the rejection sampling below would never terminate
            if len(adjacent) >= num_nodes - 1:
                raise ValueError(
                    f"Cannot give node {node_a} a degree larger than {degree_lower_limit} "
                    f"in a cluster of {num_nodes} nodes")

            # using rejection sampling, find another node to which there is no edge yet
            node_b = random.randrange(num_nodes)
            # don't generate self-loops or parallel edges
            if node_b == node_a or node_b in adjacent:
                continue

            edges.append([node_a, node_b])
            adjacent.add(node_b)
            neighbours[node_b].add(node_a)

    # reshape keeps the documented `[2, num_edges]` layout even when there are no edges
    return torch.tensor(edges, dtype=torch.int64).reshape(-1, 2).T


def _sanity_check(graph: Data, degree_lower_limit: float) -> None:
    """
    Checks some simple properties of the graph and raises an exception when one of the checks fails:

    - `graph.num_nodes` matches the largest node used in `graph.edge_index`
    - `graph.edge_index`, `graph.edge_attr`, and `graph.y` have the same number of edges
    - All node degrees are strictly larger than `degree_lower_limit`

    The degree of a node is the sum of `graph.edge_attr` over the edges whose start node (first row of
    `graph.edge_index`) is that node. This assumes that both directions of every edge are stored, as produced by
    `to_undirected` in `generate_simple_graph_plus`.

    Args:
        graph: Graph to check; `num_nodes`, `edge_index`, `edge_attr` and `y` are read.
        degree_lower_limit: Exclusive lower bound for the degree of every node.

    Raises:
        AssertionError: If one of the first two checks fails. Raised explicitly instead of via `assert`,
            so the checks also run when Python is started with `-O`.
        Exception: If a node degree is not larger than `degree_lower_limit`.
    """
    edge_index = graph.edge_index
    edge_weights = graph.edge_attr
    num_edges = edge_index.size(1)

    # `max()` is undefined for an empty tensor; a graph without edges references no node at all
    largest_node = int(edge_index.max()) if edge_index.numel() > 0 else -1
    if graph.num_nodes != largest_node + 1:
        raise AssertionError(
            f"graph.num_nodes is {graph.num_nodes}, but the largest node in graph.edge_index is {largest_node}")

    if edge_weights.size(0) != num_edges:
        raise AssertionError(
            f"graph.edge_index has {num_edges} edges, but graph.edge_attr has {edge_weights.size(0)} entries")
    if graph.y.size(0) != num_edges:
        raise AssertionError(
            f"graph.edge_index has {num_edges} edges, but graph.y has {graph.y.size(0)} entries")

    start_nodes = edge_index[0, :]
    for node in range(graph.num_nodes):
        node_degree = edge_weights[start_nodes == node].sum()

        if node_degree <= degree_lower_limit:
            raise Exception(
                f"Node degrees must be larger than {degree_lower_limit}, "
                f"but degree of node {node} was {node_degree} <= {degree_lower_limit}")
