from typing import List
from omegaconf.dictconfig import DictConfig
from torch_geometric.data import InMemoryDataset
from torch_geometric.graphgym.loader import index2mask, set_dataset_attr
from sklearn.model_selection import ShuffleSplit
import torch
import numpy as np


def prepare_splits(dataset, cfg, split_index):
    """Set up train/val/test splits according to the configured split mode.

    Reads ``cfg.dataset.split_mode`` and forwards to the matching setup
    routine: ``"standard"`` uses the splits shipped with the dataset,
    ``"random"`` generates new random splits. Any other split mode is
    ignored and the dataset is left untouched.

    Args:
        dataset: Dataset object whose splits are to be set.
        cfg: Config object providing ``cfg.dataset.split_mode``.
        split_index: Index of the split to select / generate.
    """
    mode = cfg.dataset.split_mode

    if mode == "random":
        setup_random_split(dataset, cfg, split_index)
        return

    if mode == "standard":
        setup_standard_split(dataset, cfg, split_index)


def setup_standard_split(dataset, config=None, split_index=0):
    """Select one of the standard splits that come with the dataset.

    For node-level tasks each of ``train_mask``/``val_mask``/``test_mask`` is
    read from ``dataset.data``:

    * a 2-D mask is treated as holding one split per column; the column
      ``split_index`` is selected and set on the dataset,
    * a 1-D mask is set on the dataset as is (``split_index`` is not checked),
    * a mask of any other dimensionality is left alone and only accepted for
      ``split_index == 0``.

    For every other task level nothing is set; only ``split_index == 0`` is
    accepted.

    Args:
        dataset: Dataset object; ``dataset.data`` must carry the masks for
            node-level tasks.
        config: Config object providing ``config.dataset.task``. It is
            dereferenced unconditionally, so it must be given despite the
            ``None`` default.
        split_index: Index of the split to pick.

    Raises:
        ValueError: If any one of train/val/test mask is missing.
        IndexError: If ``split_index`` is greater or equal to the number of
            columns of a 2-D mask, or non-zero for a mask that is neither
            1-D nor 2-D.
        NotImplementedError: If ``split_index`` is non-zero for a task level
            other than ``"node"``.
    """
    task_level = config.dataset.task

    if task_level != "node":
        if split_index != 0:
            raise NotImplementedError(
                f"Multiple standard splits not supported " f"for dataset task level: {task_level}"
            )
        return

    for split_name in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(dataset.data, split_name, None)
        if mask is None:
            raise ValueError(f"Missing '{split_name}' for standard split")

        n_dims = mask.dim()

        if n_dims == 1:
            # Single split: re-set the mask unchanged.
            set_dataset_attr(dataset, split_name, mask, len(mask))
            continue

        if n_dims != 2:
            if split_index != 0:
                raise IndexError("This dataset has single standard split")
            continue

        # Multiple splits stored column-wise: pick the requested column.
        n_available = mask.shape[1]
        if split_index >= n_available:
            raise IndexError(
                f"Specified split index ({split_index}) is "
                f"out of range of the number of available "
                f"splits ({n_available}) for {split_name}"
            )
        selected = mask[:, split_index]
        set_dataset_attr(dataset, split_name, selected, len(selected))


def generate_random_split(
    dataset: InMemoryDataset,
    split_ratios: List[float],
    split_index: int,
    seed: int = 0,
):
    """Generate random train/val/test index splits.

    Two ``ShuffleSplit`` passes are used: the first separates the training
    indices from the remainder, the second divides that remainder into
    validation and test indices.

    Args:
        dataset: Dataset whose labels ``dataset.data.y`` are split.
        split_ratios: Three train/val/test sizes, either fractions summing
            to 1 or absolute counts summing to ``len(dataset)``.
        split_index: Index of the split; it offsets the random seeds of both
            shuffling passes.
        seed: Base random seed.

    Returns:
        Tuple ``(train_index, val_index, test_index)`` of indices into
        ``dataset.data.y``.

    Raises:
        ValueError: If the number of split ratios is not equal to 3, or the
            ratios sum up to neither 1 nor the length of the dataset.
    """
    if len(split_ratios) != 3:
        raise ValueError(
            f"Three split ratios is expected for train/val/test, received "
            f"{len(split_ratios)} split ratios: {repr(split_ratios)}"
        )

    total = sum(split_ratios)
    # Compare against 1 with a tolerance: floating-point rounding can push
    # valid fractions slightly off (0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999), which an exact ``!= 1`` test would reject.
    sums_to_one = bool(np.isclose(total, 1.0, rtol=0.0, atol=1e-8))
    if not sums_to_one and total != len(dataset):
        raise ValueError(
            f"The train/val/test split ratios must sum up to 1/length of the dataset, input ratios "
            f"sum up to {total:.2f} instead: {repr(split_ratios)}"
        )

    labels = dataset.data.y

    # First pass: training indices vs. everything else.
    train_index, val_test_index = next(
        ShuffleSplit(
            train_size=split_ratios[0],
            random_state=int(seed + split_index * 100),
        ).split(labels, labels)
    )

    if isinstance(split_ratios[0], float):
        # Fractions: rescale the validation share relative to the remainder.
        val_test_ratio = split_ratios[1] / (1 - split_ratios[0])
    else:
        # Absolute counts: the validation count is used directly.
        val_test_ratio = split_ratios[1]

    # Second pass: divide the remainder into validation and test indices.
    remaining_labels = labels[val_test_index]
    val_index, test_index = next(
        ShuffleSplit(
            train_size=val_test_ratio,
            random_state=int(seed + split_index * 150),
        ).split(remaining_labels, remaining_labels)
    )

    # The second pass indexes into the remainder; map back to dataset indices.
    val_index = val_test_index[val_index]
    test_index = val_test_index[test_index]
    return train_index, val_index, test_index


def setup_random_split(dataset, config, split_index=0):
    """Generate random splits and set them on the dataset.

    The train/val/test indices are produced by ``generate_random_split``
    from ``config.dataset.split_ratios`` with ``config.seed`` as the base
    seed, and then stored on the dataset via ``set_dataset_splits``.

    Args:
        dataset: Dataset object to split.
        config: Config object providing ``config.dataset.split_ratios``,
            ``config.seed`` and ``config.dataset.task``.
        split_index: Index of the split; it offsets the random seeds.

    Raises:
        ValueError: If the split ratios are rejected by
            ``generate_random_split``, or the resulting splits are rejected
            by ``set_dataset_splits``.
    """
    # Delegate to the shared helper rather than duplicating its logic here.
    train_index, val_index, test_index = generate_random_split(
        dataset,
        config.dataset.split_ratios,
        split_index,
        seed=config.seed,
    )

    set_dataset_splits(dataset, [train_index, val_index, test_index], config=config)


def set_dataset_splits(dataset, splits, config=None, split_index=0):
    """Set given splits to the dataset object.

    Args:
        dataset: PyG dataset object
        splits: List of train/val/test split indices (lists, NumPy arrays or
            tensors of integer indices)
        config: Config object providing ``config.dataset.task``. It is
            dereferenced unconditionally, so it must be given despite the
            ``None`` default.
        split_index: Not used by this function; kept so existing callers
            keep working.

    Raises:
        ValueError: If any pair of splits has intersecting indices, or the
            task level is not ``"node"``.
    """
    # Convert every split once; the tensors are reused for the masks below.
    index_tensors = [torch.as_tensor(indices, dtype=torch.long) for indices in splits]

    # Build the sets from plain Python ints: elements of a tensor hash by
    # identity, so a set built directly from a tensor never intersects.
    index_sets = [set(tensor.reshape(-1).tolist()) for tensor in index_tensors]

    # First check whether splits intersect and raise error if so.
    for i, first in enumerate(index_sets):
        for j in range(i + 1, len(index_sets)):
            n_intersect = len(first & index_sets[j])
            if n_intersect != 0:
                raise ValueError(
                    f"Splits must not have intersecting indices: "
                    f"split #{i} (n = {len(splits[i])}) and "
                    f"split #{j} (n = {len(splits[j])}) have "
                    f"{n_intersect} intersecting indices"
                )

    task_level = config.dataset.task
    if task_level != "node":
        raise ValueError(f"Unsupported dataset task level: {task_level}")

    n_labels = dataset.data.y.shape[0]
    split_names = ("train_mask", "val_mask", "test_mask")
    for split_name, index_tensor in zip(split_names, index_tensors):
        mask = index2mask(index_tensor, size=n_labels)
        set_dataset_attr(dataset, split_name, mask, len(mask))


def join_dataset_splits(datasets):
    """Merge train, val, test datasets into the first dataset object.

    The items of all three datasets are fetched with ``get`` and collated
    through the first dataset's ``collate``; the first dataset is modified
    in place (``_indices``, ``_data_list``, ``data``, ``slices``) and
    returned.

    Args:
        datasets: list of 3 PyG datasets to merge
    Returns:
        the first dataset, now holding all items, with a `split_idxs`
        attribute storing the consecutive index ranges of the three splits
    """
    assert len(datasets) == 3, "Expecting train, val, test datasets"

    merged = datasets[0]
    sizes = [len(part) for part in datasets]

    # Fetch every item of every part, keeping the train/val/test order.
    data_list = []
    for part, size in zip(datasets, sizes):
        data_list.extend(part.get(i) for i in range(size))

    merged._indices = None
    merged._data_list = data_list
    merged.data, merged.slices = merged.collate(data_list)

    # Each split occupies a consecutive index range in the merged dataset.
    split_idxs = []
    start = 0
    for size in sizes:
        stop = start + size
        split_idxs.append(list(range(start, stop)))
        start = stop
    merged.split_idxs = split_idxs

    return merged
