import torch
from torch_geometric.datasets import GNNBenchmarkDataset
from functools import partial
from src.datasets.split_generator import join_dataset_splits
from tqdm import tqdm
import torch_geometric.transforms as T


def concat_x_and_pos(data):
    """Append node positions to the node feature tensor.

    ``data.pos`` is concatenated to ``data.x`` along dimension 1, so the
    positional coordinates become extra feature columns (assuming both are
    2-D ``[num_nodes, *]`` tensors; TODO confirm for all datasets used).
    If the graph carries no node features (``data.x is None``), ``data.x``
    becomes a copy of ``data.pos``.

    Args:
        data: graph data object exposing ``x`` and ``pos``; modified in place.

    Returns:
        The same ``data`` object, so the function can be used as a transform.
    """
    if data.x is None:
        # torch.cat cannot take None. The positions alone become the features.
        # Clone so that x and pos do not share storage, matching the fresh
        # tensor that torch.cat would have produced.
        data.x = data.pos.clone()
    else:
        data.x = torch.cat((data.x, data.pos), 1)
    return data


def typecast_x(data, type_str):
    """Cast the node feature tensor ``data.x`` to float or long, in place.

    Args:
        data: graph data object exposing a tensor attribute ``x``.
        type_str: either ``"float"`` or ``"long"``.

    Returns:
        The same ``data`` object with ``x`` replaced by the cast tensor.

    Raises:
        ValueError: if ``type_str`` is neither ``"float"`` nor ``"long"``.
    """
    if type_str not in ("float", "long"):
        raise ValueError(f"Unexpected type '{type_str}'.")
    features = data.x
    data.x = features.float() if type_str == "float" else features.long()
    return data


def pre_transform_in_memory(dataset, transform_func, show_progress=False):
    """Pre-transform an already loaded PyG dataset object in place.

    Apply ``transform_func`` to every example of a loaded PyG dataset object
    so that the transformed result persists for the lifespan of the object.
    Unlike PyG's ``pre_transform``, the result is not saved to disk. Unlike
    PyG's ``transform`` hook, the transform is applied only once, not on
    every data access.

    The implementation is based on
    ``torch_geometric.data.in_memory_dataset.copy``. Like that method, it
    visits the examples of the dataset's current index view
    (``dataset.indices()``), so subset views are transformed correctly. The
    view is then flattened, because the stored data is rebuilt from exactly
    those examples.

    Args:
        dataset: PyG in-memory dataset object to modify in place.
        transform_func: transformation function applied to each data example.
            Examples for which it returns ``None`` are dropped. If
            ``transform_func`` itself is ``None``, the dataset is left as is.
        show_progress: show a tqdm progress bar.

    Returns:
        The same ``dataset`` object, on every path (early exit included).

    Raises:
        ValueError: if no examples remain to collate (the dataset is empty,
            or ``transform_func`` returned ``None`` for every example).
    """
    if transform_func is None:
        return dataset

    # ``get`` takes raw storage indices, while ``len(dataset)`` counts the
    # current view. Iterating ``indices()`` keeps both consistent (as PyG's
    # own ``copy`` does). It is identical to ``range(len(dataset))`` when no
    # view is active.
    indices = dataset.indices()
    num_examples = len(indices)
    progress = tqdm(
        indices,
        disable=not show_progress,
        mininterval=10,
        miniters=num_examples // 20,
    )
    transformed = (transform_func(dataset.get(idx)) for idx in progress)
    # Drop only examples the transform explicitly rejected with None. A
    # truthiness filter would also drop data objects that happen to evaluate
    # as falsy (e.g. via a zero ``__len__``).
    data_list = [data for data in transformed if data is not None]
    if not data_list:
        raise ValueError(
            "pre_transform_in_memory: no examples left to collate "
            "(empty dataset or transform_func returned None for all examples)."
        )

    # The rebuilt storage contains exactly the selected examples, so any
    # previous index view must be cleared. The cached per-example list is
    # replaced with the transformed objects.
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset
