import tensorflow as tf

def my_init(shape, dtype=None):
    """Kernel initializer drawing values uniformly from [-0.25, 0.25).

    Follows the Keras custom-initializer signature ``(shape, dtype=None)``.

    Args:
        shape: Shape of the weight tensor to create.
        dtype: Dtype of the weights. Keras passes the layer's dtype; when the
            function is called without one, ``tf.float32`` is used because
            ``tf.random.uniform`` does not accept ``None``.

    Returns:
        A tensor of ``shape`` with values in ``[-0.25, 0.25)``.
    """
    if dtype is None:
        dtype = tf.float32
    return tf.random.uniform(shape, minval=-0.25, maxval=0.25, dtype=dtype)

class TransformerBlock(tf.keras.layers.Layer):
    """Post-norm Transformer encoder block.

    Self-attention and a two-layer feed-forward network are applied in turn.
    Each sub-layer output goes through dropout, is added back to its input
    (residual connection) and is then layer-normalised.

    Args:
        embed_dim: Output width of the feed-forward network; also used as the
            per-head ``key_dim`` of the attention layer.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden ReLU layer of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.models.Sequential(
            [tf.keras.layers.Dense(ff_dim, activation="relu"), tf.keras.layers.Dense(embed_dim),]
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block.

        Args:
            inputs: Sequence tensor, presumably ``(batch, seq, features)``;
                its last axis must equal ``embed_dim`` because the FFN output
                (``Dense(embed_dim)``) is added back to it.
            training: Whether dropout is active. ``None`` lets Keras decide
                from the calling context.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Self-attention: query and value are both the inputs.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the config needed to re-create this layer via from_config."""
        config = super(TransformerBlock, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config
    
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        maxlen: Number of rows in the position table; the sequence length
            (last axis of the input) must not exceed it.
        vocab_size: Number of rows in the token table; token ids must be
            below it.
        embed_dim: Width of both embeddings.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super(TokenAndPositionEmbedding, self).__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = tf.keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` (sequence along the last axis).

        Returns:
            ``token_emb(x) + pos_emb(0..seq_len-1)``; the position embeddings
            of shape ``(seq_len, embed_dim)`` broadcast over leading axes.
        """
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the config needed to re-create this layer via from_config."""
        config = super(TokenAndPositionEmbedding, self).get_config()
        config.update({
            "maxlen": self.maxlen,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config

class Attention(tf.keras.layers.Layer):
    """Pairwise squared-exponential (RBF) kernel between two sets of vectors.

    Despite its name, this layer has no weights and performs no attention: it
    returns ``exp(-0.5 * ||x_i - y_j||^2)`` for every pair of rows ``x_i`` and
    ``y_j`` (unit length-scale).

    Args:
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def call(self, x, y):
        """Compute the kernel matrix between ``x`` and ``y``.

        Args:
            x: Presumably a 2-D tensor ``(n, d)`` — TODO confirm for other ranks.
            y: Presumably a 2-D tensor ``(m, d)`` with the same ``d`` as ``x``.

        Returns:
            For 2-D inputs, an ``(n, m)`` tensor of kernel values in (0, 1].
        """
        w1 = tf.expand_dims(x, 1)
        w2 = tf.expand_dims(y, 0)
        # Broadcasting (n, 1, d) - (1, m, d) yields all pairwise differences;
        # summing squares over the last axis gives squared distances (n, m).
        D = tf.reduce_sum(tf.square(w1 - w2), axis=-1)
        return tf.exp(-0.5 * D)

    def get_config(self):
        """Return the base Layer config (name, trainable, dtype); no extra args."""
        return super(Attention, self).get_config()

def get_model(input_dim, inducing_points, vocab_size=20000, embed_dim=32,
              num_heads=2, ff_dim=32):
    """Build the feature/kernel model and a small regression head.

    Feature branch:
    1. Treat the input as a sequence of token ids (multi-dimensional inputs
       are flattened).
    2. Add token and position embeddings.
    3. Run one Transformer block.
    4. Mean-pool over the sequence, apply dropout and project to 64 features.

    Inducing-point branch: a second input of length ``inducing_points`` is
    projected linearly (no bias) into the same 64-d space. ``Attention``
    (an RBF kernel) then gives the kernel matrices between the two branches.

    Args:
        input_dim: Shape of one sample, either an int or a list/tuple of dims.
        inducing_points: Length of the second model input.
        vocab_size: Token vocabulary size (token ids must be below it).
        embed_dim: Embedding size for each token.
        num_heads: Number of attention heads in the Transformer block.
        ff_dim: Hidden size of the feed-forward network inside the block.

    Returns:
        ``(model, regressor)``.
        ``model`` maps ``[x, ip]`` to ``[ip_d, K_ii, K_ix, K_xx, features]``.
        ``regressor`` maps 64-d features to one linear output.
    """
    import math

    # Normalise the per-sample shape so int, list and tuple all work.
    if isinstance(input_dim, (list, tuple)):
        input_shape = tuple(input_dim)
    else:
        input_shape = (input_dim,)
    # The input is flattened into a single sequence, so the position table must
    # cover every flattened element, not just the first dimension.
    maxlen = math.prod(input_shape)
    n_points = inducing_points

    # Inducing-point branch: linear projection into the 64-d feature space.
    ip = tf.keras.layers.Input(shape=(n_points,))
    ip_d = tf.keras.layers.Dense(64, activation='linear', use_bias=False, kernel_initializer=my_init)(ip)

    # Feature branch.
    in_ = tf.keras.layers.Input(shape=input_shape)
    x = in_
    if len(input_shape) > 1:
        x = tf.keras.layers.Flatten()(x)

    x = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(x)
    x = TransformerBlock(embed_dim, num_heads, ff_dim)(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)  # could be replaced with LAF
    x = tf.keras.layers.Dropout(0.1)(x)
    x = tf.keras.layers.Dense(64, activation="linear")(x)

    # Kernel matrices: inducing/inducing, inducing/data and data/data.
    K_ii = Attention()(ip_d, ip_d)
    K_ix = Attention()(ip_d, x)
    K_xx = Attention()(x, x)

    model = tf.keras.models.Model([in_, ip], [ip_d, K_ii, K_ix, K_xx, x])
    model.build(((None,) + input_shape, (None, n_points)))

    regressor = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(64,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    regressor.build((None, 64))
    return model, regressor
