import argparse
import time
import os.path as osp
import logging
import torch
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.loader import set_dataset_attr
import custom_modules
from custom_modules.loader.utils import *
from custom_modules.loader.custom_loaders import *
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.posenc_stats import compute_posenc_stats


def log_loaded_dataset(dataset, format, name):
    """Write a short summary of a freshly loaded dataset to the root logger.

    The summary covers the collated storage object, whether the first graph
    is undirected, the number of graphs, the (floor) average node count per
    graph, node/edge feature counts, the task count (if the dataset exposes
    ``num_tasks``) and a class-count line derived from the labels.

    Args:
        dataset: Dataset object exposing ``data``, ``__len__``,
            ``__getitem__``, ``num_node_features`` and ``num_edge_features``.
        format: Format identifier the dataset was loaded with (logged only).
        name: Dataset name (logged only).
    """
    logging.info(f"[*] Loaded dataset '{name}' from '{format}':")

    storage = dataset.data
    logging.info(f"  {storage}")
    logging.info(f"  undirected: {dataset[0].is_undirected()}")

    graph_count = len(dataset)
    logging.info(f"  num graphs: {graph_count}")

    # Prefer an explicit node count; otherwise use the feature matrix height.
    if hasattr(storage, "num_nodes"):
        node_total = storage.num_nodes
    elif hasattr(storage, "x"):
        node_total = storage.x.size(0)
    else:
        node_total = 0
    logging.info(f"  avg num_nodes/graph: {node_total // graph_count}")

    logging.info(f"  num node features: {dataset.num_node_features}")
    logging.info(f"  num edge features: {dataset.num_edge_features}")
    if hasattr(dataset, "num_tasks"):
        logging.info(f"  num tasks: {dataset.num_tasks}")

    targets = getattr(storage, "y", None)
    if targets is not None:
        # Node/graph-level labels are present.
        if isinstance(targets, list):
            # A special case for ogbg-code2 dataset.
            logging.info("  num classes: n/a")
        elif targets.numel() == targets.size(0) and torch.is_floating_point(
            targets
        ):
            # One float value per row is reported as regression.
            logging.info("  num classes: (appears to be a regression task)")
        else:
            logging.info(f"  num classes: {dataset.num_classes}")
        return

    # No usable `y`: fall back to edge/link prediction labels, if any.
    if hasattr(storage, "train_edge_label"):
        edge_targets = storage.train_edge_label  # Transductive link task
    elif hasattr(storage, "edge_label"):
        edge_targets = storage.edge_label  # Inductive link task
    else:
        return

    if edge_targets.numel() == edge_targets.size(0) and torch.is_floating_point(
        edge_targets
    ):
        logging.info("  num edge classes: (probably a regression task)")
    else:
        logging.info(f"  num edge classes: {len(torch.unique(edge_targets))}")


def _enabled_posenc_names():
    """Collect the names of all enabled ``posenc_*`` groups in the global ``cfg``.

    For groups that carry a ``kernel`` sub-config, ``kernel.times`` is filled
    in from ``kernel.times_func`` (when that snippet is set) and the resulting
    kernel times are logged.

    Returns:
        list: Config keys with the ``posenc_`` prefix stripped, in ``cfg``
        iteration order.
    """
    enabled = []
    for key, pecfg in cfg.items():
        if not (key.startswith("posenc_") and pecfg.enable):
            continue
        pe_name = key.split("_", 1)[1]
        enabled.append(pe_name)
        if hasattr(pecfg, "kernel"):
            # Generate kernel times if functional snippet is set.
            if pecfg.kernel.times_func:
                # SECURITY NOTE: `eval` executes arbitrary code taken from the
                # config. Only run this with configuration files you trust.
                pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
            logging.info(
                f"Parsed {pe_name} PE kernel times / steps: "
                f"{pecfg.kernel.times}"
            )
    return enabled


def _fill_missing_node_features(dataset):
    """Attach a constant all-ones ``x`` column to every graph lacking features.

    Does nothing when the collated storage already has a non-``None`` ``x``.
    Otherwise each graph gets a float32 ``(num_nodes, 1)`` tensor of ones and
    the dataset is re-collated in place.
    """
    if getattr(dataset.data, "x", None) is not None:
        return

    data_list = []
    for i in range(len(dataset)):
        data = dataset.get(i)
        data.x = torch.ones(data.num_nodes, 1, dtype=torch.float)
        data_list.append(data)

    # Collate the dataset
    dataset.data, dataset.slices = dataset.collate(data_list)


def load_dataset_and_save(format, name, dataset_dir, data_cfg=None):
    """Load a dataset, precompute positional encodings and save the result.

    The dataset is built according to ``format``, passed through
    ``task_specific_preprocessing``, and logged. If ``check_processed_eig``
    finds an already processed copy, that copy is returned immediately.
    Otherwise the enabled positional-encoding statistics are computed, missing
    node features are replaced by a constant column, per-dataset attributes
    from ``data_cfg`` are attached, and the dataset is stored through
    ``save_processed_eig``.

    Args:
        format: Loader selector. One of ``"PyG-Actor"``, ``"PyG-WebKB"``,
            ``"PyG-WikipediaNetwork"``, ``"PyG"``, ``"Network_repository"``
            or ``"Snap_pokec"``.
        name: Dataset name forwarded to the selected loader and to the
            processed-dataset cache helpers.
        dataset_dir: Root directory for the dataset. For ``"PyG-<Class>"``
            formats the class identifier is appended, and that sub-directory
            is also the one used for the processed cache.
        data_cfg: Optional per-dataset config node. Its ``dataset_name``,
            ``task`` and ``task_type`` fields are read when it is given. It is
            required for the ``"Network_repository"`` format.

    Returns:
        The previously processed dataset if one was found, otherwise the
        newly processed dataset.

    Raises:
        ValueError: Unknown ``format``, unknown PyG dataset identifier, a
            wrong ``name`` for the Actor dataset, or a missing ``data_cfg``
            for the ``"Network_repository"`` format.
        NotImplementedError: ``name == "crocodile"`` for WikipediaNetwork.
    """
    # Imported locally so the function does not depend on `partial` leaking
    # in through one of the module's star-imports.
    from functools import partial

    if format.startswith("PyG-"):
        pyg_dataset_id = format.split("-", 1)[1]
        dataset_dir = osp.join(dataset_dir, pyg_dataset_id)

        if pyg_dataset_id == "Actor":
            if name != "actor":
                raise ValueError("Actor class provides only one dataset.")
            dataset = Actor(dataset_dir)

        elif pyg_dataset_id == "WebKB":
            dataset = WebKB(dataset_dir, name)

        elif pyg_dataset_id == "WikipediaNetwork":
            if name == "crocodile":
                raise NotImplementedError("crocodile not implemented")
            dataset = WikipediaNetwork(dataset_dir, name, geom_gcn_preprocess=True)

        else:
            raise ValueError(f"Unexpected PyG Dataset identifier: {format}")

    # GraphGym default loader for Pytorch Geometric datasets
    elif format == "PyG":
        dataset = load_pyg(name, dataset_dir)

    elif format == "Network_repository":
        if data_cfg is None:
            raise ValueError(
                "data_cfg is required for the 'Network_repository' format "
                "(its dataset_name selects the dataset to load)."
            )
        dataset_dir_network_repo = "/graph-datasets"  # change dir to where network repo datasets are stored
        dataset = NetworkRepository(
            f"{dataset_dir_network_repo}/{data_cfg.dataset_name}"
        )

    elif format == "Snap_pokec":
        dataset = SnapPokecDataset(dataset_dir)

    else:
        raise ValueError(f"Unknown data format: {format}")

    pre_transform_in_memory(dataset, partial(task_specific_preprocessing, cfg=cfg))

    log_loaded_dataset(dataset, format, name)

    # Check whether a preprocessed copy was already saved.
    # NOTE(review): this lookup runs only after the raw dataset was loaded and
    # task-preprocessed; it could presumably run earlier - confirm with the
    # cache helpers before reordering.
    dataset_loaded = check_processed_eig(dataset_dir, name)
    if dataset_loaded is not None:
        return dataset_loaded

    pe_enabled_list = _enabled_posenc_names()
    if pe_enabled_list:
        start = time.perf_counter()
        logging.info(
            f"Precomputing Positional Encoding statistics: "
            f"{pe_enabled_list} for all graphs..."
        )
        # Estimate directedness based on 10 graphs to save time.
        is_undirected = all(d.is_undirected() for d in dataset[:10])
        logging.info(f"  ...estimated to be undirected: {is_undirected}")

        pre_transform_in_memory(
            dataset,
            partial(
                compute_posenc_stats,
                pe_types=pe_enabled_list,
                is_undirected=is_undirected,
                cfg=cfg,
            ),
            show_progress=True,
        )
        elapsed = time.perf_counter() - start
        # HH:MM:SS followed by the hundredths of a second (".xx").
        timestr = (
            time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
        )
        logging.info(f"Done! Took {timestr}")

        # NOTE(review): as in the original, placeholder features are only
        # added when at least one positional encoding is enabled - confirm
        # this scoping is intended.
        _fill_missing_node_features(dataset)

    if data_cfg is not None:
        if not hasattr(dataset, "name"):
            dataset.name = data_cfg.dataset_name

        set_dataset_attr(
            dataset,
            "dataset_name",
            [data_cfg.dataset_name] * len(dataset),
            len(dataset),
        )

        set_dataset_attr(
            dataset,
            "dataset_task_name",
            [f"{data_cfg.dataset_name}_{data_cfg.task}_{data_cfg.task_type}"]
            * len(dataset),
            len(dataset),
        )

        # Node-level tasks additionally get a running node index.
        if data_cfg.task == "node":
            set_dataset_attr(
                dataset,
                "node_id",
                torch.arange(len(dataset.data.y), dtype=torch.long),
                len(dataset),
            )

    # Save the processed dataset for future use
    save_processed_eig(dataset, dataset_dir, name)

    return dataset


if __name__ == "__main__":
    # Command line: a mandatory config file plus free-form trailing options
    # that are handed to the config loader together with the parsed args.
    cli = argparse.ArgumentParser(description="GraphFM")
    cli.add_argument(
        "--cfg",
        dest="cfg_file",
        type=str,
        required=True,
        help="The configuration file path.",
    )
    cli.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="See graphgym/config.py for remaining options.",
    )
    args = cli.parse_args()

    # Populate the defaults, print them, then load the config file given on
    # the command line. The print happens before the file is loaded.
    set_cfg(cfg)
    print(cfg)
    load_cfg(cfg, args)

    # Not used anywhere in the visible part of this script.
    graph_dataset_dict = {}

    # Preprocess and save every dataset listed in `dataset_multi.name_list`;
    # each name refers to a config group of the same name on `cfg`.
    for entry_name in cfg.dataset_multi.name_list:
        entry_cfg = getattr(cfg, entry_name)
        entry_cfg.enable = True
        load_dataset_and_save(
            format=entry_cfg.format,
            name=entry_cfg.dataset_name,
            dataset_dir=f"{cfg.out_dir}/graph-datasets/real_datasets",
            data_cfg=entry_cfg,
        )
