import os
import torch
from torch.utils.data import random_split
from torchvision import datasets, transforms


def get_datasets(dataset_config, split=0.8, generator=None):
    """Build the train and test datasets described by ``dataset_config``.

    Keys read from ``dataset_config``:
        modality   -- 'image' is supported; 'text' is not implemented yet.
        gray_scale -- (image) convert images to a single grayscale channel.
        data_mean  -- (image) mean(s) passed to ``transforms.Normalize``.
        data_std   -- (image) std(s) passed to ``transforms.Normalize``.
        pre_split  -- (image) if true, load ``data/<data_dir>/train`` and
                      ``data/<data_dir>/test`` separately; otherwise load
                      ``data/<data_dir>`` and split it randomly.
        data_dir   -- (image) dataset directory name under ``data``.

    Args:
        dataset_config: Mapping with the keys above.
        split: Fraction of samples assigned to the training set when the
            data is not pre-split. Must lie in [0, 1].
        generator: Optional ``torch.Generator`` for the random split. Pass a
            seeded generator to get the same partition on every call. The
            default uses torch's global RNG.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

    Raises:
        NotImplementedError: If ``modality`` is 'text'.
        ValueError: If ``modality`` is unknown or ``split`` is out of range.
    """
    modality = dataset_config['modality']
    if modality == 'image':
        return _get_image_datasets(dataset_config, split, generator)
    if modality == 'text':
        # Previously a bare `pass`, which then failed with UnboundLocalError
        # on the return statement.
        raise NotImplementedError("Text datasets are not supported yet")
    raise ValueError("Unsupported data modality: {}".format(modality))


def _build_image_transform(dataset_config):
    """Compose [Grayscale] -> ToTensor -> Normalize -> flatten for images."""
    steps = []
    if dataset_config['gray_scale']:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(dataset_config['data_mean'],
                                      dataset_config['data_std']))
    # Flatten each sample tensor to 1-D, e.g. for fully-connected models.
    steps.append(transforms.Lambda(torch.flatten))
    return transforms.Compose(steps)


def _get_image_datasets(dataset_config, split, generator):
    """Load ImageFolder train/test datasets, splitting randomly if needed."""
    preprocessing = _build_image_transform(dataset_config)
    # Both branches resolve relative to the current working directory.
    data_root = os.path.join('data', dataset_config['data_dir'])

    if dataset_config['pre_split']:
        train_dataset = datasets.ImageFolder(
            os.path.join(data_root, 'train'), transform=preprocessing)
        test_dataset = datasets.ImageFolder(
            os.path.join(data_root, 'test'), transform=preprocessing)
        return train_dataset, test_dataset

    if not 0.0 <= split <= 1.0:
        raise ValueError("split must be in [0, 1], got {}".format(split))

    full_dataset = datasets.ImageFolder(data_root, transform=preprocessing)
    train_size = round(split * len(full_dataset))
    test_size = len(full_dataset) - train_size
    if generator is None:
        # random_split's own default; kept explicit so None is accepted.
        generator = torch.default_generator
    train_dataset, test_dataset = random_split(
        full_dataset, [train_size, test_size], generator=generator)
    return train_dataset, test_dataset

