import tensorflow as tf

"""
TODO: Test
"""

# Registry of augmentation constructors, keyed by the names accepted by
# make_augmentor(). The built-in Keras preprocessing layers are listed here;
# custom layers add themselves with @register_augmentation.
# NOTE(review): the `experimental.preprocessing` namespace depends on the
# TensorFlow version — confirm it exists in the version this project pins.
AVAIL_AUG = {
    "flip": tf.keras.layers.experimental.preprocessing.RandomFlip,
    "rotate": tf.keras.layers.experimental.preprocessing.RandomRotation,
    "contrast": tf.keras.layers.experimental.preprocessing.RandomContrast,
    "zoomout": tf.keras.layers.experimental.preprocessing.RandomZoom,
}

def register_augmentation(name):
    """Return a decorator that records its target in ``AVAIL_AUG`` under ``name``.

    The decorated object (typically a Keras layer class) is stored as-is and
    returned unchanged, so decoration has no other effect on it. Registering
    a name that already exists replaces the previous entry.
    """
    def _register(target):
        AVAIL_AUG[name] = target
        return target

    return _register

def make_augmentor(**kwargs):
    """Build a ``tf.keras.models.Sequential`` of registered augmentation layers.

    Each keyword names an entry in ``AVAIL_AUG``. The keyword's value is passed
    to that entry's constructor:

    - a tuple or list (including subclasses) is unpacked as positional
      arguments;
    - any other value is passed as the single positional argument.

    Layers are added in keyword order, e.g.
    ``make_augmentor(shotnoise=0.05, hue=0.1)``.

    Raises:
        KeyError: if a keyword is not a registered augmentation name.
    """
    layer_list = []
    for key, val in kwargs.items():
        try:
            layer_cls = AVAIL_AUG[key]
        except KeyError:
            raise KeyError(
                f"Unknown augmentation {key!r}; available: {sorted(AVAIL_AUG)}"
            ) from None
        # isinstance (rather than `type(val) == ...`) so that tuple/list
        # subclasses such as namedtuples are unpacked as well.
        args = tuple(val) if isinstance(val, (tuple, list)) else (val,)
        layer_list.append(layer_cls(*args))
    return tf.keras.models.Sequential(layer_list)

    
@register_augmentation("shotnoise")
class RandomShotNoise(tf.keras.layers.Layer):
    """Add i.i.d. uniform noise in ``[-factor, factor]`` and clip to ``[0, 1]``.

    The clip assumes inputs are float images scaled to [0, 1] — TODO confirm
    for all callers. Noise is added on every call; there is no ``training``
    switch.

    Args:
        factor: half-width of the uniform noise range.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, factor, **kwargs):
        super(RandomShotNoise, self).__init__(**kwargs)
        self.factor = factor

    def add_noise(self, x):
        """Return ``x`` plus uniform noise of the same shape, clipped to [0, 1]."""
        # tf.shape (not x.shape) so tensors with unknown static dimensions
        # work; dtype=x.dtype so e.g. float16 inputs don't hit a dtype mismatch.
        noise = tf.random.uniform(
            tf.shape(x), minval=-self.factor, maxval=self.factor, dtype=x.dtype
        )
        return tf.clip_by_value(x + noise, 0, 1)

    def call(self, inputs):
        # Each element gets an independent draw, so one batched call has the
        # same distribution as a per-image tf.map_fn, without the loop.
        return self.add_noise(inputs)

@register_augmentation("hue")
class RandomHue(tf.keras.layers.Layer):
    """Randomly shift the hue of each image in a batch independently.

    ``factor`` is forwarded as ``max_delta`` to ``tf.image.random_hue``.
    Because the op runs per image through ``tf.map_fn``, every image draws its
    own shift. Applied on every call; there is no ``training`` switch.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, inputs):
        def _shift_hue(image):
            return tf.image.random_hue(image, self.factor)

        return tf.map_fn(_shift_hue, inputs)
    
@register_augmentation("brightness")
class RandomBrightness(tf.keras.layers.Layer):
    """Randomly shift the brightness of each image in a batch independently.

    ``max_delta`` is forwarded to ``tf.image.random_brightness``. Because the
    op runs per image through ``tf.map_fn``, every image draws its own offset.
    Applied on every call; there is no ``training`` switch.
    """

    def __init__(self, max_delta):
        super().__init__()
        self.max_delta = max_delta

    def call(self, inputs):
        def _shift_brightness(image):
            return tf.image.random_brightness(image, self.max_delta)

        return tf.map_fn(_shift_brightness, inputs)
    

@tf.function
def _norm_params(mask_size, offset=None):
    """Validate and normalize the ``mask_size`` / ``offset`` arguments of cutout.

    Args:
        mask_size: tensor holding either a scalar (square mask) or a
            ``[height, width]`` pair (presumably integers — TODO confirm).
            Every entry must be even.
        offset: optional tensor of rectangle centers. A rank-1 value (a single
            center) is promoted to rank 2 by adding a leading axis.

    Returns:
        ``(mask_size, offset)``:
        - ``mask_size`` is always a length-2 tensor.
        - ``offset`` is unchanged unless it had rank 1; ``None`` stays ``None``.

    Raises:
        ``tf.errors.InvalidArgumentError`` (from ``tf.assert_equal``) if any
        ``mask_size`` entry is odd.
    """
    # Fails when any entry is odd ("divisible by 2" means "even").
    tf.assert_equal(
        tf.reduce_any(mask_size % 2 != 0),
        False,
        "mask_size should be divisible by 2",
    )
    # A scalar size means a square mask: duplicate it to [size, size].
    if tf.rank(mask_size) == 0:
        mask_size = tf.stack([mask_size, mask_size])
    # Promote a single center [h, w] to a batch of one: [[h, w]].
    if offset is not None and tf.rank(offset) == 1:
        offset = tf.expand_dims(offset, 0)
    return mask_size, offset


def random_cutout(images, mask_size, constant_values=0, seed=None) -> tf.Tensor:
    """Apply one randomly placed rectangular cutout to each image in a batch.

    Rectangle centers are drawn uniformly so that the whole rectangle lies
    inside the image. The actual masking is delegated to ``cutout``.

    Args:
        images: 4-D tensor read as ``[batch, height, width, channels]``.
        mask_size: even scalar (square mask) or even ``[height, width]`` pair.
            Each dimension must be smaller than the matching image dimension;
            otherwise the integer uniform draw has an empty range and
            TensorFlow raises.
        constant_values: value written inside each rectangle.
        seed: optional op-level seed. The height draw uses ``seed`` and the
            width draw uses ``seed + 1``.

    Returns:
        Tensor with the same shape and dtype as ``images``.
    """
    images = tf.convert_to_tensor(images)
    mask_size = tf.convert_to_tensor(mask_size)

    image_dynamic_shape = tf.shape(images)
    batch_size = image_dynamic_shape[0]
    image_height = image_dynamic_shape[1]
    image_width = image_dynamic_shape[2]

    mask_size, _ = _norm_params(mask_size, offset=None)

    half_mask_height = mask_size[0] // 2
    half_mask_width = mask_size[1] // 2

    # The two draws need distinct op seeds. In graph mode, two random ops built
    # with the same op seed emit the same raw stream, which would correlate the
    # height and width centers.
    width_seed = None if seed is None else seed + 1

    # Centers come from [half, size - half) (maxval is exclusive), so the
    # rectangle never crosses an image border.
    cutout_center_height = tf.random.uniform(
        shape=[batch_size],
        minval=half_mask_height,
        maxval=image_height - half_mask_height,
        dtype=tf.int32,
        seed=seed,
    )
    cutout_center_width = tf.random.uniform(
        shape=[batch_size],
        minval=half_mask_width,
        maxval=image_width - half_mask_width,
        dtype=tf.int32,
        seed=width_seed,
    )

    # One [h, w] center per image, shape [batch, 2].
    offset = tf.stack([cutout_center_height, cutout_center_width], axis=1)
    return cutout(images, mask_size, offset, constant_values)


def cutout(images, mask_size, offset = (0, 0), constant_values = 0,):
    """Fill a rectangle in each image with ``constant_values``.

    Args:
        images: 4-D tensor. Dims 1, 2 and 3 are read as height, width and
            channels. The static height/width are used as the ``tf.map_fn``
            output signature.
        mask_size: even scalar or even ``[height, width]`` pair giving the full
            size of the rectangle.
        offset: rectangle center(s), as a single ``[h, w]`` or as
            ``[[h, w], ...]`` with one row per mask. The default ``(0, 0)``
            centers the rectangle on the top-left corner, so only its
            in-bounds quarter is filled.
            NOTE(review): a single center yields a mask with batch dim 1,
            presumably broadcast over every image by ``tf.where`` — confirm.
        constant_values: fill value, cast to ``images.dtype``.

    Returns:
        Tensor with the static shape and dtype of ``images``. Any part of a
        rectangle that falls outside the image is clipped away.
    """
    with tf.name_scope("cutout"):
        images = tf.convert_to_tensor(images)
        mask_size = tf.convert_to_tensor(mask_size)
        offset = tf.convert_to_tensor(offset)

        image_static_shape = images.shape
        image_dynamic_shape = tf.shape(images)
        image_height, image_width, channels = (
            image_dynamic_shape[1],
            image_dynamic_shape[2],
            image_dynamic_shape[3],
        )

        mask_size, offset = _norm_params(mask_size, offset)
        # From here on mask_size holds the half-extent of the rectangle.
        mask_size = mask_size // 2

        cutout_center_heights = offset[:, 0]
        cutout_center_widths = offset[:, 1]

        # Number of rows to keep above ("lower" = lower row index) and below
        # the rectangle, and columns to keep left and right of it. tf.maximum
        # clips at 0, which truncates rectangles that extend past an edge.
        lower_pads = tf.maximum(0, cutout_center_heights - mask_size[0])
        upper_pads = tf.maximum(0, image_height - cutout_center_heights - mask_size[0])
        left_pads = tf.maximum(0, cutout_center_widths - mask_size[1])
        right_pads = tf.maximum(0, image_width - cutout_center_widths - mask_size[1])

        # Per-mask [height, width] of the (possibly truncated) rectangle.
        cutout_shape = tf.transpose(
            [
                image_height - (lower_pads + upper_pads),
                image_width - (left_pads + right_pads),
            ],
            [1, 0],
        )

        # Mask i: a block of False (the rectangle) padded with True (pixels
        # to keep) out to the full image height and width.
        def fn(i):
            padding_dims = [
                [lower_pads[i], upper_pads[i]],
                [left_pads[i], right_pads[i]],
            ]
            mask = tf.pad(
                tf.zeros(cutout_shape[i], dtype=tf.bool),
                padding_dims,
                constant_values=True,
            )
            return mask

        # One [height, width] boolean mask per row of offset.
        mask = tf.map_fn(
            fn,
            tf.range(tf.shape(cutout_shape)[0]),
            fn_output_signature=tf.TensorSpec(
                shape=image_static_shape[1:-1], dtype=tf.bool
            ),
        )
        # Repeat the mask across the channel axis to match the image layout.
        mask = tf.expand_dims(mask, -1)
        mask = tf.tile(mask, [1, 1, 1, channels])

        # Keep pixels where mask is True; write constant_values in the rectangle.
        images = tf.where(
            mask,
            images,
            tf.cast(constant_values, dtype=images.dtype),
        )
        # Reassert the input's static shape on the result.
        images.set_shape(image_static_shape)
        return images
    
    
class RandomCutout(tf.keras.layers.Layer):
    """Keras layer wrapper around ``random_cutout``.

    Every call cuts one randomly placed rectangle of size ``mask_size`` out of
    each image in the batch. There is no ``training`` switch, so it also runs
    at inference.

    Args:
        mask_size: even scalar or even ``[height, width]`` pair.
        constant_values: fill value inside the rectangle (default 0, as before).
        seed: optional seed forwarded to ``random_cutout`` (default ``None``).
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, mask_size, constant_values=0, seed=None, **kwargs):
        super(RandomCutout, self).__init__(**kwargs)
        self.mask_size = mask_size
        self.constant_values = constant_values
        self.seed = seed

    def call(self, inputs):
        return random_cutout(
            inputs,
            self.mask_size,
            constant_values=self.constant_values,
            seed=self.seed,
        )
            
class FlipHorizontal(tf.keras.layers.Layer):
    """Mirror images left-to-right on every call.

    A thin wrapper over ``tf.image.flip_left_right``. Unlike the ``Random*``
    layers in this module, the flip is deterministic.
    """

    def __init__(self):
        super(FlipHorizontal, self).__init__()

    def call(self, inputs):
        flipped = tf.image.flip_left_right(inputs)
        return flipped
        
"""---------------------------------------------------------------------"""
def make_patch_sampler(translation=0.4, cutout_size=50, num_cutouts=2):
    """Build a ``Sequential`` that randomly translates, then applies cutouts.

    The defaults reproduce the previously hard-coded pipeline: a translation
    range of ``(-0.4, 0.4)`` on both axes followed by two ``RandomCutout(50)``
    layers.

    Args:
        translation: bound passed to ``RandomTranslation`` as
            ``(-translation, translation)`` for both height and width (Keras
            reads these as fractions of the image size). Uncovered areas are
            filled with ``fill_mode="constant"``.
        cutout_size: ``mask_size`` of each ``RandomCutout`` (must be even).
        num_cutouts: number of independent ``RandomCutout`` layers to stack.
    """
    shift = (-translation, translation)
    layer_list = [
        tf.keras.layers.experimental.preprocessing.RandomTranslation(
            shift, shift, fill_mode="constant"
        )
    ]
    layer_list.extend(RandomCutout(cutout_size) for _ in range(num_cutouts))
    return tf.keras.models.Sequential(layer_list)
