from torch_geometric.datasets import *
from torch_geometric.datasets import HeterophilousGraphDataset
from ogb.graphproppred import PygGraphPropPredDataset
import torch
import torch_geometric.transforms as T
from functools import partial
import logging
from numpy.random import default_rng
from custom_modules.loader.dataset.aqsol_molecules import AQSOL
from custom_modules.loader.dataset.coco_superpixels import COCOSuperpixels
from custom_modules.loader.dataset.malnet_tiny import MalNetTiny
from custom_modules.loader.dataset.voc_superpixels import VOCSuperpixels

from custom_modules.transform.transforms import (
    pre_transform_in_memory,
    typecast_x,
    concat_x_and_pos,
    clip_graphs_to_size,
)

from torch_geometric.graphgym.config import cfg
import numpy as np


def load_pyg(name, dataset_dir):
    """
    Load PyG dataset objects. (More PyG datasets will be supported)

    A few exact names and ``<Family>_<Subname>`` patterns are dispatched
    explicitly below. Any other ``name`` is looked up as a callable of that
    exact name in this module's namespace (e.g. a dataset class star-imported
    from ``torch_geometric.datasets``) and constructed as ``cls(root)``.

    Args:
        name (string): dataset name
        dataset_dir (string): data directory; the dataset root used is
            ``<dataset_dir>/<name>``

    Returns: PyG dataset object. If the loaded dataset has no ``x``
        attribute, zero node features are added.

    Raises:
        ValueError: if ``name`` matches no known dataset, or the fallback
            class fails to construct (original error chained as the cause).
    """
    dataset_dir = "{}/{}".format(dataset_dir, name)
    if name.startswith("TU_"):
        # TU_IMDB doesn't have node features
        if name[3:] == "IMDB":
            name = "IMDB-MULTI"
            dataset = TUDataset(dataset_dir, name, transform=T.Constant())
        else:
            dataset = TUDataset(dataset_dir, name[3:])
    elif name == "Karate":
        dataset = KarateClub()
    elif "Coauthor" in name:
        if "CS" in name:
            dataset = Coauthor(dataset_dir, name="CS")
        else:
            dataset = Coauthor(dataset_dir, name="Physics")
    elif "AttributedGraphDataset" in name:
        dataset_name = name.split("_")[1]
        dataset = AttributedGraphDataset(dataset_dir, dataset_name)
    elif "HeterophilousGraphDataset" in name:
        dataset_name = name.split("_")[1]
        dataset = HeterophilousGraphDataset(dataset_dir, dataset_name)
    elif "Airports" in name:
        dataset_name = name.split("_")[1]
        dataset = Airports(dataset_dir, dataset_name)
    elif "Twitch" in name:
        dataset_name = name.split("_")[1]
        dataset = Twitch(dataset_dir, dataset_name)
    elif "LINKXDataset" in name:
        dataset_name = name.split("_")[1]
        dataset = LINKXDataset(dataset_dir, dataset_name)
    elif "Planetoid" in name:
        dataset_name = name.split("_")[1]
        dataset = Planetoid(dataset_dir, dataset_name)
    elif "CitationFull" in name:
        dataset_name = name.split("_")[1]
        # "Cora_ML" itself contains an underscore, so split() would truncate it.
        if "Cora_ML" in name:
            dataset = CitationFull(dataset_dir, "Cora_ML")
        else:
            dataset = CitationFull(dataset_dir, dataset_name)
    elif "Amazon" in name:
        if "Computers" in name:
            dataset = Amazon(dataset_dir, name="Computers")
        elif "Products" in name:
            dataset = AmazonProducts(dataset_dir)
        else:
            dataset = Amazon(dataset_dir, name="Photo")
    elif "Reddit" in name:
        if "Reddit2" in name:
            dataset = Reddit2(dataset_dir)
        else:
            dataset = Reddit(dataset_dir)
    elif name == "KarateClub":
        dataset = KarateClub()
    elif name == "MNIST":
        dataset = MNISTSuperpixels(dataset_dir)
    elif name == "PPI":
        dataset = PPI(dataset_dir)
    elif name == "QM7b":
        dataset = QM7b(dataset_dir)
    elif name == "QM9":
        dataset = QM9(dataset_dir)
    elif "SnapDataset" in name:
        dataset_name = name.split("_")[1]
        dataset = SNAPDataset(dataset_dir, name=dataset_name)
    else:
        # Resolve the class by exact name rather than eval()-ing a built
        # string: eval broke on paths containing quotes and would execute
        # any expression smuggled in through the dataset name.
        dataset_cls = globals().get(name)
        if not callable(dataset_cls):
            raise ValueError("{} not support".format(name))
        try:
            dataset = dataset_cls(dataset_dir)
        except Exception as err:
            raise ValueError("{} not support".format(name)) from err
    if not hasattr(dataset, "x"):
        # Featureless datasets get a constant zero node feature.
        dataset = preformat_add_zero_node_features(dataset)
    return dataset


def preformat_GNNBenchmarkDataset(dataset_dir, name):
    """Load and preformat datasets from PyG's GNNBenchmarkDataset.

    MNIST/CIFAR10/PATTERN/CLUSTER are loaded split by split and merged into
    one dataset (with ``split_idxs``). MNIST/CIFAR10 additionally get their
    pixel values concatenated with positions and cast to float. CSL is
    returned as loaded, without a split argument.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific dataset in the GNNBenchmarkDataset class

    Returns:
        PyG dataset object

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    if name == "CSL":
        return GNNBenchmarkDataset(root=dataset_dir, name=name)

    if name in ("PATTERN", "CLUSTER"):
        transforms = []
    elif name in ("MNIST", "CIFAR10"):
        # Concatenate pixel value and pos. coordinate, then cast to float.
        transforms = [concat_x_and_pos, partial(typecast_x, type_str="float")]
    else:
        raise ValueError(
            f"Loading dataset '{name}' from " f"GNNBenchmarkDataset is not supported."
        )

    per_split = [
        GNNBenchmarkDataset(root=dataset_dir, name=name, split=split_name)
        for split_name in ("train", "val", "test")
    ]
    merged = join_dataset_splits(per_split)
    pre_transform_in_memory(merged, T.Compose(transforms))
    return merged


def _max_out_degree(dataset):
    """Return the largest node out-degree (count in ``edge_index[0]``) over all graphs.

    ``T.OneHotDegree`` with its default ``in_degree=False`` one-hot encodes
    degrees counted on ``edge_index[0]``, so this is the ``max_degree`` it needs.
    Graphs without edges contribute 0.
    """
    max_deg = 0
    for idx in dataset.indices():
        sources = dataset.get(idx).edge_index[0]
        if sources.numel() > 0:
            max_deg = max(max_deg, int(torch.bincount(sources).max()))
    return max_deg


def preformat_MalNetTiny(dataset_dir, feature_set, max_degree=None):
    """Load and preformat Tiny version (5k graphs) of MalNet

    Args:
        dataset_dir: path where to store the cached dataset
        feature_set: select what node features to precompute as MalNet
            originally doesn't have any node nor edge features; one of
            'none'/'Constant', 'OneHotDegree', 'LocalDegreeProfile'
        max_degree: only used for 'OneHotDegree'; the largest degree to
            encode. When None, it is computed from the data.

    Returns:
        PyG dataset object with ``split_idxs`` = [train, valid, test] indices

    Raises:
        ValueError: if ``feature_set`` is not one of the supported options.
    """
    # Validate before loading so a typo fails fast, as before.
    if feature_set not in ("none", "Constant", "OneHotDegree", "LocalDegreeProfile"):
        raise ValueError(f"Unexpected transform function: {feature_set}")

    dataset = MalNetTiny(dataset_dir)
    dataset.name = "MalNetTiny"

    if feature_set in ("none", "Constant"):
        tf = T.Constant()
    elif feature_set == "OneHotDegree":
        # T.OneHotDegree requires `max_degree`; calling it without one raised
        # TypeError. Derive it from the data unless given explicitly. Note the
        # resulting feature width is max_degree + 1, which can be large.
        if max_degree is None:
            max_degree = _max_out_degree(dataset)
        tf = T.OneHotDegree(max_degree)
    else:
        tf = T.LocalDegreeProfile()

    logging.info(f'Computing "{feature_set}" node features for MalNetTiny.')
    pre_transform_in_memory(dataset, tf)

    split_dict = dataset.get_idx_split()
    dataset.split_idxs = [split_dict["train"], split_dict["valid"], split_dict["test"]]
    return dataset


def preformat_add_zero_node_features(dataset):
    """Add zero node features to every graph of a dataset.

    Each graph ``g`` receives ``g.x = zeros((g.num_nodes, 1), dtype=long)``.
    The graphs are rewritten and the dataset is re-collated in place.
    Previously the helper was applied to the dataset object itself, which
    only set an unused ``dataset.x`` attribute (or failed on
    ``dataset.num_nodes``) and left the graphs featureless.

    Args:
        dataset: PyG in-memory dataset object (must provide ``indices``,
            ``get`` and ``collate``)

    Returns:
        The same PyG dataset object, modified in place
    """

    def add_zeros(data):
        data.x = torch.zeros((data.num_nodes, 1), dtype=torch.long)
        return data

    # `get` bypasses any dynamic `transform`; `indices()` respects subsets.
    data_list = [add_zeros(dataset.get(idx)) for idx in dataset.indices()]
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset


def preformat_OGB_Graph(dataset_dir, name):
    """Load an OGB Graph Property Prediction dataset and attach its splits.

    The official 'train'/'valid'/'test' indices are stored on the dataset as
    ``split_idxs``. Two datasets get extra handling:
    - 'ogbg-ppa': a dynamic transform adds zero node features.
    - 'ogbg-code2': a target vocabulary is built from the training graphs,
      a transform adds augmented edges and encoded targets, and graphs are
      clipped in memory to at most 1000 nodes.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific OGB Graph dataset

    Returns:
        PyG dataset object
    """
    dataset = PygGraphPropPredDataset(name=name, root=dataset_dir)
    split_dict = dataset.get_idx_split()
    dataset.split_idxs = [split_dict[part] for part in ("train", "valid", "test")]

    if name == "ogbg-ppa":
        # ogbg-ppa has no node features. Add zeros on the fly as a `transform`
        # instead of a cached pre-transform, because the dataset is big
        # (~38.5M nodes) and already takes ~31GB of space.
        def add_zeros(graph):
            graph.x = torch.zeros(graph.num_nodes, dtype=torch.long)
            return graph

        dataset.transform = add_zeros
        return dataset

    if name == "ogbg-code2":
        from custom_modules.loader.ogbg_code2_utils import (
            augment_edge,
            encode_y_to_arr,
            get_vocab_mapping,
            idx2vocab,
        )

        vocab_size = 5000  # The number of vocabulary used for sequence prediction
        max_seq_len = 5  # The maximum sequence length to predict

        lengths = np.array([len(target) for target in dataset.data.y])
        logging.info(
            f"Target sequences less or equal to {max_seq_len} is "
            f"{np.sum(lengths <= max_seq_len) / len(lengths)}"
        )

        # The vocabulary is built from training targets only.
        train_targets = [dataset.data.y[i] for i in split_dict["train"]]
        vocab2idx, local_idx2vocab = get_vocab_mapping(train_targets, vocab_size)
        logging.info(f"Final size of vocabulary is {len(vocab2idx)}")
        # Extend the shared module-level list so CustomLogger can access it later.
        idx2vocab.extend(local_idx2vocab)

        def encode_targets(graph):
            return encode_y_to_arr(graph, vocab2idx, max_seq_len)

        # augment_edge adds next-token and inverse edges plus edge attributes;
        # encode_targets adds `y_arr`, the array representation of the target.
        dataset.transform = T.Compose([augment_edge, encode_targets])

        # Clip graphs to a maximum size (number of nodes).
        pre_transform_in_memory(dataset, partial(clip_graphs_to_size, size_limit=1000))

    return dataset


from ogb.nodeproppred import PygNodePropPredDataset


def preformat_OGB_Node(dataset_dir, name):
    """Load an OGB Node Property Prediction dataset with its official splits.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific OGB Node dataset

    Returns:
        PyG dataset object whose ``split_idxs`` holds the
        'train', 'valid' and 'test' indices, in that order
    """
    dataset = PygNodePropPredDataset(name=name, root=dataset_dir)
    official = dataset.get_idx_split()
    dataset.split_idxs = list(map(official.__getitem__, ("train", "valid", "test")))
    return dataset


def _materialize_subset(dataset, idx_groups):
    """Keep only the graphs in ``idx_groups`` (concatenated, in order) and re-collate.

    Returns the compacted dataset and one contiguous ``list(range(...))`` of
    new indices per group, in group order.
    """
    dataset = dataset[torch.cat(idx_groups)]
    data_list = [data for data in dataset]
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)

    split_idxs = []
    start = 0
    for group in idx_groups:
        split_idxs.append(list(range(start, start + len(group))))
        start += len(group)
    return dataset, split_idxs


def preformat_OGB_PCQM4Mv2(dataset_dir, name):
    """Load and preformat PCQM4Mv2 from OGB LSC.

    OGB-LSC provides 4 data index splits:
    2 with labeled molecules: 'train', 'valid' meant for training and dev
    2 unlabeled: 'test-dev', 'test-challenge' for the LSC challenge submission

    We will take random 150k from 'train' and make it a validation set and
    use the original 'valid' as our testing set.

    Note: PygPCQM4Mv2Dataset requires rdkit

    Args:
        dataset_dir: path where to store the cached dataset
        name: 'full', 'subset' (10% of the training set and 50k validation
            graphs, for faster debugging) or 'inference' (original 'valid',
            'test-dev' and 'test-challenge' splits)

    Returns:
        PyG dataset object with ``split_idxs`` set

    Raises:
        ValueError: if ``name`` is not one of the choices above. This is
            checked before the (large) dataset is loaded.
    """
    if name not in ("full", "subset", "inference"):
        raise ValueError(f"Unexpected OGB PCQM4Mv2 subset choice: {name}")

    try:
        # Load locally to avoid RDKit dependency until necessary.
        from ogb.lsc import PygPCQM4Mv2Dataset
    except Exception:
        logging.error(
            "ERROR: Failed to import PygPCQM4Mv2Dataset, "
            "make sure RDKit is installed."
        )
        raise

    dataset = PygPCQM4Mv2Dataset(root=dataset_dir)
    split_idx = dataset.get_idx_split()

    # A fixed seed keeps the carved-out validation set reproducible.
    rng = default_rng(seed=42)
    train_idx = torch.from_numpy(rng.permutation(split_idx["train"].numpy()))

    # Leave out 150k graphs for a new validation set.
    valid_idx, train_idx = train_idx[:150000], train_idx[150000:]

    if name == "full":
        split_idxs = [
            train_idx,  # Subset of original 'train'.
            valid_idx,  # Subset of original 'train' as validation set.
            split_idx["valid"],  # The original 'valid' as testing set.
        ]
    elif name == "subset":
        # Further subset the training set for faster debugging.
        subset_ratio = 0.1
        dataset, split_idxs = _materialize_subset(
            dataset,
            [
                train_idx[: int(subset_ratio * len(train_idx))],
                valid_idx[:50000],
                split_idx["valid"],  # The original 'valid' as testing set.
            ],
        )
    else:  # name == "inference"
        dataset, split_idxs = _materialize_subset(
            dataset,
            [
                split_idx["valid"],  # The original labeled 'valid' set.
                split_idx["test-dev"],  # Held-out unlabeled test dev.
                split_idx["test-challenge"],  # Held-out challenge test set.
            ],
        )
        # Check prediction targets: only the first split is labeled.
        assert all([not torch.isnan(dataset[i].y)[0] for i in split_idxs[0]])
        assert all([torch.isnan(dataset[i].y)[0] for i in split_idxs[1]])
        assert all([torch.isnan(dataset[i].y)[0] for i in split_idxs[2]])

    dataset.split_idxs = split_idxs
    return dataset


def preformat_PCQM4Mv2Contact(dataset_dir, name):
    """Load PCQM4Mv2-derived molecular contact link prediction dataset.

    Note: This dataset requires RDKit dependency!

    Args:
       dataset_dir: path where to store the cached dataset
       name: '<prefix>-<split>'; the part after the first '-' is the type of
           dataset split: 'shuffle', 'num-atoms'

    Returns:
       PyG dataset object with ``split_idxs`` = [train, val, test] graph
       indices. When ``cfg.dataset.resample_negative`` is set, its
       ``transform`` is ``structured_neg_sampling_transform``.
    """
    try:
        # Load locally to avoid RDKit dependency until necessary.
        try:
            # Every other dataset class in this file is imported from
            # `custom_modules`, so prefer that copy.
            from custom_modules.loader.dataset.pcqm4mv2_contact import (
                PygPCQM4Mv2ContactDataset,
                structured_neg_sampling_transform,
            )
        except ModuleNotFoundError as err:
            # Fall back to the upstream GraphGPS package only when the local
            # module is missing. Any other missing dependency (e.g. rdkit) is
            # re-raised.
            if not (err.name or "").startswith("custom_modules"):
                raise
            from graphgps.loader.dataset.pcqm4mv2_contact import (
                PygPCQM4Mv2ContactDataset,
                structured_neg_sampling_transform,
            )
    except Exception:
        logging.error(
            "ERROR: Failed to import PygPCQM4Mv2ContactDataset, "
            "make sure RDKit is installed."
        )
        raise

    split_name = name.split("-", 1)[1]
    dataset = PygPCQM4Mv2ContactDataset(dataset_dir, subset="530k")
    # Inductive graph-level split (there is no train/test edge split).
    s_dict = dataset.get_idx_split(split_name)
    dataset.split_idxs = [s_dict[s] for s in ["train", "val", "test"]]
    if cfg.dataset.resample_negative:
        dataset.transform = structured_neg_sampling_transform
    return dataset


def preformat_Peptides(dataset_dir, name):
    """Load Peptides dataset, functional or structural.

    Note: This dataset requires RDKit dependency!

    Args:
        dataset_dir: path where to store the cached dataset
        name: the type of dataset split:
            - 'peptides-functional' (10-task classification)
            - 'peptides-structural' (11-task regression)

    Returns:
        PyG dataset object with ``split_idxs`` = [train, val, test] indices

    Raises:
        ValueError: if the part of ``name`` after the first '-' is neither
            'functional' nor 'structural'. Previously this failed with
            UnboundLocalError or IndexError.
    """
    _, _, dataset_type = name.partition("-")
    if dataset_type not in ("functional", "structural"):
        raise ValueError(f"Unexpected Peptides dataset type: {name}")

    try:
        # Load locally to avoid RDKit dependency until necessary.
        from custom_modules.loader.dataset.peptides_functional import (
            PeptidesFunctionalDataset,
        )
        from custom_modules.loader.dataset.peptides_structural import (
            PeptidesStructuralDataset,
        )
    except Exception:
        logging.error(
            "ERROR: Failed to import Peptides dataset class, "
            "make sure RDKit is installed."
        )
        raise

    if dataset_type == "functional":
        dataset = PeptidesFunctionalDataset(dataset_dir)
    else:
        dataset = PeptidesStructuralDataset(dataset_dir)
    s_dict = dataset.get_idx_split()
    dataset.split_idxs = [s_dict[s] for s in ["train", "val", "test"]]
    return dataset


def preformat_TUDataset(dataset_dir, name):
    """Load a supported TUDataset, adding constant node features where needed.

    Datasets that ship node features are loaded as-is. The featureless
    social-network datasets (IMDB-*, REDDIT-*, COLLAB) get ``T.Constant()``
    as a cached pre-transform.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific dataset in the TUDataset class

    Returns:
        PyG dataset object

    Raises:
        ValueError: if ``name`` is not a supported TUDataset.
    """
    with_features = {"DD", "NCI1", "ENZYMES", "PROTEINS", "TRIANGLES", "MUTAG"}
    if name in with_features:
        pre_tf = None
    elif name.startswith(("IMDB-", "REDDIT-")) or name == "COLLAB":
        pre_tf = T.Constant()
    else:
        raise ValueError(
            f"Loading dataset '{name}' from " f"TUDataset is not supported."
        )
    return TUDataset(dataset_dir, name, pre_transform=pre_tf)


def preformat_ZINC(dataset_dir, name):
    """Load ZINC's train/val/test splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: 'subset' for the ZINC subset, 'full' for the full dataset

    Returns:
        PyG dataset object with ``split_idxs`` set

    Raises:
        ValueError: if ``name`` is neither 'subset' nor 'full'.
    """
    if name not in ("subset", "full"):
        raise ValueError(f"Unexpected subset choice for ZINC dataset: {name}")

    use_subset = name == "subset"
    parts = []
    for split_name in ("train", "val", "test"):
        parts.append(ZINC(root=dataset_dir, subset=use_subset, split=split_name))
    return join_dataset_splits(parts)


def preformat_AQSOL(dataset_dir):
    """Load AQSOL's train/val/test splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset

    Returns:
        PyG dataset object with ``split_idxs`` set
    """
    parts = []
    for split_name in ("train", "val", "test"):
        parts.append(AQSOL(root=dataset_dir, split=split_name))
    return join_dataset_splits(parts)


def preformat_VOCSuperpixels(dataset_dir, name, slic_compactness):
    """Load VOCSuperpixels' train/val/test splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: dataset variant, forwarded as ``VOCSuperpixels(name=...)``
        slic_compactness: SLIC compactness setting, forwarded as
            ``VOCSuperpixels(slic_compactness=...)``

    Returns:
        PyG dataset object with ``split_idxs`` set
    """
    parts = []
    for split_name in ("train", "val", "test"):
        parts.append(
            VOCSuperpixels(
                root=dataset_dir,
                name=name,
                slic_compactness=slic_compactness,
                split=split_name,
            )
        )
    return join_dataset_splits(parts)


def preformat_COCOSuperpixels(dataset_dir, name, slic_compactness):
    """Load COCOSuperpixels' train/val/test splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: dataset variant, forwarded as ``COCOSuperpixels(name=...)``
        slic_compactness: SLIC compactness setting, forwarded as
            ``COCOSuperpixels(slic_compactness=...)``

    Returns:
        PyG dataset object with ``split_idxs`` set
    """
    parts = []
    for split_name in ("train", "val", "test"):
        parts.append(
            COCOSuperpixels(
                root=dataset_dir,
                name=name,
                slic_compactness=slic_compactness,
                split=split_name,
            )
        )
    return join_dataset_splits(parts)


def join_dataset_splits(datasets):
    """Join train, val, test datasets into one dataset object.

    Graphs from the three datasets are concatenated in order and collated
    into the first dataset object, which is modified in place and returned.
    Graphs are read with ``get`` over ``indices()``, so a dataset that is a
    subset view (non-None ``_indices``) contributes exactly its selected
    graphs, and no dynamic ``transform`` is applied.

    Args:
        datasets: list of 3 PyG datasets to merge (train, val, test)

    Returns:
        joint dataset with `split_idxs` property storing the split indices

    Raises:
        ValueError: if ``datasets`` does not hold exactly three datasets.
    """
    if len(datasets) != 3:
        raise ValueError("Expecting train, val, test datasets")

    data_list = []
    split_idxs = []
    for ds in datasets:
        start = len(data_list)
        data_list.extend(ds.get(idx) for idx in ds.indices())
        split_idxs.append(list(range(start, len(data_list))))

    joint = datasets[0]
    joint._indices = None
    joint._data_list = data_list
    joint.data, joint.slices = joint.collate(data_list)
    joint.split_idxs = split_idxs
    return joint
