import numpy as np
from keras import backend as K
from keras import regularizers
from keras.engine import Layer
from keras.initializers import RandomNormal


class KerasMatrixFactorizer(Layer):
    """Biased matrix-factorization layer.

    Each input row holds a pair of integer indices ``(i, j)`` (column 0 is
    ``i``, column 1 is ``j``). For every row the layer outputs::

        <i_embedding[i], j_embedding[j]> + i_bias[i] + j_bias[j] + constant

    where both embedding tables have width ``rank`` and all terms are
    trainable weights.

    Args:
        rank: Width of the embedding vectors (latent dimension).
        input_dim_i: Number of rows in the ``i`` embedding/bias tables; valid
            ``i`` indices are ``0 .. input_dim_i - 1``.
        input_dim_j: Number of rows in the ``j`` embedding/bias tables; valid
            ``j`` indices are ``0 .. input_dim_j - 1``.
        embeddings_regularizer: Regularizer identifier or instance (anything
            accepted by ``keras.regularizers.get``). It is applied to both
            embedding tables only, not to the biases or the constant.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(
            self,
            rank,
            input_dim_i,
            input_dim_j,
            embeddings_regularizer=None,
            **kwargs
    ):
        self.rank = rank
        self.input_dim_i = input_dim_i
        self.input_dim_j = input_dim_j
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        super(KerasMatrixFactorizer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the embedding tables, per-index biases and global constant."""
        # Scaling the init std by 1/sqrt(rank) keeps the initial dot product
        # of two random embeddings at roughly unit variance, independent of rank.
        embedding_stddev = 1 / np.sqrt(self.rank)
        self.i_embedding = self.add_weight(
            shape=(self.input_dim_i, self.rank),
            initializer=RandomNormal(mean=0.0, stddev=embedding_stddev),
            name='i_embedding',
            regularizer=self.embeddings_regularizer
        )
        self.j_embedding = self.add_weight(
            shape=(self.input_dim_j, self.rank),
            initializer=RandomNormal(mean=0.0, stddev=embedding_stddev),
            name='j_embedding',
            regularizer=self.embeddings_regularizer
        )
        # Biases are stored as (dim, 1) so that gathering yields (batch, 1),
        # matching the shape declared by compute_output_shape.
        self.i_bias = self.add_weight(
            shape=(self.input_dim_i, 1),
            initializer='zeros',
            name='i_bias'
        )
        self.j_bias = self.add_weight(
            shape=(self.input_dim_j, 1),
            initializer='zeros',
            name='j_bias'
        )
        self.constant = self.add_weight(
            shape=(1, 1),
            initializer='zeros',
            name='constant',
        )
        # The base-class build() marks the layer as built.
        super(KerasMatrixFactorizer, self).build(input_shape)

    def call(self, inputs):
        """Score each ``(i, j)`` index pair in ``inputs``.

        Args:
            inputs: Tensor whose first two columns hold the ``i`` and ``j``
                indices. It is cast to ``int32`` when it has another dtype,
                so float-typed index tensors are accepted.

        Returns:
            Tensor of per-row scores (declared as ``(batch, 1)`` by
            ``compute_output_shape``).
        """
        # K.gather needs integer indices.
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')
        i = inputs[:, 0]  # by convention, column 0 indexes the i table
        j = inputs[:, 1]  # and column 1 indexes the j table
        i_embedding = K.gather(self.i_embedding, i)
        j_embedding = K.gather(self.j_embedding, j)
        i_bias = K.gather(self.i_bias, i)
        j_bias = K.gather(self.j_bias, j)
        # Row-wise dot product <i_embed, j_embed>, then add the biases and the
        # (1, 1) constant, which broadcasts over the batch.
        out = K.batch_dot(i_embedding, j_embedding, axes=[1, 1])
        out += (i_bias + j_bias + self.constant)
        return out

    def compute_output_shape(self, input_shape):
        """Return ``(batch, 1)``: one score per input row."""
        return (input_shape[0], 1)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, ``from_config`` / ``load_model`` would call the
        constructor without the required ``rank`` / ``input_dim_*`` arguments.
        """
        config = {
            'rank': self.rank,
            'input_dim_i': self.input_dim_i,
            'input_dim_j': self.input_dim_j,
            'embeddings_regularizer': regularizers.serialize(
                self.embeddings_regularizer),
        }
        base_config = super(KerasMatrixFactorizer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
