import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer, Embedding, LSTM, Dense, Input
from tensorflow.keras.models import Model

from ariel_tests.models.AriEL.tf_helpers import tf_update_bounds_encoder, pzToSymbol_withArgmax
from ariel_tests.models.transformer.utils import replace_column


def predefined_model(vocab_size, emb_dim, units=128):
    """Build a small Embedding -> LSTM -> Dense(softmax) Keras model.

    Args:
        vocab_size: number of discrete symbols; sizes both the embedding table
            and the output softmax layer.
        emb_dim: dimensionality of the token embeddings.
        units: number of LSTM units.

    Returns:
        A ``Model`` whose input is an integer sequence of shape
        ``(batch, None)`` named ``'discrete_sequence'``. Its output is a
        ``vocab_size``-way softmax computed from the LSTM's final output only,
        because ``return_sequences=False``.
    """
    # mask_zero must be a bool: token id 0 is treated as padding and masked
    # out of the LSTM. The former string 'True' only worked because any
    # non-empty string is truthy, so 'False' would have enabled masking too.
    embedding = Embedding(vocab_size, emb_dim, mask_zero=True)
    lstm = LSTM(units, return_sequences=False)

    input_question = Input(shape=(None,), name='discrete_sequence')
    embed = embedding(input_question)
    lstm_output = lstm(embed)
    softmax = Dense(vocab_size, activation='softmax')(lstm_output)

    return Model(inputs=input_question, outputs=softmax)


class UpdateBoundsEncoder(Layer):
    """Keras wrapper that applies ``tf_update_bounds_encoder`` for one latent dimension.

    Args:
        lat_dim: latent dimensionality. Stored and serialized, but not used
            by ``call``.
        vocab_size: vocabulary size. Stored and serialized, but not used by
            ``call``.
        curDim: index of the latent dimension whose bounds are updated. It is
            passed to ``tf_update_bounds_encoder`` as a constant tensor.
    """

    def __init__(self, lat_dim, vocab_size, curDim, **kwargs):
        super(UpdateBoundsEncoder, self).__init__(**kwargs)

        self.lat_dim, self.vocab_size, self.curDim = lat_dim, vocab_size, curDim

    def call(self, inputs, training=None):
        """Return the updated ``[low_bound, upp_bound]``.

        ``inputs`` is the list ``[low_bound, upp_bound, softmax, s_j]``. All of
        the bound arithmetic is delegated to ``tf_update_bounds_encoder``.
        """
        low_bound, upp_bound, softmax, s_j = inputs
        tf_curDim = tf.constant(self.curDim)
        low_bound, upp_bound = tf_update_bounds_encoder(low_bound, upp_bound, softmax, s_j, tf_curDim)
        return [low_bound, upp_bound]

    def compute_output_shape(self, input_shape):
        # NOTE(review): this assumes tf_update_bounds_encoder preserves the
        # shapes of low_bound and upp_bound (inputs 0 and 1). Confirm against
        # tf_helpers.
        return input_shape[0], input_shape[1]

    def get_config(self):
        # Needed for model.save / load_model. tf.keras cannot serialize a
        # layer whose __init__ takes extra arguments unless get_config
        # records them.
        config = super(UpdateBoundsEncoder, self).get_config()
        config.update({
            'lat_dim': self.lat_dim,
            'vocab_size': self.vocab_size,
            'curDim': self.curDim,
        })
        return config


class UpdateBoundsDecoder(Layer):
    """Split one latent dimension's interval into consecutive per-symbol sub-intervals.

    Take each row's interval ``[low, upp]`` in column ``curDim`` of the bound
    tensors. With ``range = upp - low`` and ``p = softmax``, symbol ``k`` is
    assigned the sub-interval
    ``[low + range * sum(p[:k]), low + range * sum(p[:k + 1])]``.
    Widths are proportional to ``p``. If the rows of ``p`` sum to 1 (not
    checked here), the last sub-interval ends at ``upp``.

    Args:
        curDim: index of the latent dimension being decoded.
    """

    def __init__(self, curDim, **kwargs):
        super(UpdateBoundsDecoder, self).__init__(**kwargs)

        self.curDim = curDim

    def call(self, inputs, training=None):
        """Return ``[updated_low, updated_upp]``, one candidate interval per symbol.

        ``inputs`` is ``[low_bound, upp_bound, softmax]``. The bounds are
        indexed as ``[:, curDim]`` and ``softmax`` is accumulated along
        axis 1. They are therefore presumably ``(batch, lat_dim)`` and
        ``(batch, vocab_size)``; confirm against callers. Both outputs take
        the shape of ``softmax``.
        """
        low_bound, upp_bound, softmax = inputs

        # The inclusive cumsum gives each symbol's right edge within the unit
        # interval; the exclusive cumsum gives its left edge.
        c_upp = tf.cumsum(softmax, axis=1)
        c_low = tf.cumsum(softmax, axis=1, exclusive=True)

        low_col = low_bound[:, self.curDim, tf.newaxis]
        range_ = upp_bound[:, self.curDim] - low_bound[:, self.curDim]

        # TF tensors have no in-place item assignment. Instead of writing back
        # into column curDim of the bound tensors, return the per-symbol
        # candidate bounds as new tensors. The (batch, 1) column broadcasts
        # against the cumulative sums.

        # upper bounds
        upp_update = range_[:, tf.newaxis] * c_upp
        updated_upp = tf.add(low_col, upp_update)

        # lower bounds
        low_update = range_[:, tf.newaxis] * c_low
        updated_low = tf.add(low_col, low_update)

        return [updated_low, updated_upp]

    def compute_output_shape(self, input_shape):
        # Both outputs broadcast a (batch, 1) column against cumsum(softmax),
        # so they have the shape of softmax (input 2), not of the bounds.
        return input_shape[2], input_shape[2]

    def get_config(self):
        # Needed for model.save / load_model: records the extra __init__
        # argument so the layer can be rebuilt from its config.
        config = super(UpdateBoundsDecoder, self).get_config()
        config.update({'curDim': self.curDim})
        return config


class FindSymbolAndBounds(Layer):
    """Select a symbol for one latent dimension and narrow that dimension's bounds to it.

    Args:
        vocab_size: number of symbols; used as the depth of the one-hot
            encoding of the selected symbol.
        curDim: index of the latent dimension being decoded.
    """

    def __init__(self, vocab_size, curDim, **kwargs):
        super(FindSymbolAndBounds, self).__init__(**kwargs)

        self.vocab_size, self.curDim = vocab_size, curDim

    def call(self, inputs, training=None):
        """Return ``[s, low_bound, upp_bound]``.

        ``inputs`` is ``[Ls, Us, low_bound, upp_bound, input_point]``. ``Ls``
        and ``Us`` are presumably the per-symbol lower and upper candidate
        bounds, each ``(batch, vocab_size)``; confirm against callers.

        Steps:
          1. ``pzToSymbol_withArgmax`` chooses symbol ``s`` from ``Us``, ``Ls``
             and the ``curDim`` coordinate of ``input_point`` (kept as a
             ``(batch, 1)`` column).
          2. ``s`` is one-hot encoded.
          3. The chosen symbol's entries of ``Ls`` and ``Us`` are written into
             column ``curDim`` of ``low_bound`` and ``upp_bound`` via
             ``replace_column``.
        """
        Ls, Us, low_bound, upp_bound, input_point = inputs

        s = pzToSymbol_withArgmax(Us, Ls, input_point[:, self.curDim, tf.newaxis])
        # NOTE(review): the masking below assumes s has shape (batch,), so
        # s_oh is (batch, vocab_size). If s were (batch, 1), s_oh would be
        # (batch, 1, vocab_size) and Ls * s_oh would broadcast to
        # (batch, batch, vocab_size). Confirm the helper's output shape.
        # tf.one_hot also requires integer indices.
        s_oh = tf.one_hot(s, self.vocab_size)

        # The one-hot mask zeroes every column except the chosen symbol's, so
        # the row sum extracts that symbol's bound.
        new_L_column = tf.reduce_sum(Ls * s_oh, axis=1)
        low_bound = replace_column(low_bound, new_L_column, self.curDim)

        new_U_column = tf.reduce_sum(Us * s_oh, axis=1)
        upp_bound = replace_column(upp_bound, new_U_column, self.curDim)

        return [s, low_bound, upp_bound]

    def get_config(self):
        # Needed for model.save / load_model: records the extra __init__
        # arguments so the layer can be rebuilt from its config.
        config = super(FindSymbolAndBounds, self).get_config()
        config.update({
            'vocab_size': self.vocab_size,
            'curDim': self.curDim,
        })
        return config
