# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn
from fairseq.models.transformer import TransformerEncoder

from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer


class LinformerTransformerEncoder(TransformerEncoder):
    """
    Transformer encoder whose layers are Linformer encoder layers.

    The class only customises how individual layers are built: every layer
    is a ``LinformerTransformerEncoderLayer``, and when
    ``args.shared_layer_kv_compressed == 1`` all layers receive the same
    ``nn.Linear`` projection mapping ``max_positions`` to
    ``max_positions // compressed``. Otherwise ``None`` is handed to each
    layer in place of a shared projection.

    Configuration read from ``args`` by this class:
        - shared_layer_kv_compressed: ``1`` enables the shared projection
        - max_positions: input width of the shared projection
        - compressed: integer divisor applied to ``max_positions``
        - freeze_compress: ``1`` turns off gradients for the projection
          weight

    NOTE(review): earlier documentation for this class described the
    encoder as a bi-directional BERT/XLM style sentence encoder taking
    ``tokens`` and ``segment_labels`` (both B x T) and returning a tuple of
    (list of T x B x C internal states, B x C representation of the first
    token). No ``forward`` is defined here, so the real input/output
    contract is whatever ``TransformerEncoder`` provides -- confirm there.
    """

    def __init__(self, args, dictionary, embed_tokens):
        # Set up the cache slot for the shared projection *before* delegating
        # to the parent constructor -- presumably because the parent builds
        # its layers through ``build_encoder_layer`` during construction
        # (TODO confirm against TransformerEncoder).
        self.compress_layer = None
        super().__init__(args, dictionary, embed_tokens)

    def build_encoder_layer(self, args):
        """
        Create one Linformer encoder layer.

        On the first call with sharing enabled, the shared compression
        projection is created, initialised and cached on ``self``; every
        later call reuses that cached module.
        """
        cfg = self.args
        if cfg.shared_layer_kv_compressed == 1 and self.compress_layer is None:
            shared_proj = nn.Linear(
                cfg.max_positions,
                cfg.max_positions // cfg.compressed,
            )
            # Initialize the projection weight (Xavier uniform, gain 1/sqrt(2)).
            nn.init.xavier_uniform_(shared_proj.weight, gain=1 / math.sqrt(2))
            if cfg.freeze_compress == 1:
                # Only the weight is frozen here; the bias flag is untouched.
                shared_proj.weight.requires_grad = False
            self.compress_layer = shared_proj

        return LinformerTransformerEncoderLayer(args, self.compress_layer)
