from pathlib import Path
import hydra
from omegaconf import DictConfig
from functools import partial
from src.utils import seed_everything, print_flattened_config
from src.datasets.pos_enc_compute import (
    compute_pos_enc_statistics,
    pre_transform_in_memory,
)
from src.datasets.split_generator import generate_random_split
from src.datasets.custom_processing import *
import pickle
import torch_geometric
from torch_geometric.graphgym.loader import load_ogb
import torch
from pathlib import Path


def _load_dataset(cfg: DictConfig, raw_dir: Path):
    """Download (if necessary) and load the dataset described by ``cfg.dataset``.

    Args:
        cfg: Run configuration; reads ``cfg.dataset.format``,
            ``cfg.dataset.source``, ``cfg.dataset.name`` and, for the
            ``"ogbn"`` source, ``cfg.dataset.shortname``.
        raw_dir: Directory used as the dataset root for raw files.

    Returns:
        The dataset object built by the selected loader.

    Raises:
        ValueError: If the dataset format or source is not supported.
    """
    if cfg.dataset.format != "PyG":
        raise ValueError(f"Unknown dataset format: {cfg.dataset.format}")

    if cfg.dataset.source == "torch_geometric":
        if "-" in cfg.dataset.name:
            # for datasets that are like: DatasetGroupName-DatasetName
            pyg_dataset_id, dataset_name = cfg.dataset.name.split("-", 1)
            ds_class = getattr(torch_geometric.datasets, pyg_dataset_id)
            return ds_class(root=raw_dir, name=dataset_name)
        # for datasets that are single datasets (not groups)
        ds_class = getattr(torch_geometric.datasets, cfg.dataset.name)
        return ds_class(root=raw_dir)

    if cfg.dataset.source == "ogbn":
        return load_ogb(cfg.dataset.name, raw_dir / cfg.dataset.shortname)

    raise ValueError(f"Unknown dataset source {cfg.dataset.source}")


def _build_random_split_masks(dataset, cfg: DictConfig, num_splits: int):
    """Create boolean train/val/test masks, one column per requested split.

    Args:
        dataset: Loaded dataset; ``dataset.data.num_nodes`` sets the row count.
        cfg: Run configuration; reads ``cfg.dataset.split_index``,
            ``cfg.dataset.split_ratios`` and ``cfg.seed``.
        num_splits: Number of columns of each mask.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean tensors with
        shape ``[num_nodes, num_splits]``.
    """
    num_nodes = dataset.data.num_nodes
    train_mask = torch.zeros((num_nodes, num_splits), dtype=torch.bool)
    val_mask = torch.zeros((num_nodes, num_splits), dtype=torch.bool)
    test_mask = torch.zeros((num_nodes, num_splits), dtype=torch.bool)

    # The column is the *position* of the split inside cfg.dataset.split_index,
    # not the split id itself: the masks only have len(split_index) columns, so
    # indexing by id would go out of bounds (or wrap around for negative ids)
    # whenever split_index is not exactly 0..n-1.
    for column, split_id in enumerate(cfg.dataset.split_index):
        train_idx, val_idx, test_idx = generate_random_split(
            dataset=dataset,
            split_ratios=cfg.dataset.split_ratios,
            split_index=split_id,
            seed=cfg.seed,
        )
        train_mask[train_idx, column] = True
        val_mask[val_idx, column] = True
        test_mask[test_idx, column] = True

    return train_mask, val_mask, test_mask


def _check_split_masks(dataset, num_splits: int):
    """Sanity-check the split masks stored on ``dataset``.

    Verifies that every mask exists, has ``num_splits`` columns, and that no
    entry is set in two different masks at the same position.

    Raises:
        AssertionError: If any of the checks fails.
    """
    # 1. presence and shape
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        assert hasattr(dataset.data, mask_name), f"{mask_name} not found in dataset"
        assert getattr(dataset.data, mask_name).size(1) == num_splits, (
            f"{mask_name} should have {num_splits} columns, "
            f"got shape {tuple(getattr(dataset.data, mask_name).shape)}"
        )

    # 2. no overlap
    assert not (dataset.train_mask & dataset.val_mask).any(), (
        "train_mask and val_mask overlap"
    )
    assert not (dataset.val_mask & dataset.test_mask).any(), (
        "val_mask and test_mask overlap"
    )
    assert not (dataset.test_mask & dataset.train_mask).any(), (
        "test_mask and train_mask overlap"
    )


@hydra.main(config_path="./configs", config_name="preprocess")
def main(cfg: DictConfig):
    """Preprocess a dataset: load it, attach split masks, add positional
    encoding statistics and pickle the result.

    Steps, in order:
      1. Seed the run and print the configuration.
      2. Load the dataset into ``<cfg.data_root>/raw``.
      3. Bring ``train_mask`` / ``val_mask`` / ``test_mask`` to a 2-D layout
         with one column per split, either from the dataset's own masks
         (``split_mode == "standard"``) or from freshly generated random
         splits (``split_mode == "random"``).
      4. Validate the masks.
      5. Run ``compute_pos_enc_statistics`` over the dataset in memory.
      6. Pickle the dataset to
         ``<cfg.data_root>/processed/<cfg.dataset.shortname>.pt``.

    Args:
        cfg: Hydra configuration for the preprocessing run.

    Raises:
        ValueError: If the dataset format or source is not supported.
        AssertionError: If the resulting split masks are missing, have the
            wrong number of columns, or overlap.
    """
    seed_everything(cfg.seed)
    print_flattened_config(cfg)

    data_root = Path(cfg.data_root)
    raw_dir = data_root / "raw"
    raw_dir.mkdir(exist_ok=True, parents=True)
    print(f"Raw dataset dir: {raw_dir}")

    # -- Download/load dataset
    dataset = _load_dataset(cfg, raw_dir)

    # A single standard split comes as 1-D masks; add a trailing split axis.
    if cfg.dataset.split_mode == "standard" and dataset.train_mask.ndim == 1:
        dataset.data.train_mask = dataset.train_mask.unsqueeze(1)
        dataset.data.val_mask = dataset.val_mask.unsqueeze(1)
        dataset.data.test_mask = dataset.test_mask.unsqueeze(1)
    print(dataset)
    print(dataset.data)

    # -- Add splits
    num_splits = len(cfg.dataset.split_index)
    if cfg.dataset.split_mode == "standard":
        # test_mask is sometimes different for each split, sometimes the same
        # we want the shape of all masks to always be [N_nodes, N_splits],
        if dataset.data.test_mask.ndim == 1:  # e.g. this happens in wikiCS
            assert dataset.data.test_mask.size(0) == dataset.data.num_nodes
            # Replicate the shared test mask into every split column.
            dataset.data.test_mask = dataset.data.test_mask.repeat(num_splits, 1).T

    elif cfg.dataset.split_mode == "random":
        train_mask, val_mask, test_mask = _build_random_split_masks(
            dataset, cfg, num_splits
        )
        dataset.data["train_mask"] = train_mask
        dataset.data["val_mask"] = val_mask
        dataset.data["test_mask"] = test_mask

    # Ensure masks make sense (presence, shape, no overlap).
    _check_split_masks(dataset, num_splits)

    # -- Add eigenvectors
    # Estimate directedness based on 10 graphs to save time.
    is_undirected = all(d.is_undirected() for d in dataset[:10])
    pre_transform_in_memory(
        dataset,
        partial(
            compute_pos_enc_statistics,
            cfg=cfg,
            is_undirected=is_undirected,
        ),
        show_progress=True,
    )

    # Save
    print(dataset.data)
    save_dir = data_root / "processed"
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{cfg.dataset.shortname}.pt"
    # The whole dataset object is pickled, despite the ".pt" suffix.
    with open(save_path, "wb") as f:
        pickle.dump(dataset, f)
    print(f"Preprocessed dataset saved at {save_path}")


# Run the preprocessing pipeline only when this file is executed as a script
# (not when it is imported). `main` is called without arguments because it is
# wrapped by `@hydra.main`, which supplies `cfg`.
if __name__ == "__main__":
    main()
