# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch.nn as nn
from fairseq import utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    FairseqMultiModel,
    FairseqEncoder,
    FairseqDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    Embedding,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)


@register_model("modular_transformer")
class ModularTransformerModel(FairseqEncoderDecoderModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task translation --dynamic-dataset`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Unless encoders/decoders are shared, one module is built per language and
    wrapped in ``ModularTransformerEncoder`` / ``ModularTransformerDecoder``,
    which pick the module from ``meta['src_lang']`` / ``meta['tgt_lang']``.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Builds one Transformer encoder per source language and one decoder per
        target language, or a single module when --share-encoders /
        --share-decoders is set. Embedding tables may additionally be shared
        via --share-all-embeddings, --share-encoder-embeddings or
        --share-decoder-embeddings.

        Args:
            args: parsed arguments; missing architecture fields are filled in
                by ``base_modular_architecture`` (``args`` is mutated).
            task: a ``TranslationTask`` run with ``--dynamic-dataset``; its
                ``src_langs``, ``tgt_langs`` and ``dicts`` are read here.

        Returns:
            An instance of ``cls`` so that subclasses (e.g. ones overriding
            ``_get_module_class``) get their own type back.

        Raises:
            ValueError: if --share-all-embeddings is combined with mismatched
                embedding dims or a distinct --decoder-embed-path.
        """
        from fairseq.tasks.translation import TranslationTask
        assert isinstance(task, TranslationTask) and args.dynamic_dataset

        # make sure all arguments are present in older models
        base_modular_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = task.src_langs
        tgt_langs = task.tgt_langs

        # sharing a whole encoder/decoder implies sharing its embeddings
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            # one table serves encoder input, decoder input and decoder output
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=getattr(task, 'src_dicts', task.dicts),
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=getattr(task, 'tgt_dicts', task.dicts),
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language, memoised so each is built once
        lang_encoders, lang_decoders = {}, {}

        # NOTE(review): per-language modules below always use task.dicts[lang],
        # while the shared-embedding path above prefers task.src_dicts /
        # task.tgt_dicts when present — confirm these agree for such tasks.
        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # a shared encoder/decoder is built from the first language's dictionary
        if args.share_encoders:
            encoder = get_encoder(src_langs[0])
        else:
            encoders = {lang: get_encoder(lang) for lang in src_langs}
            encoder = ModularTransformerEncoder(encoders)

        if args.share_decoders:
            decoder = get_decoder(tgt_langs[0])
        else:
            decoders = {lang: get_decoder(lang) for lang in tgt_langs}
            decoder = ModularTransformerDecoder(decoders)

        # use cls, not a hard-coded class name, so subclasses build themselves
        return cls(encoder, decoder)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        """Build the encoder (``is_encoder=True``) or decoder for one language.

        Override hook for subclasses; ``langs`` is not used by this default
        implementation, which builds a plain TransformerEncoder/Decoder.
        """
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Load ``state_dict`` after dropping per-language modules this model lacks.

        Keys under ``encoder.encoders.`` / ``decoder.decoders.`` that this model
        does not have (e.g. languages absent from the current task) are removed
        from ``state_dict`` in place, so a strict load does not fail on them.
        All other keys are passed through unchanged.
        """
        own_keys = set(self.state_dict().keys())
        prunable_prefixes = ('encoder.encoders.', 'decoder.decoders.')
        for k in list(state_dict.keys()):
            # remove unused encoders and decoders
            if k not in own_keys and k.startswith(prunable_prefixes):
                state_dict.pop(k)
        # Pass by keyword so ``args`` cannot land in a different positional
        # slot if the base-class signature grows extra parameters.
        return super().load_state_dict(state_dict, strict=strict, args=args)


class ModularTransformerEncoder(FairseqEncoder):
    """Dispatch each forward call to the encoder of the source language.

    The per-language encoders live in an ``nn.ModuleDict`` keyed by language,
    so they are registered as submodules of this wrapper.
    """

    def __init__(self, encoders):
        super().__init__(dictionary=None)
        self.encoders = nn.ModuleDict(encoders)

    def forward(self, *args, **kwargs):
        """Run the encoder selected by ``kwargs['meta']['src_lang']``.

        Every positional and keyword argument, ``meta`` included, is handed to
        the selected encoder unchanged.
        """
        lang_key = kwargs['meta']['src_lang']
        return self.encoders[lang_key](*args, **kwargs)

    def max_positions(self):
        """Return the smallest ``max_positions()`` among the wrapped encoders."""
        limits = [enc.max_positions() for enc in self.encoders.values()]
        return min(limits)

    def reorder_encoder_out(self, *args, **kwargs):
        """Delegate reordering to the first wrapped encoder (insertion order).

        Presumes every wrapped encoder reorders its output the same way.
        """
        first_encoder = next(iter(self.encoders.values()))
        return first_encoder.reorder_encoder_out(*args, **kwargs)


class ModularTransformerDecoder(FairseqDecoder):
    """Dispatch each forward call to the decoder of the target language.

    The per-language decoders live in an ``nn.ModuleDict`` keyed by language,
    so they are registered as submodules of this wrapper.
    """

    def __init__(self, decoders):
        super().__init__(dictionary=None)
        self.decoders = nn.ModuleDict(decoders)
        # not read anywhere in this class; kept so the attribute still exists
        self.lang = None

    def forward(self, *args, **kwargs):
        """Run the decoder selected by ``kwargs['meta']['tgt_lang']``.

        Every positional and keyword argument, ``meta`` included, is handed to
        the selected decoder unchanged.
        """
        lang_key = kwargs['meta']['tgt_lang']
        return self.decoders[lang_key](*args, **kwargs)

    def max_positions(self):
        """Return the smallest ``max_positions()`` among the wrapped decoders."""
        limits = [dec.max_positions() for dec in self.decoders.values()]
        return min(limits)


@register_model_architecture("modular_transformer", "modular_transformer")
def base_modular_architecture(args):
    """Apply the base Transformer defaults, then default the sharing flags.

    Each sharing flag keeps its existing value on ``args`` and falls back to
    False when absent. ``args`` is modified in place.
    """
    base_architecture(args)
    sharing_flags = (
        "share_encoder_embeddings",
        "share_decoder_embeddings",
        "share_encoders",
        "share_decoders",
    )
    for flag in sharing_flags:
        setattr(args, flag, getattr(args, flag, False))


@register_model_architecture(
    "modular_transformer", "modular_transformer_iwslt_de_en"
)
def modular_transformer(args):
    """Fill in IWSLT-sized defaults, then the base modular defaults.

    Values already present on ``args`` are kept; only missing attributes are
    set. Encoder and decoder both default to 6 layers, 512-dim embeddings,
    1024-dim feed-forward layers and 4 attention heads.
    """
    iwslt_defaults = (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 1024),
        ("encoder_attention_heads", 4),
        ("encoder_layers", 6),
        ("decoder_embed_dim", 512),
        ("decoder_ffn_embed_dim", 1024),
        ("decoder_attention_heads", 4),
        ("decoder_layers", 6),
    )
    for attr_name, fallback in iwslt_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_modular_architecture(args)
