
"""
original code: https://github.com/Lsdefine/attention-is-all-you-need-keras
"""

from scipy.special import softmax
from tensorflow.keras.models import Model

from ariel_tests.models.transformer.layers_transformer import *


class EncoderLayer():
    """One encoder block: self-attention followed by a position-wise feed-forward net."""

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        """Create the two sub-layers of the block.

        Args:
            d_model: forwarded to both sub-layers.
            d_inner_hid: forwarded to the feed-forward sub-layer.
            n_head: forwarded to the attention sub-layer.
            d_k: forwarded to the attention sub-layer.
            d_v: forwarded to the attention sub-layer.
            dropout: dropout rate forwarded to both sub-layers.
        """
        self.self_att_layer = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn_layer = PositionwiseFeedForward(
            d_model, d_inner_hid, dropout=dropout)

    def __call__(self, enc_input, mask=None):
        """Apply the block to ``enc_input``.

        Args:
            enc_input: input used as query, key and value of the self-attention.
            mask: optional attention mask handed to the attention sub-layer.

        Returns:
            Tuple ``(output, self_attention)``: the feed-forward result and the
            second value returned by the attention sub-layer.
        """
        # Self-attention: the same input plays query, key and value.
        attended, self_attention = self.self_att_layer(
            enc_input, enc_input, enc_input, mask=mask)
        return self.pos_ffn_layer(attended), self_attention


class DecoderLayer():
    """One decoder block: self-attention, encoder attention, then feed-forward."""

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        """Create the three sub-layers of the block.

        Args:
            d_model: forwarded to every sub-layer.
            d_inner_hid: forwarded to the feed-forward sub-layer.
            n_head: forwarded to both attention sub-layers.
            d_k: forwarded to both attention sub-layers.
            d_v: forwarded to both attention sub-layers.
            dropout: dropout rate forwarded to every sub-layer.
        """
        self.self_att_layer = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_att_layer = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn_layer = PositionwiseFeedForward(
            d_model, d_inner_hid, dropout=dropout)

    def __call__(self, dec_input, enc_output, self_mask=None, enc_mask=None):
        """Apply the block to ``dec_input``.

        Args:
            dec_input: decoder-side input; query, key and value of the
                self-attention.
            enc_output: key and value of the second (encoder) attention.
            self_mask: optional mask for the self-attention.
            enc_mask: optional mask for the encoder attention.

        Returns:
            Tuple ``(output, self_attention, encoder_attention)``.
        """
        self_attended, self_attention = self.self_att_layer(
            dec_input, dec_input, dec_input, mask=self_mask)
        # The self-attended decoder state queries the encoder output.
        cross_attended, encoder_attention = self.enc_att_layer(
            self_attended, enc_output, enc_output, mask=enc_mask)
        result = self.pos_ffn_layer(cross_attended)
        return result, self_attention, encoder_attention


class Encoder():
    """Stack of ``EncoderLayer`` blocks on top of word and position embeddings."""

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v,
                 layers=6, dropout=0.1, word_emb=None, pos_emb=None):
        """Store the embedding layers and build the stack.

        Args:
            d_model, d_inner_hid, n_head, d_k, d_v: forwarded to each
                ``EncoderLayer``.
            layers: number of ``EncoderLayer`` blocks to create.
            dropout: rate used for the embedding dropout and for each block.
            word_emb: callable applied to the source token sequence.
            pos_emb: callable applied to the source positions.
        """
        self.emb_layer = word_emb
        self.pos_layer = pos_emb
        self.emb_dropout = Dropout(dropout)
        self.layers = []
        for _ in range(layers):
            self.layers.append(
                EncoderLayer(d_model, d_inner_hid, n_head, d_k, d_v, dropout))

    def __call__(self, src_seq, src_pos, return_att=False, active_layers=999):
        """Encode ``src_seq``.

        Args:
            src_seq: source token sequence fed to the word embedding.
            src_pos: source positions; when ``None`` no position embedding
                is added.
            return_att: when true, also return the per-layer attention values.
            active_layers: only the first ``active_layers`` blocks are applied.

        Returns:
            The encoded sequence, or ``(encoded, attentions)`` when
            ``return_att`` is true.
        """
        embedded = self.emb_layer(src_seq)
        if src_pos is not None:
            positional = self.pos_layer(src_pos)
            embedded = Add()([embedded, positional])
        hidden = self.emb_dropout(embedded)

        # The mask is derived from the source sequence paired with itself.
        pad_mask = GetPadMask()([src_seq, src_seq])

        attentions = []
        for layer in self.layers[:active_layers]:
            hidden, attention = layer(hidden, pad_mask)
            attentions.append(attention)

        if return_att:
            return hidden, attentions
        return hidden


class Decoder():
    """Stack of ``DecoderLayer`` blocks on top of word and position embeddings."""

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v,
                 layers=6, dropout=0.1, word_emb=None, pos_emb=None):
        """Store the embedding layers and build the stack.

        Args:
            d_model, d_inner_hid, n_head, d_k, d_v: forwarded to each
                ``DecoderLayer``.
            layers: number of ``DecoderLayer`` blocks to create.
            dropout: dropout rate forwarded to each block.
            word_emb: callable applied to the target token sequence.
            pos_emb: callable applied to the target positions.
        """
        self.emb_layer = word_emb
        self.pos_layer = pos_emb
        self.layers = [DecoderLayer(d_model, d_inner_hid, n_head, d_k, d_v, dropout) for _ in range(layers)]

    def __call__(self, tgt_seq, tgt_pos, src_seq, enc_output, return_att=False, active_layers=999):
        """Decode ``tgt_seq`` while attending to ``enc_output``.

        Args:
            tgt_seq: target token sequence fed to the word embedding.
            tgt_pos: target positions; when ``None`` no position embedding
                is added.
            src_seq: source token sequence used to build the encoder-attention
                mask, or ``None`` to attend to ``enc_output`` unmasked.
            enc_output: encoder output (or any stand-in of compatible shape).
            return_att: when true, also return the per-layer attention values.
            active_layers: only the first ``active_layers`` blocks are applied.

        Returns:
            The decoded sequence, or ``(decoded, self_atts, enc_atts)`` when
            ``return_att`` is true.
        """
        x = self.emb_layer(tgt_seq)
        if tgt_pos is not None:
            pos = self.pos_layer(tgt_pos)
            x = Add()([x, pos])

        # Self-attention mask: element-wise minimum of the pad mask and the
        # mask produced by GetSubMask.
        self_pad_mask = GetPadMask()([tgt_seq, tgt_seq])
        self_sub_mask = Lambda(GetSubMask)(tgt_seq)
        self_mask = Min()([self_pad_mask, self_sub_mask])

        # Identity check: `==` on a tensor may be overloaded and is not a
        # reliable way to test for None.
        if src_seq is not None:
            enc_mask = GetPadMask()([tgt_seq, src_seq])
        else:
            enc_mask = None

        self_atts, enc_atts = [], []
        for dec_layer in self.layers[:active_layers]:
            x, self_att, enc_att = dec_layer(x, enc_output, self_mask, enc_mask)
            self_atts.append(self_att)
            enc_atts.append(enc_att)

        if return_att:
            return x, self_atts, enc_atts
        return x


class Transformer(object):
    """Encoder-decoder Transformer assembled from Keras layers, with decoding helpers.

    The ``vocabulary`` object is expected to provide ``getMaxVocabularySize()``,
    ``startIndex``, ``endIndex``, ``padIndex``, ``tokens`` (index -> token) and
    ``indicesByTokens`` (token -> index); these are the members read below.
    """

    # Special tokens dropped by indices2sentences.
    _REMOVABLE_TOKENS = ('<PAD>', '<UNK>', '<S>', '</S>')

    def __init__(self, vocabulary, len_limit, d_model=256,
                 d_inner_hid=512, n_head=4, d_k=64, d_v=64, layers=2, dropout=0.1,
                 share_word_emb=False):
        """Create the embeddings, encoder, decoder and output projection.

        Args:
            vocabulary: vocabulary object (see class docstring).
            len_limit: size of the position embedding table and length of the
                target buffer used while decoding.
            d_model: model / embedding width.
            d_inner_hid, n_head, d_k, d_v, layers, dropout: forwarded to the
                encoder and decoder stacks.
            share_word_emb: currently unused; the decoder always reuses the
                encoder's word embedding.
                NOTE(review): confirm whether this flag should be honoured.
        """
        self.vocabulary = vocabulary
        self.len_limit = len_limit
        self.src_loc_info = True
        self.d_model = d_model
        self.decode_model = None
        self.softmax_generator = None
        self.decoder_generator = None
        d_emb = d_model

        self.vocab_size = self.vocabulary.getMaxVocabularySize()
        # Fixed (non-trainable) position embedding initialised from
        # GetPosEncodingMatrix.
        pos_emb = Embedding(len_limit, d_emb, trainable=False,
                            weights=[GetPosEncodingMatrix(len_limit, d_emb)])
        i_word_emb = Embedding(self.vocab_size, d_emb)
        o_word_emb = i_word_emb

        self.encoder = Encoder(d_model, d_inner_hid, n_head, d_k, d_v, layers, dropout,
                               word_emb=i_word_emb, pos_emb=pos_emb)
        self.decoder = Decoder(d_model, d_inner_hid, n_head, d_k, d_v, layers, dropout,
                               word_emb=o_word_emb, pos_emb=pos_emb)
        self.target_layer = TimeDistributed(Dense(self.vocab_size, use_bias=False))

    def get_pos_seq(self, x):
        """Return 1-based positions along axis 1, zeroed wherever ``x`` equals 0."""
        mask = K.cast(K.not_equal(x, 0), 'int32')
        pos = K.cumsum(K.ones_like(x, 'int32'), 1)
        return pos * mask

    def compile(self, optimizer='adam', active_layers=999):
        """Build and compile the training model and the companion models.

        Sets ``self.model`` (its output and added loss are the masked
        cross-entropy), ``self.output_model`` (logits), ``self.encode_model``,
        ``self.ppl`` and ``self.accu``, and finally builds
        ``self.decoder_generator``.

        Args:
            optimizer: optimizer passed to ``Model.compile``.
            active_layers: number of encoder/decoder blocks to use.
        """
        src_seq_input = Input(shape=(None,), dtype='int32')
        tgt_seq_input = Input(shape=(None,), dtype='int32')

        src_seq = src_seq_input
        # Teacher forcing: the decoder sees the target without its last step
        # and is scored against the target without its first step.
        tgt_seq = Slice(1, 0, -1)(tgt_seq_input)
        tgt_true = Slice(1, 1, None)(tgt_seq_input)

        src_pos = Lambda(self.get_pos_seq)(src_seq)
        tgt_pos = Lambda(self.get_pos_seq)(tgt_seq)
        if not self.src_loc_info: src_pos = None

        enc_output = self.encoder(src_seq, src_pos, active_layers=active_layers)
        dec_output = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output, active_layers=active_layers)
        final_output = self.target_layer(dec_output)

        def get_loss(args):
            # Cross-entropy averaged over positions whose target index is not
            # 0, then averaged over the batch.
            y_pred, y_true = args
            y_true = tf.cast(y_true, 'int32')
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
            mask = tf.cast(tf.not_equal(y_true, 0), 'float32')
            loss = tf.reduce_sum(loss * mask, -1) / tf.reduce_sum(mask, -1)
            loss = K.mean(loss)
            return loss

        def get_accu(args):
            # Argmax accuracy over positions whose target index is not 0.
            y_pred, y_true = args
            mask = tf.cast(tf.not_equal(y_true, 0), 'float32')
            corr = K.cast(K.equal(K.cast(y_true, 'int32'), K.cast(K.argmax(y_pred, axis=-1), 'int32')), 'float32')
            corr = K.sum(corr * mask, -1) / K.sum(mask, -1)
            return K.mean(corr)

        loss = Lambda(get_loss)([final_output, tgt_true])
        self.ppl = Lambda(K.exp)(loss)
        self.accu = Lambda(get_accu)([final_output, tgt_true])

        self.model = Model([src_seq_input, tgt_seq_input], loss)
        self.model.add_loss([loss])
        self.output_model = Model([src_seq_input, tgt_seq_input], final_output)

        # The loss was attached with add_loss, so no loss is passed here.
        self.model.compile(optimizer, None)

        # NOTE(review): metrics_names / metrics_tensors are legacy Keras
        # bookkeeping; whether they still have any effect depends on the
        # installed Keras version -- confirm.
        self.model.metrics_tensors = []
        self.model.metrics_names.append('ppl')
        self.model.metrics_tensors.append(self.ppl)
        self.model.metrics_names.append('accu')
        self.model.metrics_tensors.append(self.accu)
        self.encode_model = Model(src_seq_input, enc_output)
        self.decoder_asGenerator_noBeam()

    def make_src_seq_matrix(self, input_seq):
        """Convert a token sequence into a ``(1, len(input_seq) + 3)`` int32 array.

        Layout: start index, the token indices, end index, then one trailing 0.

        Raises:
            KeyError: if a token is missing from ``vocabulary.indicesByTokens``
                (when that mapping is a plain dict).
        """
        src_seq = np.zeros((1, len(input_seq) + 3), dtype='int32')
        src_seq[0, 0] = self.vocabulary.startIndex

        for i, z in enumerate(input_seq):
            src_seq[0, 1 + i] = self.vocabulary.indicesByTokens[z]

        src_seq[0, len(input_seq) + 1] = self.vocabulary.endIndex
        return src_seq

    def _greedy_decode(self, predict_step, delimiter):
        """Greedy decoding loop shared by the single-sequence decoders.

        Args:
            predict_step: callable taking the ``(1, len_limit)`` target buffer
                and returning logits indexed as ``[0, step, vocab]``.
            delimiter: string used to join the decoded tokens.

        Returns:
            The decoded tokens joined by ``delimiter``. The end token is never
            included, and no token is dropped when the length limit is reached
            before an end token is produced.
        """
        end_index = self.vocabulary.endIndex
        target_seq = np.zeros((1, self.len_limit), dtype='int32')
        target_seq[0, 0] = self.vocabulary.startIndex

        decoded_tokens = []
        for i in range(self.len_limit - 1):
            output = predict_step(target_seq)
            sampled_index = np.argmax(output[0, i, :])
            if sampled_index == end_index:
                break
            decoded_tokens.append(self.vocabulary.tokens[sampled_index])
            target_seq[0, i + 1] = sampled_index
        return delimiter.join(decoded_tokens)

    def decode_sequence(self, input_seq, delimiter=''):
        """Greedily decode ``input_seq`` with ``self.output_model``.

        Requires ``compile`` to have been called. Returns the decoded tokens
        joined by ``delimiter``.
        """
        src_seq = self.make_src_seq_matrix(input_seq)

        def predict_step(target_seq):
            return self.output_model.predict_on_batch([src_seq, target_seq])

        return self._greedy_decode(predict_step, delimiter)

    def next_symbol_prediction(self, input_seq, pad_value=None):
        """
        This method will only make sense if the Transformer has been trained to do next symbol
        prediction.

        Args:
            input_seq: a sequence of tokens whose next value has to be predicted
            pad_value: the value for padding can be specified from the outside;
                defaults to ``vocabulary.padIndex``

        Returns:
            tuple of (softmax distribution over next possible symbols,
            token with the highest score)

        """
        if pad_value is None:
            pad_value = self.vocabulary.padIndex

        target_seq = self.make_src_seq_matrix(input_seq)
        # Explicit array so the model receives one tensor-like input rather
        # than a nested Python list.
        src_seq = np.array([[pad_value] + list(target_seq[0])], dtype='int32')

        output = self.output_model.predict_on_batch([src_seq, target_seq])
        last_logits = output[0, -1, :]
        soft = softmax(last_logits)
        sampled_token = self.vocabulary.tokens[np.argmax(last_logits)]

        return soft, sampled_token

    def make_fast_decode_model(self):
        """Build separate encoder and decoder models.

        Sets ``self.encode_model`` and ``self.decode_model`` so the source is
        encoded once and only the decoder is re-run per decoding step.
        """
        src_seq_input = Input(shape=(None,), dtype='int32')
        tgt_seq_input = Input(shape=(None,), dtype='int32')
        src_seq = src_seq_input
        tgt_seq = tgt_seq_input

        src_pos = Lambda(self.get_pos_seq)(src_seq)
        tgt_pos = Lambda(self.get_pos_seq)(tgt_seq)
        if not self.src_loc_info: src_pos = None
        enc_output = self.encoder(src_seq, src_pos)
        self.encode_model = Model(src_seq_input, enc_output)

        # The decoder model takes the encoder output as an explicit input.
        enc_ret_input = Input(shape=(None, self.d_model))
        dec_output = self.decoder(tgt_seq, tgt_pos, src_seq, enc_ret_input)
        final_output = self.target_layer(dec_output)
        self.decode_model = Model([src_seq_input, enc_ret_input, tgt_seq_input], final_output)

        self.encode_model.compile('adam', 'mse')
        self.decode_model.compile('adam', 'mse')

    def decoder_asGenerator_noBeam(self):
        """Build ``self.decoder_generator``: the decoder fed by an external tensor.

        The input that used to come from the encoder is supplied from outside
        (e.g. noise). No source sequence is given, so the encoder attention is
        not masked.
        """
        enc_ret_input = Input(shape=(None, self.d_model))
        tgt_seq = Input(shape=(None,), dtype='int32')

        src_seq = None

        tgt_pos = Lambda(self.get_pos_seq)(tgt_seq)
        dec_output = self.decoder(tgt_seq, tgt_pos, src_seq, enc_ret_input)
        final_output = self.target_layer(dec_output)

        self.decoder_generator = Model([enc_ret_input, tgt_seq], final_output)

    def decode_noise_fast(self, input_noise, delimiter=' '):
        """Greedily decode a batch of sentences from ``input_noise``.

        Args:
            input_noise: array whose first dimension is the batch size; fed to
                ``self.decoder_generator`` in place of the encoder output.
            delimiter: string used to join the tokens of each sentence.

        Returns:
            List with one sentence per batch row, special tokens removed.
        """
        if self.decoder_generator is None:
            self.decoder_asGenerator_noBeam()

        batch_size = input_noise.shape[0]

        target_seq = np.zeros((batch_size, self.len_limit), dtype='int32')
        target_seq[:, 0] = self.vocabulary.startIndex
        for i in range(self.len_limit - 1):
            output = self.decoder_generator.predict_on_batch([input_noise, target_seq])
            target_seq[:, i + 1] = np.argmax(output[:, i, :], axis=1)

        return self.indices2sentences(target_seq, delimiter=delimiter)

    def indices2sentences(self, indices, delimiter=' '):
        """Turn rows of token indices into sentences.

        Args:
            indices: iterable of index sequences.
            delimiter: string used to join the tokens of each sentence.

        Returns:
            List of strings; '<PAD>', '<UNK>', '<S>' and '</S>' are dropped.
        """
        tokens = self.vocabulary.tokens
        sentences = []
        for row in indices:
            words = [tokens[int(i)] for i in row]
            kept = [word for word in words if word not in self._REMOVABLE_TOKENS]
            sentences.append(delimiter.join(kept))
        return sentences

    def decode_sequence_fast(self, input_seq, delimiter=''):
        """Greedily decode ``input_seq``, encoding the source only once.

        Builds the fast decode models on first use. Returns the decoded tokens
        joined by ``delimiter``.
        """
        if self.decode_model is None:
            self.make_fast_decode_model()
        src_seq = self.make_src_seq_matrix(input_seq)
        enc_ret = self.encode_model.predict_on_batch(src_seq)

        def predict_step(target_seq):
            return self.decode_model.predict_on_batch([src_seq, enc_ret, target_seq])

        return self._greedy_decode(predict_step, delimiter)

    def beam_search(self, input_seq, topk=5, delimiter=''):
        """Beam-search decode ``input_seq``.

        Args:
            input_seq: source token sequence.
            topk: beam width.
            delimiter: string used to join the tokens of each hypothesis.

        Returns:
            List of ``(sentence, score)`` for every hypothesis that produced
            the end token, best first. ``score`` is the summed log-probability
            divided by ``len(tokens) + 1``.
        """
        if self.decode_model is None:
            self.make_fast_decode_model()
        src_seq = self.make_src_seq_matrix(input_seq)
        src_seq = src_seq.repeat(topk, 0)
        enc_ret = self.encode_model.predict_on_batch(src_seq)

        final_results = []
        decoded_tokens = [[] for _ in range(topk)]
        decoded_logps = [0] * topk
        # Only the first beam is live at step 0: all rows start identical.
        lastk = 1
        target_seq = np.zeros((topk, self.len_limit), dtype='int32')
        target_seq[:, 0] = self.vocabulary.startIndex
        for i in range(self.len_limit - 1):
            if lastk == 0 or len(final_results) > topk * 3: break
            output = self.decode_model.predict_on_batch([src_seq, enc_ret, target_seq])
            logits = output[:, i, :]
            # Subtract the row maximum before exponentiating so exp() cannot
            # overflow; the resulting softmax is mathematically unchanged.
            probs = np.exp(logits - np.max(logits, -1, keepdims=True))
            output = np.log(probs / np.sum(probs, -1, keepdims=True) + 1e-8)
            cands = []
            for k, wprobs in zip(range(lastk), output):
                # Beams that already emitted the end token are not expanded.
                if target_seq[k, i] == self.vocabulary.endIndex: continue
                wsorted = sorted(enumerate(wprobs), key=lambda x: x[-1], reverse=True)
                for wid, wp in wsorted[:topk]:
                    cands.append((k, wid, decoded_logps[k] + wp))
            cands.sort(key=lambda x: x[-1], reverse=True)
            cands = cands[:topk]
            # Beams are rewritten in place, so read parents from a snapshot.
            backup_seq = target_seq.copy()
            for kk, zz in enumerate(cands):
                k, wid, wprob = zz
                target_seq[kk,] = backup_seq[k]
                target_seq[kk, i + 1] = wid
                decoded_logps[kk] = wprob
                decoded_tokens.append(decoded_tokens[k] + [self.vocabulary.tokens[wid]])
                if wid == self.vocabulary.endIndex: final_results.append((decoded_tokens[k], wprob))
            # Drop the previous generation; new hypotheses were appended after it.
            decoded_tokens = decoded_tokens[topk:]
            lastk = len(cands)
        final_results = [(x, y / (len(x) + 1)) for x, y in final_results]
        final_results.sort(key=lambda x: x[-1], reverse=True)
        final_results = [(delimiter.join(x), y) for x, y in final_results]
        return final_results


def output2nextsymbol(output_model):
    """Wrap ``output_model`` into a model that returns a next-symbol softmax.

    A single int32 sequence input is split into two overlapping views (all
    but the last step, and all but the first step -- presumably what
    ``Slice(axis, start, stop)`` does; confirm against its definition) which
    are fed to ``output_model`` as source and target. Only the last time step
    of the result is kept and passed through a softmax.

    Args:
        output_model: callable taking ``[src_seq, tgt_seq]`` and returning a
            3-dimensional output.

    Returns:
        A ``Model`` mapping the original sequence to the softmax of the last
        output step.
    """
    full_seq = Input(shape=(None,), dtype='int32')

    source_view = Slice(1, 0, -1)(full_seq)
    target_view = Slice(1, 1, None)(full_seq)
    logits_3d = output_model([source_view, target_view])

    # Keep the final step only, then drop the now length-1 axis.
    last_step = Slice(1, -1, None)(logits_3d)
    last_step_2d = Squeeze(1)(last_step)
    next_symbol_softmax = Softmax(-1)(last_step_2d)

    return Model(full_seq, next_symbol_softmax)


# def predefinedNextsymbolTransformer(model_filename):
#    output_model = load_model(model_filename)
#    nextsymbol_model = output2nextsymbol(output_model)
#    return nextsymbol_model


# Running this module as a script only prints a marker; no model is built.
if __name__ == '__main__':
    print('done')
