import tensorflow as tf
import numpy as np


def shape_list(x):
    """Return the shape of ``x`` as a Python list.

    Dimensions known at graph-construction time are returned as Python ints;
    unknown (``None``) dimensions are filled in from ``tf.shape(x)`` as scalar
    tensors, so the result mixes ints and tensors and can be passed directly
    to ops such as ``tf.reshape``.
    """
    static_dims = x.shape.as_list()
    runtime_dims = tf.shape(x)
    out = []
    for axis, dim in enumerate(static_dims):
        out.append(runtime_dims[axis] if dim is None else dim)
    return out


def act(x, act_type="relu"):
    """Apply the activation named by ``act_type`` to ``x``.

    Previously the choice was hard-coded through a local variable that also
    shadowed this function's name, leaving every non-relu branch unreachable;
    it is now a parameter whose default keeps the old behaviour.

    Args:
        x: input tensor.
        act_type: one of ``"relu"`` (default), ``"leaky_relu"``, ``"elu"`` or
            ``"gelu"`` (tanh approximation).

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``act_type`` is not a supported name.
    """
    if act_type == "relu":
        return tf.nn.relu(x)
    if act_type == "leaky_relu":
        return tf.nn.leaky_relu(x)
    if act_type == "elu":
        return tf.nn.elu(x)
    if act_type == "gelu":
        # tanh approximation of GELU: 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
        cdf = 0.5 * (1.0 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
        return x * cdf
    raise ValueError("Unsupported activation: {!r}".format(act_type))


class PositionalEmbedding(tf.keras.layers.Layer):
    """Sinusoidal embedding of a 1-D sequence of (relative) positions.

    Args:
        demb: embedding size; the output's last dimension holds the sine half
            followed by the cosine half (``demb`` channels when it is even).
    """

    def __init__(self, demb, **kwargs):
        super().__init__(**kwargs)

        # Inverse frequencies 1 / 10000^(k / demb) for k = 0, 2, 4, ...
        exponents = tf.range(0, demb, 2.0) / demb
        self.inv_freq = 1 / (10000 ** exponents)

    def call(self, pos_seq, bsz=None):
        """Embed ``pos_seq``.

        Returns a tensor of shape ``[len(pos_seq), 1, dim]``, or
        ``[len(pos_seq), bsz, dim]`` (tiled copies) when ``bsz`` is given.
        """
        angles = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
        table = tf.concat([tf.sin(angles), tf.cos(angles)], -1)
        table = table[:, None, :]
        if bsz is None:
            return table
        return tf.tile(table, [1, bsz, 1])


class PositionwiseFF(tf.keras.layers.Layer):
    """Two-layer position-wise feed-forward block with residual and layer norm.

    Args:
        d_model: input/output feature size.
        d_inner: hidden size of the first dense layer.
        dropout: dropout rate applied after each dense layer.
        kernel_initializer: initializer for both dense kernels.
        pre_lnorm: if True, normalise the input before the block and add a
            plain residual; otherwise normalise after the residual sum.
    """

    def __init__(self, d_model, d_inner, dropout, kernel_initializer,
                 pre_lnorm=False, **kwargs):
        super().__init__(**kwargs)

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        def dense(units, layer_name):
            return tf.keras.layers.Dense(
                units, kernel_initializer=kernel_initializer, name=layer_name
            )

        # Sub-layers are assigned in a fixed order so weight ordering is stable.
        self.layer_1 = dense(d_inner, 'layer_1')
        self.act = act
        self.drop_1 = tf.keras.layers.Dropout(dropout, name='drop_1')
        self.layer_2 = dense(d_model, 'layer_2')
        self.drop_2 = tf.keras.layers.Dropout(dropout, name='drop_2')

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layer_norm')
        self.pre_lnorm = pre_lnorm

    def call(self, inp, training=False):
        """Apply the block; returns a one-element list holding the output."""
        hidden = self.layer_norm(inp) if self.pre_lnorm else inp
        hidden = self.drop_1(self.act(self.layer_1(hidden)), training=training)
        hidden = self.drop_2(self.layer_2(hidden), training=training)

        if self.pre_lnorm:
            return [hidden + inp]
        return [self.layer_norm(inp + hidden)]


class RelativeMultiHeadAttn(tf.keras.layers.Layer):
    """Transformer-XL style multi-head attention with relative positions.

    Queries come from the current segment ``w``; keys and values come from
    ``mems`` concatenated with ``w`` along axis 0. The attention score is the
    sum of a content term (AC, queries shifted by ``r_w_bias``) and a position
    term (BD, queries shifted by ``r_r_bias``, against projected ``r``).

    Args:
        n_head: number of attention heads.
        d_model: output feature size of ``o_net``.
        d_head: per-head feature size.
        dropout: dropout rate applied to the projected attention output.
        dropatt: dropout rate applied to the attention probabilities.
        kernel_initializer: initializer for the q/k/v, r and output kernels.
        pre_lnorm: if True, layer-normalise [mems; w] before the projections
            and add a plain residual; otherwise layer-normalise after the
            residual sum.
        r_r_bias, r_w_bias: optional externally owned biases shared between
            layers (presumably shaped ``[n_head, d_head]`` like the internal
            ones). Both must be given; otherwise this layer creates its own
            zero-initialised biases.
    """
    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        dropout,
        dropatt,
        kernel_initializer,
        pre_lnorm=False,
        r_r_bias=None,
        r_w_bias=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.kernel_initializer=kernel_initializer

        # One projection produces q, k and v together; split in call().
        self.qkv_net = tf.keras.layers.Dense(
            3 * n_head * d_head, kernel_initializer=kernel_initializer, use_bias=False, name="qkv"
        )
        # Projection of the positional embeddings into per-head keys.
        self.r_net = tf.keras.layers.Dense(
            self.n_head * self.d_head, kernel_initializer=kernel_initializer, use_bias=False, name="r"
        )
        self.drop = tf.keras.layers.Dropout(dropout)
        self.dropatt = tf.keras.layers.Dropout(dropatt)
        self.o_net = tf.keras.layers.Dense(
            d_model, kernel_initializer=kernel_initializer, use_bias=False, name="o"
        )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Standard 1/sqrt(d_head) dot-product scaling.
        self.scale = 1 / (d_head ** 0.5)

        if r_r_bias is not None and r_w_bias is not None:  # Biases are shared
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            self.r_r_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
            )
            self.r_w_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
            )

        self.pre_lnorm = pre_lnorm

    def _rel_shift(self, x):
        """Relative-shift trick for the BD term.

        Pads one zero slice at the front of axis 1, reinterprets the padded
        tensor with axes 0 and 1 swapped in size, drops the first row and
        reshapes back to the original shape. This realigns the [qlen, rlen]
        score matrix so each query row is indexed by key position rather than
        by entry of the positional sequence.
        """
        x_size = shape_list(x)

        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
        x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
        x = tf.reshape(x, x_size)

        return x

    def call(self, inputs, training=False):
        """Attend from ``w`` to ``[mems; w]``.

        Args:
            inputs: ``[w, r, attn_mask, mems]``.
                ``w`` is indexed as ``[qlen, bsz, features]``.
                ``r`` holds the positional embeddings with ``rlen`` along axis 0.
                After ``r_net`` it must reshape to ``[rlen, n_head, d_head]``,
                i.e. any batch axis must be size 1.
                ``attn_mask`` is ``[qlen, klen]``; entries equal to 1 are masked.
                ``mems`` is None or is concatenated with ``w`` on axis 0.
            training: enables dropout.

        Returns:
            A one-element list holding the ``[qlen, bsz, d_model]`` output.
        """
        w, r, attn_mask, mems = inputs
        qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]

        if mems is not None:
            cat = tf.concat([mems, w], 0)
        else:
            cat = w
        
        if self.pre_lnorm:
            cat = self.layer_norm(cat)

        w_heads = self.qkv_net(cat)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
        # Only the current segment (last qlen positions) issues queries.
        w_head_q = w_head_q[-qlen:]

        klen = shape_list(w_head_k)[0]

        w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head))
        w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head))
        w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head))

        r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))

        rw_head_q = w_head_q + self.r_w_bias
        rr_head_q = w_head_q + self.r_r_bias

        # Scores are laid out as [query i, key j, batch b, head n].
        AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k)
        BD = tf.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k)
        BD = self._rel_shift(BD)

        attn_score = AC + BD
        attn_score = attn_score * self.scale

        # Masked entries get ~-1e30 so softmax assigns them ~0 probability.
        # NOTE(review): -1e30 overflows in float16; revisit for mixed precision.
        attn_mask_t = attn_mask[:, :, None, None]
        attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

        # Normalise over the key axis.
        attn_prob = tf.nn.softmax(attn_score, axis=1)
        attn_prob = self.dropatt(attn_prob, training=training)

        attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
        size_t = shape_list(attn_vec)
        attn_vec = tf.reshape(attn_vec, (size_t[0], size_t[1], self.n_head * self.d_head))

        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out, training=training)

        if self.pre_lnorm:
            outputs = [w + attn_out]
        else:
            outputs = [self.layer_norm(w + attn_out)]

        return outputs


class TransformerXLLayer(tf.keras.layers.Layer):
    """One Transformer-XL block.

    The block applies relative multi-head attention followed by a
    position-wise feed-forward network.

    Args:
        n_head, d_model, d_head, d_inner: layer sizes.
        dropout: dropout rate for attention output and feed-forward.
        dropatt: dropout rate for attention probabilities.
        initializer: kernel initializer for all dense layers.
        pre_lnorm: pre- vs. post-layer-norm residual arrangement.
        r_w_bias, r_r_bias: optional attention biases shared across layers;
            when either is None the attention layer creates its own.
    """

    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        d_inner,
        dropout,
        dropatt,
        initializer,
        pre_lnorm=False,
        r_w_bias=None,
        r_r_bias=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.d_inner = d_inner
        self.dropout = dropout
        self.dropatt = dropatt
        self.initializer = initializer
        self.pre_lnorm = pre_lnorm

        attn_kwargs = dict(
            n_head=n_head,
            d_model=d_model,
            d_head=d_head,
            dropout=dropout,
            dropatt=dropatt,
            kernel_initializer=initializer,
            pre_lnorm=pre_lnorm,
            r_w_bias=r_w_bias,
            r_r_bias=r_r_bias,
            name="xltran_attn",
        )
        self.xltran_attn = RelativeMultiHeadAttn(**attn_kwargs)

        ff_kwargs = dict(
            d_model=d_model,
            d_inner=d_inner,
            dropout=dropout,
            kernel_initializer=initializer,
            pre_lnorm=pre_lnorm,
            name="pos_ff",
        )
        self.pos_ff = PositionwiseFF(**ff_kwargs)

    def call(self, inputs, training=False):
        """Run attention then feed-forward.

        Args:
            inputs: ``[inp, r, attn_mask, mems]`` (see RelativeMultiHeadAttn).
            training: enables dropout.

        Returns:
            A one-element list holding the block output.
        """
        inp, r, attn_mask, mems = inputs
        attended = self.xltran_attn([inp, r, attn_mask, mems], training=training)[0]
        return [self.pos_ff(attended, training=training)[0]]


class AdaptiveEmbedding(tf.keras.layers.Layer):
    """Adaptive token embedding with per-cluster tables and projections.

    The vocabulary is split at ``cutoffs`` into clusters. Cluster ``i`` covers
    ids ``[cutoff_ends[i], cutoff_ends[i + 1])``. It owns an embedding table of
    width ``d_embed // div_val**i`` and, when needed, a projection to
    ``d_proj``. Outputs are scaled by ``sqrt(d_proj)``.

    Args:
        n_token: vocabulary size.
        d_embed: embedding width of the first cluster.
        d_proj: output width.
        cutoffs: cluster boundaries, excluding ``n_token``.
        initializer: initializer for embedding tables.
        proj_initializer: initializer for projections (defaults to
            ``initializer``).
        div_val: width divisor applied per cluster index.
        proj_same_dim: if True and ``div_val != 1``, every cluster gets a
            projection even when its width already equals ``d_proj``.
        use_tpu: if True, look up embeddings via a one-hot matmul instead of
            ``tf.nn.embedding_lookup``.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, initializer,
                 proj_initializer=None, div_val=1, proj_same_dim=True,
                 use_tpu=True, **kwargs):
        super().__init__(**kwargs)

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        # A new list is built here, so the caller's ``cutoffs`` is not mutated.
        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs

        self.initializer = initializer
        self.proj_initializer = proj_initializer if proj_initializer is not None else initializer

        self.div_val = div_val
        self.proj_same_dim = proj_same_dim

        self.use_tpu = use_tpu

        self.emb_scale = d_proj ** 0.5

        self.emb_weights = []
        self.emb_projs = []

        for i in range(len(self.cutoffs)):
            l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
            d_emb_i = self.d_embed // (self.div_val ** i)
            self.emb_weights.append(
                self.add_weight(
                    shape=(r_idx - l_idx, d_emb_i),
                    initializer=self.initializer,
                    name="emb_weights_._{}".format(i),
                )
            )
            # Skip the projection when the table is already d_proj wide,
            # unless proj_same_dim forces projections for div_val != 1.
            if d_emb_i == d_proj and (not self.proj_same_dim or self.div_val == 1):
                self.emb_projs.append(None)
            else:
                self.emb_projs.append(
                    self.add_weight(
                        shape=(d_emb_i, self.d_proj),
                        initializer=self.proj_initializer,
                        trainable=True,
                        name="emb_projs_._{}".format(i),
                    )
                )

    def get_weights(self):
        """Return the per-cluster embedding tables and projections.

        The previous implementation iterated the non-existent
        ``self.emb_layers`` and always raised ``AttributeError``.

        NOTE: unlike ``tf.keras.layers.Layer.get_weights`` this returns a dict
        of variables, not a flat list of numpy arrays. ``"emb_projs"`` entries
        are None for clusters without a projection.
        """
        return {
            "emb_layers": [self.emb_weights[i] for i in range(len(self.emb_weights))],
            "emb_projs": [self.emb_projs[i] for i in range(len(self.emb_projs))],
        }

    @staticmethod
    def _embedding_lookup(lookup_table, x, use_tpu=False):
        """Gather rows of ``lookup_table`` for ids ``x``.

        When ``use_tpu`` is set, this uses a one-hot einsum instead of
        ``tf.nn.embedding_lookup``. That path handles 1-D and 2-D ``x``.
        """
        if use_tpu:
            n_token = shape_list(lookup_table)[0]
            one_hot_idx = tf.one_hot(x, n_token)
            if one_hot_idx.shape.ndims == 2:
                return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
            else:
                return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
        else:
            return tf.nn.embedding_lookup(lookup_table, x)

    def call(self, inp):
        """Embed integer ids ``inp`` to shape ``inp.shape + [d_proj]``."""
        inp_flat = tf.reshape(inp, (-1,))
        emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj])
        for i in range(len(self.cutoffs)):
            l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

            # Every id is looked up in every cluster (clamped into range) and
            # only rows belonging to this cluster are kept via the mask.
            mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
            inp_i = tf.minimum(inp_flat, r_idx - 1)
            inp_i = tf.maximum(inp_i - l_idx, 0)
            emb_i = self._embedding_lookup(self.emb_weights[i], inp_i, self.use_tpu)
            if self.emb_projs[i] is not None:
                emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])

            mask_i = tf.tile(tf.reshape(mask_i, [-1, 1]), [1, self.d_proj])
            emb_flat = tf.where(mask_i, emb_i, emb_flat)

        embed_shape = shape_list(inp) + [self.d_proj]
        embed = tf.reshape(emb_flat, embed_shape)

        embed *= self.emb_scale

        return embed


class AdaptiveSoftmax(tf.keras.layers.Layer):
    """Adaptive-softmax negative log-likelihood tied to an AdaptiveEmbedding.

    Output word weights always reuse ``tied_to.emb_weights``. The head cluster
    scores the first ``cutoffs[0]`` tokens plus one logit per tail cluster.
    Each tail cluster scores its own slice of the vocabulary, and its
    log-probabilities are offset by the head's log-probability of that cluster.

    Args:
        n_token: vocabulary size.
        d_embed: width of the head cluster (and of ``cluster_weight``).
        d_proj: hidden (model) size.
        cutoffs: cluster boundaries, excluding ``n_token``.
        tie_projs: per-cluster flags. A true entry reuses the embedding
            projection; a false entry creates an untied projection. Only read
            for clusters that have a projection.
        initializer: unused; kept for interface compatibility.
        proj_initializer: initializer for untied projections.
        div_val: stored only.
        proj_same_dim: stored only.
        tied_to: the AdaptiveEmbedding whose weights and projections are
            reused. Required.

    Raises:
        ValueError: if ``tied_to`` is None.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, tie_projs,
                 initializer=None, proj_initializer=None, div_val=1,
                 proj_same_dim=True, tied_to=None, **kwargs):
        super().__init__(**kwargs)

        if tied_to is None:
            raise ValueError("AdaptiveSoftmax requires `tied_to` (an AdaptiveEmbedding).")

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.n_clusters = len(self.cutoffs) - 1

        self.div_val = div_val
        # Previously hard-coded to True, ignoring the argument.
        self.proj_same_dim = proj_same_dim

        self.tied_to = tied_to
        self.tie_projs = tie_projs

        self.out_weights = []
        self.out_biases = []
        self.out_projs = []

        if self.n_clusters > 0:
            # One extra head logit (and bias) per tail cluster.
            self.cluster_weight = self.add_weight(
                shape=(self.n_clusters, self.d_embed), initializer="zeros",
                trainable=True, name="cluster_weight"
            )
            self.cluster_bias = self.add_weight(
                shape=(self.n_clusters,), initializer="zeros", trainable=True,
                name="cluster_bias"
            )

        for i, emb_weight in enumerate(self.tied_to.emb_weights):
            self.out_weights.append(emb_weight)
            vocab_size = shape_list(emb_weight)[0]
            self.out_biases.append(
                self.add_weight(
                    shape=(vocab_size,),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_.bias".format(i)
                )
            )

        for i, emb_proj in enumerate(self.tied_to.emb_projs):
            out_proj = emb_proj
            if emb_proj is not None and not self.tie_projs[i]:
                out_proj = self.add_weight(
                    shape=shape_list(emb_proj),
                    initializer=proj_initializer,
                    trainable=True,
                    name="out_projs_._{}".format(i)
                )
            self.out_projs.append(out_proj)

    @staticmethod
    def _logit(x, W, b, proj=None):
        """Compute logits ``x @ proj @ W^T + b`` for 2-D or 3-D ``x``.

        ``proj`` has shape ``[d_cluster, d_proj]`` and ``W`` has shape
        ``[n_words, d_cluster]``.
        """
        y = x
        if x.shape.ndims == 3:
            if proj is not None:
                y = tf.einsum("ibd,ed->ibe", y, proj)
            return tf.einsum("ibd,nd->ibn", y, W) + b
        else:
            if proj is not None:
                y = tf.einsum('id,ed->ie', y, proj)
            return tf.einsum('id,nd->in', y, W) + b

    @staticmethod
    def _gather_logprob(logprob, target):
        """Pick ``logprob[i, j, target[i, j]]`` for a 2-D ``target``."""
        lp_size = shape_list(target)
        r = tf.range(lp_size[0])
        c = tf.range(lp_size[1])
        C, R = tf.meshgrid(c, r)
        idx = tf.stack([R, C, target], axis=2)
        return tf.gather_nd(logprob, idx)

    def call(self, inputs, return_mean=True):
        """Compute the per-token negative log-likelihood.

        Args:
            inputs: ``[hidden, target]``. With clusters, ``hidden`` must be
                3-D and ``target`` 2-D, matching ``hidden``'s first two axes.
            return_mean: if True, reduce the loss to its mean.

        Returns:
            The loss per target position, or its scalar mean.
        """
        hidden, target = inputs
        if self.n_clusters == 0:
            logits = self._logit(hidden, self.out_weights[0], self.out_biases[0], self.out_projs[0])
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=logits)
        else:
            hidden_sizes = shape_list(hidden)
            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
            head_logprob = None  # always set in the i == 0 iteration
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                # Targets of this cluster, re-indexed into the cluster's slice;
                # out-of-cluster targets are clamped and later masked out.
                mask = (target >= l_idx) & (target < r_idx)
                cur_target = tf.minimum(target, r_idx - 1)
                cur_target = tf.maximum(cur_target - l_idx, 0)

                cur_W = self.out_weights[i]
                cur_b = self.out_biases[i]
                cur_P = self.out_projs[i]

                if i == 0:
                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)

                    head_logit = self._logit(hidden, cur_W, cur_b, cur_P)
                    head_logprob = tf.nn.log_softmax(head_logit)

                    cur_loss = self._gather_logprob(head_logprob, cur_target)
                    loss = tf.where(mask, cur_loss, loss)
                else:
                    tail_logit = self._logit(hidden, cur_W, cur_b, cur_P)
                    tail_logprob = tf.nn.log_softmax(tail_logit)

                    # Head log-prob of entering tail cluster i.
                    cluster_prob_idx = self.cutoffs[0] + i - 1
                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob

                    cur_loss = self._gather_logprob(logprob_i, cur_target)
                    loss = tf.where(mask, cur_loss, loss)
            loss = -loss
        if return_mean:
            loss = tf.reduce_mean(loss)

        return loss


class CompressionLayer(tf.keras.layers.Layer):
    """Shorten a time-major sequence by 1-D pooling over ``[mem; inp]``.

    Args:
        d_model: stored only.
        pool_size: pooling window length.
        strides: pooling stride.
        initializer: unused; kept for interface compatibility.
        dropout: stored only.
        comp_type: ``'avg_pooling'`` or ``'max_pooling'``.
        pre_lnorm: stored only.

    Raises:
        NotImplementedError: for ``comp_type='query_only_pooling'``.
        ValueError: for any other unknown ``comp_type``.
    """

    def __init__(self, d_model, pool_size, strides, initializer,
                 dropout, comp_type, pre_lnorm, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.pool_size = pool_size
        self.strides = strides
        self.dropout = dropout
        self.comp_type = comp_type
        self.pre_lnorm = pre_lnorm
        # Previously `raise NotImplemented`, which raises TypeError because
        # NotImplemented is not an exception class.
        if self.comp_type == 'avg_pooling':
            self.avg_pooling = tf.keras.layers.AveragePooling1D(
                pool_size=self.pool_size,
                strides=self.strides
            )
        elif self.comp_type == 'max_pooling':
            self.max_pooling = tf.keras.layers.MaxPooling1D(
                pool_size=self.pool_size,
                strides=self.strides
            )
        elif self.comp_type == 'query_only_pooling':
            raise NotImplementedError("comp_type 'query_only_pooling' is not implemented")
        else:
            raise ValueError("Unknown comp_type: {!r}".format(self.comp_type))

    def call(self, inp, mem, training=None):
        """Pool ``[mem; inp]`` along axis 0 and keep the last ``len(inp)`` steps.

        Args:
            inp: time-major ``[qlen, bsz, d]`` tensor (rank 3 is required by
                the transposes below).
            mem: memory with the same trailing shape, prepended on axis 0.
            training: accepted for Keras compatibility; unused.

        Returns:
            The last ``min(qlen, n_pooled)`` pooled steps. When ``mem`` is
            long enough this keeps ``qlen`` steps, some of which summarise
            memory positions.
        """
        qlen = shape_list(inp)[0]
        seq = tf.concat([mem, inp], axis=0)
        if self.pool_size > self.strides:
            # Left-pad so overlapping windows also cover the first positions.
            pad_shape = [self.pool_size - self.strides] + shape_list(seq)[1:]
            seq = tf.concat([tf.zeros(pad_shape, dtype=seq.dtype), seq], axis=0)

        if self.comp_type == 'avg_pooling':
            pool = self.avg_pooling
        elif self.comp_type == 'max_pooling':
            pool = self.max_pooling
        else:
            raise NotImplementedError("comp_type {!r} is not implemented".format(self.comp_type))

        # Keras 1-D pooling expects [batch, steps, channels]; ours is time-major.
        out = tf.transpose(seq, perm=[1, 0, 2])
        out = pool(out)
        out = tf.transpose(out, perm=[1, 0, 2])

        return out[-qlen:]


class DropPathLayer(tf.keras.layers.Layer):
    """Randomly drop paths along axis 2 of a 4-D input during training.

    For each position of the leading two axes, with probability ``drop_prob``
    only a random suffix of the paths on axis 2 is kept. The last path always
    survives because its threshold is 1.0. Otherwise all paths are kept. Kept
    paths are scaled by ``1 / number_kept``. Outside training the input is
    returned unchanged.
    """

    def __init__(self, drop_prob, **kargs):
        super(DropPathLayer, self).__init__(**kargs)
        self.drop_prob = drop_prob

    def call(self, inp, training=False):
        @tf.custom_gradient
        def droppath(x):
            dims = shape_list(x)
            n_path = dims[2]
            mask_shape = dims[:-2] + [1, 1]

            # Path k survives when the draw is below (k + 1) / n_path.
            thresholds = tf.reshape(
                (1 + tf.range(n_path, dtype=tf.float32)) / n_path,
                shape=[1, 1, -1, 1]
            )

            # The two random draws stay in this order to keep RNG consumption unchanged.
            drop_draw = tf.keras.backend.random_uniform(mask_shape)
            is_dropping = tf.cast(drop_draw > 1. - self.drop_prob, dtype=tf.float32)
            cut_draw = tf.keras.backend.random_uniform(mask_shape)
            # A zero draw (no dropping) keeps every path.
            effective_cut = cut_draw * is_dropping

            keep = tf.cast(effective_cut < thresholds, tf.float32)
            n_kept = tf.keras.backend.sum(keep, axis=2, keepdims=True)
            weights = keep / n_kept

            def grad(dy):
                return dy * weights

            return x * weights, grad

        if training:
            return droppath(inp)
        return inp


class AccumulationLayer(tf.keras.layers.Layer):
    """Merge the final states of all scales back at full target length.

    Each scale's state is expanded along time to the length of the first
    scale's state using a 0/1 expansion matrix. The expanded states are
    stacked on axis 2, passed through DropPathLayer and averaged over axis 2.
    The result is layer-normalised when there is more than one input.

    Args:
        n_scale: stored only.
        droppath: drop probability for the DropPathLayer.
        d_model: stored only.
        initializer: unused.
    """
    def __init__(self, n_scale, droppath, d_model, initializer, **kwargs):
        super().__init__(**kwargs)
        self.n_scale = n_scale
        self.droppath = droppath
        self.d_model = d_model
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layer_norm')

        self.drop_path = DropPathLayer(self.droppath)

    def _create_expansion_matrix(self, tgt_len, tlen, mlen):
        """Build the [tgt_len, tlen] (or [tgt_len, tlen + 1] if mlen > 0) map.

        With ``mf = tgt_len // tlen``, compressed state ``i`` is copied to
        output positions ``(i + 1) * mf - 1 .. (i + 2) * mf - 2`` (clipped at
        ``tgt_len``). When ``mlen > 0`` an extra leading column feeds the first
        ``mf - 1`` output positions. In call() that column is the last memory
        row. Otherwise those positions get zeros.

        Requires Python-int arguments: they feed ``range`` and
        ``tf.constant(shape=...)``.
        """
        mf = tgt_len//tlen
        exp_mat = tf.stack([tf.eye(tlen) for _ in range(mf)], axis=2)
        exp_mat = tf.reshape(exp_mat, [tlen, -1])
        # Shift right by mf - 1 columns and truncate to tgt_len.
        exp_mat = tf.concat([tf.zeros((tlen, mf-1)), exp_mat], axis=1)[:, :tgt_len]
        if mlen > 0:
            trow = tf.constant(
                value=[1.]*(mf-1)+[0.]*(tgt_len-mf+1), dtype=tf.float32, shape=[1, tgt_len]
            )
            exp_mat = tf.concat([trow, exp_mat], axis=0)
        return tf.transpose(exp_mat, [1, 0])

    def call(self, inps, training=False):
        """Expand and merge per-scale states.

        Args:
            inps: list of dicts with ``'state'`` and ``'memory'`` tensors,
                both time-major on axis 0. ``inps[0]['state']`` fixes the
                output length. Presumably the tensors are
                ``[len, bsz, d_model]``; TODO confirm.
            training: enables drop-path.

        Returns:
            The merged tensor, layer-normalised when ``len(inps) > 1``.
        """
        tgt_len = shape_list(inps[0]['state'])[0]
        exp_hids = []
        for i, inp in enumerate(inps):
            hl, mem = inp['state'], inp['memory']
            tlen, mlen = shape_list(hl)[0], shape_list(mem)[0]
            if mlen > 0:
                # Prepend the last memory row for the expansion matrix's extra column.
                hl = tf.concat([mem[-1, None], hl], axis=0)
            exp_mat = self._create_expansion_matrix(tgt_len, tlen, mlen)
            hl = tf.tensordot(exp_mat, hl, axes=[[1], [0]])
            exp_hids.append(hl)
        exp_hids = tf.stack(exp_hids, axis=2)

        # NOTE(review): in training, DropPathLayer already scales kept paths by
        # 1/k and reduce_mean divides by the path count again, so train and
        # inference outputs differ in scale when len(inps) > 1. The layer norm
        # below presumably absorbs this; confirm.
        out = self.drop_path(exp_hids, training)
        out = tf.reduce_mean(out, axis=2)

        if len(inps) == 1:
            return out
        else:
            return self.layer_norm(out)


class TransformerQL(tf.keras.Model):
    """Multi-scale Transformer-XL language model with compressed scales.

    Layers are grouped into scales. Between consecutive scales a
    CompressionLayer (pool size 2, stride 2) shortens the sequence. The
    final state of every scale is expanded back to the input length and merged
    by AccumulationLayer. Then ``n_output_layer`` more Transformer-XL layers
    run before the adaptive-softmax loss.

    Args:
        n_token: vocabulary size.
        n_layer_per_scale: list with the number of Transformer-XL layers per scale.
        n_output_layer: number of layers applied after accumulation.
        d_model, d_embed, n_head, d_head, d_inner: model sizes.
        dropout: dropout rate for embeddings, layers and the output.
        dropatt: attention-probability dropout rate.
        droppath: drop probability used when merging scales.
        initializer: kernel/embedding initializer.
        proj_initializer: projection initializer (defaults to ``initializer``).
        pre_lnorm: pre- vs. post-layer-norm residual arrangement.
        comp_type: ``'avg_pooling'`` or ``'max_pooling'``.
        tgt_len, mem_len: stored (see ``reset_length``); not read by ``call``.
        cutoffs: adaptive embedding/softmax cluster boundaries (default: none).
        div_val: adaptive embedding width divisor.
        tie_projs: per-cluster flags; true entries share the softmax
            projection with the embedding projection (default: empty).
        clamp_len: if > 0, relative positions are clamped to this value.
        untie_r: if False, all layers share one ``r_w_bias``/``r_r_bias`` pair.
        proj_same_dim: see AdaptiveEmbedding.
        use_tpu: one-hot embedding lookup instead of ``embedding_lookup``.
        use_mem: if False, incoming memories are ignored (zero length).
    """

    def __init__(self, n_token, n_layer_per_scale, n_output_layer, d_model, d_embed, n_head,
                 d_head, d_inner, dropout, dropatt, droppath, initializer, proj_initializer=None,
                 pre_lnorm=False, comp_type='avg_pooling', tgt_len=None, mem_len=0, cutoffs=None,
                 div_val=1, tie_projs=None, clamp_len=-1, untie_r=False, proj_same_dim=True,
                 use_tpu=True, use_mem=True):

        super(TransformerQL, self).__init__()

        # None instead of shared mutable [] defaults.
        cutoffs = [] if cutoffs is None else cutoffs
        tie_projs = [] if tie_projs is None else tie_projs

        self.n_token = n_token
        self.d_model = d_model
        self.d_embed = d_embed
        self.n_head = n_head
        self.d_head = d_head
        self.d_inner = d_inner

        self.n_layer_per_scale = n_layer_per_scale
        self.n_layer = sum(n_layer_per_scale)
        self.n_scale = len(n_layer_per_scale)
        self.n_output_layer = n_output_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len

        self.dropout = dropout
        self.dropatt = dropatt
        self.droppath = droppath

        self.cutoffs = cutoffs
        self.div_val = div_val
        self.tie_projs = tie_projs
        self.clamp_len = clamp_len
        self.untie_r = untie_r
        self.proj_same_dim = proj_same_dim

        self.initializer = initializer
        self.proj_initializer = proj_initializer if proj_initializer is not None else initializer

        self.pre_lnorm = pre_lnorm
        self.comp_type = comp_type
        self.use_tpu = use_tpu
        self.use_mem = use_mem

        self.embedding_layer = AdaptiveEmbedding(
                n_token=self.n_token,
                d_embed=self.d_embed,
                d_proj=self.d_model,
                cutoffs=self.cutoffs,
                initializer=self.initializer,
                proj_initializer=self.proj_initializer,
                div_val=self.div_val,
                proj_same_dim=self.proj_same_dim,
                use_tpu=self.use_tpu,
                name='emb_layer'
            )
        self.pos_emb = PositionalEmbedding(d_model)

        self.emb_dropout = tf.keras.layers.Dropout(dropout, name='emb_drop')
        self.pos_dropout = tf.keras.layers.Dropout(dropout, name='pos_drop')

        if not self.untie_r:
            # One bias pair shared by every Transformer-XL layer.
            self.r_w_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
            )
            self.r_r_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
            )

        # Running layer totals. A layer index found in this list is the first
        # layer of a new scale and is preceded by a compression layer.
        self.cum_n_layer_per_scale = [n_layer_per_scale[0]]
        for i in range(1, self.n_scale):
            self.cum_n_layer_per_scale.append(n_layer_per_scale[i]+self.cum_n_layer_per_scale[i-1])

        self.tran_layers = []
        self.comp_layers = []
        self.norm_layers = []
        for i in range(self.n_layer):
            if i in self.cum_n_layer_per_scale:
                self.comp_layers.append(
                    CompressionLayer(
                        self.d_model, 2, 2,
                        self.initializer,
                        self.dropout,
                        self.comp_type,
                        self.pre_lnorm,
                    )
                )
                self.norm_layers.append(
                    tf.keras.layers.LayerNormalization(epsilon=1e-6)
                )

            layer = TransformerXLLayer(
                n_head=self.n_head,
                d_model=self.d_model,
                d_head=self.d_head,
                d_inner=self.d_inner,
                dropout=self.dropout,
                dropatt=self.dropatt,
                initializer=self.initializer,
                pre_lnorm=self.pre_lnorm,
                r_w_bias=None if self.untie_r else self.r_w_bias,
                r_r_bias=None if self.untie_r else self.r_r_bias,
                name='layers_._{}'.format(i)
            )
            self.tran_layers.append(layer)
        self.norm_layers.append(
            tf.keras.layers.LayerNormalization(epsilon=1e-6)
        )

        self.output_layers = []
        for i in range(self.n_output_layer):
            self.output_layers.append(
                TransformerXLLayer(
                    n_head=self.n_head,
                    d_model=self.d_model,
                    d_head=self.d_head,
                    d_inner=self.d_inner,
                    dropout=self.dropout,
                    dropatt=self.dropatt,
                    initializer=self.initializer,
                    pre_lnorm=self.pre_lnorm,
                    r_w_bias=None if self.untie_r else self.r_w_bias,
                    r_r_bias=None if self.untie_r else self.r_r_bias,
                    name='output_layers_._{}'.format(i)
                )
            )

        if self.pre_lnorm:
            self.out_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='out_layer_norm')

        self.out_dropout = tf.keras.layers.Dropout(dropout, name='out_drop')

        self.logsoftmax_layer = AdaptiveSoftmax(
                n_token=self.n_token,
                d_embed=self.d_embed,
                d_proj=self.d_model,
                cutoffs=self.cutoffs,
                tie_projs=self.tie_projs,
                initializer=self.initializer,
                proj_initializer=self.proj_initializer,
                div_val=self.div_val,
                proj_same_dim=self.proj_same_dim,
                tied_to=self.embedding_layer,
                name='softmax_layer'
            )

        self.accumulation_layer = AccumulationLayer(
            self.n_scale,
            self.droppath,
            self.d_model,
            self.initializer
        )

    def reset_length(self, tgt_len, mem_len):
        """Update the stored ``tgt_len`` / ``mem_len`` attributes."""
        self.tgt_len = tgt_len
        self.mem_len = mem_len

    def init_mems(self, bsz, mem_len_per_layer):
        """Return one zero tensor ``[mlen, bsz, d_model]`` per requested length."""
        mems = []
        for mlen in mem_len_per_layer:
            empty = tf.zeros([mlen, bsz, self.d_model])
            mems.append(empty)
        return mems

    def _update_mems(self, hids, mems, shift_len_per_layer):
        """Append new hidden states to each layer's memory, keeping its length.

        For layer ``i`` the first ``shift_len_per_layer[i]`` positions of
        ``hids[i]`` are appended, and the last ``len(mems[i])`` positions are
        kept with gradients stopped. Zero-length memories stay zero-length.
        For compressed scales the layer input also covers positions the next
        segment recomputes, which is presumably why only the leading positions
        are cached.
        """
        if mems is None:
            return None
        assert len(hids) == len(mems), "len(hids) != len(mems)"
        new_mems = []
        for i in range(len(hids)):
            slen = shift_len_per_layer[i]
            cat = tf.concat([mems[i], hids[i][:slen]], axis=0)
            cat = tf.stop_gradient(cat)
            mlen = shape_list(mems[i])[0]
            if mlen > 0:
                new_mems.append(cat[-mlen:])
            else:
                shape = [mlen]+shape_list(cat)[1:]
                new_mems.append(tf.zeros(shape))
        return new_mems

    def _create_mask(self, qlen, mlen):
        """Return a ``[qlen, mlen + qlen]`` attention mask (1 = masked).

        Memory columns are never masked. Within the current segment, the
        strictly upper triangle (future positions) is masked.
        """
        attn_mask = tf.ones([qlen, qlen])
        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen])
        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
        return ret

    def call(self, inp, tgt, mems=None, return_mean=False, training=False):
        """Compute the language-model loss and the updated memories.

        Args:
            inp: ``[bsz, len]`` input token ids.
            tgt: ``[bsz, len]`` target token ids.
            mems: per-layer memories concatenated on axis 0, batch on axis 1,
                as returned by a previous call. The per-layer split computed
                below must sum exactly to its length (``tf.split`` requirement).
                If None, or if ``use_mem`` is False, every layer uses a
                zero-length memory. Previously ``mems=None`` with
                ``use_mem=True`` raised AttributeError.
            return_mean: return the mean loss instead of per-token losses.
            training: enables dropout and drop-path.

        Returns:
            ``(loss, new_mems)``. A 2-D loss is returned as ``[bsz, len]``.
        """
        # the original code for Transformer-XL used shapes [len, bsz]
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        inp = tf.transpose(inp, perm=(1, 0))
        tgt = tf.transpose(tgt, perm=(1, 0))

        qlen, bsz = shape_list(inp)

        # Number of new positions each layer pushes into its memory; halves
        # after every compression layer.
        shift_len_per_layer = []
        slen = qlen
        for i in range(self.n_layer):
            if i in self.cum_n_layer_per_scale:
                # Add the layer before compression layer
                shift_len_per_layer.append(slen)
                slen = slen // 2
            shift_len_per_layer.append(slen)
        # Add extra element for the output of the last layer
        shift_len_per_layer.append(slen)
        # Add shift_len for output layers
        for _ in range(self.n_output_layer):
            shift_len_per_layer.append(qlen)

        if not self.use_mem or mems is None:
            mem_len_per_layer = [0]*len(shift_len_per_layer)
            mlen = 0
        else:
            total_mlen = shape_list(mems)[0]
            n_layer = len(shift_len_per_layer)
            if total_mlen >= n_layer*qlen:
                mlen = total_mlen//n_layer
                mem_len_per_layer = [mlen]*n_layer
            else:
                # First-scale and output layers get mlen; compressed-scale
                # layers get (mlen + qlen) // 2.
                n_s1_layer = self.n_layer_per_scale[0]+1+self.n_output_layer
                mlen = (2*total_mlen - (n_layer-n_s1_layer)*qlen)//(n_layer+n_s1_layer)
                mem_len_per_layer = (
                    [mlen]*(self.n_layer_per_scale[0]+1)
                    + [(mlen+qlen)//2]*(n_layer-n_s1_layer)
                    + [mlen]*self.n_output_layer
                )

        if not self.use_mem or mems is None:
            mems = self.init_mems(bsz, mem_len_per_layer)
        else:
            mems = tf.split(mems, mem_len_per_layer, axis=0)
            assert(shape_list(mems[0])[1] == bsz)
            assert(len(mems) == self.n_layer+self.n_scale+self.n_output_layer)

        word_emb = self.embedding_layer(inp)
        d_word_emb = self.emb_dropout(word_emb, training=training)

        klen = mlen + qlen
        pos_seq = tf.range(klen - 1, -1, -1.0)
        if self.clamp_len > 0:
            pos_seq = tf.minimum(pos_seq, self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)
        d_pos_emb = self.pos_dropout(pos_emb, training=training)

        core_out = d_word_emb
        hids = []
        outs = []
        comp_layer_idx = 0
        for i, layer in enumerate(self.tran_layers):
            if i in self.cum_n_layer_per_scale:
                # End of a scale: record its state for accumulation, then
                # compress before the first layer of the next scale.
                if self.pre_lnorm:
                    core_out = self.norm_layers[comp_layer_idx](core_out)
                hids.append(core_out)

                mems_i = mems[i+comp_layer_idx]
                shift = shift_len_per_layer[i+comp_layer_idx]
                outs.append(
                    {
                        'state': core_out[-shift:],
                        'memory': tf.concat([mems_i, core_out[:-shift]], axis=0)
                    }
                )
                core_out = self.comp_layers[comp_layer_idx](core_out, mems_i, training=training)
                comp_layer_idx = comp_layer_idx + 1

            hids.append(core_out)
            mems_i = mems[i+comp_layer_idx]

            tlen, mlen = shape_list(core_out)[0], shape_list(mems_i)[0]
            klen = tlen + mlen
            attn_mask_i = self._create_mask(tlen, mlen)

            # NOTE(review): [:klen] takes the first rows of d_pos_emb, i.e.
            # positions from the top of pos_seq rather than the last klen rows
            # (positions klen-1 .. 0). For layers with a shorter klen this may
            # offset relative positions. Left unchanged for checkpoint
            # compatibility; confirm intent.
            all_out = layer([core_out, d_pos_emb[:klen], attn_mask_i, mems_i], training=training)
            core_out = all_out[0]

        if self.pre_lnorm:
            core_out = self.norm_layers[comp_layer_idx](core_out)

        hids.append(core_out)
        shift = shift_len_per_layer[self.n_layer+self.n_scale-1]
        mems_i = mems[self.n_layer+self.n_scale-1]
        outs.append(
            {
                'state': core_out[-shift:],
                'memory': tf.concat([mems_i, core_out[:-shift]], axis=0)
            }
        )
        core_out = self.accumulation_layer(outs, training=training)

        for i, layer in enumerate(self.output_layers):
            hids.append(core_out)

            mems_i = mems[i+self.n_layer+self.n_scale]

            tlen, mlen = shape_list(core_out)[0], shape_list(mems_i)[0]
            klen = tlen + mlen
            attn_mask_i = self._create_mask(tlen, mlen)

            all_out = layer([core_out, d_pos_emb[:klen], attn_mask_i, mems_i], training=training)
            core_out = all_out[0]

        if self.pre_lnorm:
            core_out = self.out_layer_norm(core_out)

        core_out = self.out_dropout(core_out, training=training)

        new_mems = self._update_mems(hids, mems, shift_len_per_layer)
        new_mems = tf.concat(new_mems, axis=0)

        loss = self.logsoftmax_layer([core_out, tgt], return_mean=return_mean, training=training)

        # transpose loss back to shape [bsz, len] if necessary
        if loss.shape.ndims == 2:
            loss = tf.transpose(loss, [1, 0])

        return loss, new_mems
        

