import logging
from fairseq.modules.transformer_layer import TransformerEncoderLayer
from fairseq.modules.transformer_layer import TransformerDecoderLayer

from fairseq.models.transformer import TransformerEncoder
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.transformer import TransformerModel


from fairseq.models.transformer import base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big
from fairseq.modules.adapter_layer import AdapterLayer
from fairseq.models import register_model, register_model_architecture
import torch
from torch import nn


logger = logging.getLogger(__name__)


class AdapterTransformerEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer followed by optional adapter modules.

    After the parent layer has run, up to three adapters are applied, in this
    order: the adapter selected by ``adapter_uid`` (from ``self.adapters``),
    the shared ``univ_adapter`` (only if ``args.univ``) and the adapter
    selected by ``corpus_adapter_uid`` (from ``self.corpus_adapters``).

    ``corpus_adapter_uid`` and ``drop_corpus_adapter`` are not initialized
    here: ``AdapterTransformerEncoder.forward`` assigns them on every layer
    before delegating to the parent encoder.
    """

    def __init__(self, layer_id, args, *args_, **kwargs):
        """Build the parent layer and the adapter dictionaries.

        Args:
            layer_id: index of this layer; its parity decides which adapter
                kind the layer hosts when ``args.no_stack`` is set.
            args: model arguments. ``adapter_uids``, ``domain_adapter_dim``,
                ``adapter_dim``, ``encoder_embed_dim``, ``pfeiffer`` and
                ``adapter_init`` are required; ``enc_corpus_adapter_uids``,
                ``no_stack``, ``corpus_adapter_ld``, ``adapter_dropout`` and
                ``univ`` are optional.
            *args_, **kwargs: forwarded to ``TransformerEncoderLayer``.
        """
        super().__init__(args, *args_, **kwargs)
        adapter_uids = args.adapter_uids
        if adapter_uids is None:
            adapter_uids = ['default']
        corpus_adapter_uids = getattr(args, "enc_corpus_adapter_uids", None)
        no_stack = getattr(args, "no_stack", False)
        self.corpus_adapter_ld = getattr(args, "corpus_adapter_ld", 0.0)
        self.adapter_dropout = getattr(args, "adapter_dropout", 0.0)
        self.layer_id = layer_id
        self.univ = getattr(args, "univ", False)
        # -1 means "use the same bottleneck size as the regular adapters".
        if args.domain_adapter_dim == -1:
            domain_adapter_dim = args.adapter_dim
        else:
            domain_adapter_dim = args.domain_adapter_dim

        def build_adapters(uids, bottleneck_dim, **extra):
            # One AdapterLayer per uid; `extra` carries options that only
            # some adapter kinds receive (dropout for corpus adapters).
            return nn.ModuleDict({
                uid: AdapterLayer(
                    args.encoder_embed_dim,
                    bottleneck_dim,
                    pfeiffer=args.pfeiffer,
                    init=args.adapter_init,
                    **extra,
                )
                for uid in uids
            })

        # Corpus adapters are created before the regular ones so that module
        # registration order is unchanged.
        if corpus_adapter_uids:
            # With no_stack, corpus adapters live on even layers only and the
            # regular adapters on odd layers only.
            if no_stack and self.layer_id % 2 == 1:
                self.corpus_adapters = None
            else:
                self.corpus_adapters = build_adapters(
                    corpus_adapter_uids,
                    domain_adapter_dim,
                    dropout=self.adapter_dropout,
                )
            if args.adapter_uids and not (no_stack and self.layer_id % 2 == 0):
                self.adapters = build_adapters(adapter_uids, args.adapter_dim)
            else:
                self.adapters = None
        else:
            self.corpus_adapters = None
            self.adapters = build_adapters(adapter_uids, args.adapter_dim)

        if self.univ:
            self.univ_adapter = AdapterLayer(
                args.encoder_embed_dim,
                args.adapter_dim,
                pfeiffer=args.pfeiffer,
                init=args.adapter_init,
            )
        # Default adapter; the enclosing encoder may overwrite this attribute
        # before each forward pass.
        self.adapter_uid = adapter_uids[0]
        self.return_prenorm = args.pfeiffer

    def forward(self, x, *args, **kwargs):
        """Run the parent layer, then the adapters that are active.

        Args:
            x: layer input, forwarded to ``TransformerEncoderLayer.forward``.
            *args, **kwargs: forwarded unchanged to the parent layer.

        Returns:
            The parent layer's output after the active adapters were applied.
        """
        x = super().forward(x, *args, **kwargs)

        # Input of the corpus adapter when it is stacked on top of the regular
        # adapter in pfeiffer mode. It stays None when no regular adapter ran.
        z = None
        if self.adapters and self.adapter_uid in self.adapters:
            adapter = self.adapters[self.adapter_uid]
            if not self.return_prenorm:
                x = adapter(x)
            elif self.corpus_adapters:
                # Stacked: keep x untouched, the corpus adapter consumes z.
                # NOTE(review): if the corpus adapter is skipped below (layer
                # drop during training, or unknown corpus uid), z is discarded
                # and x is returned as-is -- confirm this is intended.
                z = adapter(self.final_layer_norm(x)) + x
            else:
                y = adapter(self.final_layer_norm(x))
                x = self.final_layer_norm(y + x)

        if self.univ:
            x = self.univ_adapter(x)

        if self.corpus_adapters and self.corpus_adapter_uid in self.corpus_adapters:
            # `drop_corpus_adapter` is True when the corpus adapter must be
            # *applied*; it only matters in training mode.
            if not self.training or self.drop_corpus_adapter:
                corpus_adapter = self.corpus_adapters[self.corpus_adapter_uid]
                if self.return_prenorm:
                    if z is None:
                        # No regular adapter ran on this layer (e.g. no_stack
                        # or unknown adapter uid): feed the normalized layer
                        # output instead of reading an unbound variable.
                        z = self.final_layer_norm(x)
                    y = corpus_adapter(z)
                    x = self.final_layer_norm(y + x)
                else:
                    x = corpus_adapter(x)
        return x


class AdapterTransformerDecoderLayer(TransformerDecoderLayer):
    """Transformer decoder layer followed by optional adapter modules.

    After the parent layer has run, up to three adapters are applied to its
    first output, in this order: the adapter selected by ``adapter_uid``
    (from ``self.adapters``), the shared ``univ_adapter`` (only if
    ``args.univ``) and the adapter selected by ``corpus_adapter_uid`` (from
    ``self.corpus_adapters``).

    ``corpus_adapter_uid`` and ``drop_corpus_adapter`` are not initialized
    here: ``AdapterTransformerDecoder.forward`` assigns them on every layer
    before delegating to the parent decoder.
    """

    def __init__(self, layer_id, args, *args_, **kwargs):
        """Build the parent layer and the adapter dictionaries.

        Args:
            layer_id: index of this layer; its parity decides which adapter
                kind the layer hosts when ``args.no_stack`` is set.
            args: model arguments. ``decoder_adapter_uids``, ``adapter_uids``,
                ``adapter_dim``, ``decoder_embed_dim``, ``pfeiffer`` and
                ``adapter_init`` are required; ``corpus_adapter_uids``,
                ``no_stack``, ``corpus_adapter_ld``, ``adapter_dropout`` and
                ``univ`` are optional.
            *args_, **kwargs: forwarded to ``TransformerDecoderLayer``.
        """
        super().__init__(args, *args_, **kwargs)
        adapter_uids = args.decoder_adapter_uids
        if adapter_uids is None:
            # Fall back to the encoder-side uids, then to a single adapter.
            adapter_uids = args.adapter_uids or ['default']
        corpus_adapter_uids = getattr(args, "corpus_adapter_uids", None)
        no_stack = getattr(args, "no_stack", False)
        self.corpus_adapter_ld = getattr(args, "corpus_adapter_ld", 0.0)
        self.adapter_dropout = getattr(args, "adapter_dropout", 0.0)
        self.layer_id = layer_id
        self.univ = getattr(args, "univ", False)

        def build_adapters(uids, **extra):
            # One AdapterLayer per uid; `extra` carries options that only
            # some adapter kinds receive (dropout for corpus adapters).
            return nn.ModuleDict({
                uid: AdapterLayer(
                    args.decoder_embed_dim,
                    args.adapter_dim,
                    pfeiffer=args.pfeiffer,
                    init=args.adapter_init,
                    **extra,
                )
                for uid in uids
            })

        # Corpus adapters are created before the regular ones so that module
        # registration order is unchanged.
        if corpus_adapter_uids:
            # With no_stack, corpus adapters live on even layers only and the
            # regular adapters on odd layers only.
            if no_stack and self.layer_id % 2 == 1:
                self.corpus_adapters = None
            else:
                self.corpus_adapters = build_adapters(
                    corpus_adapter_uids,
                    dropout=self.adapter_dropout,
                )
            if args.decoder_adapter_uids and not (no_stack and self.layer_id % 2 == 0):
                self.adapters = build_adapters(adapter_uids)
            else:
                self.adapters = None
        else:
            self.corpus_adapters = None
            self.adapters = build_adapters(adapter_uids)

        if self.univ:
            self.univ_adapter = AdapterLayer(
                args.decoder_embed_dim,
                args.adapter_dim,
                pfeiffer=args.pfeiffer,
                init=args.adapter_init,
            )
        # Default adapter; the enclosing decoder may overwrite this attribute
        # before each forward pass.
        self.adapter_uid = adapter_uids[0]
        self.return_prenorm = args.pfeiffer

    def forward(self, x, *args, **kwargs):
        """Run the parent layer, then the adapters that are active.

        Args:
            x: layer input, forwarded to ``TransformerDecoderLayer.forward``.
            *args, **kwargs: forwarded unchanged to the parent layer.

        Returns:
            A tuple: the adapted first output of the parent layer followed by
            the parent's remaining outputs, unchanged.
        """
        x, *extra = super().forward(x, *args, **kwargs)

        # Input of the corpus adapter when it is stacked on top of the regular
        # adapter in pfeiffer mode. It stays None when no regular adapter ran.
        z = None
        if self.adapters and self.adapter_uid in self.adapters:
            adapter = self.adapters[self.adapter_uid]
            if not self.return_prenorm:
                x = adapter(x)
            elif self.corpus_adapters:
                # Stacked: keep x untouched, the corpus adapter consumes z.
                # NOTE(review): if the corpus adapter is skipped below (layer
                # drop during training, or unknown corpus uid), z is discarded
                # and x is returned as-is -- confirm this is intended.
                z = adapter(self.final_layer_norm(x)) + x
            else:
                y = adapter(self.final_layer_norm(x))
                x = self.final_layer_norm(y + x)

        if self.univ:
            x = self.univ_adapter(x)

        if self.corpus_adapters and self.corpus_adapter_uid in self.corpus_adapters:
            # `drop_corpus_adapter` is True when the corpus adapter must be
            # *applied*; it only matters in training mode.
            if not self.training or self.drop_corpus_adapter:
                corpus_adapter = self.corpus_adapters[self.corpus_adapter_uid]
                if self.return_prenorm:
                    if z is None:
                        # No regular adapter ran on this layer (e.g. corpus
                        # adapters without decoder language adapters): feed
                        # the normalized layer output instead of reading an
                        # unbound variable.
                        z = self.final_layer_norm(x)
                    y = corpus_adapter(z)
                    x = self.final_layer_norm(y + x)
                else:
                    x = corpus_adapter(x)

        return (x, *extra)


class AdapterTransformerEncoder(TransformerEncoder):
    """Transformer encoder whose layers carry adapters.

    Before every forward pass the encoder tells its layers which adapter to
    use (``adapter_uid`` / ``corpus_adapter_uid``) and whether the corpus
    adapter of each layer is active (``drop_corpus_adapter``).
    """

    def __init__(self, args, *args_, **kwargs):
        """Store the adapter options, then build the parent encoder.

        The attributes are assigned before ``super().__init__`` because
        ``build_encoder_layer`` reads them (presumably it is invoked by the
        parent constructor -- confirm against ``TransformerEncoder``).
        """
        self.args = args
        self.layer_id = 0
        # argparse stores None when --ignore-enc-layers is not given, so the
        # getattr default alone is not enough to guarantee a container.
        self.to_skip = getattr(args, "ignore_enc_layers", None) or []
        self.freeze_adapters = getattr(args, "freeze_adapters", False)
        self.corpus_adapter_ld = getattr(args, "corpus_adapter_ld", 0.0)
        super().__init__(args, *args_, **kwargs)

    def build_encoder_layer(self, *args, **kwargs):
        """Return the next layer: a plain one if its index is in ``to_skip``,
        an adapter layer otherwise. Layer indices are counted from zero in
        call order.
        """
        if self.layer_id in self.to_skip:
            layer = TransformerEncoderLayer(*args, **kwargs)
        else:
            layer = AdapterTransformerEncoderLayer(self.layer_id, *args, **kwargs)
        self.layer_id += 1
        return layer

    def forward(self, *args, **kwargs):
        """Configure the layers' adapters, then run the parent encoder.

        When any adapter-selection option is enabled, ``kwargs['meta']`` must
        be a mapping providing the keys read below (``corpus_tag``,
        ``src_lang``, ``tgt_lang``).
        """
        # Per-layer random decision: a layer's corpus adapter is applied when
        # its sample exceeds `corpus_adapter_ld`. The flag is stored under the
        # name `drop_corpus_adapter` although True means "apply".
        num_layers = len(self.layers)
        samples = torch.empty(num_layers).uniform_()
        apply_flags = (samples > self.corpus_adapter_ld).tolist()
        if self.freeze_adapters and apply_flags and not any(apply_flags):
            # Every corpus adapter was dropped: force the first layer's on.
            apply_flags[0] = True
        for layer, flag in zip(self.layers, apply_flags):
            layer.drop_corpus_adapter = flag

        if self.args.enc_corpus_adapters:
            corpus = kwargs['meta']['corpus_tag']
            for layer in self.layers:
                layer.corpus_adapter_uid = f'corpus:{corpus}'
            if self.args.lang_adapters:
                # NOTE(review): the target language is used here, whereas the
                # lang_adapters-only branch below uses the source language --
                # confirm this is intended.
                tgt_lang = kwargs['meta']['tgt_lang']
                for layer in self.layers:
                    layer.adapter_uid = f'lang:{tgt_lang}'
        elif self.args.lang_adapters:
            src_lang = kwargs['meta']['src_lang']
            for layer in self.layers:
                layer.adapter_uid = f'lang:{src_lang}'
        elif self.args.lang_pair_adapters:
            src_lang = kwargs['meta']['src_lang']
            tgt_lang = kwargs['meta']['tgt_lang']
            for layer in self.layers:
                layer.adapter_uid = f'lang:{src_lang}->{tgt_lang}'

        return super().forward(*args, **kwargs)


class AdapterTransformerDecoder(TransformerDecoder):
    """Transformer decoder whose layers carry adapters.

    Before every forward pass the decoder tells its layers which adapter to
    use (``adapter_uid`` / ``corpus_adapter_uid``) and whether the corpus
    adapter of each layer is active (``drop_corpus_adapter``).
    """

    def __init__(self, args, *args_, **kwargs):
        """Store the adapter options, then build the parent decoder.

        The attributes are assigned before ``super().__init__`` because
        ``build_decoder_layer`` reads them (presumably it is invoked by the
        parent constructor -- confirm against ``TransformerDecoder``).
        """
        self.args = args
        self.layer_id = 0
        # argparse stores None when --ignore-dec-layers is not given, so the
        # getattr default alone is not enough to guarantee a container.
        self.to_skip = getattr(args, "ignore_dec_layers", None) or []
        self.freeze_adapters = getattr(args, "freeze_adapters", False)
        self.corpus_adapter_ld = getattr(args, "corpus_adapter_ld", 0.0)
        super().__init__(args, *args_, **kwargs)

    def build_decoder_layer(self, *args, **kwargs):
        """Return the next layer: a plain one if its index is in ``to_skip``,
        an adapter layer otherwise. Layer indices are counted from zero in
        call order.
        """
        if self.layer_id in self.to_skip:
            layer = TransformerDecoderLayer(*args, **kwargs)
        else:
            layer = AdapterTransformerDecoderLayer(self.layer_id, *args, **kwargs)
        self.layer_id += 1
        return layer

    def forward(self, *args, **kwargs):
        """Configure the layers' adapters, then run the parent decoder.

        When any adapter-selection option is enabled, ``kwargs['meta']`` must
        be a mapping providing the keys read below (``corpus_tag``,
        ``src_lang``, ``tgt_lang``).
        """
        # Per-layer random decision: a layer's corpus adapter is applied when
        # its sample exceeds `corpus_adapter_ld`. The flag is stored under the
        # name `drop_corpus_adapter` although True means "apply".
        num_layers = len(self.layers)
        samples = torch.empty(num_layers).uniform_()
        apply_flags = (samples > self.corpus_adapter_ld).tolist()
        if self.freeze_adapters and apply_flags and not any(apply_flags):
            # Every corpus adapter was dropped: force the first layer's on.
            apply_flags[0] = True
        for layer, flag in zip(self.layers, apply_flags):
            layer.drop_corpus_adapter = flag

        if self.args.corpus_adapters:
            corpus = kwargs['meta']['corpus_tag']
            for layer in self.layers:
                layer.corpus_adapter_uid = f'corpus:{corpus}'
            if self.args.lang_adapters:
                tgt_lang = kwargs['meta']['tgt_lang']
                for layer in self.layers:
                    layer.adapter_uid = f'lang:{tgt_lang}'
        elif self.args.lang_adapters:
            tgt_lang = kwargs['meta']['tgt_lang']
            for layer in self.layers:
                layer.adapter_uid = f'lang:{tgt_lang}'
        elif self.args.lang_pair_adapters:
            src_lang = kwargs['meta']['src_lang']
            tgt_lang = kwargs['meta']['tgt_lang']
            for layer in self.layers:
                layer.adapter_uid = f'lang:{src_lang}->{tgt_lang}'

        return super().forward(*args, **kwargs)


@register_model('adapter_transformer')
class AdapterTransformerModel(TransformerModel):
    """
    Transformer model whose encoder and/or decoder layers carry adapters.

    On top of ``TransformerModel`` this class adds the adapter command-line
    options, freezes the non-adapter parameters, fills in the default adapter
    identifiers from the task, and loads checkpoints non-strictly.
    """
    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        super().add_args(parser)
        parser.add_argument('--adapter-dim', type=int, default=64,
                            help="size of the adapter bottleneck dimension")
        parser.add_argument('--domain-adapter-dim', type=int, default=-1,
                            help="bottleneck dimension of the encoder corpus adapters "
                                 "(-1: same as --adapter-dim)")
        parser.add_argument('--ignore-enc-layers', nargs='+', type=int,
                            help='no adapters at these encoder layers (zero-indexed layer IDs)')
        parser.add_argument('--ignore-dec-layers', nargs='+', type=int,
                            help='no adapters at these decoder layers (zero-indexed layer IDs)')
        parser.add_argument('--pfeiffer', action='store_true')
        # --encoder-only keeps a plain decoder, --decoder-only a plain encoder
        # (see build_encoder / build_decoder).
        parser.add_argument('--encoder-only', action='store_true', help='no decoder adapters')
        parser.add_argument('--decoder-only', action='store_true', help='no encoder adapters')
        parser.add_argument('--adapter-init', default='small', choices=['small', 'bert'])
        parser.add_argument('--train-all-params', action='store_true',
                            help='train all model parameters, not only the adapters')
        parser.add_argument('--lang-adapters', action='store_true',
                            help='train per-language adapters')
        parser.add_argument('--corpus-adapters', action='store_true',
                            help='train per-corpus adapters (decoder)')
        parser.add_argument('--enc-corpus-adapters', action='store_true',
                            help='train per-corpus adapters (encoder)')
        parser.add_argument('--lang-pair-adapters', action='store_true',
                            help='train per-language-pair adapters')
        parser.add_argument('--univ', action='store_true',
                            help='add one adapter per layer that is applied to every input, '
                                 'whatever its language or corpus')
        parser.add_argument('--no-stack', action='store_true',
                            help='dont stack corpus and language adapters: corpus adapters '
                                 'go to even layers, language adapters to odd layers')
        parser.add_argument('--freeze-adapters', action='store_true',
                            help='only train the corpus adapters (all other parameters are frozen)')
        parser.add_argument('--corpus-adapter-ld', type=float, default=0.0,
                            help="probability of skipping a layer's corpus adapter at training time")
        parser.add_argument('--adapter-dropout', type=float, default=0.0,
                            help="dropout passed to the corpus adapters")
        parser.add_argument('--existing-src-langs', type=str, default='none',
                            help="comma-separated source languages used to name the language "
                                 "adapters ('none': use the source languages of the task)")
        parser.add_argument('--existing-tgt-langs', type=str, default='none',
                            help="comma-separated target languages used to name the language "
                                 "adapters ('none': use the target languages of the task)")
        # fmt: on

    def __init__(self, args, encoder, decoder):
        """Build the model and freeze the parameters that must not train.

        With ``args.freeze_adapters`` every parameter whose name does not
        contain ``"corpus_adapters"`` is frozen. Otherwise, unless
        ``args.train_all_params`` is set, every parameter whose name does not
        contain ``"adapters"`` is frozen.
        """
        super().__init__(args, encoder, decoder)
        args.freeze_adapters = getattr(args, "freeze_adapters", False)
        # NOTE(review): parameters of the layers' `univ_adapter` do not seem
        # to match the "adapters" substring, so they would be frozen unless
        # --train-all-params is given -- confirm this is intended.
        for name, parameter in self.named_parameters():
            if args.freeze_adapters and "corpus_adapters" not in name:
                parameter.requires_grad = False
            elif not args.train_all_params and "adapters" not in name:
                parameter.requires_grad = False

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Fills in the adapter identifier lists that are still ``None`` from the
        languages, language pairs and corpus tags of ``task``, then defers to
        ``TransformerModel.build_model``.
        """
        from fairseq.tasks.translation import TranslationTask

        assert isinstance(task, TranslationTask) and args.dynamic_dataset

        args.corpus_adapters = getattr(args, "corpus_adapters", False)
        args.enc_corpus_adapters = getattr(args, "enc_corpus_adapters", False)
        args.existing_src_langs = getattr(args, "existing_src_langs", 'none')
        args.existing_tgt_langs = getattr(args, "existing_tgt_langs", 'none')

        if args.existing_src_langs != 'none':
            src_langs = args.existing_src_langs.split(',')
        else:
            src_langs = task.src_langs
        if args.existing_tgt_langs != 'none':
            tgt_langs = args.existing_tgt_langs.split(',')
        else:
            tgt_langs = task.tgt_langs

        # The corpus uid lists are read with getattr because the layers also
        # treat them as optional attributes; a missing attribute is handled
        # like None.
        if args.enc_corpus_adapters:
            if getattr(args, "enc_corpus_adapter_uids", None) is None:
                args.enc_corpus_adapter_uids = [f'corpus:{corpus}' for corpus in task.corpus_tags]

        if args.corpus_adapters and not args.lang_adapters:
            if args.adapter_uids is None:
                args.adapter_uids = [f'lang:{lang}' for lang in src_langs]
            if getattr(args, "corpus_adapter_uids", None) is None:
                args.corpus_adapter_uids = [f'corpus:{corpus}' for corpus in task.corpus_tags]

        if args.corpus_adapters and args.lang_adapters:
            if args.adapter_uids is None:
                args.adapter_uids = [f'lang:{lang}' for lang in src_langs]
            if args.decoder_adapter_uids is None:
                args.decoder_adapter_uids = [f'lang:{lang}' for lang in tgt_langs]
            if getattr(args, "corpus_adapter_uids", None) is None:
                args.corpus_adapter_uids = [f'corpus:{corpus}' for corpus in task.corpus_tags]
        elif args.lang_adapters:
            if args.adapter_uids is None:
                args.adapter_uids = [f'lang:{lang}' for lang in src_langs]
            if args.decoder_adapter_uids is None:
                args.decoder_adapter_uids = [f'lang:{lang}' for lang in tgt_langs]
        elif args.lang_pair_adapters:
            if args.adapter_uids is None:
                args.adapter_uids = [f'lang:{src}->{tgt}' for src, tgt in task.lang_pairs]

        return super().build_model(args, task)

    @classmethod
    def build_encoder(cls, args, *args_, **kwargs):
        """Return a plain encoder with ``--decoder-only``, else an adapter encoder."""
        if args.decoder_only:
            return TransformerEncoder(args, *args_, **kwargs)
        else:
            return AdapterTransformerEncoder(args, *args_, **kwargs)

    @classmethod
    def build_decoder(cls, args, *args_, **kwargs):
        """Return a plain decoder with ``--encoder-only``, else an adapter decoder."""
        if args.encoder_only:
            return TransformerDecoder(args, *args_, **kwargs)
        else:
            return AdapterTransformerDecoder(args, *args_, **kwargs)

    def load_state_dict(self, state_dict, strict=False, args=None):
        """
        Some hacks to load TransformerModel checkpoints into
        AdapterModel.

        The checkpoint is always loaded non-strictly, so ``strict`` and
        ``args`` are accepted for interface compatibility but ignored.

        Returns:
            The status object returned by the parent ``load_state_dict``
            (exposing ``missing_keys`` and ``unexpected_keys``).
        """
        self.upgrade_state_dict(state_dict)

        status = super().load_state_dict(state_dict, strict=False)

        if status.missing_keys:
            logger.info("Missing keys detected")

        if status.unexpected_keys:
            logger.info("Unexpected keys found")

        # Hand the status back instead of dropping it so that callers can
        # inspect which keys were missing or unexpected.
        return status


@register_model_architecture('adapter_transformer', 'adapter_transformer')
def adapter_transformer(args):
    """Fill in the adapter defaults on ``args``, then the base Transformer ones.

    Each option keeps its value when ``args`` already has the attribute and
    receives the fallback listed below otherwise.
    """
    adapter_defaults = (
        ("adapter_dim", 64),
        ("domain_adapter_dim", -1),
        ("pfeiffer", False),
        ("encoder_only", False),
        ("decoder_only", False),
        ("adapter_init", "small"),
        ("train_all_params", False),
        ("lang_adapters", False),
        ("lang_pair_adapters", False),
    )
    for option, fallback in adapter_defaults:
        setattr(args, option, getattr(args, option, fallback))
    base_architecture(args)

@register_model_architecture('adapter_transformer', 'adapter_transformer_iwslt_de_en')
def adapter_transformer_iwslt_de_en(args):
    """Apply the ``transformer_iwslt_de_en`` defaults, then the adapter defaults."""
    # Order matters: the architecture preset runs before the adapter defaults.
    for apply_defaults in (transformer_iwslt_de_en, adapter_transformer):
        apply_defaults(args)

@register_model_architecture('adapter_transformer', 'adapter_transformer_vaswani_wmt_en_de_big')
def adapter_transformer_vaswani_wmt_en_de_big(args):
    """Apply the ``transformer_vaswani_wmt_en_de_big`` defaults, then the adapter defaults."""
    # Order matters: the architecture preset runs before the adapter defaults.
    for apply_defaults in (transformer_vaswani_wmt_en_de_big, adapter_transformer):
        apply_defaults(args)
