import os
import sys

import networkx as nx
import numpy as np
import torch
import torch_geometric
import torch_geometric.utils
from scipy.spatial import Delaunay
from torch_geometric.data import Batch, Data
from tqdm import tqdm

# Directory expected to hold the project's sources. It is resolved relative to
# the current working directory, not to this file.
# NOTE(review): presumably the directory that contains
# ``datasets/spectre_dataset.py`` -- confirm against the repo layout.
PROJECT_SRC_PATH = ".."

# Root directory of the generated dataset; files are written to ``<root>/raw``.
DATASET_ROOT_DIR = "./data/my_planar"

# Number of graphs generated per split (for each requested node count).
TRAIN_GRAPHS = 128
VAL_GRAPHS = 32
TEST_GRAPHS = 40
# Number of nodes in every generated graph when run as a script.
NUM_NODES = 128

# Make the project sources importable. Nothing is imported here, so the only
# possible failure is a missing directory; report exactly that and abort.
if not os.path.isdir(PROJECT_SRC_PATH):
    print(
        f"Error: PROJECT_SRC_PATH ({PROJECT_SRC_PATH!r}) is not a directory. "
        "Please set PROJECT_SRC_PATH to the repo's src directory.",
        file=sys.stderr,
    )
    raise SystemExit(1)

# Guard against duplicate entries when the module is executed more than once
# (e.g. importlib.reload or repeated notebook runs).
if PROJECT_SRC_PATH not in sys.path:
    sys.path.append(PROJECT_SRC_PATH)
print(f"Added project src path {PROJECT_SRC_PATH!r} to sys.path.")


def generate_connected_planar_graph(num_nodes: int) -> nx.Graph:
    """Sample a random connected planar graph with ``num_nodes`` nodes.

    Points are drawn uniformly from the unit square (using NumPy's global
    random state) and joined along the edges of their Delaunay triangulation,
    which is planar by construction. Sampling is repeated until the resulting
    graph is connected.

    Args:
        num_nodes: Number of nodes (sampled points). Must be at least 3, the
            minimum required to form a single triangle.

    Returns:
        An undirected graph whose nodes are the Python ints
        ``0 .. num_nodes - 1``, inserted in that order; node ``i`` corresponds
        to the i-th sampled point.

    Raises:
        ValueError: If ``num_nodes`` is smaller than 3.
    """
    if num_nodes < 3:
        raise ValueError(
            f"num_nodes must be at least 3 to build a triangulation, got {num_nodes}"
        )

    while True:
        # One (x, y) row per node, uniform in [0, 1) x [0, 1).
        points = np.random.rand(num_nodes, 2)
        triangulation = Delaunay(points)

        graph = nx.Graph()
        # Register the nodes first so iteration order (and therefore the
        # row/column order of any adjacency matrix built from the graph) is
        # 0..num_nodes-1 instead of the order in which simplices mention them.
        graph.add_nodes_from(range(num_nodes))
        # tolist() yields plain Python ints, keeping node labels homogeneous.
        for simplex in triangulation.simplices.tolist():
            nx.add_cycle(graph, simplex)

        # A point left out of the triangulation stays isolated, so the
        # connectivity test also rejects samples that dropped a point.
        if nx.is_connected(graph):
            return graph


def convert_nx_to_adjacency_matrix(graph: nx.Graph) -> torch.Tensor:
    """Return the dense adjacency matrix of ``graph`` as a float32 tensor.

    Args:
        graph: Graph to convert.

    Returns:
        A 2-D ``torch.float32`` tensor built from ``nx.to_numpy_array(graph)``;
        rows and columns follow the iteration order of ``graph.nodes``.
    """
    dense = nx.to_numpy_array(graph)
    # torch.tensor with an explicit dtype replaces the legacy torch.Tensor(...)
    # constructor, whose result type depends on the global default tensor type.
    return torch.tensor(dense, dtype=torch.float32)


def generate_and_save_datasets(
    node_counts,
    train_size: int = TRAIN_GRAPHS,
    val_size: int = VAL_GRAPHS,
    test_size: int = TEST_GRAPHS,
    root_dir: str = DATASET_ROOT_DIR,
) -> None:
    """Generate train/val/test sets of planar graphs and save them to disk.

    For every split, ``<split>_size`` graphs are generated for each entry of
    ``node_counts``. The adjacency matrices of one split are collected in a
    single list and written with ``torch.save`` to ``<root_dir>/raw/<split>.pt``
    (``train.pt``, ``val.pt``, ``test.pt``). Progress is printed to stdout.

    Args:
        node_counts: Iterable of graph sizes (number of nodes per graph).
        train_size: Graphs per node count in the train split.
        val_size: Graphs per node count in the validation split.
        test_size: Graphs per node count in the test split.
        root_dir: Dataset root; the ``raw`` sub-directory is created if missing.
    """
    # node_counts is walked once per split; materialise it so a one-shot
    # iterator/generator does not leave the later splits empty.
    node_counts = list(node_counts)

    raw_dir = os.path.join(root_dir, "raw")
    os.makedirs(raw_dir, exist_ok=True)
    print(f"Using raw data directory: {raw_dir}/")

    split_sizes = {
        "train": train_size,
        "val": val_size,
        "test": test_size,
    }

    for split_name, graphs_per_count in split_sizes.items():
        print(f"\n>>> Generating {split_name} set...")
        adjacency_matrices = []

        for n_nodes in node_counts:
            print(f"  Generating {split_name} graphs with {n_nodes} nodes...")

            for _ in tqdm(range(graphs_per_count), desc=f"  {n_nodes}-node graphs"):
                nx_graph = generate_connected_planar_graph(n_nodes)
                adjacency_matrices.append(convert_nx_to_adjacency_matrix(nx_graph))

        # One file per split, holding a plain Python list of tensors.
        file_path = os.path.join(raw_dir, f"{split_name}.pt")
        torch.save(adjacency_matrices, file_path)
        print(f"  Saved {len(adjacency_matrices)} graphs to '{file_path}'.")


if __name__ == "__main__":
    # Script entry point: build all three splits using a single graph size.
    node_list = [NUM_NODES]
    size_variants = len(node_list)
    split_sizes = {
        "train_size": TRAIN_GRAPHS,
        "val_size": VAL_GRAPHS,
        "test_size": TEST_GRAPHS,
    }

    print("Generating planar datasets (compatible with spectre_dataset.py)...")
    generate_and_save_datasets(node_counts=node_list, **split_sizes)

    # Each split holds <split size> graphs for every entry of node_list.
    print("\nDone.")
    print("Summary:")
    print(f"  - train graphs: {size_variants * TRAIN_GRAPHS}")
    print(f"  - val graphs:   {size_variants * VAL_GRAPHS}")
    print(f"  - test graphs:  {size_variants * TEST_GRAPHS}")
    print("Note: update spectre_dataset.py to point to this dataset if needed.")
