from typing import Any, Dict
import tensorflow as tf
import numpy as np


@tf.keras.utils.register_keras_serializable()
class QLayerScale(tf.keras.layers.Layer):
    """Scale inputs by a learnable vector ``gamma`` (LayerScale).

    ``gamma`` is a trainable variable of shape ``(projection_dim,)``,
    initialised to ``init_values * ones(projection_dim)``. ``call`` returns
    ``x * gamma``, so ``gamma`` broadcasts against the last axis of ``x``.
    This presumably expects ``x.shape[-1] == projection_dim``; confirm
    against callers.

    See https://github.com/sayakpaul/deit-tf/blob/79cc91d3cb497f7abe5111fb536968fb9d9a754d/vit/layers/ls.py#L7

    Args:
        projection_dim: Length of the ``gamma`` vector.
        init_values: Initial value(s) for ``gamma``. It is multiplied
            elementwise with ``ones(projection_dim)``, so it may be a scalar
            or something broadcastable to ``(projection_dim,)``.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, projection_dim: int, init_values: np.ndarray, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the constructor arguments so get_config() can report them.
        # Previously they were never stored, so get_config() raised
        # AttributeError.
        self.projection_dim = projection_dim
        self.init_values = init_values
        self.gamma = tf.Variable(
            init_values * tf.ones((projection_dim,)),
            name="layer_scale",
        )

    def call(self, x):
        """Return ``x`` scaled elementwise by ``gamma``."""
        return x * self.gamma

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the constructor arguments.

        NumPy values are converted to plain Python scalars/lists so the
        config can be serialised, e.g. to JSON when saving a model.
        """
        config = super().get_config()
        init_values = self.init_values
        if isinstance(init_values, np.ndarray):
            # tolist() yields a Python scalar for 0-d arrays, a list otherwise.
            init_values = init_values.tolist()
        elif isinstance(init_values, np.generic):
            init_values = init_values.item()
        config["projection_dim"] = self.projection_dim
        config["init_values"] = init_values
        return config
