from custom_modules.loader.custom_loaders import *
from custom_modules.loader.utils import *
import os.path as osp
from torch_geometric.graphgym.loader import load_ogb, set_dataset_attr
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
import time

from custom_modules.loader.utils import *
from torch_geometric.graphgym.loader import set_dataset_attr
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
import time
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch
from custom_modules.loader.dataset.node_former_NC_dataset import (
    SnapPatentsDataset,
    SnapPokecDataset,
)


def log_loaded_dataset(dataset, format, name):
    """Log a human-readable summary of a loaded dataset at INFO level.

    The summary contains the collated data object, whether the first graph is
    undirected, the number of graphs, the average number of nodes per graph
    (integer division), the node/edge feature counts, ``num_tasks`` when the
    dataset has that attribute, and a class count taken from the labels ``y``
    or, when ``y`` is missing or ``None``, from the edge labels.

    Args:
        dataset: Dataset object exposing ``data``, ``len()``, indexing,
            ``num_node_features`` and ``num_edge_features``.
        format: Format identifier the dataset was loaded from (only logged).
        name: Dataset name (only logged).
    """

    def is_flat_float(labels):
        # One floating-point value per entry is taken as a regression target.
        return labels.numel() == labels.size(0) and torch.is_floating_point(labels)

    logging.info(f"[*] Loaded dataset '{name}' from '{format}':")
    collated = dataset.data
    logging.info(f"  {collated}")
    logging.info(f"  undirected: {dataset[0].is_undirected()}")
    logging.info(f"  num graphs: {len(dataset)}")

    # Prefer an explicit node count; fall back to the feature matrix height.
    if hasattr(collated, "num_nodes"):
        node_total = collated.num_nodes
    elif hasattr(collated, "x"):
        node_total = collated.x.size(0)
    else:
        node_total = 0
    avg_nodes = node_total // len(dataset)
    logging.info(f"  avg num_nodes/graph: {avg_nodes}")

    logging.info(f"  num node features: {dataset.num_node_features}")
    logging.info(f"  num edge features: {dataset.num_edge_features}")
    if hasattr(dataset, "num_tasks"):
        logging.info(f"  num tasks: {dataset.num_tasks}")

    targets = getattr(collated, "y", None)
    if targets is not None:
        if isinstance(targets, list):
            # A special case for ogbg-code2 dataset.
            logging.info("  num classes: n/a")
        elif is_flat_float(targets):
            logging.info("  num classes: (appears to be a regression task)")
        else:
            logging.info(f"  num classes: {dataset.num_classes}")
        return

    # No usable ``y``: look for edge/link prediction labels instead.
    if hasattr(collated, "train_edge_label"):
        edge_labels = collated.train_edge_label  # Transductive link task
    elif hasattr(collated, "edge_label"):
        edge_labels = collated.edge_label  # Inductive link task
    else:
        return

    if is_flat_float(edge_labels):
        logging.info("  num edge classes: (probably a regression task)")
    else:
        logging.info(f"  num edge classes: {len(torch.unique(edge_labels))}")


def load_dataset_and_save(format, name, dataset_dir, data_cfg=None):
    """Load a dataset, preprocess it, attach metadata and save it to disk.

    Steps, in order:
      1. Instantiate the dataset selected by ``format`` and ``name``.
      2. Apply ``task_specific_preprocessing`` to it in memory.
      3. Log a summary via ``log_loaded_dataset``.
      4. For every enabled ``posenc_*`` section of the global ``cfg``,
         precompute positional-encoding statistics with
         ``compute_posenc_stats``.
      5. If ``data_cfg`` is given, attach ``dataset_name`` and
         ``dataset_task_name`` attributes (plus ``node_id`` for node tasks).
      6. Save the result with ``save_processed_eig``.

    The function reads the module-level GraphGym ``cfg`` and may overwrite
    ``cfg.posenc_*.kernel.times`` when a ``times_func`` snippet is set.

    Args:
        format: Dataset family. One of ``"PyG-<ClassId>"`` (Actor,
            GNNBenchmarkDataset, MalNetTiny, Planetoid, TUDataset, WebKB,
            WikipediaNetwork, ZINC, AQSOL, VOCSuperpixels, COCOSuperpixels),
            ``"PyG"``, ``"Snap_pokec"``, ``"Snap_patents"`` or ``"OGB"``.
        name: Dataset name within the family.
        dataset_dir: Root directory of the datasets. A leading ``~`` is
            expanded. For ``"PyG-<ClassId>"`` formats the class id is appended
            as a sub-directory, and that sub-directory is also the location
            handed to ``save_processed_eig``.
        data_cfg: Optional config node providing ``dataset_name``, ``task``
            and ``task_type``. When ``None`` no metadata is attached.

    Returns:
        The loaded and preprocessed dataset object.

    Raises:
        ValueError: If ``format`` is unknown, the ``PyG-`` class id or the OGB
            dataset name is unsupported, or ``name`` is not ``"actor"`` for
            the Actor dataset.
        NotImplementedError: For the ``crocodile`` WikipediaNetwork dataset.
    """
    # File APIs do not expand "~" by themselves. Resolve it once here so that
    # the dataset constructors and save_processed_eig all receive the same
    # concrete location.
    if isinstance(dataset_dir, str):
        dataset_dir = osp.expanduser(dataset_dir)

    if format.startswith("PyG-"):
        pyg_dataset_id = format.split("-", 1)[1]
        # The per-class sub-directory stays in effect for the rest of the
        # function, including the final save.
        dataset_dir = osp.join(dataset_dir, pyg_dataset_id)

        if pyg_dataset_id == "Actor":
            if name != "actor":
                raise ValueError("Actor class provides only one dataset.")
            dataset = Actor(dataset_dir)

        elif pyg_dataset_id == "GNNBenchmarkDataset":
            dataset = preformat_GNNBenchmarkDataset(dataset_dir, name)

        elif pyg_dataset_id == "MalNetTiny":
            dataset = preformat_MalNetTiny(dataset_dir, feature_set=name)

        elif pyg_dataset_id == "Planetoid":
            dataset = Planetoid(dataset_dir, name)

        elif pyg_dataset_id == "TUDataset":
            dataset = preformat_TUDataset(dataset_dir, name)

        elif pyg_dataset_id == "WebKB":
            dataset = WebKB(dataset_dir, name)

        elif pyg_dataset_id == "WikipediaNetwork":
            if name == "crocodile":
                raise NotImplementedError("crocodile not implemented")
            dataset = WikipediaNetwork(dataset_dir, name, geom_gcn_preprocess=True)

        elif pyg_dataset_id == "ZINC":
            dataset = preformat_ZINC(dataset_dir, name)

        elif pyg_dataset_id == "AQSOL":
            dataset = preformat_AQSOL(dataset_dir, name)

        elif pyg_dataset_id == "VOCSuperpixels":
            dataset = preformat_VOCSuperpixels(
                dataset_dir, name, cfg.dataset.slic_compactness
            )

        elif pyg_dataset_id == "COCOSuperpixels":
            dataset = preformat_COCOSuperpixels(
                dataset_dir, name, cfg.dataset.slic_compactness
            )

        else:
            raise ValueError(f"Unexpected PyG Dataset identifier: {format}")

    # GraphGym default loader for Pytorch Geometric datasets
    elif format == "PyG":
        dataset = load_pyg(name, dataset_dir)

    elif format == "Snap_pokec":
        dataset = SnapPokecDataset(dataset_dir)

    elif format == "Snap_patents":
        dataset = SnapPatentsDataset(dataset_dir)

    elif format == "OGB":
        if name.startswith("ogbg"):
            dataset = preformat_OGB_Graph(dataset_dir, name.replace("_", "-"))

        elif name.startswith("ogbn"):
            dataset = preformat_OGB_Node(dataset_dir, name.replace("_", "-"))

        elif name.startswith("PCQM4Mv2-"):
            subset = name.split("-", 1)[1]
            dataset = preformat_OGB_PCQM4Mv2(dataset_dir, subset)

        elif name.startswith("peptides-"):
            dataset = preformat_Peptides(dataset_dir, name)

        ### Link prediction datasets.
        elif name.startswith("ogbl-"):
            # GraphGym default loader.
            dataset = load_ogb(name, dataset_dir)

            # OGB link prediction datasets are binary classification tasks,
            # however the default loader creates float labels => convert to int.
            def convert_to_int(ds, prop):
                tmp = getattr(ds.data, prop).int()
                set_dataset_attr(ds, prop, tmp, len(tmp))

            convert_to_int(dataset, "train_edge_label")
            convert_to_int(dataset, "val_edge_label")
            convert_to_int(dataset, "test_edge_label")

        elif name.startswith("PCQM4Mv2Contact-"):
            dataset = preformat_PCQM4Mv2Contact(dataset_dir, name)

        else:
            raise ValueError(f"Unsupported OGB(-derived) dataset: {name}")
    else:
        raise ValueError(f"Unknown data format: {format}")

    # Re-using a previously processed copy (check_and_load_processed_eig) is
    # currently disabled: the dataset is always processed from scratch.

    pre_transform_in_memory(dataset, partial(task_specific_preprocessing, cfg=cfg))

    log_loaded_dataset(dataset, format, name)

    # Collect the names of all enabled positional encodings ("posenc_<name>").
    pe_enabled_list = []
    for key, pecfg in cfg.items():
        if key.startswith("posenc_") and pecfg.enable:
            pe_name = key.split("_", 1)[1]
            pe_enabled_list.append(pe_name)
            if hasattr(pecfg, "kernel"):
                # Generate kernel times if functional snippet is set.
                if pecfg.kernel.times_func:
                    # SECURITY: the snippet comes from the config and is run as
                    # Python code (presumably a range(...) expression). Only
                    # use config files from trusted sources.
                    pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
                logging.info(
                    f"Parsed {pe_name} PE kernel times / steps: "
                    f"{pecfg.kernel.times}"
                )
    if pe_enabled_list:
        start = time.perf_counter()
        logging.info(
            f"Precomputing Positional Encoding statistics: "
            f"{pe_enabled_list} for all graphs..."
        )
        # Estimate directedness based on 10 graphs to save time.
        is_undirected = all(d.is_undirected() for d in dataset[:10])
        logging.info(f"  ...estimated to be undirected: {is_undirected}")
        pre_transform_in_memory(
            dataset,
            partial(
                compute_posenc_stats,
                pe_types=pe_enabled_list,
                is_undirected=is_undirected,
                cfg=cfg,
            ),
            show_progress=True,
        )
        elapsed = time.perf_counter() - start
        # HH:MM:SS followed by the hundredths of a second (".xx").
        timestr = (
            time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
        )
        logging.info(f"Done! Took {timestr}")

    # Split preparation (set_dataset_splits / prepare_splits) is currently
    # disabled here; only the metadata below is attached.
    if data_cfg is not None:

        set_dataset_attr(
            dataset,
            "dataset_name",
            [data_cfg.dataset_name] * len(dataset),
            len(dataset),
        )

        set_dataset_attr(
            dataset,
            "dataset_task_name",
            [f"{data_cfg.dataset_name}_{data_cfg.task}_{data_cfg.task_type}"]
            * len(dataset),
            len(dataset),
        )

        # add if task is node classification
        if data_cfg.task == "node":
            # NOTE(review): the size argument is len(dataset) while the tensor
            # holds one id per entry of ``y`` - confirm this is what
            # set_dataset_attr and the consumers of ``node_id`` expect.
            set_dataset_attr(
                dataset,
                "node_id",
                torch.arange(len(dataset.data.y), dtype=torch.long),
                len(dataset),
            )
    print(dataset.data)

    # Save the processed dataset for future use
    save_processed_eig(dataset, dataset_dir, name)

    return dataset


# Parse the command line, then initialise the global GraphGym config and
# load the settings selected by those arguments.
args = parse_args()
set_cfg(cfg)
load_cfg(cfg, args)


# Process and save every dataset named in cfg.dataset_multi.name_list. Each
# name doubles as the name of the cfg section holding that dataset's settings.
# graph_dataset_dict is initialised here but not filled by the loop below.
graph_dataset_dict = {}
for dataset_name in cfg.dataset_multi.name_list:
    dataset_cfg = getattr(cfg, dataset_name)
    dataset_cfg.enable = True
    load_dataset_and_save(
        format=dataset_cfg.format,
        name=dataset_cfg.dataset_name,
        dataset_dir="~/graph_datasets_ram/graph-datasets/real_datasets/datasets/real_data_no_error",
        data_cfg=dataset_cfg,
    )