# Copyright (c) 2019-present, Anon.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from logging import getLogger

import torch

from .transformer import TransformerModel
from ..data.dictionary import (
    Dictionary,
    BOS_WORD,
    EOS_WORD,
    PAD_WORD,
    UNK_WORD,
    MASK_WORD,
)
from ..utils import AttrDict

logger = getLogger()


class SentenceEmbedder(object):
    """
    Wrapper around a pretrained `TransformerModel`.

    Exposes the model's train / eval / cuda switches, selects parameter
    groups by layer range (`get_parameters`) and returns one vector per
    sentence (`get_embeddings`).
    """

    @staticmethod
    def reload(path, params):
        """
        Create a sentence embedder from a pretrained model.

        Inputs:
            `path`   : location of the checkpoint, as accepted by `torch.load`.
                       The checkpoint must contain the keys "model",
                       "dico_id2word", "dico_word2id", "dico_counts" and
                       "params".
            `params` : parameter object of the caller. It is only written to:
                       `params.max_batch_size` is set to 0.
        Outputs:
            a `SentenceEmbedder` holding the reloaded model (in eval mode),
            its dictionary and the pretraining parameters.
        """
        # reload model
        # NOTE: `torch.load` unpickles the file, which can execute arbitrary
        # code -- only reload checkpoints coming from a trusted source.
        # Tensors are mapped to CPU so that a checkpoint saved from a GPU can
        # be reloaded on a machine without one; the weights are copied into
        # the freshly built model below, and `cuda()` moves it afterwards.
        reloaded = torch.load(path, map_location="cpu")
        state_dict = reloaded["model"]

        # handle models from multi-GPU checkpoints: drop the "module." prefix
        # from the parameter names (`str` so that path-like objects work too)
        prefix = "module."
        if "checkpoint" in str(path):
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        # reload dictionary and model parameters
        dico = Dictionary(
            reloaded["dico_id2word"], reloaded["dico_word2id"], reloaded["dico_counts"]
        )
        pretrain_params = AttrDict(reloaded["params"])
        # vocabulary size and special token indices come from the dictionary
        pretrain_params.n_words = len(dico)
        pretrain_params.bos_index = dico.index(BOS_WORD)
        pretrain_params.eos_index = dico.index(EOS_WORD)
        pretrain_params.pad_index = dico.index(PAD_WORD)
        pretrain_params.unk_index = dico.index(UNK_WORD)
        pretrain_params.mask_index = dico.index(MASK_WORD)

        # build model and reload weights
        # NOTE(review): the meaning of the two positional flags is defined by
        # `TransformerModel` (presumably is_encoder / with_output) -- confirm.
        model = TransformerModel(pretrain_params, dico, True, True)
        model.load_state_dict(state_dict)
        model.eval()

        # adding missing parameters
        params.max_batch_size = 0

        return SentenceEmbedder(model, dico, pretrain_params)

    def __init__(self, model, dico, pretrain_params):
        """
        Wrapper on top of the different sentence embedders.
        Returns sequence-wise or single-vector sentence representations.

        Inputs:
            `model`           : model exposing `n_layers`, `dim` and `n_words`
            `dico`            : dictionary associated with the model
            `pretrain_params` : parameters of the pretrained model; the content
                                of its `__dict__` is copied into a plain dict
        """
        # shallow copy, stored as a regular dictionary
        self.pretrain_params = dict(pretrain_params.__dict__)
        self.model = model
        self.dico = dico
        self.n_layers = model.n_layers
        self.out_dim = model.dim
        self.n_words = model.n_words

    def train(self):
        """
        Switch the underlying model to training mode.
        """
        self.model.train()

    def eval(self):
        """
        Switch the underlying model to evaluation mode.
        """
        self.model.eval()

    def cuda(self):
        """
        Move the underlying model to the GPU.
        """
        self.model.cuda()

    def get_parameters(self, layer_range):
        """
        Return the list of model parameters selected by `layer_range`.

        `layer_range` is a string "i:j" with both bounds inclusive. Index 0
        designates the embeddings (token, positional, language and embedding
        layer norm), index k >= 1 designates the k-th Transformer layer.
        Negative indices count from the end (-1 is the last layer) and can be
        written with an underscore instead of the minus sign (e.g. "_1").
        An empty list is returned when i > j.
        """
        s = layer_range.split(":")
        assert len(s) == 2, "layer_range must be of the form 'i:j'"
        i, j = int(s[0].replace("_", "-")), int(s[1].replace("_", "-"))

        # negative indexing
        i = self.n_layers + i + 1 if i < 0 else i
        j = self.n_layers + j + 1 if j < 0 else j

        # sanity check
        assert 0 <= i <= self.n_layers, "invalid first layer index"
        assert 0 <= j <= self.n_layers, "invalid last layer index"

        if i > j:
            return []

        parameters = []

        # embeddings
        if i == 0:
            # embeddings
            parameters.extend(self.model.embeddings.parameters())
            logger.info("Adding embedding parameters to optimizer")
            # positional embeddings
            if self.pretrain_params["sinusoidal_embeddings"] is False:
                parameters.extend(self.model.position_embeddings.parameters())
                logger.info("Adding positional embedding parameters to optimizer")
            # language embeddings
            if hasattr(self.model, "lang_embeddings"):
                parameters.extend(self.model.lang_embeddings.parameters())
                logger.info("Adding language embedding parameters to optimizer")
            parameters.extend(self.model.layer_norm_emb.parameters())

        # layers (layer k is stored at position k - 1 of the module lists)
        for layer_id in range(max(i - 1, 0), j):
            parameters.extend(self.model.attentions[layer_id].parameters())
            parameters.extend(self.model.layer_norm1[layer_id].parameters())
            parameters.extend(self.model.ffns[layer_id].parameters())
            parameters.extend(self.model.layer_norm2[layer_id].parameters())
            logger.info("Adding layer-%s parameters to optimizer", layer_id + 1)

        logger.info(
            "Optimizing on %i Transformer elements.",
            sum(p.nelement() for p in parameters),
        )

        return parameters

    def get_embeddings(self, x, lengths, positions=None, langs=None):
        """
        Inputs:
            `x`         : LongTensor of shape (slen, bs)
            `lengths`   : LongTensor of shape (bs,), whose maximum is slen
            `positions` : forwarded as is to the model (optional)
            `langs`     : forwarded as is to the model (optional)
        Outputs:
            `sent_emb`  : FloatTensor of shape (bs, out_dim)
        With out_dim == emb_dim
        """
        slen, bs = x.size()
        assert lengths.size(0) == bs and lengths.max().item() == slen

        # get transformer last hidden layer
        tensor = self.model(
            "fwd", x=x, lengths=lengths, positions=positions, langs=langs, causal=False
        )
        assert tensor.size() == (slen, bs, self.out_dim)

        # single-vector sentence representation (first position of the output)
        return tensor[0]
