import os
import sys

import networkx as nx
import torch
import torch_geometric
from tqdm import tqdm

# Directory appended to sys.path below so that `datasets.spectre_dataset`
# can be imported. A relative value is resolved against the current
# working directory, not against this file's location.
PROJECT_SRC_PATH = ".."

# Default output root; generated files are written to its "raw" subdirectory.
DATASET_ROOT_DIR = "./data/my_tree"

# Default number of graphs generated per split *for each* requested node count.
TRAIN_GRAPHS = 128
VAL_GRAPHS = 32
TEST_GRAPHS = 100

# Make the project sources importable and verify that the SPECTRE dataset
# module can be loaded before doing any work. The imported class is not used
# by the code visible in this file; the import acts as an availability check.
try:
    if not os.path.isdir(PROJECT_SRC_PATH):
        raise FileNotFoundError(f"directory not found: {PROJECT_SRC_PATH!r}")
    sys.path.append(PROJECT_SRC_PATH)
    from datasets.spectre_dataset import SpectreGraphDataset

    print("Loaded SpectreGraphDataset from project src path.")
except (ImportError, FileNotFoundError) as err:
    print(
        "Error: could not import datasets.spectre_dataset from PROJECT_SRC_PATH. "
        "Please set PROJECT_SRC_PATH to the repo's src directory."
    )
    # Report the underlying reason: a missing directory and a failing
    # transitive import (e.g. an uninstalled dependency of the project
    # module) need different fixes, and would otherwise look identical.
    print(f"Cause: {type(err).__name__}: {err}")
    raise SystemExit(1) from err


def convert_nx_to_pyg_data(graph: nx.Graph) -> torch_geometric.data.Data:
    """Turn a NetworkX graph into a PyG ``Data`` object with constant features.

    The graph is routed through its dense adjacency matrix, so every nonzero
    adjacency entry becomes one column of ``edge_index``.

    Args:
        graph: The NetworkX graph to convert.

    Returns:
        A ``Data`` object holding:
          * ``x``: an all-ones float tensor of shape ``(n, 1)``;
          * ``edge_index``: sparse connectivity from the dense adjacency;
          * ``edge_attr``: a float tensor of shape ``(num_edges, 2)`` whose
            column 0 is 0 and column 1 is 1 for every edge;
          * ``y``: an empty float tensor of shape ``(1, 0)``;
          * ``n_nodes``: a one-element long tensor containing ``n``.
    """
    dense_adj = torch.Tensor(nx.to_numpy_array(graph))
    node_total = dense_adj.shape[-1]

    edge_index, _ = torch_geometric.utils.dense_to_sparse(dense_adj)
    edge_total = edge_index.shape[-1]

    # Every edge carries the same two-channel feature row [0, 1].
    edge_features = torch.tensor([0.0, 1.0], dtype=torch.float).repeat(edge_total, 1)

    node_features = torch.ones((node_total, 1), dtype=torch.float)
    empty_target = torch.zeros(1, 0, dtype=torch.float)
    node_count = torch.tensor([node_total], dtype=torch.long)

    return torch_geometric.data.Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        y=empty_target,
        n_nodes=node_count,
    )


def _make_random_tree(n_nodes) -> "nx.Graph":
    """Return a random tree on ``n_nodes`` nodes, or a single isolated node."""
    if n_nodes > 1:
        # ``nx.random_tree`` was deprecated in NetworkX 3.2 and removed in 3.4
        # in favour of ``nx.random_labeled_tree``. Prefer the new name and fall
        # back to the old one so the script runs on both old and new releases.
        make_tree = getattr(nx, "random_labeled_tree", None)
        if make_tree is None:
            make_tree = nx.random_tree
        return make_tree(n_nodes, seed=None)
    # Any requested size of 1 or less yields a graph with exactly one node.
    return nx.empty_graph(n=1)


def generate_and_process_datasets(
    node_counts,
    train_size=TRAIN_GRAPHS,
    val_size=VAL_GRAPHS,
    test_size=TEST_GRAPHS,
    root_dir=DATASET_ROOT_DIR,
):
    """Generate random-tree datasets and save one collated file per split.

    For each of the ``train``, ``val`` and ``test`` splits, the split's number
    of random trees is generated for every entry of ``node_counts``, each tree
    is converted with ``convert_nx_to_pyg_data``, and all of them are merged
    into a single PyG ``Batch`` written to ``<root_dir>/raw/<split>.pt``.

    Args:
        node_counts: Iterable of tree sizes (number of nodes) to generate.
        train_size: Graphs to generate per node count for the train split.
        val_size: Graphs to generate per node count for the val split.
        test_size: Graphs to generate per node count for the test split.
        root_dir: Dataset root; the ``raw`` subdirectory is created if missing.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    raw_dir = os.path.join(root_dir, "raw")
    os.makedirs(raw_dir, exist_ok=True)
    print(f"Using raw data directory: {raw_dir}/")

    # node_counts is walked once per split; materialize it so a one-shot
    # iterator (e.g. a generator) does not leave later splits empty.
    node_counts = list(node_counts)

    split_sizes = {
        "train": train_size,
        "val": val_size,
        "test": test_size,
    }

    for dataset_name, dataset_size in split_sizes.items():
        print(f"\n>>> Generating {dataset_name} set...")
        all_data_list = []

        for n_nodes in node_counts:
            print(f"  Generating {dataset_name} graphs with {n_nodes} nodes...")
            all_data_list.extend(
                convert_nx_to_pyg_data(_make_random_tree(n_nodes))
                for _ in tqdm(range(dataset_size), desc=f"  {n_nodes}-node graphs")
            )

        collated_data = torch_geometric.data.Batch.from_data_list(all_data_list)
        # NOTE(review): the file stores ``(batch, None)``. A loader that expects
        # real per-graph slice offsets as the second element would not be able
        # to split the batch back into graphs -- confirm against the consumer.
        slices = None

        file_path = os.path.join(raw_dir, f"{dataset_name}.pt")
        torch.save((collated_data, slices), file_path)
        print(f"  Saved {len(all_data_list)} graphs to '{file_path}'.")


if __name__ == "__main__":
    # Tree sizes to generate; each split gets its full quota for every size.
    node_list = [80]
    size_variants = len(node_list)

    generate_and_process_datasets(
        node_counts=node_list,
        train_size=TRAIN_GRAPHS,
        val_size=VAL_GRAPHS,
        test_size=TEST_GRAPHS,
    )

    # Emit the closing summary in one call, one entry per output line.
    print(
        "\nDone.",
        "Summary:",
        f"  - train graphs: {size_variants * TRAIN_GRAPHS}",
        f"  - val graphs:   {size_variants * VAL_GRAPHS}",
        f"  - test graphs:  {size_variants * TEST_GRAPHS}",
        "Note: update spectre_dataset.py to point to this dataset if needed.",
        sep="\n",
    )
