from datasets.utils import train_test_split
from torchvision import transforms, datasets
from torch.utils.data import Subset

def get_dataset_from_config(config):
    """Build the train/valid/test datasets selected by ``config``.

    Args:
        config: Object exposing ``dataset_name`` (``'cifar10'`` or
            ``'cifar100'``) and ``dataset_root`` (passed as ``root`` to the
            torchvision dataset constructors). The whole object is also
            forwarded to ``train_test_split``.

    Returns:
        A dict with the keys ``"train"``, ``"valid"`` and ``"test"``.
        ``"train"`` and ``"valid"`` are ``Subset`` views over the official
        training split; ``"test"`` is the official test split.

    Raises:
        ValueError: If ``config.dataset_name`` is not a supported dataset.
    """
    # Resolve the dataset class first so an unsupported name fails fast with
    # a clear error instead of an UnboundLocalError further down.
    if config.dataset_name == 'cifar10':
        dataset_cls = datasets.CIFAR10
    elif config.dataset_name == 'cifar100':
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError(
            f"Unsupported dataset_name {config.dataset_name!r}; "
            "expected 'cifar10' or 'cifar100'"
        )

    # NOTE(review): these per-channel mean/std values look like the commonly
    # cited CIFAR-10 statistics and are applied to CIFAR-100 as well --
    # confirm that this is intended.
    norm_transform = transforms.Normalize(
        [0.49139968, 0.48215827, 0.44653124],
        [0.24703233, 0.24348505, 0.26158768],
    )
    # Augmentation (random crop + horizontal flip) is applied to training
    # samples only.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_transform,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        norm_transform,
    ])

    train_dataset = dataset_cls(root=config.dataset_root, train=True,
                                download=True, transform=train_transform)
    # The validation set is carved out of the training split, but it is read
    # through the non-augmenting transform; the data was already downloaded
    # by the call above, hence download=False.
    valid_dataset = dataset_cls(root=config.dataset_root, train=True,
                                download=False, transform=val_transform)
    test_dataset = dataset_cls(root=config.dataset_root, train=False,
                               download=True, transform=val_transform)

    # train_test_split is a project helper; it is given the config and the
    # training labels and its two results are used as index collections.
    # Presumably the two index sets are disjoint -- verify in the helper.
    train_idx, valid_idx = train_test_split(config, train_dataset.targets)

    train_dataset = Subset(train_dataset, train_idx)
    valid_dataset = Subset(valid_dataset, valid_idx)

    return {"train": train_dataset, "valid": valid_dataset, "test": test_dataset}


def is_valid_dataset_name(dataset_name):
    """Return True if ``dataset_name`` is a dataset this module can load.

    Only ``'cifar10'`` and ``'cifar100'`` are accepted, matching the names
    handled by ``get_dataset_from_config``. A bare prefix check would also
    accept names such as ``'cifar'`` or ``'cifar5'`` that cannot be built.
    Non-string values simply yield ``False``.
    """
    # A tuple (not a set) keeps the check safe for unhashable inputs.
    return dataset_name in ('cifar10', 'cifar100')