import tensorflow as tf
from typing import Dict, Any


@tf.keras.utils.register_keras_serializable()
class StochasticDepth(tf.keras.layers.Layer):
    """Randomly zero whole samples of a batch during training.

    In training mode each sample (index along axis 0) is dropped with
    probability ``drop_prob``; surviving samples are divided by
    ``1 - drop_prob`` so the expected value of the output matches the input.
    Outside training the input is returned unchanged.

    Args:
        drop_prop: Probability of dropping a sample. Must satisfy
            ``0 <= drop_prop < 1`` when given as a plain Python number.
            (The keyword keeps its historical spelling so existing callers
            and saved configs continue to work.)
        *args: Forwarded to ``tf.keras.layers.Layer``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``drop_prop`` is a Python number outside ``[0, 1)``.
    """

    def __init__(self, drop_prop: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A drop probability of 1 would make keep_prob zero and turn the
        # rescaling in `call` into a division by zero (NaN outputs).
        if isinstance(drop_prop, (int, float)) and not 0.0 <= drop_prop < 1.0:
            raise ValueError(
                f"drop_prop must be in the range [0, 1), got {drop_prop!r}"
            )
        self.drop_prob = drop_prop

    def call(self, x, training=None):
        """Apply per-sample dropping to ``x`` when ``training`` is truthy.

        Args:
            x: Floating-point input tensor whose first axis indexes samples.
            training: Truthy to apply stochastic dropping; otherwise ``x``
                is passed through untouched.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        if not training:
            return x

        keep_prob = 1 - self.drop_prob

        # Mask shape is (batch, 1, ..., 1) so one draw is broadcast over each
        # sample. It is assembled from tensor ops because `len()` is not
        # defined for symbolic tensors inside `tf.function` / graph mode.
        dynamic_shape = tf.shape(x)
        mask_shape = tf.concat(
            [dynamic_shape[:1], tf.ones_like(dynamic_shape[1:])], axis=0
        )

        # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob and
        # 0 otherwise. Draw in float32, then cast so the mask matches x's
        # dtype (e.g. float16 under mixed precision, or float64).
        random_tensor = tf.floor(keep_prob + tf.random.uniform(mask_shape))
        binary_mask = tf.cast(random_tensor, x.dtype)

        return (x / keep_prob) * binary_mask

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the constructor's ``drop_prop``."""
        config = super().get_config()
        # The key must match the `__init__` keyword (`drop_prop`); the value
        # lives on the instance as `drop_prob`.
        config["drop_prop"] = self.drop_prob
        return config
