def _random_label_matrix(num_classes, num_examples, ensemble_size, seed):
    """
    Draw one random labeling of ``num_examples`` examples per ensemble member.

    Every column of a ``(num_classes, num_examples)`` grid is an independent
    random permutation of ``0 .. num_classes - 1``. The first ``ensemble_size``
    rows are returned, so for every example the labels assigned to different
    ensemble members are pairwise distinct.

    Parameters
    ----------
    num_classes : int
        Number of classes the labels are drawn from.
    num_examples : int
        Number of examples to label.
    ensemble_size : int
        Number of labelings (rows) to return.
    seed : int
        Seed passed to ``jax.random.PRNGKey``.

    Returns
    -------
    jax.Array
        int32 array of shape ``(ensemble_size, num_examples)``.

    Raises
    ------
    ValueError
        If ``ensemble_size`` is negative or larger than ``num_classes``. In the
        latter case there are not enough classes to give every member a
        distinct label for the same example.
    """
    if not 0 <= ensemble_size <= num_classes:
        raise ValueError(
            f'ensemble_size must be between 0 and the number of classes ({num_classes}), '
            f'got {ensemble_size}: each ensemble member needs a distinct random label '
            'for every unlabeled example.')
    # Row c of the grid is filled with class c; shuffling each column
    # independently along axis 0 gives every example its own class permutation.
    grid = jnp.tile(jnp.arange(num_classes, dtype=jnp.int32)[:, None], (1, num_examples))
    key = jax.random.PRNGKey(seed)
    shuffled = jax.random.permutation(key, grid, axis=0, independent=True)
    return shuffled[:ensemble_size]


def get_randomized_datasets(cfg):
    """
    Build per-member datasets for an ensemble trained on randomly labeled unlabeled data.

    Each ensemble member gets the unchanged training set and its own copy of
    the unlabeled set, whose labels are replaced by a random labeling. For any
    given unlabeled example, the labels given to different members are
    pairwise distinct.

    Parameters
    ----------
    cfg : DictConfig
        The configuration file for the experiment. Reads
        ``hyperparameters.dataset_name``, ``hyperparameters.ensemble_size`` and
        ``hyperparameters.seed_shuffle_labels``, plus whatever ``get_datasets``
        reads.

    Returns
    -------
    train_ds_list: List
        ``ensemble_size`` references to the same training-set dict (not copies).
    original_test_ds: Dict
        The test set, unchanged.
    unlabeled_ds_list: List
        ``ensemble_size`` shallow copies of the unlabeled set. Each has its own
        int32 random ``'label'`` array; the images are shared.
    original_validation_ds: Dict
        The validation set, unchanged.

    Raises
    ------
    ValueError
        If ``ensemble_size`` is negative or exceeds the number of classes of
        the dataset.
    """
    # Get the original datasets
    original_train_ds, original_test_ds, original_unlabeled_ds, original_validation_ds = get_datasets(cfg)

    random_labels = _random_label_matrix(
        num_classes=dataset_num_classes[cfg.hyperparameters.dataset_name],
        num_examples=original_unlabeled_ds['label'].shape[0],
        ensemble_size=cfg.hyperparameters.ensemble_size,
        seed=cfg.hyperparameters.seed_shuffle_labels,
    )

    train_ds_list = []
    unlabeled_ds_list = []
    for member_labels in random_labels:
        train_ds_list.append(original_train_ds)
        # Shallow copy: only the 'label' entry is replaced, images are shared.
        unlabeled_ds_tmp = original_unlabeled_ds.copy()
        unlabeled_ds_tmp['label'] = member_labels
        unlabeled_ds_list.append(unlabeled_ds_tmp)

    return train_ds_list, original_test_ds, unlabeled_ds_list, original_validation_ds

def _load_tfds_splits(name, splits, data_dir):
    """Load each split in ``splits`` of TFDS dataset ``name`` as one full in-memory batch.

    With ``data_dir == 'default'`` the dataset is prepared through
    ``tfds.builder`` and converted with ``tfds.as_numpy``. Otherwise it is
    loaded with ``tfds.load`` from ``data_dir``. Returns a list of dicts, one
    per requested split, in the same order.
    """
    if data_dir == 'default':
        ds_builder = tfds.builder(name)
        ds_builder.download_and_prepare()
        return [tfds.as_numpy(ds_builder.as_dataset(split=split, batch_size=-1)) for split in splits]
    return [tfds.load(name=name, data_dir=data_dir, split=split, batch_size=-1) for split in splits]


def _preprocess_labeled(ds):
    """Divide images by 255 as float32, cast labels to int32 and drop 'id'/'coarse_label'.

    Returns a new dict; the input dict is left untouched.
    """
    out = {key: value for key, value in ds.items() if key != 'id' and key != 'coarse_label'}
    # NOTE(review): assumes raw pixels are in [0, 255] (uint8), so images end up in [0, 1].
    out['image'] = jnp.float32(ds['image']) / 255.
    out['label'] = jnp.int32(ds['label'])
    return out


def get_datasets(cfg):
    """
    Load train, test, unlabeled and validation datasets into memory.

    The validation set is the first ``size_validation`` examples of the TFDS
    ``train`` split. The training set is the next ``size_training`` examples.
    The unlabeled set is one of two things:

    * if ``cfg.hyperparameters.in_distribution`` is set, the remainder of the
      ``train`` split;
    * otherwise, the ``unlabelled`` split of
      ``cfg.hyperparameters.unlabeled_dataset_name``, resized to the target
      dataset's height/width and converted to grayscale when the target has a
      single channel.

    In ``'diverse'`` mode with ``size_unlabeled > 0``, the unlabeled set is
    truncated to its first ``size_unlabeled`` examples.

    Parameters
    ----------
    cfg : DictConfig
        The configuration file for the experiment.

    Returns
    -------
    train_ds: dict
        Dictionary with keys 'image' and 'label' corresponding to the training set.
    test_ds: dict
        Dictionary with keys 'image' and 'label' corresponding to the test set.
    unlabeled_ds: dict
        Dictionary with keys 'image' and 'label' corresponding to the unlabeled set.
    validation_ds: dict
        Dictionary with keys 'image' and 'label' corresponding to the validation set.
    """
    hp = cfg.hyperparameters
    dataset_dir = cfg.server.dataset_dir

    if hp.dataset_name not in ['Cifar10', 'Cifar100', 'svhn_cropped', 'fashion_mnist']:
        warnings.warn(hp.dataset_name+' might not exist in tensorflow_datasets. These experiments have been created for datasets ``Cifar10``, ``Cifar100``, ``svhn_cropped`` and ``fashion_mnist``.')

    train_ds_tmp, test_ds = _load_tfds_splits(hp.dataset_name, ['train', 'test'], dataset_dir)
    train_ds_tmp = _preprocess_labeled(train_ds_tmp)
    test_ds = _preprocess_labeled(test_ds)

    # The TFDS 'train' split is partitioned as [validation | training | rest].
    validation_end = hp.size_validation
    training_end = hp.size_validation + hp.size_training

    validation_ds = {
        'image': train_ds_tmp['image'][0:validation_end],
        'label': train_ds_tmp['label'][0:validation_end],
    }
    train_ds = {
        'image': train_ds_tmp['image'][validation_end:training_end],
        'label': train_ds_tmp['label'][validation_end:training_end],
    }

    if hp.in_distribution:
        unlabeled_ds = {
            'image': train_ds_tmp['image'][training_end:],
            'label': train_ds_tmp['label'][training_end:],
        }
    else:
        (unlabeled_ds,) = _load_tfds_splits(hp.unlabeled_dataset_name, ['unlabelled'], dataset_dir)

        # Index 2 of the target dimensions is the channel count (see the grayscale
        # check below), so indices 0 and 1 are the spatial height and width.
        target_dims = dataset_dimensions[hp.dataset_name]
        # Resize to match the target distribution
        unlabeled_ds['image'] = tf.image.resize(unlabeled_ds['image'], [target_dims[0], target_dims[1]])
        # Turn to grayscale if the target distribution is greyscale
        if target_dims[2] == 1:
            unlabeled_ds['image'] = tf.image.rgb_to_grayscale(unlabeled_ds['image'], name=None)
        unlabeled_ds['image'] = jnp.float32(unlabeled_ds['image']) / 255.
        unlabeled_ds['label'] = jnp.int32(unlabeled_ds['label'])
        unlabeled_ds = {key: value for key, value in unlabeled_ds.items() if key != 'id'}

    if hp.mode == 'diverse' and hp.size_unlabeled > 0:
        unlabeled_ds['image'] = unlabeled_ds['image'][:hp.size_unlabeled]
        unlabeled_ds['label'] = unlabeled_ds['label'][:hp.size_unlabeled]

    return train_ds, test_ds, unlabeled_ds, validation_ds