from typing import Literal, Optional

import joblib
import networkx as nx
import numpy as np
from torch_geometric.data import Batch
from torch_geometric.utils import from_networkx
from tqdm.rich import tqdm

from polygraph.datasets.base import ProceduralGraphDataset, SplitGraphDataset
from polygraph.datasets.base.graph_storage import GraphStorage


def is_planar_graph(graph: nx.Graph) -> bool:
    """Check whether a graph is non-empty, connected and planar.

    Args:
        graph: NetworkX graph to check.

    Returns:
        True if the graph has at least one node, is connected and is planar;
        False otherwise.
    """
    # nx.is_connected raises NetworkXPointlessConcept for a graph without
    # nodes; a validity check should report such a graph as invalid rather
    # than crash.
    if graph.number_of_nodes() == 0:
        return False
    return nx.is_connected(graph) and nx.is_planar(graph)


class ProceduralPlanarGraphDataset(ProceduralGraphDataset):
    """Procedural version of [`PlanarGraphDataset`][polygraph.datasets.PlanarGraphDataset].

    Graphs are generated by sampling random node positions in the unit square and producing a Delaunay triangulation.

    Args:
        split: Split to load.
        num_graphs: Number of graphs to generate for this split.
        n_nodes: Number of nodes in the graphs.
        seed: Seed for the random number generator.
        memmap: Whether to use memory mapping for the dataset.
        show_generation_progress: Whether to display a progress bar while the graphs are generated.
    """

    def __init__(
        self,
        split: Literal["train", "val", "test"],
        num_graphs: int,
        n_nodes: int = 64,
        seed: int = 42,
        memmap: bool = False,
        show_generation_progress: bool = False,
    ) -> None:
        # The hash of the generation parameters is handed to the base class
        # as the configuration identifier and also seeds the RNG, so each
        # (num_graphs, n_nodes, seed, split) combination gets its own
        # reproducible random stream.
        config_hash: str = joblib.hash(  # pyright: ignore
            (num_graphs, n_nodes, seed, split), hash_name="md5"
        )
        # Interpret the bytes of the hash string as one big unsigned integer.
        self._rng = np.random.default_rng(
            int.from_bytes(config_hash.encode(), "big")
        )
        self._num_graphs = num_graphs
        self._n_nodes = n_nodes
        # NOTE(review): the generation state above is set before calling the
        # base initializer, presumably because the base class may call
        # `generate_data` during construction -- confirm against
        # ProceduralGraphDataset.
        super().__init__(
            split,
            config_hash,
            memmap,
            show_generation_progress,
        )

    def generate_data(self) -> GraphStorage:
        """Generate `num_graphs` random planar graphs and pack them into a `GraphStorage`."""
        graphs = [
            from_networkx(self._random_planar())
            for _ in tqdm(
                range(self._num_graphs),
                desc="Generating planar graphs",
                disable=not self.show_generation_progress,
            )
        ]
        return GraphStorage.from_pyg_batch(Batch.from_data_list(graphs))

    def is_valid(self, graph: nx.Graph) -> bool:
        """Check if a graph is valid (connected and planar)."""
        return is_planar_graph(graph)

    def _random_planar(self) -> nx.Graph:
        """Sample one planar graph as the Delaunay triangulation of random points.

        Returns:
            Graph with nodes `0 .. n_nodes - 1` and one edge per side of every
            triangle in the triangulation.
        """
        # Import the class from its subpackage: a bare `import scipy` does not
        # guarantee that `scipy.spatial` is loaded on older SciPy releases, in
        # which case `scipy.spatial.Delaunay` fails with an AttributeError.
        from scipy.spatial import Delaunay

        node_locations = self._rng.uniform(size=(self._n_nodes, 2))
        # Create the delaunay triangulation
        triangulation = Delaunay(node_locations)
        graph = nx.Graph()
        graph.add_nodes_from(range(self._n_nodes))
        # `.tolist()` yields plain Python ints, so edge endpoints have the same
        # type as the node labels added above (instead of NumPy scalars).
        for a, b, c in triangulation.simplices.tolist():
            graph.add_edges_from(((a, b), (a, c), (b, c)))
        return graph


class PlanarGraphDataset(SplitGraphDataset):
    """Planar graph dataset proposed by Martinkus et al. [1].

    Every graph has 64 nodes and is both connected and planar.

    {{ plot_first_k_graphs("PlanarGraphDataset", "train", 3) }}


    Dataset statistics:

    {{ summary_md_table("PlanarGraphDataset", ["train", "val", "test"]) }}

    References:
        [1] Martinkus, K., Loukas, A., Perraudin, N., & Wattenhofer, R. (2022).
            [SPECTRE: Spectral Conditioning Helps to Overcome the Expressivity Limits
            of One-shot Graph Generators](https://arxiv.org/abs/2204.01613). In Proceedings of the 39th International
            Conference on Machine Learning (ICML).
    """

    # Download URL of the stored graphs, keyed by split name.
    _URL_FOR_SPLIT = dict(
        train="https://sandbox.zenodo.org/records/332447/files/planar_train.pt?download=1",
        val="https://sandbox.zenodo.org/records/332447/files/planar_val.pt?download=1",
        test="https://sandbox.zenodo.org/records/332447/files/planar_test.pt?download=1",
    )

    # Expected hash of each split's file, keyed by split name.
    # NOTE(review): presumably used by the base class to verify the download.
    _HASH_FOR_SPLIT = dict(
        train="edc2630954a23b1cf6a549d43a95e359",
        val="56a6d569407e47a93f6febe56ec07843",
        test="6135c071784a5efe48ad5a0a30a0028c",
    )

    def hash_for_split(self, split: str) -> Optional[str]:
        """Return the expected file hash for `split`.

        Raises:
            KeyError: If `split` is not one of the known splits.
        """
        expected_hash = self._HASH_FOR_SPLIT[split]
        return expected_hash

    def url_for_split(self, split: str) -> str:
        """Return the download URL for `split`.

        Raises:
            KeyError: If `split` is not one of the known splits.
        """
        download_url = self._URL_FOR_SPLIT[split]
        return download_url

    def is_valid(self, graph: nx.Graph) -> bool:
        """Check whether a graph is connected and planar.

        Args:
            graph: NetworkX graph to check.

        Returns:
            bool: True if the graph is connected and planar, False otherwise.
        """
        connected_and_planar = is_planar_graph(graph)
        return connected_and_planar



class PlanarLGraphDataset(ProceduralPlanarGraphDataset):
    """Procedurally generated planar graph dataset with fixed split sizes.

    The "train" split contains 8192 graphs; the "val" and "test" splits
    contain 4096 graphs each. The number of nodes per graph is left at the
    default of the parent class.

    Args:
        split: Split to load.
        num_graphs: Ignored for the "train", "val" and "test" splits, whose
            sizes are fixed; only used as given for any other split value.
        seed: Seed for the random number generator.
        memmap: Whether to use memory mapping for the dataset.
        show_generation_progress: Whether to display a progress bar while the
            graphs are generated.
    """

    # Fixed number of graphs generated for each known split.
    _NUM_GRAPHS_FOR_SPLIT = {"train": 8192, "val": 4096, "test": 4096}

    def __init__(
        self,
        split: Literal["train", "val", "test"],
        num_graphs: int,
        seed: int = 42,
        memmap: bool = False,
        show_generation_progress: bool = False,
    ) -> None:
        num_graphs = self._NUM_GRAPHS_FOR_SPLIT.get(split, num_graphs)

        # Pass the options by keyword: the parent's third positional
        # parameter is `n_nodes`, so forwarding them positionally would use
        # `seed` as the node count, `memmap` as the seed and
        # `show_generation_progress` as `memmap`.
        super().__init__(
            split,
            num_graphs,
            seed=seed,
            memmap=memmap,
            show_generation_progress=show_generation_progress,
        )