from torch_geometric.datasets import Flickr
from .datasets import Flickr_v2, Amazon
from torchvision import datasets
from torchvision import transforms


def ImageNet(root):
    """Build the ImageNet train and validation datasets located under ``root``.

    Both splits share one preprocessing pipeline: resize to 227x227, convert
    to a tensor, then normalise each channel.

    Args:
        root: Directory handed to ``torchvision.datasets.ImageNet`` as its
            ``root`` argument. Nothing is downloaded by this function.

    Returns:
        A one-element list containing the ``(train_dataset, val_dataset)``
        tuple.
    """
    # NOTE(review): these per-channel statistics look like the values commonly
    # quoted for CIFAR-10 rather than the usual ImageNet ones
    # (mean 0.485/0.456/0.406, std 0.229/0.224/0.225). They are kept as-is so
    # existing results stay reproducible -- confirm this is intentional.
    normalise = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalise,
    ])

    # `download` is deliberately not passed: torchvision deprecated the flag
    # for ImageNet (it warns on `download=False`) and later removed it, so
    # forwarding it raises TypeError on recent releases. Omitting it never
    # triggers a download on any version.
    train_dataset = datasets.ImageNet(
        root=root, split='train', transform=transform
    )

    val_dataset = datasets.ImageNet(
        root=root, split='val', transform=transform
    )

    # Callers receive a list wrapping a single (train, val) pair.
    return [(train_dataset, val_dataset)]


# Registry from dataset identifier string to the class or function registered
# for it. Several identifiers share one entry; insertion order is preserved.
data = {"Flickr": Flickr}

# Every Flickr_v2 variant resolves to the same Flickr_v2 class.
data.update(dict.fromkeys(
    (
        "Flickr_v2",
        "Flickr_v2-18",
        "Flickr_v2-50",
        "Flickr_v2-16",
        "Flickr_v2-CH",
    ),
    Flickr_v2,
))

# Every Amazon category/encoding combination resolves to the Amazon class.
data.update(dict.fromkeys(
    (
        "AmazonElectronics-BytePair",
        "AmazonElectronics-BoW",
        "AmazonElectronics-Encoded",
        "AmazonElectronics-roBERTa",
        "AmazonInstruments-BytePair",
        "AmazonInstruments-BoW",
        "AmazonInstruments-Encoded",
        "AmazonInstruments-roBERTa",
    ),
    Amazon,
))

# ImageNet is served by the ImageNet() builder function defined in this module.
data["ilsvrc2017"] = ImageNet

