import tensorflow as tf
import tensorflow_datasets as tfds

def celeb_a_loader(split, batch_size=160):
    """Build a shuffled, batched ``tf.data`` pipeline over TFDS ``celeb_a``.

    Args:
        split: TFDS split specifier forwarded to ``tfds.load`` (e.g. ``"train"``).
        batch_size: Number of examples per batch. Also sizes the shuffle
            buffer (``10 * batch_size``). A trailing partial batch is dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, attributes)`` batches.
        ``images`` is the ``"image"`` feature cast to float32, divided by 255
        and clipped to ``[0, 1]``. ``attributes`` is the values of the
        ``"attributes"`` feature dict, stacked into one vector per example
        and cast to float32.
    """

    def preprocess_data(data):
        # Use an explicit tf.cast rather than Python's float(): float() on a
        # symbolic tensor only works because AutoGraph rewrites that builtin
        # into a cast. tf.cast does not depend on that conversion.
        image = tf.cast(data["image"], tf.float32) / 255.0
        # Defensive clamp; a no-op for 8-bit pixel values.
        image = tf.clip_by_value(image, 0.0, 1.0)
        # Pack the per-attribute values into one vector. Iterating the dict's
        # values keeps the same attribute order as before.
        attributes = tf.cast(
            tf.stack(list(data["attributes"].values())), tf.float32
        )
        return image, attributes

    # the validation dataset is shuffled as well, because data order matters
    # for the KID estimation
    return (
        tfds.load("celeb_a", split=split, shuffle_files=True)
        # Cache the decoded examples *before* the float32 conversion. TFDS
        # decodes images as uint8 by default, so this cache is ~4x smaller
        # than caching float32 images. The cast/scale is cheap to redo each
        # epoch.
        .cache()
        .map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

def load_dataset(dataset_name, split, batch_size):
    """Return the input pipeline registered under ``dataset_name``.

    Only ``"celeb_a"`` is currently supported. ``split`` and ``batch_size``
    are passed straight through to the matching loader.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    # Reject unsupported names up front; the happy path stays unindented.
    if dataset_name != "celeb_a":
        raise ValueError(f"Unknown dataset name: {dataset_name}.")
    return celeb_a_loader(split, batch_size)
