from dataclasses import dataclass

from ruamel.yaml import YAML, yaml_object
from typing_extensions import override

from ..._graph_generator_config import GraphGeneratorConfig
from ._generator import noigen


@yaml_object(YAML())
@dataclass()
class NoigenConfig(GraphGeneratorConfig):
    """
    Configuration for the `noigen` clustered graph generator.

    Fields:

    - `num_nodes`: The number of nodes in the graph.
                   Must be at least 2.
    - `edge_density`: The density of edges as a fraction.
                      This means the number of edges is always
                      `num_edges = int(num_nodes * (num_nodes - 1) * edge_density / 2)`.
                      Must be inside the interval `[2 / num_nodes, 1]`,
                      so that there are enough edges for the graph to be connected.
    - `num_clusters`: The number of clusters.
                      Nodes within the same cluster are intended to be more strongly connected.
                      Must be between 1 and `num_nodes` (inclusive).
    - `inter_cluster_edge_scale`: The weights of edges between clusters are scaled down by this factor.
                                  Must be inside the interval `(0, 1]`.
    - `inter_cluster_edge_ratio`: The final graph will have exactly `int(num_edges * inter_cluster_edge_ratio)` edges
                                  between clusters, and the remaining edges will be between nodes of the same cluster.
                                  Must be inside the interval `(0, 1)`.
                                  If `edge_density` is too high, this value must be chosen carefully or the graph
                                  generation can get stuck in an infinite loop.
                                  Extreme values can also lead to the graph not being connected.
    """

    num_nodes: int
    edge_density: float
    num_clusters: int
    inter_cluster_edge_scale: float
    inter_cluster_edge_ratio: float

    # overriding abstract method (no type hint on purpose)
    generate_graph = noigen

    @override
    def validate(self):
        """
        Check that every field lies within the range documented on the class.

        Raises:
            ValueError: if a field is out of range. The message names the offending
                field and includes its actual value.
        """
        if self.num_nodes < 2:
            raise ValueError(f"num_nodes must be at least 2, but was {self.num_nodes}")

        # Checked on the truncated integer edge count rather than on edge_density directly:
        # a connected graph on num_nodes nodes needs at least num_nodes - 1 edges.
        if self.num_edges < self.num_nodes - 1:
            raise ValueError(
                "edge_density must be at least 2 / num_nodes, but was "
                f"{self.edge_density} < 2 / {self.num_nodes} = {2 / self.num_nodes}. "
                "This is to ensure that there are enough edges for the graph to be connected"
            )

        if self.edge_density > 1:
            raise ValueError(f"edge_density must be at most 1, but was {self.edge_density}")

        # NOTE(review): num_clusters == 1 is accepted even though inter_cluster_edge_ratio must be > 0;
        # whether noigen can then place any inter-cluster edges is not visible here — confirm.
        if not 1 <= self.num_clusters <= self.num_nodes:
            raise ValueError(
                f"num_clusters must be between 1 and num_nodes (inclusive), but was {self.num_clusters}"
            )

        # Chained comparisons are False for NaN, so NaN values are rejected as well.
        if not 0 < self.inter_cluster_edge_scale <= 1:
            raise ValueError(
                "inter_cluster_edge_scale must be 0 < inter_cluster_edge_scale <= 1, "
                f"but was {self.inter_cluster_edge_scale}"
            )

        if not 0 < self.inter_cluster_edge_ratio < 1:
            raise ValueError(
                "inter_cluster_edge_ratio must be 0 < inter_cluster_edge_ratio < 1, "
                f"but was {self.inter_cluster_edge_ratio}"
            )

    @property
    def num_edges(self) -> int:
        """
        The number of edges in the generated graphs (all graphs will have the same number of edges).

        Computed as `num_nodes * (num_nodes - 1) * edge_density / 2`, truncated by `int()`.
        """
        return int(self.num_nodes * (self.num_nodes - 1) * self.edge_density / 2)


if __name__ == "__main__":
    # Manual demo: build one graph from a fixed config, attach a MetaGraph built from it,
    # run karger_stein_repeated(graph, 2, 100) on it, and visualise the result as edge features.
    from karger import MetaGraph, karger_stein_repeated
    from visualise_graph import visualise_graph

    demo_config = NoigenConfig(
        num_nodes=30,
        edge_density=0.5,
        num_clusters=2,
        inter_cluster_edge_scale=0.2,
        inter_cluster_edge_ratio=0.1,
    )
    demo_graph = demo_config.generate_graph()
    demo_graph.meta_graph = MetaGraph.from_pyg(demo_graph)
    visualise_graph(
        demo_graph,
        edge_features=karger_stein_repeated(demo_graph, 2, 100),
    )
