import tensorflow as tf
from keras import layers


class Sampling(layers.Layer):
    """Sample z, the vector encoding the input, from (z_mean, z_log_var).

    Uses the reparameterization trick: ``z = z_mean + sigma * epsilon`` where
    ``sigma = exp(0.5 * z_log_var)`` and ``epsilon`` is standard-normal noise,
    so the random draw is kept outside the path that gradients flow through.
    """

    def call(self, inputs):
        """Return a random sample z for each entry of the inputs.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of tensors. They are
                combined element-wise, so their shapes must be
                broadcast-compatible.

        Returns:
            A tensor equal to ``z_mean + epsilon * exp(0.5 * z_log_var)``.
        """
        z_mean, z_log_var = inputs
        # Draw the noise with the full dynamic shape and the dtype of z_mean:
        # this works for any input rank (not only (batch, dim)) and avoids a
        # dtype mismatch when the layer runs in a non-float32 precision.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log(var)) is the standard deviation.
        return z_mean + epsilon * tf.exp(0.5 * z_log_var)


class Patches(layers.Layer):
    """Split images into non-overlapping square patches and flatten each one.

    Args:
        patch_size: Side length, in pixels, of each square patch. It is also
            used as the stride, so patches do not overlap.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``),
            forwarded to ``layers.Layer``.
    """

    def __init__(self, patch_size, **kwargs):
        # Forwarding **kwargs lets callers set name/dtype/trainable and lets
        # from_config() rebuild the layer from a saved configuration.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Extract the patches of a batch of images.

        Args:
            images: A 4-D tensor, as required by ``tf.image.extract_patches``
                (batch, rows, columns, channels).

        Returns:
            A tensor of shape ``(batch, num_patches, patch_dims)`` where
            ``patch_dims`` is the size of the last axis produced by
            ``tf.image.extract_patches`` (one flattened patch).
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            # Stride equal to the patch size: tiles, no overlap.
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            # "VALID" drops border pixels that do not fill a whole patch.
            padding="VALID",
        )
        # The static last dimension is used here, so it must be known when
        # the graph is built.
        patch_dims = patches.shape[-1]
        # Collapse the grid of patch positions into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the layer configuration, including ``patch_size``.

        Needed so a model containing this layer can be saved and reloaded.
        """
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


class PatchEncoder(layers.Layer):
    """Project patches linearly and add a learned position embedding.

    Args:
        num_patches: Number of patch positions; sets the size of the position
            embedding table and the range of position indices.
        projection_dim: Output size of both the dense projection and the
            position embedding.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``),
            forwarded to ``layers.Layer``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forwarding **kwargs lets callers set name/dtype/trainable and lets
        # from_config() rebuild the layer from a saved configuration.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Encode a batch of patches.

        Args:
            patch: Tensor of patches. Assumed to be
                ``(batch, num_patches, patch_dims)`` so that the position
                embedding broadcasts over the batch axis -- confirm against
                the caller.

        Returns:
            The projected patches plus the embedding of positions
            ``0 .. num_patches - 1``.
        """
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The embedding of a 1-D index vector is (num_patches, projection_dim)
        # and is broadcast against the projected patches by the addition.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the layer configuration, including the constructor arguments.

        Needed so a model containing this layer can be saved and reloaded.
        """
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config
