import tensorflow as tf

# def my_init(shape, dtype=None):
#     return tf.random.uniform(shape, minval=-0.25, maxval=0.25, dtype=dtype) 

class Attention(tf.keras.layers.Layer):
    """Pairwise Gaussian (RBF) similarity between the rows of two tensors.

    For inputs ``x`` and ``y`` the layer returns ``K`` with
    ``K[i, j] = exp(-0.5 * sum((x[i] - y[j]) ** 2))``, the sum running over
    the last axis.  The layer owns no weights.
    """

    def __init__(self, **kwargs):
        # Forward the standard Layer options (name, dtype, trainable, ...) so
        # the layer can be named and re-created from its config.
        super().__init__(**kwargs)

    def call(self, x, y):
        """Return the kernel matrix between the rows of ``x`` and ``y``.

        Args:
            x: Tensor whose rows are compared against the rows of ``y``.
                In this file it is 2-D, ``(n, d)``.
            y: Tensor with the same last dimension as ``x``; in this file
                2-D, ``(m, d)``.

        Returns:
            For 2-D inputs, an ``(n, m)`` tensor of values in ``(0, 1]``.
        """
        # Broadcast (n, 1, d) against (1, m, d) to obtain every pairwise
        # difference at once.
        lhs = tf.expand_dims(x, 1)
        rhs = tf.expand_dims(y, 0)
        squared_distance = tf.reduce_sum(tf.square(lhs - rhs), axis=-1)
        return tf.exp(-0.5 * squared_distance)

    def get_config(self):
        """Return the serialization config.

        The layer adds no constructor arguments of its own, so the base
        Layer config (name, dtype, trainable, ...) is the complete config.
        """
        return super().get_config()
    
def _get_segment_indices(ids_):
    """Assign every row of an integer id matrix a group index.

    Each row is collapsed into one integer key using a mixed-radix
    encoding (the radix of a column is derived from the maxima of the
    columns to its right), and ``tf.unique`` then numbers the distinct keys
    in order of first appearance.

    Args:
        ids_: 2-D integer tensor, one row per item and one column per id
            level.  Assumes non-negative ids so that distinct rows map to
            distinct keys -- TODO confirm with callers.

    Returns:
        1-D tensor with one entry per row of ``ids_``: the index of that
        row's key among the distinct keys.
    """
    # Number of representable values per column: max id + 1.
    column_sizes = tf.reduce_max(ids_, axis=0) + 1
    # The weight of column i is the product of the sizes of all columns to
    # its right; the last column therefore has weight 1.
    trailing_sizes = tf.concat([column_sizes[1:], [1]], 0)
    radices = tf.math.cumprod(trailing_sizes, reverse=True)
    row_keys = tf.reduce_sum(tf.multiply(ids_, radices), axis=1)
    _, row_to_group = tf.unique(row_keys)
    return row_to_group

def _reduce_indices(ids_):
    """Return ``ids_`` with its last column removed.

    The previous implementation obtained the same result through a
    transpose / flatten / gather / reshape / transpose round trip, which
    is equivalent to this slice for rank-1 and rank-2 inputs but reshaped
    against a mismatched layout (scrambling the data) for higher ranks.
    A direct slice is correct for any rank >= 1.

    Args:
        ids_: Tensor (or tensor-like) of rank >= 1; in this file a 2-D
            integer id matrix.

    Returns:
        Tensor with the same leading dimensions as ``ids_`` and a last
        dimension that is one shorter (empty stays empty).
    """
    return tf.convert_to_tensor(ids_)[..., :-1]

def reduce_segment_ids(segment_ids, n):
    """Compute segment indices at the ``n``-th level of an id hierarchy.

    Level 1 groups the rows of ``segment_ids`` by all of their columns.
    Each further level drops the last id column, keeps one row per group of
    the previous level (the per-group maximum) and groups those rows again.

    Args:
        segment_ids: 2-D integer tensor, one row per item and one column
            per id level.
        n: Number of levels to apply; must be at least 1.

    Returns:
        1-D ``tf.int32`` tensor of group indices.  For ``n == 1`` it has one
        entry per row of ``segment_ids``; for larger ``n`` it has one entry
        per group of level ``n - 1``.

    Raises:
        ValueError: If ``n`` is smaller than 1 (previously this surfaced as
            an UnboundLocalError).
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n!r}")

    level_ids = segment_ids
    # Only the first n - 1 levels need the reduced id matrix; the final
    # level only needs its indices, so no reduction is computed for it.
    # NOTE: tf.math.segment_max requires sorted segment ids, so rows that
    # share an id are expected to be contiguous in the input.
    for _ in range(n - 1):
        group_of_row = _get_segment_indices(level_ids)
        level_ids = tf.math.segment_max(_reduce_indices(level_ids), group_of_row)
    return tf.cast(_get_segment_indices(level_ids), tf.int32)

def _segment_pool(features, segment_ids, units):
    """Project ``features`` twice and pool each projection per segment.

    Two independent ``Dense(units)`` projections are built; the first is
    reduced with ``tf.math.segment_max`` and the second with
    ``tf.math.segment_mean`` over ``segment_ids``.  The pooled results are
    concatenated, giving ``2 * units`` features per segment.
    """
    # Layers are created in the order Dense, Dense, Lambda(max),
    # Lambda(mean), Concatenate so auto-generated layer names and seeded
    # weight initialization stay stable.
    max_branch = tf.keras.layers.Dense(units)(features)
    mean_branch = tf.keras.layers.Dense(units)(features)
    max_branch = tf.keras.layers.Lambda(
        lambda t: tf.math.segment_max(t[0], t[1]))([max_branch, segment_ids])
    mean_branch = tf.keras.layers.Lambda(
        lambda t: tf.math.segment_mean(t[0], t[1]))([mean_branch, segment_ids])
    return tf.keras.layers.Concatenate()([max_branch, mean_branch])


def get_model(input_dim, inducing_points, *, vocab_size=137,
              embedding_dim=500, pool_units=250, latent_dim=64):
    """Build the feature/kernel model and the linear regressor head.

    The feature model takes two inputs:

    * ``in_`` of width ``input_dim``.  The first ``input_dim - 2`` columns
      are fed to an ``Embedding`` layer (so they are presumably integer
      token ids below ``vocab_size`` -- TODO confirm); the last two columns
      are cast to ``int32`` and used as hierarchical segment ids.
    * ``ip`` of width ``inducing_points``, mapped by a bias-free linear
      ``Dense`` layer into the latent space.

    Rows of ``in_`` are embedded, pooled over the level-1 segments, pooled
    again over the level-2 segments and projected to ``latent_dim``.

    Args:
        input_dim: Width of the main input, including the two trailing
            segment-id columns.
        inducing_points: Width of the inducing-point input.
        vocab_size: Input dimension of the embedding table.
        embedding_dim: Output dimension of the embedding table.
        pool_units: Width of each dense projection preceding a segment
            pooling step.
        latent_dim: Width of the shared latent space; also the input width
            of the regressor.

    Returns:
        Tuple ``(model, regressor)``.  ``model`` maps ``[in_, ip]`` to
        ``[ip_d, K_ii, K_ix, K_xx, x]`` where ``ip_d`` and ``x`` are the
        latent inducing points and latent segment features, and the ``K_*``
        outputs are the ``Attention`` kernels between them.  ``regressor``
        is a single linear ``Dense(1)`` layer over ``latent_dim`` inputs.
    """
    n_points = inducing_points

    # Inducing-point branch.
    ip = tf.keras.layers.Input(shape=(n_points,))
    ip_d = tf.keras.layers.Dense(
        latent_dim, activation='linear', use_bias=False)(ip)

    # Main branch: split token columns from the two trailing id columns.
    in_ = tf.keras.layers.Input(shape=(input_dim,))
    n_token_cols = int(input_dim - 2)

    # The id columns are sliced separately for each Lambda, exactly as
    # before, so the layer graph keeps its original topology.
    # NOTE(review): output_shape=(1,) is the originally declared shape; the
    # wrapped function itself returns a rank-1 vector of segment indices.
    out_1 = tf.keras.layers.Lambda(
        lambda t: reduce_segment_ids(tf.cast(t, tf.int32), 1),
        output_shape=(1,))(in_[:, n_token_cols:])
    out_2 = tf.keras.layers.Lambda(
        lambda t: reduce_segment_ids(tf.cast(t, tf.int32), 2),
        output_shape=(1,))(in_[:, n_token_cols:])

    x = in_[:, :n_token_cols]
    x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(x)
    x = tf.keras.layers.Flatten()(x)

    # Two-stage hierarchical pooling: rows -> level-1 segments -> level-2
    # segments.
    x = _segment_pool(x, out_1, pool_units)
    x = _segment_pool(x, out_2, pool_units)

    # Must match the width of ip_d so the kernels below are well defined.
    x = tf.keras.layers.Dense(latent_dim, activation='linear')(x)

    k_ii = Attention()(ip_d, ip_d)
    k_ix = Attention()(ip_d, x)
    k_xx = Attention()(x, x)

    model = tf.keras.models.Model([in_, ip], [ip_d, k_ii, k_ix, k_xx, x])
    model.build(((None, input_dim), (None, n_points)))

    regressor = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(latent_dim,)),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    regressor.build((None, latent_dim))
    return model, regressor