# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Transformer XL model.
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import math
import logging
import collections
import sys
from io import open

import numpy as np
import tensorflow as tf

from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)

# Maps pretrained model shortcut names to the URL of their TF 2.0 (.h5) weight file.
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
}

def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch Transformer-XL checkpoint into ``tf_model``.

    The model is first called once on a small dummy batch so that all of its
    variables get created. The actual weight transfer is then delegated to
    ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 Transformer-XL model to fill.
        pytorch_checkpoint_path: path to the PyTorch checkpoint.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    # Dummy token ids, shape (batch=3, seq_len=5), used only to build the network.
    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tf_model(tf_inputs, training=False)  # build the network; the outputs are not needed
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


class TFPositionalEmbedding(tf.keras.layers.Layer):
    """Sinusoidal embedding of (relative) positions.

    With inverse frequencies ``f_k = 1 / 10000 ** (2k / demb)``, a position ``p``
    is embedded as ``concat([sin(p * f), cos(p * f)])``: all sines first, then
    all cosines.
    """
    def __init__(self, demb, **kwargs):
        super(TFPositionalEmbedding, self).__init__(**kwargs)

        even_dims = tf.range(0, demb, 2.0)
        self.inv_freq = 1 / (10000 ** (even_dims / demb))

    def call(self, pos_seq, bsz=None):
        """Embed the 1-D float tensor ``pos_seq``.

        Returns a tensor of shape ``[len(pos_seq), 1, dim]``. When ``bsz`` is
        given, it is tiled along axis 1 to ``[len(pos_seq), bsz, dim]``.
        """
        angles = tf.einsum('i,j->ij', pos_seq, self.inv_freq)
        table = tf.concat([tf.sin(angles), tf.cos(angles)], -1)
        table = tf.expand_dims(table, 1)

        if bsz is None:
            return table
        return tf.tile(table, [1, bsz, 1])


class TFPositionwiseFF(tf.keras.layers.Layer):
    """Position-wise feed-forward sublayer: Dense(relu) -> dropout -> Dense -> dropout.

    With ``pre_lnorm`` the input is layer-normalized before the feed-forward
    net, and the residual is added without a final norm. Otherwise the residual
    sum ``inp + ff(inp)`` is layer-normalized.
    """
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, init_std=0.02, **kwargs):
        super(TFPositionwiseFF, self).__init__(**kwargs)

        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout

        # Sublayers are assigned in this exact order (weight-order sensitive);
        # the names match the converted checkpoint.
        self.layer_1 = tf.keras.layers.Dense(
            d_inner, kernel_initializer=get_initializer(init_std),
            activation=tf.nn.relu, name='CoreNet_._0')
        self.drop_1 = tf.keras.layers.Dropout(dropout)
        self.layer_2 = tf.keras.layers.Dense(
            d_model, kernel_initializer=get_initializer(init_std), name='CoreNet_._3')
        self.drop_2 = tf.keras.layers.Dropout(dropout)

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')

        self.pre_lnorm = pre_lnorm

    def call(self, inp, training=False):
        ff_in = self.layer_norm(inp) if self.pre_lnorm else inp

        hidden = self.drop_1(self.layer_1(ff_in), training=training)
        hidden = self.drop_2(self.layer_2(hidden), training=training)

        if self.pre_lnorm:
            # residual connection only (input was normalized up front)
            return hidden + inp
        # residual connection followed by layer normalization
        return self.layer_norm(inp + hidden)


class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
    """Multi-head self-attention with Transformer-XL relative positional scores.

    The score between query ``i`` and key ``j`` is the sum of two terms, scaled
    by ``1 / sqrt(d_head)``:
    - a content term ``AC`` = (q_i + r_w_bias) . k_j;
    - a position term ``BD`` = (q_i + r_r_bias) . r_j, realigned by ``_rel_shift``.
    When ``mems`` are given, keys and values are computed over
    ``concat([mems, w])`` along the sequence axis.

    If both ``r_r_bias`` and ``r_w_bias`` are passed in, they are used as-is
    (shared with the caller). Otherwise the layer creates its own zero-initialized
    biases in ``build``. ``tgt_len``, ``ext_len`` and ``mem_len`` are accepted
    but not used by this layer.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False, 
                 layer_norm_epsilon=1e-5, init_std=0.02, **kwargs):
        super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # Single projection producing queries, keys and values (split in `call`).
        self.qkv_net = tf.keras.layers.Dense(3 * n_head * d_head,
                                             kernel_initializer=get_initializer(init_std),
                                             use_bias=False,
                                             name='qkv_net')

        self.drop = tf.keras.layers.Dropout(dropout)
        self.dropatt = tf.keras.layers.Dropout(dropatt)  # dropout on attention probabilities
        self.o_net = tf.keras.layers.Dense(d_model,
                                           kernel_initializer=get_initializer(init_std),
                                           use_bias=False,
                                           name='o_net')

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is not None and r_w_bias is not None: # Biases are shared
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Placeholders; the layer's own biases are created in `build`.
            self.r_r_bias = None
            self.r_w_bias = None

        # Projection of the relative position embeddings (position "keys").
        self.r_net = tf.keras.layers.Dense(self.n_head * self.d_head,
                                           kernel_initializer=get_initializer(init_std),
                                           use_bias=False,
                                           name='r_net')

    def build(self, input_shape):
        """Create per-layer ``r_r_bias`` / ``r_w_bias`` of shape ``(n_head, d_head)`` unless shared ones were given."""
        if self.r_r_bias is None or self.r_w_bias is None: # Biases are not shared
            self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
                                            initializer='zeros',
                                            trainable=True,
                                            name='r_r_bias')
            self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
                                            initializer='zeros',
                                            trainable=True,
                                            name='r_w_bias')
        super(TFRelPartialLearnableMultiHeadAttn, self).build(input_shape)

    def _rel_shift(self, x):
        """Realign position scores ``x`` of shape ``[qlen, rlen, bsz, n_head]``.

        The function pads one zero column in front of axis 1 and reinterprets the
        buffer as ``[rlen + 1, qlen, ...]``. It then drops the first row and
        reshapes back to the input shape. This is the relative-shift trick of the
        original Transformer-XL implementation.
        """
        x_size = shape_list(x)

        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
        x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
        x = tf.reshape(x, x_size)

        return x

    def call(self, inputs, training=False):
        """Apply relative multi-head attention.

        Args:
            inputs: list ``[w, r, attn_mask, mems, head_mask]``.
                ``w`` is ``[qlen, bsz, d_model]``.
                ``r`` holds the relative position embeddings with leading
                dimension ``rlen``; its projection is reshaped to
                ``[rlen, n_head, d_head]``.
                ``attn_mask`` is ``[qlen, klen]`` with 1 marking masked
                positions, or ``None``.
                ``mems`` is ``[mlen, bsz, d_model]`` or ``None``.
                ``head_mask`` multiplies the attention probabilities, or ``None``.
            training: enables dropout.

        Returns:
            list ``[output]`` with ``output`` of shape ``[qlen, bsz, d_model]``.
            When ``output_attentions`` is set, the attention probabilities
            ``[qlen, klen, bsz, n_head]`` are appended.
        """
        w, r, attn_mask, mems, head_mask = inputs
        qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]

        if mems is not None:
            # Keys/values see memory + current segment; queries only the current segment.
            cat = tf.concat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)

        klen = shape_list(w_head_k)[0]

        w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head))  # qlen x bsz x n_head x d_head
        w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head))  # klen x bsz x n_head x d_head
        w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head))  # klen x bsz x n_head x d_head

        r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))       # rlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + self.r_w_bias                                    # qlen x bsz x n_head x d_head
        AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)                  # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)                   # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score = attn_score * self.scale

        #### compute attention probability
        if attn_mask is not None:
            # Masked entries (mask == 1) get a huge negative score, so softmax gives them ~0 weight.
            attn_mask_t = attn_mask[:, :, None, None]
            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

        # [qlen x klen x bsz x n_head]; normalized over the key axis
        attn_prob = tf.nn.softmax(attn_score, axis=1)
        attn_prob = self.dropatt(attn_prob, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x (n_head * d_head)]
        attn_vec_sizes = shape_list(attn_vec)
        attn_vec = tf.reshape(attn_vec, 
                        (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out, training=training)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs


class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
    """One Transformer-XL decoder block: relative multi-head attention followed
    by a position-wise feed-forward sublayer.

    ``call`` returns ``[hidden]``. When the attention sublayer was built with
    ``output_attentions``, its attention probabilities are appended.
    """
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 tgt_len=None, ext_len=None, mem_len=None,
                 dropatt=0., pre_lnorm=False,
                 r_w_bias=None,
                 r_r_bias=None,
                 output_attentions=False,
                 layer_norm_epsilon=1e-5,
                 init_std=0.02,
                 **kwargs):
        super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)

        attn_options = dict(tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                            dropatt=dropatt, pre_lnorm=pre_lnorm,
                            r_w_bias=r_w_bias, r_r_bias=r_r_bias,
                            init_std=init_std,
                            output_attentions=output_attentions,
                            layer_norm_epsilon=layer_norm_epsilon)
        # Attention is assigned before the feed-forward sublayer (weight-order sensitive).
        self.dec_attn = TFRelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, name='dec_attn', **attn_options)
        self.pos_ff = TFPositionwiseFF(
            d_model, d_inner, dropout,
            pre_lnorm=pre_lnorm, init_std=init_std,
            layer_norm_epsilon=layer_norm_epsilon,
            name='pos_ff')

    def call(self, inputs, training=False):
        dec_inp, r, dec_attn_mask, mems, head_mask = inputs

        attn_results = self.dec_attn([dec_inp, r, dec_attn_mask, mems, head_mask],
                                     training=training)
        hidden = self.pos_ff(attn_results[0], training=training)

        # Keep anything extra the attention produced (e.g. attention probabilities).
        outputs = [hidden]
        outputs.extend(attn_results[1:])
        return outputs


class TFAdaptiveEmbedding(tf.keras.layers.Layer):
    """Adaptive input embedding.

    The vocabulary ``[0, n_token)`` is split at ``cutoffs`` into clusters.
    Cluster ``i`` uses an embedding table of width ``d_embed // div_val ** i``,
    whose output is projected to ``d_proj`` by a learned matrix ``emb_projs[i]``.
    The final embedding is scaled by ``sqrt(d_proj)``.

    Only ``div_val != 1`` is implemented; ``div_val == 1`` raises
    ``NotImplementedError``. ``sample_softmax`` is accepted but not used.
    ``cutoffs`` must be a list, because it is concatenated with ``[n_token]``.
    """
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02,
                 sample_softmax=False, **kwargs):
        super(TFAdaptiveEmbedding, self).__init__(**kwargs)

        self.n_token = n_token
        self.d_embed = d_embed
        self.init_std = init_std

        # Upper bounds (exclusive) of each cluster, last one being the vocab size.
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        # Cluster i covers token ids [cutoff_ends[i], cutoff_ends[i + 1]).
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = []
        self.emb_projs = []  # filled in `build`
        if div_val == 1:
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(tf.keras.layers.Embedding(r_idx-l_idx,
                                                                 d_emb_i,
                                                                 embeddings_initializer=get_initializer(init_std),
                                                                 name='emb_layers_._{}'.format(i)))

    def build(self, input_shape):
        """Create one ``(d_embed // div_val ** i, d_proj)`` projection matrix per cluster."""
        for i in range(len(self.cutoffs)):
            d_emb_i = self.d_embed // (self.div_val ** i)
            self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj),
                                                  initializer=get_initializer(self.init_std),
                                                  trainable=True,
                                                  name='emb_projs_._{}'.format(i)))
        super(TFAdaptiveEmbedding, self).build(input_shape)

    def call(self, inp):
        """Embed integer token ids ``inp`` of any shape into ``shape(inp) + [d_proj]``."""
        if self.div_val == 1:
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
        else:
            inp_flat = tf.reshape(inp, (-1,))
            emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj])
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                # Tokens belonging to cluster i.
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)

                # Look up with ids relative to the start of the cluster, then project to d_proj.
                inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = tf.einsum('id,de->ie', emb_i, self.emb_projs[i])

                # Write the cluster's embeddings back at their original flat positions.
                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(tf.shape(emb_flat), dtype=tf.int64))

            embed_shape = shape_list(inp) + [self.d_proj]
            embed = tf.reshape(emb_flat, embed_shape)

        embed *= self.emb_scale

        return embed


class TFTransfoXLMainLayer(tf.keras.layers.Layer):
    """Core Transformer-XL stack shared by the TF 2.0 Transformer-XL models.

    Combines the adaptive word embedding, the sinusoidal relative position
    embedding, ``config.n_layer`` relative-attention decoder layers and the
    segment-level memory (``mems``) bookkeeping. Internally the computation is
    sequence-first (``[len, bsz, ...]``), as in the original Transformer-XL code.
    ``call`` transposes from and to the library's batch-first layout.
    """
    def __init__(self, config, **kwargs):
        super(TFTransfoXLMainLayer, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.n_token = config.n_token

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head
        self.untie_r = config.untie_r

        self.word_emb = TFAdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                            div_val=config.div_val, init_std=config.init_std, name='word_emb')

        self.drop = tf.keras.layers.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        # NOTE(review): with ``config.untie_r == False``, the shared biases
        # ``self.r_w_bias`` / ``self.r_r_bias`` are read below, but they are only
        # created later in ``build``, so construction raises AttributeError.
        # Only the untied configuration (each attention layer creates its own
        # biases) works as written.
        self.layers = []
        if config.attn_type == 0:  # the default attention
            for i in range(config.n_layer):
                self.layers.append(
                    TFRelPartialLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if self.untie_r else self.r_w_bias,
                        r_r_bias=None if self.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions,
                        layer_norm_epsilon=config.layer_norm_epsilon,
                        init_std=config.init_std,
                        name='layers_._{}'.format(i))
                )
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = TFPositionalEmbedding(self.d_model, name='pos_emb')
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

    def build(self, input_shape):
        """Create the shared ``(n_head, d_head)`` relative-attention biases when ``untie_r`` is False."""
        if not self.untie_r:
            self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
                                            initializer='zeros',
                                            trainable=True,
                                            name='r_w_bias')
            self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
                                            initializer='zeros',
                                            trainable=True,
                                            name='r_r_bias')
        super(TFTransfoXLMainLayer, self).build(input_shape)

    def _resize_token_embeddings(self, new_num_tokens):
        # Resizing is not performed; the current embedding layer is returned unchanged.
        return self.word_emb

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Change the target, extended-context and memory lengths used by later calls."""
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        raise NotImplementedError

    def init_mems(self, data):
        """Return ``n_layer`` zero memories of shape ``[mem_len, bsz, d_model]``,
        or ``None`` when ``mem_len <= 0``.

        ``data`` is sequence-first: ``shape_list(data)[1]`` is taken as the batch size.
        """
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer):
                empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Compute the memories for the next segment.

        Args:
            hids: per-layer inputs collected in ``call``, each ``[qlen, bsz, d_model]``.
            mems: current per-layer memories, each ``[mlen, bsz, d_model]``, or ``None``.
            qlen: length of the current segment.
            mlen: length of the current memory.

        Returns:
            ``None`` if ``mems`` is ``None``; otherwise one tensor per layer holding
            at most ``mem_len`` steps. These tensors are wrapped in
            ``tf.stop_gradient``, so no gradient flows through them.
        """
        # does not deal with None
        if mems is None:
            return None

        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        new_mems = []
        end_idx = mlen + max(0, qlen - self.ext_len)
        beg_idx = max(0, end_idx - self.mem_len)
        for mem, hid in zip(mems, hids):
            cat = tf.concat([mem, hid], axis=0)
            # Memories are constants for the next segment; the result of
            # stop_gradient must be used, otherwise it has no effect.
            cat = tf.stop_gradient(cat)
            new_mems.append(cat[beg_idx:end_idx])

        return new_mems

    def call(self, inputs, mems=None, head_mask=None, training=False):
        """Run the Transformer-XL stack.

        Args:
            inputs: one of the following:
                - ``input_ids`` of shape ``[bsz, len]``;
                - a list/tuple ``[input_ids, mems, head_mask]`` (trailing items optional);
                - a dict with keys ``'input_ids'``, ``'mems'``, ``'head_mask'``.
            mems: per-layer memories ``[mlen, bsz, d_model]`` from a previous call, or
                ``None``. In that case, zero memories of length ``mem_len`` are
                created when ``mem_len > 0``.
            head_mask: not supported; must be ``None``.
            training: enables dropout.

        Returns:
            list ``[last_hidden_state, new_mems]``:
            - ``last_hidden_state`` has shape ``[bsz, len, d_model]``;
            - all hidden states (batch-first) are appended if ``output_hidden_states`` is set;
            - all attentions ``[bsz, n_head, qlen, klen]`` are appended if ``output_attentions`` is set.

        Raises:
            NotImplementedError: if ``head_mask`` is not ``None``.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            mems = inputs[1] if len(inputs) > 1 else mems
            head_mask = inputs[2] if len(inputs) > 2 else head_mask
            assert len(inputs) <= 3, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            mems = inputs.get('mems', mems)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 3, "Too many inputs."
        else:
            input_ids = inputs

        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = tf.transpose(input_ids, perm=(1, 0))

        if mems is None:
            mems = self.init_mems(input_ids)

        qlen, bsz = shape_list(input_ids)

        # Head masking is not implemented in the TF 2.0 model.
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layer

        word_emb = self.word_emb(input_ids)

        mlen = shape_list(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        # Causal mask of shape [qlen, klen] (1 = masked): memory positions are
        # always visible, and future positions of the current segment are hidden.
        attn_mask = tf.ones([qlen, qlen])
        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen])
        dec_attn_mask = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
        if self.same_length:
            mask_l = tf.linalg.band_part(attn_mask, -1, 0)
            dec_attn_mask = tf.concat([dec_attn_mask[:, :qlen] + mask_l - mask_dia,
                                       dec_attn_mask[:, qlen:]], 1)
        # ::: PyTorch masking code for reference :::
        # if self.same_length:
        #     all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
        #     mask_len = klen - self.mem_len
        #     if mask_len > 0:
        #         mask_shift_len = qlen - mask_len
        #     else:
        #         mask_shift_len = qlen
        #     dec_attn_mask = (torch.triu(all_ones, 1+mlen)
        #             + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
        # else:
        #     dec_attn_mask = torch.triu(
        #         word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]

        hids = []
        attentions = []
        if self.attn_type == 0:  # default
            # Relative distances klen-1, ..., 0 (optionally clamped).
            pos_seq = tf.range(klen-1, -1, -1.0)
            if self.clamp_len > 0:
                pos_seq = tf.minimum(pos_seq, self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb, training=training)
            pos_emb = self.drop(pos_emb, training=training)

            for i, layer in enumerate(self.layers):
                # The input of each layer is what gets cached as that layer's memory.
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer([core_out, pos_emb, dec_attn_mask,
                                       mems_i, head_mask[i]], training=training)
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out, training=training)

        # Argument order must match the `_update_mems(hids, mems, qlen, mlen)` signature.
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = [tf.transpose(t, perm=(1, 0, 2)) for t in hids]
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = [tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions]
            outputs.append(attentions)
        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)


class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
    """ Abstract base class of the TF 2.0 Transformer-XL models. It handles
        weights initialization and provides a simple interface for downloading
        and loading pretrained models.

        The class attributes below bind the Transformer-XL configuration class,
        the pretrained-archive map, the PyTorch-to-TF weight loader and the
        attribute name of the base model (``transformer``).
    """
    base_model_prefix = "transformer"
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_transfo_xl_pt_weights_in_tf2


TRANSFO_XL_START_DOCSTRING = r"""    The Transformer-XL model was proposed in
    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
    by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
    It's a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse
    previously computed hidden-states to attend to longer context (memory).
    This model also uses adaptive softmax inputs and outputs (tied).

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
        https://arxiv.org/abs/1901.02860

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, mems])` or `model([input_ids, mems, head_mask])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'mems': mems})`

    Parameters:
        config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

TRANSFO_XL_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
            the right or on the left.
            Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **mems**: (`optional`)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer), in sequence-first layout
            ``(memory_length, batch_size, hidden_size)``:
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
            Not supported yet by the TF 2.0 implementation: passing a value raises ``NotImplementedError``.
"""

@add_start_docstrings("The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top.",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **mems**:
            list of ``tf.Tensor`` (one for each layer), in sequence-first layout
            ``(memory_length, batch_size, hidden_size)``, or ``None`` when ``config.mem_len == 0``:
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape
            ``(batch_size, num_heads, sequence_length, memory_length + sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import TransfoXLTokenizer, TFTransfoXLModel

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TFTransfoXLModel.from_pretrained('transfo-xl-wt103')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states, mems = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFTransfoXLMainLayer(config, name='transformer')

    def call(self, inputs, **kwargs):
        # All input parsing (tensor / list / dict) is done by the main layer.
        outputs = self.transformer(inputs, **kwargs)
        return outputs


@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
    r"""
    Additional input:
        **labels**: (`optional`)
            Target token ids, passed together with the last hidden states to the adaptive softmax head.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor``
            Output of the adaptive softmax head (``TFAdaptiveSoftmaxMask``) computed from the last hidden states
            (scores for each vocabulary token). See ``TFAdaptiveSoftmaxMask`` for the exact format and for how
            ``labels`` are used.
        **mems**:
            list of ``tf.Tensor`` (one for each layer), in sequence-first layout
            ``(memory_length, batch_size, hidden_size)``, or ``None`` when ``config.mem_len == 0``:
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape
            ``(batch_size, num_heads, sequence_length, memory_length + sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import TransfoXLTokenizer, TFTransfoXLLMHeadModel

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        # Forward extra positional/keyword arguments to the base class, as TFTransfoXLModel does.
        super(TFTransfoXLLMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFTransfoXLMainLayer(config, name='transformer')
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            raise NotImplementedError
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = TFAdaptiveSoftmaxMask(config.n_token, config.d_embed, config.d_model,
                                              config.cutoffs, div_val=config.div_val, name='crit')

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward to :meth:`TFTransfoXLMainLayer.reset_length`."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        """Forward to :meth:`TFTransfoXLMainLayer.init_mems`."""
        return self.transformer.init_mems(data)

    def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
        """Run the transformer and apply the adaptive softmax head.

        Inputs follow the same conventions as the main layer. The list/tuple
        order is ``[input_ids, mems, head_mask, labels]``; the dict keys are
        ``'input_ids'``, ``'mems'``, ``'head_mask'``, ``'labels'``.

        Returns:
            list ``[softmax_output, new_mems, (all hidden states), (all attentions)]``.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            mems = inputs[1] if len(inputs) > 1 else mems
            head_mask = inputs[2] if len(inputs) > 2 else head_mask
            labels = inputs[3] if len(inputs) > 3 else labels
            assert len(inputs) <= 4, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            mems = inputs.get('mems', mems)
            head_mask = inputs.get('head_mask', head_mask)
            labels = inputs.get('labels', labels)
            assert len(inputs) <= 4, "Too many inputs."
        else:
            input_ids = inputs

        _, tgt_len = shape_list(input_ids)[:2]

        transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)

        last_hidden = transformer_outputs[0]
        # Keep the last `tgt_len` positions (batch-first hidden states).
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and training:
            raise NotImplementedError
        else:
            softmax_output = self.crit([pred_hid, labels], training=training)
            outputs = [softmax_output] + outputs

        return outputs  # logits, new_mems, (all hidden states), (all attentions)
