from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object
from typing_extensions import override

from ..._graph_generator_config import GraphGeneratorConfig
from ._generator import generate_simple_graph_plus


@yaml_object(YAML())
@dataclass()
class SimpleGraphPlusConfig(GraphGeneratorConfig):
    """
    Configuration for the SimpleGraphPlus graph generator.

    Fields:

    - `num_nodes`: Total number of nodes in the graph.
                   Must be at least `num_clusters * (max_edges_between_clusters + 2)`.
    - `num_clusters`: The number of clusters.
                      Nodes within the same cluster are intended to be more strongly connected.
    - `min_edges_between_clusters`, `max_edges_between_clusters`: The number of edges generated between the clusters
                                                                  will be drawn from this interval (inclusive), so
                                                                  `min_edges_between_clusters` must not be larger than
                                                                  `max_edges_between_clusters`.
                                                                  The actual number of edges between clusters might be
                                                                  lower than `min_edges_between_clusters`!
    - `edge_weight_range`: If not `None`, edge weights will be drawn uniformly at random from this range.
                           The second value must be larger than the first one.
                           If this is `None`, all edges will have weight 1.
    """

    num_nodes: int
    num_clusters: int
    min_edges_between_clusters: int
    max_edges_between_clusters: int
    edge_weight_range: Optional[tuple[float, float]] = None

    # overriding abstract method (no type hint on purpose)
    generate_graph = generate_simple_graph_plus

    @override
    def validate(self):
        """
        Check that the configured values are consistent with each other.

        Raises:
            ValueError: If `min_edges_between_clusters` is larger than `max_edges_between_clusters`,
                if `num_nodes` is smaller than `num_clusters * (max_edges_between_clusters + 2)`,
                or if `edge_weight_range` is given and its first value is not strictly smaller
                than its second value.

        If `min_edges_between_clusters <= 1`, a warning is printed instead of raising, because
        the ground-truth minimum cut is then less likely to be the unique minimum cut.
        """
        # The edge count is drawn from [min, max] (inclusive); an empty interval is a config error.
        if self.min_edges_between_clusters > self.max_edges_between_clusters:
            raise ValueError(
                f"min_edges_between_clusters ({self.min_edges_between_clusters}) cannot be larger than "
                f"max_edges_between_clusters ({self.max_edges_between_clusters})"
            )

        # Every cluster is required to hold at least max_edges_between_clusters + 2 nodes.
        # NOTE(review): the reason for the "+ 2" lives in the generator; presumably each
        # cluster must provide enough endpoints for the inter-cluster edges — confirm in _generator.
        min_cluster_size = self.max_edges_between_clusters + 2
        required_nodes = self.num_clusters * min_cluster_size
        if required_nodes > self.num_nodes:
            raise ValueError(
                f"num_nodes ({self.num_nodes}) cannot be smaller than "
                "num_clusters * (max_edges_between_clusters + 2) "
                f"({self.num_clusters} * {min_cluster_size} = {required_nodes})"
            )

        if self.edge_weight_range is not None and self.edge_weight_range[0] >= self.edge_weight_range[1]:
            raise ValueError(
                "First value of edge_weight_range must be smaller than the second one, but was "
                f"{self.edge_weight_range[0]} >= {self.edge_weight_range[1]}"
            )

        # Not an error: the config is usable, but uniqueness of the minimum cut is less likely.
        if self.min_edges_between_clusters <= 1:
            print(
                "WARNING: The SimpleGraphPlus graph generator doesn't guarantee that the provided ground truth minimum",
                "cut is the only minimum cut. However, if min_edges_between_clusters is > 1, it is sufficiently likely",
                "that the minimum cut is unique. You are seeing this warning because min_edges_between_clusters is",
                self.min_edges_between_clusters,
            )


if __name__ == "__main__":
    # Demo: build a small two-cluster graph and show it.
    from visualise_graph import visualise_graph

    demo_config = SimpleGraphPlusConfig(
        num_nodes=40,
        num_clusters=2,
        min_edges_between_clusters=3,
        max_edges_between_clusters=18,
    )
    demo_graph = demo_config.generate_graph()

    # The summed `y` labels are halved before reporting.
    # NOTE(review): presumably each inter-cluster edge is labelled once per
    # direction — confirm against _generator.
    labelled_total = demo_graph.y.sum().item()
    print("Edges between clusters:", int(labelled_total) // 2)

    visualise_graph(demo_graph)
