


def preprocess(image, label, num_classes=120):
    """Normalize an image and one-hot encode its label.

    The image is converted to float32 with ``tf.image.convert_image_dtype``
    (integer images are rescaled to [0, 1]) and then standardized per channel
    with the mean/std constants below.

    Args:
        image: Image tensor. The mean/std are shaped ``[1, 1, 3]``, so the
            image is expected to be ``(height, width, 3)``.
        label: Integer class index.
        num_classes: Depth of the one-hot encoding. Defaults to 120, which
            must agree with the class filter and the label shape declared by
            ``load_data``.

    Returns:
        ``(image, label)``: ``image`` is a float32 tensor; ``label`` is an
        int32 one-hot vector of length ``num_classes``.
    """
    import tensorflow as tf

    # Per-channel RGB statistics (the commonly cited ImageNet values), shaped
    # to broadcast over the height and width axes.
    mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32, shape=[1, 1, 3])
    std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32, shape=[1, 1, 3])

    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = (image - mean) / std

    label = tf.one_hot(label, depth=num_classes, dtype=tf.int32)

    return image, label


def augment(image, label):
    """Apply random training-time augmentation to one image; pass the label through.

    Steps, in order:
      1. pad to 20x20 with ``resize_with_crop_or_pad``, then take a random
         16x16x3 crop (a random translation of up to 2 pixels);
      2. random horizontal flip;
      3. a 4x4 random cutout via ``tfa.image.random_cutout``.

    Assumes a 16x16x3 input, as loaded from ``imagenet_resized/16x16``.
    When this runs after ``preprocess`` (as in ``load_data``), a zero pixel
    equals the per-channel mean. So the zero padding and the cutout (which
    fills with 0 by default) both fill with the mean colour.
    """
    import tensorflow as tf
    import tensorflow_addons as tfa

    side = 16
    border = 2
    padded_side = side + 2 * border

    # Random translation: enlarge with zero padding, then crop back to size.
    padded = tf.image.resize_with_crop_or_pad(image, padded_side, padded_side)
    shifted = tf.image.random_crop(padded, size=[side, side, 3])

    flipped = tf.image.random_flip_left_right(shifted)

    # random_cutout works on a batch, so add a batch axis and drop it again.
    batch = tfa.image.random_cutout(flipped[tf.newaxis], (4, 4))
    return batch[0], label


def load_data():
    """Describe the train/validation data for a 120-class subset of 16x16 ImageNet.

    Nothing is loaded when this is called. The returned dict holds
    zero-argument factories that build the ``tf.data`` pipelines on demand,
    plus hard-coded metadata.

    Returns:
        dict with keys:
            'train_gen': callable returning the training dataset. It is
                filtered to labels < 120, preprocessed, cached, then
                augmented.
            'train_size': hard-coded number of training examples.
            'valid_gen': callable returning the validation dataset. It is
                filtered, preprocessed and cached, with no augmentation.
            'valid_size': hard-coded number of validation examples.
            'types' / 'shapes': declared element dtypes and shapes.
    """
    num_classes = 120

    def filtered_preprocessed(split):
        # Steps shared by both splits: load, keep only labels below
        # num_classes, normalize + one-hot encode, then cache the result.
        import tensorflow as tf
        import tensorflow_datasets as tfds

        ds = tfds.load('imagenet_resized/16x16', as_supervised=True,
                       split=split)
        ds = ds.filter(lambda x, y: y < num_classes)
        ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        return ds.cache()

    def load_ds_train():
        import tensorflow as tf

        # Augmentation is mapped after cache() so it is not cached and is
        # re-run on every pass over the data.
        cached = filtered_preprocessed('train')
        return cached.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    def load_ds_valid():
        return filtered_preprocessed('validation')

    # NOTE(review): the datasets yield (image, one_hot_label) tuples, while
    # 'types'/'shapes' describe the features as a dict keyed 'input_0'.
    # Presumably the consumer renames them; verify against the caller.
    return {
        'train_gen': load_ds_train,
        'train_size': 151700,
        'valid_gen': load_ds_valid,
        'valid_size': 6000,
        'types': ({'input_0': 'float32'}, 'int32'),
        'shapes': ({'input_0': (16, 16, 3)}, (num_classes,)),
    }
