

"""
original code: https://github.com/Lsdefine/attention-is-all-you-need-keras
"""

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.initializers import *
from tensorflow.keras.layers import *

from ariel_tests.models.transformer.utils import slice_from_to


class Squeeze:
    """Callable helper that removes one size-1 dimension from a Keras tensor.

    This is not a Layer: each call builds a fresh ``Lambda`` layer around
    ``K.squeeze``, handing ``axis`` to it through ``Lambda``'s ``arguments``.

    Args:
        axis: index of the dimension to drop.
    """

    def __init__(self, axis):
        self.axis = axis

    def __call__(self, inputs):
        def _drop_axis(tensor, axis):
            return K.squeeze(tensor, axis=axis)

        squeeze_layer = Lambda(_drop_axis, arguments={'axis': self.axis})
        return squeeze_layer(inputs)


class Slice(Layer):
    """Layer wrapper around the project helper ``slice_from_to``.

    ``call`` returns ``slice_from_to(inputs, initial, final)``.

    Args:
        axis: stored and serialized, but NOT used. ``slice_from_to`` is called
            without it, so the sliced axis is whatever that helper slices.
        initial: start bound forwarded to ``slice_from_to``.
        final: end bound forwarded to ``slice_from_to``.
        **kwargs: standard ``Layer`` keyword arguments (name, dtype, ...).
    """

    def __init__(self, axis, initial, final, **kwargs):
        self.axis, self.initial, self.final = axis, initial, final
        super(Slice, self).__init__(**kwargs)

    def build(self, input_shape):
        super(Slice, self).build(input_shape)

    def call(self, inputs):
        return slice_from_to(inputs, self.initial, self.final)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        The default ``Layer.get_config`` cannot be relied on for custom
        ``__init__`` arguments. TF2 tf.keras raises NotImplementedError for
        them, which breaks model saving and cloning.
        """
        config = super(Slice, self).get_config()
        config.update({
            'axis': self.axis,
            'initial': self.initial,
            'final': self.final,
        })
        return config


class RepeatElements(Layer):
    """Repeat every entry along axis 0 ``n_head`` times, consecutively.

    Wraps ``K.repeat_elements(inputs, n_head, 0)``, which behaves like
    ``np.repeat``. An axis-0 sequence ``[x0, x1, ...]`` becomes
    ``[x0, x0, ..., x1, x1, ...]``, so row ``b * n_head + h`` holds ``x_b``.

    NOTE(review): this is batch-major ordering. A tensor whose heads were
    stacked head-major (row ``h * batch + b``) needs ``K.tile`` instead.

    Args:
        n_head: number of copies of each axis-0 element.
        **kwargs: standard ``Layer`` keyword arguments.
    """

    def __init__(self, n_head, **kwargs):
        self.n_head = n_head
        super(RepeatElements, self).__init__(**kwargs)

    # Must sit at class level (not nested in __init__) to override Layer.build.
    def build(self, input_shape):
        super(RepeatElements, self).build(input_shape)

    def call(self, inputs):
        return K.repeat_elements(inputs, self.n_head, 0)

    def get_config(self):
        """Return ``n_head`` so the layer can be re-created from its config."""
        config = super(RepeatElements, self).get_config()
        config.update({'n_head': self.n_head})
        return config


def GetPosEncodingMatrix(max_len, d_emb):
    """Build the sinusoidal positional-encoding table.

    For ``pos >= 1`` and column ``j``, the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_emb)``. Even columns store its sine
    and odd columns its cosine. Row 0 is left all zeros (not sin/cos of 0),
    so position index 0 contributes no encoding. That is presumably reserved
    for padding; confirm with the embedding caller.

    Args:
        max_len: number of positions (rows). ``0`` yields an empty
            ``(0, d_emb)`` array.
        d_emb: encoding width (columns).

    Returns:
        float64 ndarray of shape ``(max_len, d_emb)``.
    """
    positions = np.arange(max_len, dtype=np.float64)[:, np.newaxis]
    # Columns 2i and 2i+1 share one frequency, hence the j // 2.
    exponents = 2 * (np.arange(d_emb) // 2) / d_emb
    pos_enc = positions / np.power(10000.0, exponents)[np.newaxis, :]
    # Row 0 stays zero; only rows 1.. are converted to sin/cos.
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
    return pos_enc


def get_pad_mask(q, k):
    """Build a query-by-key float mask that is 1 where the key entry is non-zero.

    Entries of ``k`` equal to 0 are masked out (0.0). The 0/1 key row is
    broadcast to every query position via a batched outer product with a
    ones tensor shaped like ``q``.

    Args:
        q: query tensor; only its shape is used. Presumably ``(batch, len_q)``
            token ids (TODO confirm).
        k: key tensor compared against 0. Presumably ``(batch, len_k)``.

    Returns:
        float32 tensor, ``(batch, len_q, len_k)`` under the assumed 2-D inputs.
    """
    key_keep = K.expand_dims(K.cast(K.not_equal(k, 0), 'float32'), 1)
    query_ones = K.expand_dims(K.ones_like(q, 'float32'), -1)
    return K.batch_dot(query_ones, key_keep, axes=[2, 1])


class GetPadMask(Layer):
    """Layer form of ``get_pad_mask``; call it on the pair ``[q, k]``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        query, key = inputs
        return get_pad_mask(query, key)


class Min(Layer):
    """Element-wise minimum of two tensors; call it on the pair ``[a, b]``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        first, second = inputs
        return K.minimum(first, second)


def GetSubMask(s):
    """Return a batched lower-triangular 0/1 mask, diagonal included.

    Entry ``[b, i, j]`` is 1.0 exactly when ``j <= i``. It is produced by a
    cumulative sum down the rows (axis 1) of a batched identity matrix.
    Presumably this is the decoder's subsequent-position mask.

    Args:
        s: tensor whose axis 0 is the batch and axis 1 the sequence length;
            only its dynamic shape is read.

    Returns:
        float32 tensor of shape ``(batch, len_s, len_s)``.
    """
    dims = tf.shape(s)
    identity = tf.eye(dims[1], batch_shape=dims[:1])
    return K.cumsum(identity, 1)


class LayerNormalization(Layer):
    """Layer normalization over the last axis with learned scale and shift.

    Computes ``gamma * (x - mean) / (std + eps) + beta``, with ``mean`` and
    ``std`` taken over the last axis. ``eps`` is added to the standard
    deviation rather than to the variance under a square root. Results
    therefore differ slightly from the Keras built-in
    ``LayerNormalization``, which this class shadows after the star import.

    Args:
        eps: small constant added to the standard deviation to avoid
            division by zero.
        **kwargs: standard ``Layer`` keyword arguments.
    """

    def __init__(self, eps=1e-6, **kwargs):
        self.eps = eps
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        # One scale and one shift per feature of the last axis.
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return ``eps`` so the layer can be re-created from its config.

        The default ``Layer.get_config`` cannot be relied on for custom
        ``__init__`` arguments. TF2 tf.keras raises NotImplementedError.
        """
        config = super(LayerNormalization, self).get_config()
        config.update({'eps': self.eps})
        return config


class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention called on the list ``[q, k, v, mask]``.

    Computes ``softmax(q . k^T / sqrt(d_model) + (-1e10) * (1 - mask))``,
    applies dropout to those weights, then multiplies by ``v``. ``mask`` may
    be None, in which case no bias is added.

    NOTE(review): scores are divided by ``sqrt(d_model)`` rather than
    ``sqrt(d_k)`` (the per-head key width) used in the original Transformer
    paper. It is kept as-is because changing it alters trained models.

    Args:
        d_model: width whose square root scales the scores.
        attn_dropout: dropout rate applied to the attention weights.
        **kwargs: standard ``Layer`` keyword arguments.
    """

    def __init__(self, d_model, attn_dropout=0.1, **kwargs):
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        # Kept for get_config; temper/dropout are what call() actually uses.
        self.d_model = d_model
        self.attn_dropout = attn_dropout
        self.temper = np.sqrt(d_model)
        self.dropout = Dropout(attn_dropout)

    def build(self, input_shape):
        super(ScaledDotProductAttention, self).build(input_shape)

    def call(self, inputs):
        """Return ``(output, attn)``: attended values and the attention weights."""
        q, k, v, mask = inputs
        # Contract axis 2 of q with axis 2 of k: one score per (query, key) pair.
        attn = K.batch_dot(q, k, axes=[2, 2]) / self.temper
        if mask is not None:
            # Where mask == 0 add a huge negative bias so softmax gives ~0 weight.
            mmask = (-1e+10) * (1 - mask)
            attn = tf.add(attn, mmask)
        attn = K.softmax(attn, axis=-1)
        attn = self.dropout(attn)
        output = K.batch_dot(attn, v)
        return output, attn

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(ScaledDotProductAttention, self).get_config()
        config.update({
            'd_model': self.d_model,
            'attn_dropout': self.attn_dropout,
        })
        return config


class MultiHeadAttention():
    """Multi-head attention wired from Keras layers, with optional residual + norm.

    This is not a Layer itself. Sub-layers are created in ``__init__`` and
    connected into the graph on every ``__call__``.

    Args:
        n_head: number of attention heads.
        d_model: output width of the final projection ``w_o``; also the
            scaling width given to ScaledDotProductAttention.
        d_k: per-head query/key width.
        d_v: per-head value width.
        dropout: dropout rate applied after ``w_o``.
        mode: 0 projects all heads with one large Dense per q/k/v and folds
            heads into the batch axis (faster). 1 keeps a separate
            TimeDistributed(Dense) per head (clearer).
        use_norm: when True, add the residual ``q`` and apply
            LayerNormalization to the output.

    Raises:
        ValueError: if ``mode`` is not 0 or 1.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout, mode=0, use_norm=True):
        if mode not in (0, 1):
            raise ValueError('mode must be 0 or 1, got %r' % (mode,))
        self.mode = mode
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.dropout = dropout
        if mode == 0:
            self.qs_layer = Dense(n_head * d_k, use_bias=False)
            self.ks_layer = Dense(n_head * d_k, use_bias=False)
            self.vs_layer = Dense(n_head * d_v, use_bias=False)
        else:
            self.qs_layers = []
            self.ks_layers = []
            self.vs_layers = []
            for _ in range(n_head):
                self.qs_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.ks_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.vs_layers.append(TimeDistributed(Dense(d_v, use_bias=False)))
        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization() if use_norm else None
        self.w_o = TimeDistributed(Dense(d_model))

    def __call__(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``; return ``(outputs, attn)``.

        ``mask`` (optional) is a per-sample mask accepted by
        ScaledDotProductAttention, presumably ``(batch, len_q, len_k)``.
        """
        if self.mode == 0:
            head, attn = self._attend_fused(q, k, v, mask)
        else:
            head, attn = self._attend_per_head(q, k, v, mask)

        outputs = self.w_o(head)
        outputs = Dropout(self.dropout)(outputs)
        if self.layer_norm is None:
            return outputs, attn
        outputs = Add()([outputs, q])
        return self.layer_norm(outputs), attn

    def _attend_fused(self, q, k, v, mask):
        """Mode 0: project all heads at once and fold them into the batch axis."""
        n_head, d_v = self.n_head, self.d_v

        def reshape1(x):
            # [batch, len, n_head * d] -> [n_head * batch, len, d], head-major:
            # row h * batch + b holds head h of sample b.
            s = tf.shape(x)
            x = tf.reshape(x, [s[0], s[1], n_head, s[2] // n_head])
            x = tf.transpose(x, [2, 0, 1, 3])
            return tf.reshape(x, [-1, s[1], s[2] // n_head])

        def reshape2(x):
            # Inverse of reshape1: [n_head * batch, len, d_v] -> [batch, len, n_head * d_v].
            s = tf.shape(x)
            x = tf.reshape(x, [n_head, -1, s[1], s[2]])
            x = tf.transpose(x, [1, 2, 0, 3])
            return tf.reshape(x, [-1, s[1], n_head * d_v])

        qs = self.qs_layer(q)
        ks = self.ks_layer(k)
        vs = self.vs_layer(v)

        qs = Lambda(reshape1)(qs)
        ks = Lambda(reshape1)(ks)
        vs = Lambda(reshape1)(vs)

        if mask is not None:
            # Heads are stacked head-major (row h * batch + b), so the per-sample
            # mask must be tiled: row h * batch + b then holds mask[b].
            # repeat_elements would put mask[b] at rows b * n_head + h and pair
            # each head with another sample's mask whenever batch > 1.
            mask = Lambda(lambda m: K.tile(m, [n_head, 1, 1]))(mask)

        head, attn = self.attention([qs, ks, vs, mask])
        head = Lambda(reshape2)(head)
        return head, attn

    def _attend_per_head(self, q, k, v, mask):
        """Mode 1: run each head with its own projections, then concatenate."""
        heads = []
        attns = []
        for qs_layer, ks_layer, vs_layer in zip(self.qs_layers, self.ks_layers, self.vs_layers):
            qs = qs_layer(q)
            ks = ks_layer(k)
            vs = vs_layer(v)
            # ScaledDotProductAttention.call takes a single list input.
            head, attn = self.attention([qs, ks, vs, mask])
            heads.append(head)
            attns.append(attn)
        head = Concatenate()(heads) if self.n_head > 1 else heads[0]
        attn = Concatenate()(attns) if self.n_head > 1 else attns[0]
        return head, attn


class PositionwiseFeedForward:
    """Position-wise feed-forward sub-block with residual connection and norm.

    Output is ``LayerNormalization(x + Dropout(w_2(relu(w_1(x)))))``. ``w_1``
    and ``w_2`` are kernel-size-1 Conv1D layers, i.e. the same dense
    transform applied independently at every position.

    Args:
        d_hid: output width of ``w_2``; must match ``x``'s last dimension for
            the residual Add.
        d_inner_hid: hidden width of ``w_1``.
        dropout: dropout rate applied before the residual Add.
    """

    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        self.w_1 = Conv1D(d_inner_hid, 1, activation='relu')
        self.w_2 = Conv1D(d_hid, 1)
        self.layer_norm = LayerNormalization()
        self.dropout = Dropout(dropout)

    def __call__(self, x):
        hidden = self.w_1(x)
        projected = self.dropout(self.w_2(hidden))
        residual = Add()([projected, x])
        return self.layer_norm(residual)
