# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import math

import torch
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params


def ensemble_encoder(func):
    """Decorate an encoder ``forward`` so it can evaluate an ensemble.

    The wrapper reads ``self.ensemble_models``. When it is ``None`` or holds
    at most one model, ``func`` is simply called on ``self``. Otherwise
    ``func`` is called once per model (with ``return_all_hiddens=True``) and
    the per-model results are merged into a single dict.

    Args:
        func: the undecorated method; it must return a dict containing the
            keys ``"encoder_out"``, ``"encoder_embedding"`` (each a list whose
            first item is a tensor or ``None``) and ``"encoder_states"`` (a
            list of tensors).

    Returns:
        The wrapped method. In ensemble mode its result is a shallow copy of
        the first model's dict in which the three keys above hold tensors
        stacked over a new trailing "model" dimension; every other key keeps
        the first model's value.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        models = self.ensemble_models
        # None, an empty list, or a single model all mean "no ensemble".
        if models is None or len(models) <= 1:
            return func(self, *args, **kwargs)

        # Hidden states are always requested in ensemble mode. Injecting the
        # flag through a dict (rather than as a second keyword in the call)
        # avoids a duplicate-keyword TypeError when the caller already
        # passed ``return_all_hiddens``.
        ensemble_kwargs = dict(kwargs, return_all_hiddens=True)
        encoder_outs = [func(model, *args, **ensemble_kwargs) for model in models]

        # Shallow copy: keys that are not re-assigned below still refer to
        # the first model's objects.
        _encoder_out = encoder_outs[0].copy()

        def stack(key):
            # Only the first list item of each model's entry is used; a
            # ``None`` entry on the first model yields ``[None]``.
            outs = [e[key][0] for e in encoder_outs]
            return [torch.stack(outs, -1) if outs[0] is not None else None]

        _encoder_out["encoder_out"] = stack("encoder_out")
        _encoder_out["encoder_embedding"] = stack("encoder_embedding")

        # The number of hidden states is taken from the first model.
        num_layers = len(_encoder_out["encoder_states"])
        if num_layers > 0:
            _encoder_out["encoder_states"] = [
                torch.stack([e["encoder_states"][i] for e in encoder_outs], -1)
                for i in range(num_layers)
            ]
        return _encoder_out

    return wrapper


def ensemble_decoder(func):
    """Decorate a decoder method so it can evaluate an ensemble of decoders.

    The wrapper reads ``self.ensemble_models``. When it is ``None`` or holds
    at most one model, ``func`` is simply called on ``self``. Otherwise
    ``func`` is called once per model; model ``i`` receives a shallow copy of
    ``encoder_out`` whose ``"encoder_out"`` entry is the ``i``-th slice along
    the fourth dimension of the original entry (the layout produced by
    ``ensemble_encoder``).

    Args:
        func: the undecorated method. It is always invoked with ``normalize``
            and ``encoder_out`` as keyword arguments, and may return either a
            single value or a tuple of values.

    Returns:
        The wrapped method. In ensemble mode each output position is merged
        across models:

        * position 0 with ``normalize=True``: ``logsumexp`` over models minus
          ``log(num_models)``, i.e. the log of the mean of the exponentiated
          values (presumably these are log-probabilities -- confirm against
          the decorated methods);
        * any other non-``None`` output: stacked over a new trailing
          dimension;
        * ``None`` outputs (judged by the first model) stay ``None``.

        A single merged value is returned bare, several as a tuple.
    """

    @functools.wraps(func)
    def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
        models = self.ensemble_models
        # None, an empty list, or a single model all mean "no ensemble".
        if models is None or len(models) <= 1:
            return func(
                self, *args, normalize=normalize, encoder_out=encoder_out, **kwargs
            )

        def _select(model_index):
            # Shallow copy so the caller's dict is not mutated; only the
            # "encoder_out" entry is replaced by this model's slice.
            member_out = encoder_out.copy()
            member_out["encoder_out"] = [
                encoder_out["encoder_out"][0][:, :, :, model_index]
            ]
            return member_out

        action_outs = [
            func(
                model, *args, normalize=normalize, encoder_out=_select(i), **kwargs
            )
            for i, model in enumerate(models)
        ]

        # Normalise every result to a list so single-value and tuple returns
        # share the merging loop below.
        if isinstance(action_outs[0], tuple):  # func returned multiple values
            action_outs = [list(a) for a in action_outs]
        else:  # func returned a single value
            action_outs = [[a] for a in action_outs]

        ensembled_outs = []
        for i, first_model_out in enumerate(action_outs[0]):
            if i == 0 and normalize:
                stacked = torch.stack([a[i] for a in action_outs], -1)
                ensembled_outs.append(
                    torch.logsumexp(stacked, dim=-1) - math.log(len(models))
                )
            elif first_model_out is not None:
                ensembled_outs.append(torch.stack([a[i] for a in action_outs], -1))
            else:
                ensembled_outs.append(None)

        if len(ensembled_outs) == 1:
            return ensembled_outs[0]
        return tuple(ensembled_outs)

    return wrapper


class FairseqNATModel(TransformerModel):
    """
    Abstract class for all nonautoregressive-based models

    Subclasses are expected to override ``forward``, ``forward_decoder`` and
    ``initialize_output_tokens``; the versions defined here raise
    ``NotImplementedError``.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Cache the target dictionary and its special-symbol indices.
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
        self.pad = decoder.dictionary.pad()
        self.unk = decoder.dictionary.unk()

        self.ensemble_models = None

    @property
    def allow_length_beam(self):
        """Flag reporting whether a length beam is allowed; ``False`` here."""
        return False

    @property
    def allow_ensemble(self):
        """Flag reporting whether ensembling is allowed; ``True`` here."""
        return True

    def enable_ensemble(self, models):
        """Point this model's encoder and decoder at an ensemble.

        Args:
            models: iterable of models; the ``encoder`` / ``decoder``
                attribute of each is collected into the ``ensemble_models``
                list of this model's own encoder / decoder.
        """
        self.encoder.ensemble_models = [m.encoder for m in models]
        self.decoder.ensemble_models = [m.decoder for m in models]

    @staticmethod
    def add_args(parser):
        """Add the ``TransformerModel`` arguments plus ``--apply-bert-init``."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build a ``FairseqNATDecoder``.

        ``init_bert_params`` is applied to it when ``args.apply_bert_init``
        is set and truthy.
        """
        decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build a ``FairseqNATEncoder``.

        ``init_bert_params`` is applied to it when ``args.apply_bert_init``
        is set and truthy.
        """
        encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            encoder.apply(init_bert_params)
        return encoder

    def forward_encoder(self, encoder_inputs):
        """Call the encoder with ``encoder_inputs`` unpacked as positional args."""
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, *args, **kwargs):
        """Abstract hook; must be overridden by subclasses.

        Raises:
            NotImplementedError: always.
        """
        # Raise (rather than return) the exception so an un-overridden call
        # fails loudly instead of handing back the exception class.
        raise NotImplementedError

    def initialize_output_tokens(self, *args, **kwargs):
        """Abstract hook; must be overridden by subclasses.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Abstract hook; must be overridden by subclasses.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError


class FairseqNATEncoder(TransformerEncoder):
    """Transformer encoder whose ``forward`` goes through ``ensemble_encoder``.

    ``ensemble_models`` starts as ``None``; ``FairseqNATModel.enable_ensemble``
    assigns it a list of encoders. The ``ensemble_encoder`` wrapper around
    ``forward`` reads the attribute on every call.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        # Read by the ensemble_encoder wrapper; None means the wrapper just
        # runs this encoder's own forward.
        self.ensemble_models = None

    @ensemble_encoder
    def forward(self, *args, **kwargs):
        """Delegate to the parent class ``forward`` with the same arguments."""
        return super().forward(*args, **kwargs)


class FairseqNATDecoder(TransformerDecoder):
    """Transformer decoder that carries an ``ensemble_models`` attribute.

    ``ensemble_models`` starts as ``None``; ``FairseqNATModel.enable_ensemble``
    assigns it a list of decoders. It is the attribute the
    ``ensemble_decoder`` wrapper reads from ``self`` (no method of this class
    is decorated here -- presumably subclasses decorate their own methods).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        # None means no ensemble has been enabled for this decoder.
        self.ensemble_models = None
