
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="retsim")
class RETSimEmbedding(tf.keras.layers.Layer):
    """Keras layer that embeds strings with a frozen pretrained model.

    Each input string is decoded to Unicode code points (invalid UTF-8 bytes
    become U+FFFD) and split into consecutive chunks of ``chunk_size`` code
    points. The last chunk is zero-padded. Every chunk is embedded by the
    wrapped model. The layer returns, per string, the mean of that string's
    chunk embeddings; padding chunks added to square up the batch are
    excluded.

    Args:
        model_path: Location of a saved Keras model, passed to
            ``tf.keras.models.load_model``.
        chunk_size: Number of code points per chunk. Presumably this must
            match the input width the loaded model expects; TODO confirm.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, model_path, chunk_size=512, **kwargs):
        super().__init__(**kwargs)
        self.model_path = model_path
        self.chunk_size = chunk_size
        self.model = self._load_model(model_path)
        # Last dimension of the final layer's output, used as the embedding width.
        self.embedding_size = self.model.layers[-1].output_shape[-1]

    def call(self, inputs, training=False):
        """Embed a batch of strings.

        Args:
            inputs: String tensor whose axis 1 has size 1, e.g. ``(batch, 1)``.
                That axis is squeezed away after decoding.
            training: Forwarded to the wrapped model.

        Returns:
            Tensor of shape ``(batch, embedding_size)`` holding the mean chunk
            embedding per string. Empty strings yield all-zero vectors
            (not NaN).
        """
        batch_size = tf.shape(inputs)[0]
        chunk_size = self.chunk_size

        char_codepoints = tf.strings.unicode_decode(
            inputs,
            "utf-8",
            errors="replace",
        )
        # Drop the size-1 axis: (batch, 1, None) -> (batch, None).
        char_codepoints = tf.squeeze(char_codepoints, axis=1)

        # Integer ceil-division: number of real chunks in each string.
        char_lengths = char_codepoints.row_lengths()
        num_chunks = tf.cast((char_lengths + chunk_size - 1) // chunk_size, tf.int32)
        # Clamp at 0: reduce_max over an empty batch returns the dtype minimum.
        max_chunks = tf.maximum(tf.reduce_max(num_chunks), 0)

        # Zero-pad every string to max_chunks * chunk_size code points, then
        # lay each chunk out as its own row for the model.
        s_dense = char_codepoints.to_tensor(shape=(batch_size, max_chunks * chunk_size))
        s_dense = tf.reshape(s_dense, shape=(batch_size * max_chunks, chunk_size))

        embeddings = self.model(s_dense, training=training)
        embeddings = tf.reshape(
            embeddings, shape=(batch_size, max_chunks, self.embedding_size)
        )

        # Mask out padding chunks (index >= num_chunks). The mask and the
        # denominator both come from num_chunks, so they always agree.
        mask = tf.sequence_mask(num_chunks, maxlen=max_chunks, dtype=embeddings.dtype)
        embeddings_sum = tf.reduce_sum(embeddings * mask[:, :, tf.newaxis], axis=1)

        denominator = tf.cast(num_chunks, embeddings.dtype)[:, tf.newaxis]
        # An empty string has zero chunks; divide_no_nan turns 0/0 into 0.
        return tf.math.divide_no_nan(embeddings_sum, denominator)

    def _load_model(self, path) -> tf.keras.models.Model:
        """Load the pretrained model and freeze its weights.

        Args:
            path: Location of a saved Keras model (str or pathlib.Path), as
                accepted by ``tf.keras.models.load_model``.

        Returns:
            The loaded model with ``trainable`` set to False.
        """
        model = tf.keras.models.load_model(path)
        model.trainable = False
        return model

    def get_config(self):
        """Return the layer config, including ``model_path`` and ``chunk_size``."""
        config = super().get_config()
        config.update({"model_path": self.model_path, "chunk_size": self.chunk_size})
        return config
