import argparse
import logging

import torch.nn as nn
import fairseq.checkpoint_utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model built around a RoBERTa sentence encoder.

    The encoder is the ``sentence_encoder`` of a ``roberta.RobertaModel``
    (loaded from ``--pretrained-mlm-checkpoint`` or freshly built). The
    decoder is a ``TransformerDecoder`` configured from the RoBERTa
    arguments; with ``--pretrained-decoder`` it is initialised from the
    encoder weights.
    """

    @staticmethod
    def add_args(parser):
        """Register the model-specific command-line options on ``parser``."""
        parser.add_argument(
            "--pretrained-mlm-checkpoint",
            default=None,
            type=str,
            metavar="PRETRAINED",
            help="path to pretrained mlm checkpoint",
        )
        parser.add_argument(
            "--pretrained-decoder", action="store_true", help="reload decoder"
        )
        parser.add_argument(
            "--hack-layernorm-embedding",
            action="store_true",
            help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed arguments. Mutated in place: missing options are
                filled in by ``base_enc_dec_architecture`` and, when no
                checkpoint is given, ``untie_weights_roberta`` (and, with
                ``--hack-layernorm-embedding``, ``layernorm_embedding`` and
                ``encoder_normalize_before``) are overwritten.
            task: the task; passed to ``RobertaModel.build_model`` and used
                for its ``source_dictionary``.

        Returns:
            The model produced by ``from_roberta``.
        """

        # make sure all arguments are present
        base_enc_dec_architecture(args)
        if args.pretrained_mlm_checkpoint:
            arg_overrides = None
            if args.hack_layernorm_embedding:
                arg_overrides = {"layernorm_embedding": False}
            loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
            )
            # Exactly one checkpoint path was given, so exactly one model is
            # expected back; the loaded config and task are not used.
            ([roberta_enc], _cfg, _task) = loaded
        else:
            # Do we need to edit untie_weights here ?
            share_in_out = (
                args.share_decoder_input_output_embed or args.share_all_embeddings
            )
            args.untie_weights_roberta = not share_in_out
            if args.hack_layernorm_embedding:
                args.layernorm_embedding = False
                args.encoder_normalize_before = False
            roberta_enc = roberta.RobertaModel.build_model(args, task)

        return cls.from_roberta(roberta_enc, args, task.source_dictionary)

    @staticmethod
    def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
        """Assemble an encoder-decoder model from a RoBERTa model.

        Args:
            roberta_enc: RoBERTa model providing the sentence encoder, the
                LM head (reused when all embeddings are shared) and the
                ``args`` the decoder is configured from.
            args: must provide ``share_all_embeddings`` and
                ``share_decoder_input_output_embed``; ``pretrained_decoder``
                is optional and treated as ``False`` when absent.
            dictionary: dictionary handed to the decoder; its ``pad()`` index
                is used as the decoder embedding padding index.

        Returns:
            A ``RobertaEncDecModel`` wrapping the encoder and the new decoder.

        Raises:
            AssertionError: if the requested embedding sharing cannot be
                honoured, or if ``--pretrained-decoder`` leaves missing or
                unexpected keys when loading the decoder state.
        """
        encoder = roberta_enc.encoder.sentence_encoder
        vocab_size, embed_dim = encoder.embed_tokens.weight.shape

        if args.share_all_embeddings:
            # Reuse the RoBERTa LM head so that encoder embeddings, decoder
            # embeddings and the output projection are one single tensor.
            lm_head = roberta_enc.encoder.lm_head
            assert encoder.embed_tokens.weight is lm_head.weight, (
                "Can't use --share-all-embeddings with a model "
                "that was pretrained with --untie-weights-roberta"
            )
        else:
            lm_head = roberta.RobertaLMHead(
                embed_dim, vocab_size, roberta_enc.args.activation_fn
            )

        dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
        if args.share_all_embeddings or args.share_decoder_input_output_embed:
            # Note: I wasn't able to use Embedding _weight parameter to achieve this sharing.
            dec_embs.weight = lm_head.weight

        decoder = TransformerDecoder(
            RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
            dictionary,
            dec_embs,
            no_encoder_attn=False,
            output_projection=lm_head,
        )
        if getattr(args, "pretrained_decoder", False):
            decoder_dict = encoder.state_dict()

            # TODO: hide setting "encoder_attn" layers behind a flag.
            # Iterate over a snapshot because new keys are added to the dict:
            # every self-attention tensor is also copied under the matching
            # ".encoder_attn" key.
            for k, w in list(decoder_dict.items()):
                if ".self_attn" in k:
                    k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                    decoder_dict[k_enc_attn] = w.detach().clone()

            # The decoder holds the LM head as "output_projection".
            for k, w in lm_head.state_dict().items():
                decoder_dict["output_projection." + k] = w

            # strict=False so that mismatches are reported through the
            # assertion below, listing both kinds of offending keys.
            missing_keys, unexpected_keys = decoder.load_state_dict(
                decoder_dict, strict=False
            )
            assert not missing_keys and not unexpected_keys, (
                "Failed to load state dict. "
                f"Missing keys: {missing_keys}. "
                f"Unexpected keys: {unexpected_keys}."
            )

        # Sanity-check that the weight tying matches the requested sharing.
        if args.share_all_embeddings:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
        elif args.share_decoder_input_output_embed:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
        else:
            assert decoder.output_projection.weight is not decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

        return RobertaEncDecModel(encoder, decoder)

    @staticmethod
    def read_args_from_roberta(roberta_args: argparse.Namespace):
        """Derive decoder arguments from RoBERTa (encoder) arguments.

        Args:
            roberta_args: namespace of the RoBERTa model. It is not modified;
                it must define every ``encoder_*`` attribute listed in
                ``attr_map`` below, ``max_positions`` and
                ``untie_weights_roberta``.

        Returns:
            A new ``argparse.Namespace`` holding a copy of all attributes of
            ``roberta_args`` plus the ``decoder_*`` / ``max_target_positions``
            counterparts, adaptive-softmax defaults and
            ``share_decoder_input_output_embed``.
        """
        # TODO: this would become easier if encoder/decoder were using a similar
        # TransformerConfig object
        args = argparse.Namespace(**vars(roberta_args))
        # (source attribute on roberta_args, attribute to set on the copy)
        attr_map = [
            ("encoder_attention_heads", "decoder_attention_heads"),
            ("encoder_embed_dim", "decoder_embed_dim"),
            ("encoder_embed_dim", "decoder_output_dim"),
            ("encoder_normalize_before", "decoder_normalize_before"),
            ("encoder_layers_to_keep", "decoder_layers_to_keep"),
            ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
            ("encoder_layerdrop", "decoder_layerdrop"),
            ("encoder_layers", "decoder_layers"),
            ("encoder_learned_pos", "decoder_learned_pos"),
            # should this be set from here ?
            ("max_positions", "max_target_positions"),
        ]
        for k1, k2 in attr_map:
            setattr(args, k2, getattr(roberta_args, k1))

        # Keep existing adaptive-softmax settings, otherwise fall back to
        # "no cutoff" / zero dropout.
        args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
        args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
        args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
        return args

    def upgrade_state_dict_named(self, state_dict, name):
        """Rewrite ``state_dict`` keys in place to this model's layout.

        After delegating to the parent implementation, this:

        * drops every key starting with ``<prefix>encoder.lm_head``;
        * removes the ``.sentence_encoder.`` path component from keys;
        * renames ``decoder.lm_head.`` to ``decoder.output_projection.``.

        Presumably this lets checkpoints saved with the RoBERTa key layout
        be loaded into this model; the code only shows the renaming itself.

        Args:
            state_dict: mapping of parameter names to tensors; modified in
                place.
            name: name of this module inside ``state_dict``; an empty string
                means the keys carry no prefix.
        """
        prefix = name + "." if name != "" else ""
        super().upgrade_state_dict_named(state_dict, name)
        # Snapshot the keys: the dict is modified while walking over them.
        old_keys = list(state_dict.keys())

        for k in old_keys:
            if k.startswith(prefix + "encoder.lm_head"):
                state_dict.pop(k)
                continue
            new_k = k
            new_k = new_k.replace(".sentence_encoder.", ".")
            new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
            if k == new_k:
                continue
            state_dict[new_k] = state_dict.pop(k)


@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
    """Fill in defaults for the ``roberta_enc_dec`` architecture.

    Every option that is missing from ``args`` gets its default value;
    options that are already set are left untouched. The RoBERTa base
    architecture defaults are applied afterwards.

    Args:
        args: argument namespace; modified in place.
    """
    args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
    args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
    # --pretrained-decoder is a store_true flag, so its "absent" value is
    # False, like the other boolean options here (it used to default to None).
    args.pretrained_decoder = getattr(args, "pretrained_decoder", False)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )

    roberta.base_architecture(args)
