# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
    Embedding,
    TransformerDecoderEmbedding,
    TransformerDecoderLayer,
    TransformerDecoderOutputLayer,
    TransformerEncoderEmbedding,
    TransformerEncoderLayer,
    TransformerEncoderLayerNorm,
)
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    base_architecture,
    transformer_iwslt_de_en,
    transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Fallback position limits; `build_model_base` applies them when the parsed
# args do not define `max_source_positions` / `max_target_positions`.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
# Set to True by `import_pipe()` when the torch.distributed pipeline imports
# succeed; the model classes use it to pick the torch `Pipe` code path over
# the fairscale one.
TORCH_PIPE = False
# Set to True by `import_pipe()` after it has called `rpc.init_rpc`, so that
# the RPC agent is only initialized once.
RPC_INIT = False


def import_pipe():
    """Bind a ``Pipe`` implementation into this module's globals.

    The torch implementation (``torch.distributed.pipeline.sync``) is tried
    first. When it imports, ``Pipe`` and ``partition_model`` become module
    globals, ``TORCH_PIPE`` is set to True and, on the first such call, a
    single-process RPC agent is initialized (``RPC_INIT`` records that).
    If any of the torch imports fails, ``Pipe`` is imported from fairscale
    instead and ``TORCH_PIPE`` is left untouched.

    Raises:
        ImportError: if neither the torch nor the fairscale ``Pipe`` can be
            imported.
    """
    # All globals are declared up front: the language reference requires a
    # ``global`` statement to precede any binding of the name in the block.
    global TORCH_PIPE, RPC_INIT, Pipe, partition_model
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa
        from torch.distributed.pipeline.sync.utils import partition_model  # noqa
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
        # RPC with a single node.
        if not RPC_INIT:
            # The temporary file only serves as the file:// rendezvous target,
            # so it is created solely when RPC actually has to be initialized
            # and is released as soon as initialization returns.
            with tempfile.NamedTemporaryFile() as tmpfile:
                rpc.init_rpc(
                    name="worker",
                    rank=0,
                    world_size=1,
                    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                        init_method="file://{}".format(tmpfile.name),
                    ),
                )
            RPC_INIT = True
        logger.info("Using torch pipe")
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa

            logger.info("Using fairscale pipe")
        except ImportError as err:
            raise ImportError(
                "Please install fairscale with: pip install fairscale"
            ) from err


@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
    """Encoder-decoder transformer executed through a ``Pipe``.

    For training, the encoder modules and decoder modules are concatenated
    into one ``nn.Sequential`` that is wrapped in ``Pipe`` (``self.model``).
    For evaluation, ``prepare_for_inference_()`` must be called first; it
    regroups the partitioned modules into a separate pipelined
    ``self.encoder`` and ``self.decoder``.
    """

    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        """Build the training pipeline from an un-pipelined encoder/decoder.

        Args:
            encoder (FairseqEncoder): must expose ``embedding_layer``,
                ``encoder_layers`` and ``final_layer_norm``.
            decoder (FairseqDecoder): must expose ``embedding_layer``,
                ``decoder_layers`` and ``decoder_output_layer``.
            balance: forwarded to ``partition_model`` (torch) or ``Pipe``
                (fairscale).
            devices: forwarded like *balance*; ``devices[0]`` is also where
                ``forward`` moves the inputs during training.
            chunks: forwarded to ``Pipe``.
            checkpoint: forwarded to ``Pipe``.
        """
        import_pipe()
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        # Remembered so `prepare_for_inference_` can split the flat module
        # sequence back into its encoder and decoder parts.
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        if TORCH_PIPE:
            self.model = Pipe(
                partition_model(nn.Sequential(*module_list), balance, devices),
                chunks=chunks,
                checkpoint=checkpoint,
            )
        else:
            self.model = Pipe(
                nn.Sequential(*module_list),
                balance=balance,
                devices=devices,
                chunks=chunks,
                checkpoint=checkpoint,
            )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run the model on one batch.

        In training mode the three inputs are moved to ``self.devices[0]``
        and passed as a tuple through the single pipeline ``self.model``.
        In evaluation mode the separate ``self.encoder`` / ``self.decoder``
        built by ``prepare_for_inference_()`` are used instead.

        Args:
            src_tokens: source tokens tensor.
            src_lengths: source lengths tensor.
            prev_output_tokens: previous decoder output tokens tensor.

        Returns:
            The pipeline output in training mode (``local_value()`` of it
            when the torch ``Pipe`` is in use); the decoder's return value
            in evaluation mode.

        Raises:
            AssertionError: in evaluation mode when
                ``prepare_for_inference_()`` has not been called.
        """
        if self.training:
            pipe_input = tuple(
                t.to(self.devices[0], non_blocking=True)
                for t in (src_tokens, src_lengths, prev_output_tokens)
            )
            if TORCH_PIPE:
                return self.model(pipe_input).local_value()
            return self.model(pipe_input)

        assert self.encoder is not None and self.decoder is not None, (
            "encoder and decoder need to be initialized by "
            + "calling the `prepare_for_inference_()` method"
        )
        # The encoder takes (src_tokens, src_lengths) and the decoder takes
        # (prev_output_tokens, encoder_out); previously an unbound local
        # `input` was passed here, which failed before reaching either.
        encoder_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(prev_output_tokens, encoder_out)

    def prepare_for_inference_(self, cfg):
        """Split the training pipeline into pipelined encoder and decoder.

        Walks ``self.model.partitions`` in order, assigns the first
        ``self.num_encoder_modules`` modules to the encoder and the rest to
        the decoder, drops ``self.model`` and builds ``self.encoder`` /
        ``self.decoder`` from ``cfg.distributed_training``. Does nothing
        (besides logging) if both are already set.

        Args:
            cfg: config object; only ``cfg.distributed_training`` is read
                here and handed to the encoder/decoder constructors.
        """
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(
            cfg.distributed_training, None, None, encoder_module_list
        )
        self.decoder = TransformerDecoder(
            cfg.distributed_training,
            None,
            None,
            decoder_module_list=decoder_module_list,
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        # Each help fragment ends with a space so the implicitly concatenated
        # literals do not run words together.
        parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                            help='Number of embedding layer chunks (enables more even distribution '
                                 'of optimizer states across data parallel nodes '
                                 'when using optimizer state sharding and '
                                 'a big embedding vocabulary)')
        # fmt: on

    @classmethod
    def build_model_base(cls, args, task):
        """Build the un-pipelined encoder and decoder for *task*.

        Args:
            args: parsed arguments; missing fields are filled in by
                ``base_architecture`` and the module-level position defaults.
            task: provides ``source_dictionary`` and ``target_dictionary``.

        Returns:
            tuple: ``(encoder, decoder)``.

        Raises:
            ValueError: if ``--share-all-embeddings`` is requested with
                different dictionaries, different embedding dimensions, or
                a decoder embedding path that differs from the encoder's.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            """Create the token embedding(s) for *dictionary*.

            Returns a single ``Embedding`` loaded from *path* when a path is
            given, otherwise an ``nn.ModuleList`` of *num_embed_chunks*
            embeddings of width ``embed_dim // num_embed_chunks`` each.
            """
            # The check requires embed_dim to be a multiple of the chunk
            # count; the message is worded to match that direction.
            assert embed_dim % num_embed_chunks == 0, (
                f"Embedding dimension = {embed_dim} should be divisible by "
                + f"the number of embedding chunks = {num_embed_chunks}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for i in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            # Same module object is reused, so encoder and decoder share it.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return an un-pipelined ``TransformerEncoder``."""
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Return an un-pipelined ``TransformerDecoder``."""
        return TransformerDecoder(args, tgt_dict, embed_tokens)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        The pipeline layout is read from ``args.pipeline_balance``,
        ``args.pipeline_devices``, ``args.pipeline_chunks`` and
        ``args.pipeline_checkpoint``.
        """
        encoder, decoder = cls.build_model_base(args, task)
        return PipelineParallelTransformerModel(
            encoder=encoder,
            decoder=decoder,
            balance=utils.eval_str_list(args.pipeline_balance, type=int),
            devices=utils.eval_str_list(args.pipeline_devices, type=int),
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size).

        Delegates to ``self.decoder``, which is only set after
        ``prepare_for_inference_()`` has run.
        """
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder_max_positions, self.decoder_max_positions)

    def max_positions_helper(
        self, embedding_layer, max_positions_field="max_source_positions"
    ):
        """Maximum input length supported by the encoder or decoder.

        Args:
            embedding_layer: module exposing ``embed_positions`` and the
                attribute named by *max_positions_field*.
            max_positions_field (str): name of the limit attribute to read.

        Returns:
            The configured limit, capped by
            ``embed_positions.max_positions`` when positional embeddings
            are present.
        """
        if embedding_layer.embed_positions is None:
            return getattr(embedding_layer, max_positions_field)
        return min(
            getattr(embedding_layer, max_positions_field),
            embedding_layer.embed_positions.max_positions,
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output, target=target)
            return out.exp_() if not log_probs else out

        # A Pipe() module returns a tuple of tensors as the output.
        # In this case, the tuple has one element - the output tensor of logits
        logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=False)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=False)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder_max_positions

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints and
        converts a state dict with no ``model.partitions`` keys into the
        pipeline-parallel key layout. *model_cfg* is accepted but not used
        by this method.
        """
        self.upgrade_state_dict(state_dict)
        is_regular_transformer = not any("model.partitions" in k for k in state_dict)
        if is_regular_transformer:
            state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
        return super().load_state_dict(state_dict, strict)

    def convert_to_pipeline_parallel_state_dict(self, state_dict):
        """Map ``encoder.*`` / ``decoder.*`` keys onto ``model.partitions.*``.

        Starts from this model's own ``state_dict()`` and overwrites the
        entries of each partitioned module with the matching tensors from
        *state_dict*, using the module's type and its position within its
        partition to build the target key.

        Args:
            state_dict: state dict keyed by ``encoder.*`` / ``decoder.*``.

        Returns:
            dict: state dict keyed by ``model.partitions.{pid}.{mid}.*``.
        """
        new_state_dict = self.state_dict()
        # Running indices into `encoder.layers.N` / `decoder.layers.N`; they
        # advance across partitions as layers are encountered in order.
        encoder_layer_idx = 0
        decoder_layer_idx = 0
        encoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        decoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "encoder_attn.k_proj.weight",
            "encoder_attn.k_proj.bias",
            "encoder_attn.v_proj.weight",
            "encoder_attn.v_proj.bias",
            "encoder_attn.q_proj.weight",
            "encoder_attn.q_proj.bias",
            "encoder_attn.out_proj.weight",
            "encoder_attn.out_proj.bias",
            "encoder_attn_layer_norm.weight",
            "encoder_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        for pid, partition in enumerate(self.model.partitions):
            logger.info(f"Begin Partition {pid}")
            for mid, module in enumerate(partition):
                # fmt: off
                if isinstance(module, TransformerEncoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
                if isinstance(module, TransformerEncoderLayer):
                    for suffix in encoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                    encoder_layer_idx += 1
                if isinstance(module, TransformerDecoderLayer):
                    for suffix in decoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                    decoder_layer_idx += 1
                if isinstance(module, TransformerEncoderLayerNorm):
                    # The encoder-level layer norm is only copied when the
                    # source state dict actually contains it.
                    if 'encoder.layer_norm.weight' in state_dict:
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                if isinstance(module, TransformerDecoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
                if isinstance(module, TransformerDecoderOutputLayer):
                    new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                # fmt: on
        return new_state_dict


class TransformerEncoder(FairseqEncoder):
    """
    Encoder half of the pipeline-parallel transformer.

    Without *encoder_module_list* the encoder is assembled from an embedding
    layer, *args.encoder_layers* :class:`TransformerEncoderLayer` modules and
    a final layer norm. With *encoder_module_list* those already-built
    modules are wrapped in a ``Pipe`` instead.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding (a single module
            or an ``nn.ModuleList`` of embedding chunks)
        encoder_module_list (list, optional): pre-built encoder modules to
            run as a pipeline (default: None)
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = encoder_module_list is not None

        if self.use_pipeline:
            balance = utils.eval_str_list(args.pipeline_encoder_balance, type=int)
            devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int)
            num_modules = len(encoder_module_list)
            # Every module has to be assigned to exactly one partition.
            assert sum(balance) == num_modules, (
                f"Sum of encoder_balance={balance} is not equal "
                + f"to num_encoder_modules={num_modules}"
            )
            stack = nn.Sequential(*encoder_module_list)
            if TORCH_PIPE:
                # torch's Pipe takes an already partitioned module.
                self.model = Pipe(
                    module=partition_model(stack, balance, devices),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                # fairscale's Pipe does the partitioning itself.
                self.model = Pipe(
                    module=stack,
                    balance=balance,
                    devices=devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            return

        self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
        layers = [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        self.encoder_layers = nn.Sequential(*layers)
        # Chunked embeddings contribute the sum of their widths.
        emb_dim = (
            sum(chunk.embedding_dim for chunk in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)

    def forward(self, src_tokens, src_lengths):
        """
        Encode a batch of source sentences.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            EncoderOut: only the first two fields are populated, taken from
            the first two elements of the layer stack's output:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
            The remaining four fields are ``None``.
        """
        # The layers consume a 3-tuple; a one-element zero tensor stands in
        # for the decoder tokens that the encoder itself does not have.
        placeholder = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        inputs = (src_tokens, src_lengths, placeholder)
        if not self.use_pipeline:
            embedded = self.embedding_layer(inputs)
            hidden = self.encoder_layers(embedded)
            result = self.final_layer_norm(hidden)
        else:
            first_device = self.model.devices[0]
            inputs = tuple(t.to(first_device) for t in inputs)
            result = self.model(inputs)
            if TORCH_PIPE:
                result = result.local_value()
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(result[0], result[1], None, None, None, None)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # (field name, dimension along which the field is re-indexed)
        reindexed_fields = (
            ("encoder_out", 1),
            ("encoder_padding_mask", 0),
            ("encoder_embedding", 0),
        )
        for field, dim in reindexed_fields:
            value = getattr(encoder_out, field)
            if value is not None:
                encoder_out = encoder_out._replace(
                    **{field: value.index_select(dim, new_order)}
                )
        states = encoder_out.encoder_states
        if states is not None:
            # Updated in place, entry by entry.
            for position, hidden in enumerate(states):
                states[position] = hidden.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        layer = self.embedding_layer
        positions = layer.embed_positions
        limit = layer.max_source_positions
        if positions is None:
            return limit
        return min(limit, positions.max_positions)


class TransformerDecoder(FairseqDecoder):
    """
    Decoder half of the pipeline-parallel transformer.

    Without *decoder_module_list* the decoder is assembled from an embedding
    layer, *args.decoder_layers* :class:`TransformerDecoderLayer` modules and
    an output layer. With *decoder_module_list* those already-built modules
    are wrapped in a ``Pipe`` instead.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        decoder_module_list (list, optional): pre-built decoder modules to
            run as a pipeline (default: None)
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        decoder_module_list=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = decoder_module_list is not None

        if self.use_pipeline:
            balance = utils.eval_str_list(args.pipeline_decoder_balance, type=int)
            devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int)
            num_modules = len(decoder_module_list)
            # Every module has to be assigned to exactly one partition.
            assert sum(balance) == num_modules, (
                f"Sum of decoder_balance={balance} is not equal "
                + f"to num_decoder_modules={num_modules}"
            )
            stack = nn.Sequential(*decoder_module_list)
            if TORCH_PIPE:
                # torch's Pipe takes an already partitioned module.
                self.model = Pipe(
                    module=partition_model(stack, balance, devices),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                # fairscale's Pipe does the partitioning itself.
                self.model = Pipe(
                    module=stack,
                    balance=balance,
                    devices=devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            return

        self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
        layers = [
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ]
        self.decoder_layers = nn.Sequential(*layers)
        self.decoder_output_layer = TransformerDecoderOutputLayer(
            args, embed_tokens, dictionary
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
    ):
        """
        Decode one batch given the encoder output.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out: output from the encoder; its ``encoder_out`` and
                ``encoder_padding_mask`` fields are read unconditionally, so
                despite the ``None`` default it has to be provided

        Returns:
            tuple: a one-element tuple holding the output of the decoder
            module stack (the pipeline's ``local_value()`` when the torch
            ``Pipe`` is in use)
        """
        inputs = (
            encoder_out.encoder_out,
            encoder_out.encoder_padding_mask,
            prev_output_tokens,
        )
        if not self.use_pipeline:
            embedded = self.embedding_layer(inputs)
            hidden = self.decoder_layers(embedded)
            return (self.decoder_output_layer(hidden),)

        first_device = self.model.devices[0]
        inputs = tuple(t.to(first_device) for t in inputs)
        result = self.model(inputs)
        if TORCH_PIPE:
            result = result.local_value()
        return (result,)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # NOTE(review): `share_input_output_embed`, `embed_tokens` and
        # `embed_out` are not assigned in this class's __init__ - confirm
        # where (or whether) they are provided.
        if self.adaptive_softmax is not None:
            return features
        # project back to size of vocabulary
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        layer = self.embedding_layer
        positions = layer.embed_positions
        limit = layer.max_target_positions
        if positions is None:
            return limit
        return min(limit, positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Return a cached upper-triangular ``-inf`` mask sized to *tensor*.

        The mask is rebuilt when none is cached, when the cached one lives
        on a different device than *tensor*, or when it is smaller than
        ``tensor.size(0)``; the top-left ``dim x dim`` corner is returned.
        """
        dim = tensor.size(0)
        cached = getattr(self, "_future_mask", None)
        needs_rebuild = (
            cached is None
            or cached.device != tensor.device
            or cached.size(0) < dim
        )
        if needs_rebuild:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # NOTE(review): `embed_positions` and `layers` are not assigned in
        # this class's __init__ - confirm where (or whether) they are provided.
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # Sinusoidal embeddings carry no stored weights; drop the old
            # entry and install the placeholder tensor instead.
            state_dict.pop(f"{name}.embed_positions.weights", None)
            state_dict[f"{name}.embed_positions._float_tensor"] = torch.FloatTensor(1)

        # update layer norms: old numbered `layer_norms.N` -> named modules
        renamed_norms = (
            ("0", "self_attn_layer_norm"),
            ("1", "encoder_attn_layer_norm"),
            ("2", "final_layer_norm"),
        )
        for layer_idx in range(len(self.layers)):
            for old_idx, new_name in renamed_norms:
                for param in ("weight", "bias"):
                    old_key = (
                        f"{name}.layers.{layer_idx}.layer_norms.{old_idx}.{param}"
                    )
                    if old_key in state_dict:
                        new_key = f"{name}.layers.{layer_idx}.{new_name}.{param}"
                        state_dict[new_key] = state_dict.pop(old_key)

        version_key = f"{name}.version"
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    """Architecture ``transformer_iwslt_de_en_pipeline_parallel``.

    Registered for the ``pipeline_parallel_transformer`` model; delegates
    entirely to ``transformer_iwslt_de_en`` to process *args*.
    """
    transformer_iwslt_de_en(args)


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    """Architecture ``transformer_wmt_en_de_big_pipeline_parallel``.

    Registered for the ``pipeline_parallel_transformer`` model; delegates
    entirely to ``transformer_wmt_en_de_big`` to process *args*.
    """
    transformer_wmt_en_de_big(args)
