import functools
import jax.numpy as jnp

import tensorflow.compat.v1 as tf
# import tensorflow as tf
import tensorflow_datasets as tfds
import jax
from flax.jax_utils import prefetch_to_device
from tqdm import tqdm



# Ensure TensorFlow does not see the GPU and does not grab all GPU memory;
# TF is only used here for the tf.data input pipelines.
tf.config.set_visible_devices([], device_type='GPU')
# Only configure memory growth when a GPU physically exists: indexing [0]
# unconditionally raised IndexError on CPU-only hosts at import time.
_tf_gpus = tf.config.list_physical_devices('GPU')
if _tf_gpus:
    tf.config.experimental.set_memory_growth(_tf_gpus[0], True)


def _pmap_device_order():
    return jax.local_devices()


def split_batch(batch):
    """Reshape every leaf of ``batch`` to ``(local_device_count, -1, ...)``.

    Splits the leading (batch) axis across the local devices so the result
    can be fed to a ``jax.pmap``-ed function. The leading axis of every leaf
    must be divisible by ``jax.local_device_count()``.

    Uses ``jax.tree_util.tree_map``: the ``jax.tree_map`` alias is deprecated
    and has been removed from recent JAX releases.
    """
    num_devices = jax.local_device_count()

    def _split(x):
        return x.reshape((num_devices, -1) + x.shape[1:])

    return jax.tree_util.tree_map(_split, batch)

def unsplit_batch(batch):
    """Inverse of ``split_batch``: merge the device and per-device axes.

    Every leaf is first fetched to host with ``jax.device_get`` and then
    reshaped from ``(devices, per_device, ...)`` to
    ``(devices * per_device, ...)``.

    Uses ``jax.tree_util.tree_map`` because ``jax.tree_map`` has been removed
    from recent JAX releases.
    """
    host_batch = jax.device_get(batch)  # ensure every leaf is on host
    return jax.tree_util.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), host_batch)

def replicate(tree, devices=None):
    """Copy every array in ``tree`` onto each of ``devices``.

    Args:
        tree: pytree of arrays to replicate.
        devices: target devices; when falsy, the local devices are used in
            the order returned by ``_pmap_device_order()``.

    Returns:
        A pytree of the same structure whose leaves carry a new leading
        axis of length ``len(devices)``.
    """
    if not devices:
        devices = _pmap_device_order()
    return jax.device_put_replicated(tree, devices)


def unreplicate(tree):
    """Fetch ``tree`` to host and keep only the first replica of every leaf.

    Uses ``jax.tree_util.tree_map`` because ``jax.tree_map`` has been removed
    from recent JAX releases.
    """
    return jax.tree_util.tree_map(lambda x: x[0], jax.device_get(tree))

def get_iter(ds, parallel):
    """Return an iterator over ``ds``.

    When ``parallel`` is false this is simply ``iter(ds)``. Otherwise each
    batch is passed through ``split_batch`` and the stream is wrapped in
    ``prefetch_to_device`` with a prefetch size of 2.
    """
    if not parallel:
        return iter(ds)
    sharded = map(split_batch, ds)
    return prefetch_to_device(sharded, 2)



def get_ssl_cifar_loaders(batch_size):
    """Build CIFAR-10 pipelines for SimCLR-style training and linear evaluation.

    Every pipeline reads from ``'datasets'``, batches with
    ``drop_remainder=True`` and is returned as a numpy iterable
    (``tfds.as_numpy``).

    Args:
        batch_size: number of examples per batch in every pipeline.

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)``:
        - ``ssl_train_ds`` yields ``(view1, view2, label)`` from the shuffled
          train split, each view augmented by ``preprocess_for_train`` and
          normalized independently;
        - ``lep_train_ds`` yields ``(image, label)`` from the train split
          (not shuffled) with eval preprocessing and normalization;
        - ``test_ds`` is the same eval pipeline over the test split.
    """
    data_dir = 'datasets'
    autotune = tf.data.experimental.AUTOTUNE

    def normalize_cifar(image, label):
        # Per-channel standardization; (1, 1, 3) broadcasts over H and W.
        mean = tf.reshape(tf.constant([0.4914, 0.4822, 0.4465]), (1, 1, 3))
        std = tf.reshape(tf.constant([0.2023, 0.1994, 0.2010]), (1, 1, 3))
        return (image - mean) / std, label

    def cifar_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)

        def make_view():
            view = preprocess_for_train(x, 32, 32, color_distort=True,
                                        color_jitter_strength=0.5, crop=True)
            return normalize_cifar(view, y)[0]

        first = make_view()
        second = make_view()
        return first, second, y

    def cifar_simclr_val(x, y):
        image = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return preprocess_for_eval(image, 32, 32, crop=False), y

    def load(split_name):
        return tfds.load(name='cifar10', split=split_name,
                         as_supervised=True, data_dir=data_dir)

    def eval_pipeline(source):
        mapped = source.map(cifar_simclr_val, num_parallel_calls=autotune)
        mapped = mapped.map(normalize_cifar)
        batched = mapped.batch(batch_size, drop_remainder=True).prefetch(autotune)
        return tfds.as_numpy(batched)

    train_source = load('train')
    ssl_pipeline = train_source.shuffle(len(train_source))
    ssl_pipeline = ssl_pipeline.map(cifar_simclr_aug, num_parallel_calls=autotune)
    ssl_pipeline = ssl_pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)
    ssl_train_ds = tfds.as_numpy(ssl_pipeline)

    lep_train_ds = eval_pipeline(load('train'))
    test_ds = eval_pipeline(load('test'))

    return ssl_train_ds, lep_train_ds, test_ds

def get_ssl_cifar100_loaders(batch_size):
    """Build CIFAR-100 pipelines for SimCLR-style training and linear evaluation.

    Every pipeline reads from ``'datasets'``, batches with
    ``drop_remainder=True`` and is returned as a numpy iterable
    (``tfds.as_numpy``).

    Args:
        batch_size: number of examples per batch in every pipeline.

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)``:
        - ``ssl_train_ds`` yields ``(view1, view2, label)`` from the shuffled
          train split, each view augmented and normalized independently;
        - ``lep_train_ds`` yields ``(image, label)`` from the train split
          (not shuffled) with eval preprocessing and normalization;
        - ``test_ds`` is the same eval pipeline over the test split.
    """
    data_dir = 'datasets'
    autotune = tf.data.experimental.AUTOTUNE

    def normalize_cifar(image, label):
        # Per-channel standardization with CIFAR-100 statistics.
        mean = tf.reshape(tf.constant([0.5071, 0.4867, 0.4408]), (1, 1, 3))
        std = tf.reshape(tf.constant([0.2675, 0.2565, 0.2761]), (1, 1, 3))
        return (image - mean) / std, label

    def cifar_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)

        def make_view():
            view = preprocess_for_train(x, 32, 32, color_distort=True,
                                        color_jitter_strength=0.5, crop=True)
            return normalize_cifar(view, y)[0]

        first = make_view()
        second = make_view()
        return first, second, y

    def cifar_simclr_val(x, y):
        image = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return preprocess_for_eval(image, 32, 32, crop=False), y

    def load(split_name):
        return tfds.load(name='cifar100', split=split_name,
                         as_supervised=True, data_dir=data_dir)

    def eval_pipeline(source):
        mapped = source.map(cifar_simclr_val, num_parallel_calls=autotune)
        mapped = mapped.map(normalize_cifar)
        batched = mapped.batch(batch_size, drop_remainder=True).prefetch(autotune)
        return tfds.as_numpy(batched)

    train_source = load('train')
    ssl_pipeline = train_source.shuffle(len(train_source))
    ssl_pipeline = ssl_pipeline.map(cifar_simclr_aug, num_parallel_calls=autotune)
    ssl_pipeline = ssl_pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)
    ssl_train_ds = tfds.as_numpy(ssl_pipeline)

    lep_train_ds = eval_pipeline(load('train'))
    test_ds = eval_pipeline(load('test'))

    return ssl_train_ds, lep_train_ds, test_ds


def get_ssl_stl10_loaders(batch_size):
    """Build STL-10 pipelines for SimCLR-style training and linear evaluation.

    Args:
        batch_size: number of examples per batch in every pipeline
            (remainders are dropped).

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)`` numpy iterables:
        - ``ssl_train_ds`` yields ``(view1, view2, label)`` from this JAX
          process's shard of the 'unlabelled' split concatenated with its
          shard of the labelled 'train' split;
        - ``lep_train_ds`` yields ``(image, label)`` from this process's
          shard of 'train' with eval preprocessing and normalization;
        - ``test_ds`` yields ``(image, label)`` from the full 'test' split.
    """
    base_dir = 'datasets'

    def normalize_stl10(image, label):
        mean = tf.constant([0.485, 0.456, 0.406])
        mean = tf.reshape(mean, (1,1,3))
        std = tf.constant([0.229, 0.224, 0.225])
        std = tf.reshape(std, (1,1,3))
        image = (image - mean)/std
        return image, label

    def stl10_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        image1 = preprocess_for_train(x, 96, 96, color_distort=True, crop=True, blur=True)
        image1, _ = normalize_stl10(image1, y)
        image2 = preprocess_for_train(x, 96, 96, color_distort=True, crop=True, blur=True)
        image2, _ = normalize_stl10(image2, y)
        return image1, image2, y

    def stl10_simclr_val(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        x = preprocess_for_eval(x, 96, 96, crop=True)
        return x, y

    # Shard BOTH sources of the SSL pipeline per JAX process. Previously only
    # 'unlabelled' was sharded while the labelled 'train' split was loaded in
    # full, so in multi-host runs every process consumed all labelled images.
    # The 'train' shard matches the one used for the lep pipeline below. With
    # a single process both shards are the complete splits.
    unlabelled_split = tfds.split_for_jax_process('unlabelled')
    labelled_split = tfds.split_for_jax_process('train', drop_remainder=True)

    train_ds = tfds.load('stl10', split=unlabelled_split, data_dir=base_dir, download=False,
                         shuffle_files=True, as_supervised=True)

    # concatenate train and unlabeled data
    train_ds = train_ds.concatenate(tfds.load('stl10', split=labelled_split, data_dir=base_dir,
                                              download=False, shuffle_files=True,
                                              as_supervised=True))

    ssl_train_ds = (
        train_ds
        .shuffle(len(train_ds))  # buffer covers the whole (per-process) dataset
        .map(stl10_simclr_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    ssl_train_ds = tfds.as_numpy(ssl_train_ds)

    sub_split = tfds.split_for_jax_process('train', drop_remainder=True)
    sub_train_ds = tfds.load('stl10',
            split=sub_split, data_dir=base_dir, download=False, as_supervised=True)

    lep_train_ds = (
        sub_train_ds
        .map(stl10_simclr_val, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .map(normalize_stl10)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    lep_train_ds = tfds.as_numpy(lep_train_ds)

    test_ds = tfds.load('stl10',
            split='test', data_dir=base_dir, download=False, as_supervised=True)

    test_ds = (
        test_ds
        .map(stl10_simclr_val, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .map(normalize_stl10)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    test_ds = tfds.as_numpy(test_ds)

    return ssl_train_ds, lep_train_ds, test_ds


def get_ssl_imagenet_loaders(batch_size):
    """Build ImageNet-2012 pipelines for SimCLR-style training and linear evaluation.

    The two train-split pipelines read this JAX process's shard of 'train'
    (``tfds.split_for_jax_process(..., drop_remainder=True)``). The
    'validation' split is read in full. Only the SSL pipeline gets the
    custom ``tf.data.Options`` (private thread pool of 48,
    ``max_intra_op_parallelism=1``, map parallelization).

    Args:
        batch_size: number of examples per batch (remainders dropped).

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)`` numpy iterables yielding
        ``(view1, view2, label)``, ``(image, label)`` and ``(image, label)``.
    """
    base_dir = 'datasets'
    autotune = tf.data.experimental.AUTOTUNE

    def normalize_imagenet(image, label):
        mean = tf.reshape(tf.constant([0.485, 0.456, 0.406]), (1, 1, 3))
        std = tf.reshape(tf.constant([0.229, 0.224, 0.225]), (1, 1, 3))
        return (image - mean) / std, label

    def imagenet_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)

        def make_view():
            view = preprocess_for_train(x, 224, 224, color_distort=True, crop=True, blur=True)
            return normalize_imagenet(view, y)[0]

        first = make_view()
        second = make_view()
        return first, second, y

    def imagenet_simclr_val(x, y):
        image = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return preprocess_for_eval(image, 224, 224, crop=True), y

    def eval_pipeline(source):
        mapped = source.map(imagenet_simclr_val, num_parallel_calls=autotune)
        mapped = mapped.map(normalize_imagenet)
        batched = mapped.batch(batch_size, drop_remainder=True).prefetch(autotune)
        return tfds.as_numpy(batched)

    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    options.threading.max_intra_op_parallelism = 1
    options.experimental_optimization.map_parallelization = True

    ssl_source = tfds.load(
        'imagenet2012',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        data_dir=base_dir, download=False, shuffle_files=True,
        as_supervised=True, with_info=False)

    ssl_pipeline = ssl_source.with_options(options)
    ssl_pipeline = ssl_pipeline.shuffle(buffer_size=10 * batch_size)
    ssl_pipeline = ssl_pipeline.map(imagenet_simclr_aug, num_parallel_calls=autotune)
    ssl_pipeline = ssl_pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)
    ssl_train_ds = tfds.as_numpy(ssl_pipeline)

    lep_source = tfds.load(
        'imagenet2012',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        data_dir=base_dir, download=False, as_supervised=True)
    lep_train_ds = eval_pipeline(lep_source)

    test_source = tfds.load(
        'imagenet2012', split='validation', data_dir=base_dir,
        download=False, as_supervised=True)
    test_ds = eval_pipeline(test_source)

    return ssl_train_ds, lep_train_ds, test_ds




def get_ssl_im32_loaders(batch_size):
    """Build ImageNet 32x32 pipelines for SimCLR-style training and linear evaluation.

    Both train-split pipelines read this JAX process's shard of 'train'. The
    'validation' split is read in full.

    Args:
        batch_size: number of examples per batch (remainders dropped).

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)`` numpy iterables yielding
        ``(view1, view2, label)``, ``(image, label)`` and ``(image, label)``.
    """
    data_dir = 'datasets'
    autotune = tf.data.experimental.AUTOTUNE

    def normalize_imagenet(image, label):
        mean = tf.reshape(tf.constant([0.485, 0.456, 0.406]), (1, 1, 3))
        std = tf.reshape(tf.constant([0.229, 0.224, 0.225]), (1, 1, 3))
        return (image - mean) / std, label

    def im32_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)

        def make_view():
            view = preprocess_for_train(x, 32, 32, color_distort=True,
                                        color_jitter_strength=0.5, crop=True)
            return normalize_imagenet(view, y)[0]

        first = make_view()
        second = make_view()
        return first, second, y

    def im32_simclr_val(x, y):
        image = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return preprocess_for_eval(image, 32, 32, crop=False), y

    def eval_pipeline(source):
        mapped = source.map(im32_simclr_val, num_parallel_calls=autotune)
        mapped = mapped.map(normalize_imagenet)
        batched = mapped.batch(batch_size, drop_remainder=True).prefetch(autotune)
        return tfds.as_numpy(batched)

    ssl_source = tfds.load(
        name='imagenet_resized/32x32',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        as_supervised=True, data_dir=data_dir,
        shuffle_files=True, download=False)

    # The shuffle buffer spans the entire per-process shard.
    ssl_pipeline = ssl_source.shuffle(len(ssl_source))
    ssl_pipeline = ssl_pipeline.map(im32_simclr_aug, num_parallel_calls=autotune)
    ssl_pipeline = ssl_pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)
    ssl_train_ds = tfds.as_numpy(ssl_pipeline)

    lep_source = tfds.load(
        name='imagenet_resized/32x32',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        as_supervised=True, data_dir=data_dir, download=False)
    lep_train_ds = eval_pipeline(lep_source)

    # NOTE(review): unlike the train loads, this one does not pass
    # download=False — confirm whether that difference is intended.
    test_source = tfds.load(name='imagenet_resized/32x32', split='validation',
                            as_supervised=True, data_dir=data_dir)
    test_ds = eval_pipeline(test_source)

    return ssl_train_ds, lep_train_ds, test_ds


def get_ssl_tinyimagenet_loaders(batch_size):
    """Build Tiny-ImageNet pipelines for SimCLR-style training and linear evaluation.

    Both train-split pipelines read this JAX process's shard of 'train'. The
    'validation' split is read in full. Only the SSL pipeline gets the custom
    ``tf.data.Options``.

    Args:
        batch_size: number of examples per batch (remainders dropped).

    Returns:
        ``(ssl_train_ds, lep_train_ds, test_ds)`` numpy iterables yielding
        ``(view1, view2, label)``, ``(image, label)`` and ``(image, label)``.
    """
    base_dir = 'datasets'
    autotune = tf.data.experimental.AUTOTUNE

    def normalize_imagenet(image, label):
        mean = tf.reshape(tf.constant([0.485, 0.456, 0.406]), (1, 1, 3))
        std = tf.reshape(tf.constant([0.229, 0.224, 0.225]), (1, 1, 3))
        return (image - mean) / std, label

    def imagenet_simclr_aug(x, y):
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)

        def make_view():
            view = preprocess_for_train(x, 64, 64, color_distort=True, crop=True,
                                        color_jitter_strength=0.5)
            return normalize_imagenet(view, y)[0]

        first = make_view()
        second = make_view()
        return first, second, y

    def imagenet_simclr_val(x, y):
        image = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return preprocess_for_eval(image, 64, 64, crop=True), y

    def eval_pipeline(source):
        mapped = source.map(imagenet_simclr_val, num_parallel_calls=autotune)
        mapped = mapped.map(normalize_imagenet)
        batched = mapped.batch(batch_size, drop_remainder=True).prefetch(autotune)
        return tfds.as_numpy(batched)

    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    options.threading.max_intra_op_parallelism = 1
    options.experimental_optimization.map_parallelization = True

    ssl_source = tfds.load(
        'tiny_imagenet_dataset',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        data_dir=base_dir, download=False, shuffle_files=True,
        as_supervised=True, with_info=False)

    ssl_pipeline = ssl_source.with_options(options)
    ssl_pipeline = ssl_pipeline.shuffle(buffer_size=10 * batch_size)
    ssl_pipeline = ssl_pipeline.map(imagenet_simclr_aug, num_parallel_calls=autotune)
    ssl_pipeline = ssl_pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)
    ssl_train_ds = tfds.as_numpy(ssl_pipeline)

    lep_source = tfds.load(
        'tiny_imagenet_dataset',
        split=tfds.split_for_jax_process('train', drop_remainder=True),
        data_dir=base_dir, download=False, as_supervised=True)
    lep_train_ds = eval_pipeline(lep_source)

    test_source = tfds.load(
        'tiny_imagenet_dataset', split='validation', data_dir=base_dir,
        download=False, as_supervised=True)
    test_ds = eval_pipeline(test_source)

    return ssl_train_ds, lep_train_ds, test_ds





import numpy as np

def get_lep_loaders(bsize, state, lep_ds, parallel, num_batches=-1):
    """Embed ``lep_ds`` with the model in ``state`` and build linear-eval datasets.

    Runs ``state.apply_fn`` (third positional argument ``False``) over the
    batches of ``lep_ds``. It averages ``z['out']`` over axes (1, 2) — (2, 3)
    under pmap, which adds a leading device axis; presumably these are the
    spatial axes, so confirm against the model. The pooled features,
    ``proj['out']`` and the labels are collected into host numpy buffers.
    From those buffers, two shuffled, batched numpy iterables are built.

    Args:
        bsize: batch size of the returned datasets.
        state: train state providing ``apply_fn``, ``params``,
            ``batch_stats`` and ``direct_pred`` (added to the model variables
            only when not None).
        lep_ds: re-iterable, ``len``-able dataset of ``(x, y)`` batches.
        parallel: if True, batches are split across local devices and the
            model runs under ``jax.pmap``.
        num_batches: maximum number of batches to embed. Values <= 0 mean
            "the whole dataset". The loop previously treated 0 as "all"
            while reporting a progress total of 0.

    Returns:
        ``(batched_ds, batched_proj_ds)``: numpy iterables yielding
        ``(features, labels)`` and ``(projections, labels)`` batches.
    """
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    if state.direct_pred is not None:
        variables['direct_pred'] = state.direct_pred

    if parallel:
        apply_fn = jax.pmap(state.apply_fn, static_broadcasted_argnums=2)
        avg_axis = (2, 3)  # outputs carry a leading device axis under pmap
    else:
        apply_fn = state.apply_fn
        avg_axis = (1, 2)

    def _embed(x):
        """Forward ``x``; return (``z['out']`` averaged over ``avg_axis``, ``proj``)."""
        z, proj, _, _ = apply_fn(variables, x, False)
        return jnp.mean(z['out'], axis=avg_axis), proj

    def _shuffle_and_batch(ds):
        """Fully shuffle ``ds``, batch it by ``bsize`` and expose it as numpy."""
        ds = ds.shuffle(len(ds))
        ds = ds.batch(bsize).prefetch(tf.data.experimental.AUTOTUNE)
        return tfds.as_numpy(ds)

    # One forward pass on the first batch to discover the output sizes.
    x, _ = next(iter(lep_ds))
    flatten_z, proj = _embed(split_batch(x) if parallel else x)
    lep_batch_size = unsplit_batch(flatten_z).shape[0] if parallel else flatten_z.shape[0]
    feat_dim = flatten_z.shape[-1]
    proj_dim = proj['out'].shape[-1]

    # Size the host buffers for the batches that will actually be embedded,
    # not always for the full dataset.
    total_batches = len(lep_ds)
    if num_batches > 0:
        total_batches = min(total_batches, num_batches)
    capacity = total_batches * lep_batch_size

    zs = np.empty((capacity, feat_dim), dtype=np.float32)
    ys = np.empty((capacity,), dtype=np.int32)
    projs = np.empty((capacity, proj_dim), dtype=np.float32)

    size = 0
    lep_it = get_iter(lep_ds, parallel)

    print('Iterating over the lep dataset...')
    for i, (x, y) in tqdm(enumerate(lep_it), total=total_batches):
        if i >= total_batches:
            break
        flatten_z, proj = _embed(x)
        if parallel:
            flatten_z, proj, y = unsplit_batch((flatten_z, proj, y))

        # Write at the running offset so a short final batch is also handled.
        n = flatten_z.shape[0]
        zs[size:size + n] = np.asarray(flatten_z, dtype=np.float32)
        ys[size:size + n] = np.asarray(y, dtype=np.int32)
        projs[size:size + n] = np.asarray(proj['out'], dtype=np.float32)
        size += n

    print('Creating tensorflow dataset features...')
    batched_ds = tf.data.Dataset.from_tensor_slices((zs[:size], ys[:size]))

    print('Creating tensorflow dataset projections...')
    batched_proj_ds = tf.data.Dataset.from_tensor_slices((projs[:size], ys[:size]))

    print('Shuffling and batching...')
    batched_ds = _shuffle_and_batch(batched_ds)
    batched_proj_ds = _shuffle_and_batch(batched_proj_ds)

    print('Number of data samples in the lep dataset:', size)
    return batched_ds, batched_proj_ds





















### Augmentation functions adapted from the SimCLR repository ###


def random_apply(func, p, x):
    """Return ``func(x)`` with probability ``p``, otherwise ``x`` unchanged.

    The decision uses one scalar sample from U[0, 1) compared in-graph
    against ``p`` (cast to float32).
    """
    coin = tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32)
    threshold = tf.cast(p, tf.float32)
    return tf.cond(tf.less(coin, threshold), lambda: func(x), lambda: x)


def random_brightness(image, max_delta, impl='simclrv2'):
    """Randomly change image brightness.

    ``impl='simclrv2'`` multiplies by a factor drawn uniformly from
    ``[max(1 - max_delta, 0), 1 + max_delta)``. ``impl='simclrv1'`` adds a
    random shift via ``tf.image.random_brightness``. Any other ``impl``
    raises ValueError.
    """
    if impl == 'simclrv1':
        return tf.image.random_brightness(image, max_delta=max_delta)
    if impl != 'simclrv2':
        raise ValueError('Unknown impl {} for random brightness.'.format(impl))
    lower = tf.maximum(1.0 - max_delta, 0)
    factor = tf.random_uniform([], lower, 1.0 + max_delta)
    return image * factor


def to_grayscale(image, keep_channels=True):
    """Convert an RGB image to grayscale, optionally tiling back to 3 channels."""
    gray = tf.image.rgb_to_grayscale(image)
    return tf.tile(gray, [1, 1, 3]) if keep_channels else gray


def color_jitter(image, strength, random_order=True, impl='simclrv2'):
    """Distort the colors of ``image``.

    Brightness, contrast and saturation strengths are ``0.8 * strength``;
    hue strength is ``0.2 * strength``.

    Args:
        image: the input image tensor.
        strength: overall strength of the color augmentation.
        random_order: if True use ``color_jitter_rand`` (random op order),
            otherwise ``color_jitter_nonrand`` (fixed op order).
        impl: 'simclrv1' or 'simclrv2', forwarded to pick the brightness
            implementation.

    Returns:
        The distorted image tensor.
    """
    brightness = contrast = saturation = 0.8 * strength
    hue = 0.2 * strength
    jitter_fn = color_jitter_rand if random_order else color_jitter_nonrand
    return jitter_fn(image, brightness, contrast, saturation, hue, impl=impl)


def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0,
                         impl='simclrv2'):
    """Distorts the color of the image (jittering order is fixed).

    Applies brightness, contrast, saturation and hue jitter, in that order,
    each at most once; a jitter with strength 0 is skipped. The image is
    clipped to [0, 1] after every step.

    Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
        version of random brightness.
    Returns:
    The distorted image tensor.
    """
    with tf.name_scope('distort_color'):
        def apply_transform(i, x, brightness, contrast, saturation, hue):
            """Apply the i-th transformation (0 brightness, 1 contrast, 2 saturation, 3 hue)."""
            # Every branch is gated on its own index. The hue branch used to
            # lack `i == 3`, so any step whose jitter had strength 0 fell
            # through to hue and hue was applied more than once.
            if i == 0 and brightness != 0:
                x = random_brightness(x, max_delta=brightness, impl=impl)
            elif i == 1 and contrast != 0:
                x = tf.image.random_contrast(
                    x, lower=1-contrast, upper=1+contrast)
            elif i == 2 and saturation != 0:
                x = tf.image.random_saturation(
                    x, lower=1-saturation, upper=1+saturation)
            elif i == 3 and hue != 0:
                x = tf.image.random_hue(x, max_delta=hue)
            return x

        for i in range(4):
            image = apply_transform(i, image, brightness, contrast, saturation, hue)
            image = tf.clip_by_value(image, 0., 1.)
        return image


def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl='simclrv2'):
    """Distorts the color of the image (jittering order is random).

    The four jitters (brightness, contrast, saturation, hue) are each applied
    once, in the order given by a random permutation of range(4). A jitter
    whose strength is 0 leaves the image unchanged for that step. The image
    is clipped to [0, 1] after every step.

    Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
        version of random brightness.
    Returns:
    The distorted image tensor.
    """
    with tf.name_scope('distort_color'):
        def apply_transform(i, x):
            """Apply the i-th transformation."""
            # `i` is a scalar int tensor (an element of `perm`), so the op is
            # selected in-graph with nested tf.cond below:
            # 0 -> brightness, 1 -> contrast, 2 -> saturation, 3 -> hue.
            def brightness_foo():
                if brightness == 0:
                    return x
                else:
                    return random_brightness(x, max_delta=brightness, impl=impl)

            def contrast_foo():
                if contrast == 0:
                    return x
                else:
                    return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
            def saturation_foo():
                if saturation == 0:
                    return x
                else:
                    return tf.image.random_saturation(
                          x, lower=1-saturation, upper=1+saturation)
            def hue_foo():
                if hue == 0:
                    return x
                else:
                    return tf.image.random_hue(x, max_delta=hue)
            x = tf.cond(tf.less(i, 2),
                      lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
                      lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
            return x

        # Random order: one permutation of the four op indices per call.
        perm = tf.random_shuffle(tf.range(4))
        for i in range(4):
            image = apply_transform(perm[i], image)
            image = tf.clip_by_value(image, 0., 1.)
        return image


def _compute_crop_shape(image_height, image_width, aspect_ratio, crop_proportion):
    """Compute an aspect-ratio-preserving crop shape for a central crop.

    One side keeps exactly ``crop_proportion`` of the image; the other keeps
    a proportion less than or equal to that.

    Args:
        image_height: height of the image to be cropped.
        image_width: width of the image to be cropped.
        aspect_ratio: desired output aspect ratio (width / height).
        crop_proportion: fraction of the image kept along the less-cropped side.

    Returns:
        ``(crop_height, crop_width)`` as int32 tensors.
    """
    width_f = tf.cast(image_width, tf.float32)
    height_f = tf.cast(image_height, tf.float32)

    def _round_to_int(value):
        return tf.cast(tf.rint(value), tf.int32)

    def _width_bound():
        # Requested ratio is wider than the image: width limits the crop.
        return (_round_to_int(crop_proportion / aspect_ratio * width_f),
                _round_to_int(crop_proportion * width_f))

    def _height_bound():
        # Image is at least as wide as requested: height limits the crop.
        return (_round_to_int(crop_proportion * height_f),
                _round_to_int(crop_proportion * aspect_ratio * height_f))

    return tf.cond(aspect_ratio > width_f / height_f, _width_bound, _height_bound)


def center_crop(image, height, width, crop_proportion):
    """Take a central crop of ``image`` and resize it to ``height`` x ``width``.

    The crop keeps ``crop_proportion`` of the image along its less-cropped
    side at the output aspect ratio (see ``_compute_crop_shape``). It is then
    resized with bicubic interpolation.

    Args:
        image: image tensor to crop.
        height: output height.
        width: output width.
        crop_proportion: fraction of the image kept along the less-cropped side.

    Returns:
        A ``height`` x ``width`` x channels tensor.
    """
    image_shape = tf.shape(image)
    img_h, img_w = image_shape[0], image_shape[1]
    crop_h, crop_w = _compute_crop_shape(img_h, img_w, width / height, crop_proportion)
    # Round the top/left margin up, so an odd leftover pixel goes there.
    top = (img_h - crop_h + 1) // 2
    left = (img_w - crop_w + 1) // 2
    cropped = tf.image.crop_to_bounding_box(image, top, left, crop_h, crop_w)
    return tf.image.resize_bicubic([cropped], [height, width])[0]


def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
    """Crop ``image`` to a randomly distorted bounding box.

    See `tf.image.sample_distorted_bounding_box` for details of the sampling.

    Args:
    image: `Tensor` of image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
    scope: Optional `str` for name scope.
    Returns:
    The cropped image `Tensor` only. (The previous docstring claimed that a
    `(cropped image, distorted bbox)` tuple was returned; it never was.)
    """
    with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
        bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
            tf.shape(image),
            bounding_boxes=bbox,
            min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range,
            area_range=area_range,
            max_attempts=max_attempts,
            use_image_if_no_bounding_boxes=True)

        # Crop the image to the sampled bounding box (channel axis untouched).
        offset_y, offset_x, _ = tf.unstack(bbox_begin)
        target_height, target_width, _ = tf.unstack(bbox_size)
        return tf.image.crop_to_bounding_box(
            image, offset_y, offset_x, target_height, target_width)


def crop_and_resize(image, height, width):
    """Randomly crop ``image`` and resize the crop to ``height`` x ``width``.

    The crop covers 8%-100% of the image area. Its aspect ratio lies within
    [3/4, 4/3] times the output aspect ratio (width / height). Resizing uses
    bicubic interpolation.

    Args:
        image: image tensor.
        height: output height.
        width: output width.

    Returns:
        A ``height`` x ``width`` x channels tensor.
    """
    whole_image_box = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    target_ratio = width / height
    cropped = distorted_bounding_box_crop(
        image,
        whole_image_box,
        min_object_covered=0.1,
        aspect_ratio_range=(3. / 4 * target_ratio, 4. / 3. * target_ratio),
        area_range=(0.08, 1.0),
        max_attempts=100,
        scope=None)
    return tf.image.resize_bicubic([cropped], [height, width])[0]


def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
    """Blur `image` with a separable (horizontal then vertical) Gaussian.

    Args:
        image: float Tensor to blur. A 3-D input is treated as a single
            image and temporarily given a batch axis; any other rank is passed
            to `tf.nn.depthwise_conv2d` unchanged (presumably 4-D batched).
            The kernel is built in float32, so `image` is expected to be float32.
        kernel_size: Integer (or integer Tensor) kernel size. It should be odd;
            an even value yields an effective size of `kernel_size + 1`.
        sigma: Standard deviation of the Gaussian.
        padding: Convolution padding, typically 'SAME' or 'VALID'.

    Returns:
        The blurred Tensor, with the same rank as `image`.
    """
    half_width = tf.cast(kernel_size / 2, tf.int32)
    taps = 2 * half_width + 1
    offsets = tf.cast(tf.range(-half_width, half_width + 1), tf.float32)
    sigma_f32 = tf.cast(sigma, tf.float32)
    weights = tf.exp(-tf.pow(offsets, 2.0) / (2.0 * tf.pow(sigma_f32, 2.0)))
    weights = weights / tf.reduce_sum(weights)

    # Replicate the 1-D kernel once per channel for a depthwise convolution.
    channels = tf.shape(image)[-1]
    row_kernel = tf.tile(tf.reshape(weights, [1, taps, 1, 1]), [1, 1, channels, 1])
    col_kernel = tf.tile(tf.reshape(weights, [taps, 1, 1, 1]), [1, 1, channels, 1])

    # TensorFlow convolutions need a batch axis; fake one for single images.
    single_image = image.shape.ndims == 3
    batch = tf.expand_dims(image, axis=0) if single_image else image
    unit_strides = [1, 1, 1, 1]
    out = tf.nn.depthwise_conv2d(
        batch, row_kernel, strides=unit_strides, padding=padding)
    out = tf.nn.depthwise_conv2d(
        out, col_kernel, strides=unit_strides, padding=padding)
    return tf.squeeze(out, axis=0) if single_image else out


def random_crop_with_resize(image, height, width, p=1.0):
    """With probability `p`, random-crop `image` and resize it.

    Args:
        image: `Tensor` representing an image of arbitrary size.
        height: Height of output image.
        width: Width of output image.
        p: Probability of applying the crop-and-resize.

    Returns:
        A preprocessed image `Tensor`.
    """
    return random_apply(
        lambda img: crop_and_resize(img, height, width), p=p, x=image)


def random_color_jitter(image, strength=1.0, p=1.0, impl='simclrv2'):
    """With probability `p`, apply SimCLR-style color distortion.

    When the distortion fires, `color_jitter` runs with probability 0.8 and
    then `to_grayscale` runs with probability 0.2.

    Args:
        image: image `Tensor`.
        strength: Color-jitter strength forwarded to `color_jitter`.
        p: Probability of applying the whole distortion.
        impl: 'simclrv1' or 'simclrv2', forwarded to `color_jitter`.

    Returns:
        A (possibly) color-distorted image `Tensor`.
    """
    def _distort(img):
        jitter = functools.partial(color_jitter, strength=strength, impl=impl)
        img = random_apply(jitter, p=0.8, x=img)
        img = random_apply(to_grayscale, p=0.2, x=img)
        return img

    return random_apply(_distort, p=p, x=image)


def random_blur(image, height, width, p=0.5):
    """With probability `p`, Gaussian-blur `image`.

    The kernel size is `height // 10`, and sigma is drawn uniformly from
    [0.1, 2.0) each time the blur is built.

    Args:
        image: `Tensor` representing an image of arbitrary size.
        height: Image height; used to derive the blur kernel size.
        width: Accepted but unused.
        p: Probability of applying the blur.

    Returns:
        A preprocessed image `Tensor`.
    """
    del width  # unused

    def _blur(img):
        blur_sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
        return gaussian_blur(
            img, kernel_size=height // 10, sigma=blur_sigma, padding='SAME')

    return random_apply(_blur, p=p, x=image)


def batch_random_blur(images_list, height, width, blur_probability=0.5):
    """Blur each example of each batch independently with some probability.

    For every batch, the blurred version is always computed (`random_blur`
    with p=1). The result is then mixed per-example with the original through
    a random 0/1 mask. A single sigma is sampled per `random_blur` call, so
    all blurred examples in one batch share it.

    Args:
        images_list: list of batched image tensors. Each is indexed on axis 0
            as the batch axis; presumably [batch, H, W, C].
        height: Image height (drives the blur kernel size).
        width: Image width (unused by the blur).
        blur_probability: Per-example probability of using the blurred image.

    Returns:
        A list of preprocessed batches, clipped to [0, 1].
    """
    def generate_selector(p, bsz):
        # Per-example {0., 1.} mask shaped to broadcast over the trailing axes.
        # tf.random.uniform is the canonical name (matching random_blur); the
        # old tf.random_uniform alias is absent from the TF2 namespace.
        draws = tf.random.uniform([bsz, 1, 1, 1], 0, 1, dtype=tf.float32)
        return tf.cast(tf.less(draws, p), tf.float32)

    def _blend(images):
        blurred = random_blur(images, height, width, p=1.)
        selector = generate_selector(blur_probability, tf.shape(images)[0])
        mixed = blurred * selector + images * (1 - selector)
        return tf.clip_by_value(mixed, 0., 1.)

    return [_blend(images) for images in images_list]


def preprocess_for_train(image,
                         height,
                         width,
                         color_distort=True,
                         color_jitter_strength=1.0,
                         crop=True,
                         flip=True,
                         blur=False,
                         impl='simclrv2'):
    """Run the training augmentation pipeline on a single image.

    Enabled stages run in this fixed order: random crop+resize, random
    horizontal flip, color distortion, blur. The result is then reshaped to
    [height, width, 3] and clipped to [0, 1].

    Args:
        image: `Tensor` representing an image of arbitrary size.
        height: Height of output image.
        width: Width of output image.
        color_distort: Whether to apply the color distortion.
        color_jitter_strength: Strength passed to the color distortion.
        crop: Whether to random-crop (and resize) the image.
        flip: Whether to randomly flip the image left/right.
        blur: Whether to randomly blur the image.
        impl: 'simclrv1' or 'simclrv2'; selects which SimCLR version of
            random brightness the color distortion uses.

    Returns:
        A preprocessed image `Tensor`.
    """
    stages = (
        (crop, lambda img: random_crop_with_resize(img, height, width)),
        (flip, lambda img: tf.image.random_flip_left_right(img)),
        (color_distort, lambda img: random_color_jitter(
            img, strength=color_jitter_strength, impl=impl)),
        (blur, lambda img: random_blur(img, height, width)),
    )
    for enabled, stage in stages:
        if enabled:
            image = stage(image)
    shaped = tf.reshape(image, [height, width, 3])
    return tf.clip_by_value(shaped, 0., 1.)


def preprocess_for_eval(image, height, width, crop=True):
    """Prepare a single image for evaluation.

    Args:
        image: `Tensor` representing an image of arbitrary size.
        height: Height of output image.
        width: Width of output image.
        crop: Whether to take a central crop (87.5% proportion) first.

    Returns:
        A preprocessed image `Tensor` of shape [height, width, 3] in [0, 1].
    """
    # 0.875 crop proportion: the author notes this is standard for ImageNet.
    cropped = (center_crop(image, height, width, crop_proportion=0.875)
               if crop else image)
    shaped = tf.reshape(cropped, [height, width, 3])
    return tf.clip_by_value(shaped, 0., 1.)


def preprocess_image(image, height, width, is_training=False,
                     color_distort=True, test_crop=True):
    """Convert `image` to float32 and run the train or eval pipeline.

    Only `color_distort` is forwarded to the training pipeline; its other
    options keep the defaults of `preprocess_for_train`.

    Args:
        image: `Tensor` representing an image of arbitrary size.
        height: Height of output image.
        width: Width of output image.
        is_training: Whether to use the training pipeline.
        color_distort: Whether to apply color distortion (training only).
        test_crop: Whether to take a central crop during evaluation, as in
            standard ImageNet evaluation.

    Returns:
        A preprocessed image `Tensor` of range [0, 1].
    """
    as_float = tf.image.convert_image_dtype(image, dtype=tf.float32)
    if not is_training:
        return preprocess_for_eval(as_float, height, width, test_crop)
    return preprocess_for_train(as_float, height, width, color_distort)



### END OF AUGMENTATIONS ###




