import tensorflow as tf

class Attention(tf.keras.layers.Layer):
    """Pairwise RBF (squared-exponential) kernel between two sets of vectors.

    Despite its name, this layer creates no weights and computes no attention
    scores: ``call(x, y)`` returns the Gaussian kernel matrix
    ``K[i, j] = exp(-0.5 * ||x[i] - y[j]||^2)`` with a fixed unit lengthscale.
    """

    def __init__(self, **kwargs):
        # Forward the standard Layer kwargs (name, dtype, trainable, ...) so the
        # layer can be re-created from the config returned by ``get_config``.
        super().__init__(**kwargs)

    def call(self, x, y):
        """Return the kernel matrix between the rows of ``x`` and ``y``.

        Args:
            x: tensor whose last axis holds feature vectors; assumed rank 2,
                ``(n, d)`` — TODO confirm behavior is wanted for other ranks.
            y: tensor with the same last-axis size as ``x``; assumed ``(m, d)``.

        Returns:
            For rank-2 inputs, an ``(n, m)`` tensor of kernel values in [0, 1].
        """
        # Broadcast (n, 1, d) against (1, m, d) to form every pairwise
        # difference. This costs O(n*m*d) memory, but unlike the expanded
        # ||x||^2 + ||y||^2 - 2<x, y> form it cannot yield small negative
        # squared distances from floating-point round-off.
        diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
        sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)
        return tf.exp(-0.5 * sq_dist)

    def get_config(self):
        """Return the serializable config, including the base Layer fields."""
        # The layer has no hyperparameters of its own; returning the base
        # config preserves name/trainable/dtype across save/load.
        return super().get_config()

def get_model(input_dim, inducing_points, hidden_units=128, n_hidden_layers=3,
              embedding_dim=64):
    """Build the kernel feature model and its linear regression head.

    Args:
        input_dim: size of a flat input vector (int), or the full per-sample
            input shape (list/tuple). List/tuple inputs are flattened before
            the dense stack.
        inducing_points: number of inducing points, i.e. the width of the
            second model input.
        hidden_units: width of each hidden Dense layer.
        n_hidden_layers: number of Dense -> LayerNormalization -> ReLU blocks.
        embedding_dim: size of the embedding shared by the data branch, the
            inducing-point projection and the regressor input.

    Returns:
        ``(model, regressor)``.

        ``model`` takes inputs ``[data, inducing]`` and returns
        ``[ip_d, K_ii, K_ix, K_xx, x]``. These are the inducing-point
        embeddings, the three RBF kernel matrices, and the data embeddings.

        ``regressor`` is a single linear unit applied to an
        ``embedding_dim``-sized vector.
    """
    n_points = inducing_points

    # Normalize once so both the Input layer and model.build() receive a
    # tuple; `(None,) + input_dim` raises TypeError for an int or a list.
    multi_dim = isinstance(input_dim, (list, tuple))
    input_shape = tuple(input_dim) if multi_dim else (input_dim,)

    # Inducing-point branch: bias-free linear projection into embedding space.
    ip = tf.keras.layers.Input(shape=(n_points,))
    ip_d = tf.keras.layers.Dense(
        embedding_dim, activation='linear', use_bias=False)(ip)

    # Data branch: (optional flatten) -> hidden blocks -> linear embedding.
    in_ = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Flatten()(in_) if multi_dim else in_
    for _ in range(n_hidden_layers):
        x = tf.keras.layers.Dense(hidden_units, activation='linear')(x)
        # Chain the normalized tensor into the activation; binding it to a
        # different name would leave LayerNormalization out of the graph.
        x = tf.keras.layers.LayerNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Dense(embedding_dim, activation='linear')(x)

    # RBF kernel matrices; each is taken between rows along the leading
    # (batch) axis of its two operands.
    K_ii = Attention()(ip_d, ip_d)
    K_ix = Attention()(ip_d, x)
    K_xx = Attention()(x, x)

    model = tf.keras.models.Model([in_, ip], [ip_d, K_ii, K_ix, K_xx, x])
    model.build(((None,) + input_shape, (None, n_points)))

    regressor = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(embedding_dim,)),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    regressor.build((None, embedding_dim))
    return model, regressor
