import tensorflow as tf

tfl = tf.keras.layers


class FCResidualBlock(tfl.Layer):
    """Pre-activation fully connected residual block.

    Computes ``inputs + Dense_out(act(Dense_k(... act(Dense_1(act(inputs))))))``.
    The activation is applied to the inputs first. Each hidden ``Dense``
    layer is then followed by the activation. A final ``Dense`` layer
    projects back to the input's last dimension so the result can be added
    to ``inputs``.

    Args:
        hidden_dims: Iterable of unit counts. One hidden ``Dense`` layer is
            created per entry and applied in order. Defaults to no hidden
            layers.
        activation: Callable applied to the inputs and after every hidden
            ``Dense`` layer. Defaults to a fresh ``Activation('relu')``
            layer owned by this block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, hidden_dims=None, activation=None, **kwargs):
        super(FCResidualBlock, self).__init__(**kwargs)
        # The signature uses None sentinels because Python evaluates defaults
        # only once. A `[]` or `Activation(...)` default would be one object
        # shared by every block created without an explicit argument.
        hidden_dims = [] if hidden_dims is None else list(hidden_dims)
        if activation is None:
            activation = tfl.Activation('relu')
        self._hidden_dims = hidden_dims
        self._hidden_layers = [tfl.Dense(dim) for dim in hidden_dims]
        self._activation = activation

    def build(self, input_shape):
        # The output projection must have the input's width so the skip
        # addition in `call` lines up. This presumably requires the last
        # input dimension to be statically known.
        self._input_dim = input_shape[-1]
        self._out_layer = tfl.Dense(self._input_dim)
        super(FCResidualBlock, self).build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` plus the residual branch output."""
        res_out = self._activation(inputs)
        for layer in self._hidden_layers:
            res_out = self._activation(layer(res_out))
        res_out = self._out_layer(res_out)
        return inputs + res_out


class SpectralNormalization(tfl.Wrapper):
    """Rescale a wrapped layer's ``kernel`` by its estimated spectral norm.

    The wrapper keeps a non-trainable vector ``u`` of shape ``[1, n]``, where
    ``n`` is the kernel's last dimension. ``u`` persists across calls.

    On every call the wrapper first runs ``power_iterations`` steps of power
    iteration on the kernel reshaped to 2-D ``[-1, n]``. This estimates its
    largest singular value ``sigma``. The kernel variable is then divided by
    ``sigma`` in place, and finally the wrapped layer runs.

    NOTE(review): the rescale is an ``assign`` on the wrapped layer's kernel.
    It is not part of the differentiated forward computation, and it runs
    on every call, including inference.

    Args:
        layer: Keras layer that exposes a ``kernel`` variable once built.
        power_iterations: Number of power-iteration steps per call. Must be
            at least 1.
        eps: Epsilon passed to ``tf.nn.l2_normalize``.
        **kwargs: Forwarded to ``tf.keras.layers.Wrapper`` (e.g. ``name``).
            This is also what lets ``from_config`` reconstruct the layer.

    Raises:
        ValueError: If ``power_iterations`` is less than 1.
    """

    def __init__(self, layer, power_iterations=1, eps=1e-12, **kwargs):
        assert isinstance(layer, tf.keras.layers.Layer)
        if power_iterations < 1:
            # With zero steps the singular vector `v` would never be
            # computed. Fail here rather than on the first call.
            raise ValueError(
                'power_iterations must be >= 1, got {}'.format(power_iterations))
        super(SpectralNormalization, self).__init__(layer, **kwargs)
        self.power_iterations = power_iterations
        self._eps = eps

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)
        # Use the variable's static shape: add_weight needs concrete sizes,
        # which a dynamic tf.shape() tensor only provides in eager mode.
        self.kernel_shape = self.layer.kernel.shape
        self.u = self.add_weight(shape=[1, int(self.kernel_shape[-1])],
                                 initializer=tf.keras.initializers.RandomNormal(),
                                 trainable=False)
        self.built = True

    def call(self, inputs):
        """Normalize the wrapped kernel, then run the wrapped layer."""
        self.power_iteration(self.power_iterations)
        return self.layer(inputs)

    def power_iteration(self, iterations):
        """Update ``u`` and divide the wrapped kernel by its estimated sigma.

        Args:
            iterations: Number of power-iteration steps (must be >= 1).

        Raises:
            ValueError: If ``iterations`` is less than 1.
        """
        if iterations < 1:
            raise ValueError(
                'iterations must be >= 1, got {}'.format(iterations))
        kernel = self.layer.kernel
        w = tf.reshape(kernel, [-1, int(self.kernel_shape[-1])])
        u = tf.identity(self.u)
        for _ in range(iterations):
            v = tf.nn.l2_normalize(tf.matmul(u, w, transpose_b=True),
                                   epsilon=self._eps)
            u = tf.nn.l2_normalize(tf.matmul(v, w), epsilon=self._eps)
        u, v = tf.stop_gradient(u), tf.stop_gradient(v)
        self.u.assign(u)
        # [1, m] @ [m, n] @ [n, 1] -> [1, 1]; broadcasts over any kernel rank.
        sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
        kernel.assign(kernel / sigma)

    def get_config(self):
        """Extend the Wrapper config so the constructor arguments round-trip."""
        config = super(SpectralNormalization, self).get_config()
        config.update({'power_iterations': self.power_iterations,
                       'eps': self._eps})
        return config
