from vprnn.layers import VanillaCell
import keras.backend as K
from keras.initializers import Constant
from keras.layers import SimpleRNNCell


class FastVanillaCell(VanillaCell):
    """VPRNN cell with a learned scalar mix of candidate and previous state.

    The parent ``VanillaCell`` produces a candidate state ``h_t``.  This cell
    instead returns::

        new_state = sigmoid(alpha) * h_t + sigmoid(beta) * prev_state

    ``alpha`` and ``beta`` are trainable scalar weights (stored with shape
    ``(1, 1)`` so they broadcast over the state).  NOTE(review): this looks
    like the FastRNN residual update (Kusupati et al., 2018) — confirm.

    Args:
        *args: Forwarded unchanged to ``VanillaCell``.
        alpha_init: Initial (pre-sigmoid) value of ``alpha``.
        beta_init: Initial (pre-sigmoid) value of ``beta``.
        clip_scalar: If true, the previous-state coefficient ``sigmoid(beta)``
            is clipped to ``[0, 1 - 2 * sigmoid(alpha)]`` at every step.
        **kwargs: Forwarded unchanged to ``VanillaCell``.
    """

    def __init__(self, *args,
                 alpha_init=-3.75,
                 beta_init=3.0,
                 clip_scalar=True,
                 **kwargs):
        self.alpha_init = alpha_init
        self.beta_init = beta_init
        # Weights are created lazily in build().
        self.alpha = None
        self.beta = None
        self.clip_scalar = clip_scalar
        super().__init__(*args, **kwargs)

    def build(self, input_shape):
        """Create the scalar ``alpha``/``beta`` weights, then build the parent."""
        self.alpha = self.add_weight(name='alpha',
                                     shape=(1, 1),
                                     initializer=Constant(self.alpha_init))
        self.beta = self.add_weight(name='beta',
                                    shape=(1, 1),
                                    initializer=Constant(self.beta_init))
        super().build(input_shape)

    def call(self, inputs, states, **kwargs):
        """Run one step and return ``(new_state, [new_state])``.

        ``states[0]`` is taken as the previous state.  ``inputs``, ``states``
        and ``kwargs`` are also passed to ``VanillaCell.call`` to obtain the
        candidate state.
        """
        prev_state = states[0]
        a = K.sigmoid(self.alpha)
        b = K.sigmoid(self.beta)
        if self.clip_scalar:
            # Same result as K.clip(b, 0, 1 - 2 * a), i.e. clip_by_value's
            # min-then-max.  It is written out explicitly because some Keras
            # backends' K.clip compare `max_value < min_value` in Python, and
            # that comparison fails when max_value is a tensor.
            # Bounding b by 1 - 2a keeps a + b <= 1 - a; presumably the intent
            # is to keep the update contractive — NOTE(review): confirm.
            b = K.maximum(K.minimum(b, 1 - 2 * a), 0.)
        # Candidate state from the underlying VPRNN cell; its returned state
        # list is discarded because this cell emits its own mixed state.
        h_t, _ = super().call(inputs, states, **kwargs)
        new_state = a * h_t + b * prev_state
        return new_state, [new_state]

    def get_config(self):
        """Return the serializable config, including this cell's own arguments.

        Without this override, saving/cloning the model would drop
        ``alpha_init``, ``beta_init`` and ``clip_scalar``.  The rebuilt cell
        would then silently fall back to the defaults (e.g. clipping
        re-enabled).
        """
        config = super().get_config()
        config.update({
            'alpha_init': self.alpha_init,
            'beta_init': self.beta_init,
            'clip_scalar': self.clip_scalar,
        })
        return config


class FastSimpleCell(SimpleRNNCell):
    """``SimpleRNNCell`` with a learned scalar mix of candidate and previous state.

    The parent ``SimpleRNNCell`` produces a candidate state ``h_t``.  This
    cell instead returns::

        new_state = sigmoid(alpha) * h_t + sigmoid(beta) * prev_state

    ``alpha`` and ``beta`` are trainable scalar weights (stored with shape
    ``(1, 1)`` so they broadcast over the state).  Unlike ``FastVanillaCell``,
    no clipping is applied to ``sigmoid(beta)``.

    Args:
        *args: Forwarded unchanged to ``SimpleRNNCell``.
        alpha_init: Initial (pre-sigmoid) value of ``alpha``.
        beta_init: Initial (pre-sigmoid) value of ``beta``.
        **kwargs: Forwarded unchanged to ``SimpleRNNCell``.
    """

    def __init__(self, *args,
                 alpha_init=-3.0,
                 beta_init=3.0,
                 **kwargs):
        self.alpha_init = alpha_init
        self.beta_init = beta_init
        # Weights are created lazily in build().
        self.alpha = None
        self.beta = None
        super().__init__(*args, **kwargs)

    def build(self, input_shape):
        """Create the scalar ``alpha``/``beta`` weights, then build the parent."""
        self.alpha = self.add_weight(name='alpha',
                                     shape=(1, 1),
                                     initializer=Constant(self.alpha_init))
        self.beta = self.add_weight(name='beta',
                                    shape=(1, 1),
                                    initializer=Constant(self.beta_init))
        super().build(input_shape)

    def call(self, inputs, states, **kwargs):
        """Run one step and return ``(new_state, [new_state])``.

        ``states[0]`` is taken as the previous state.  ``inputs``, ``states``
        and ``kwargs`` are also passed to ``SimpleRNNCell.call`` to obtain the
        candidate state.
        """
        prev_state = states[0]
        # Candidate state from SimpleRNNCell; its returned state list is
        # discarded because this cell emits its own mixed state.
        h_t, _ = super().call(inputs, states, **kwargs)
        new_state = K.sigmoid(self.alpha) * h_t + K.sigmoid(self.beta) * prev_state
        return new_state, [new_state]

    def get_config(self):
        """Return the serializable config, including this cell's own arguments.

        Without this override, saving/cloning the model would drop
        ``alpha_init`` and ``beta_init``.  A cloned cell would then re-initialize
        ``alpha``/``beta`` from the defaults.
        """
        config = super().get_config()
        config.update({
            'alpha_init': self.alpha_init,
            'beta_init': self.beta_init,
        })
        return config
