# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from fairseq.models.roberta.hub_interface import RobertaHubInterface
import torch
import torch.nn.functional as F


class XMODHubInterface(RobertaHubInterface):
    """Hub interface whose feature extraction and prediction accept ``lang_id``.

    Both methods pass ``lang_id`` through to the wrapped model unchanged; its
    expected type and meaning are defined by the model, not by this class.
    """

    def extract_features(
        self,
        tokens: torch.LongTensor,
        return_all_hiddens: bool = False,
        lang_id=None,
    ) -> torch.Tensor:
        """Run the wrapped model in features-only mode on ``tokens``.

        Args:
            tokens: Token ids. A 1-D tensor is given a leading dimension of
                size 1 before being passed to the model.
            return_all_hiddens: When true, return every entry of the model's
                ``inner_states`` instead of only the final features.
            lang_id: Forwarded verbatim to the model as ``lang_id``.

        Returns:
            The model's feature output when ``return_all_hiddens`` is false.
            Otherwise a list with one tensor per inner state, each with its
            first two dimensions swapped (note that the return annotation
            only describes the first case).

        Raises:
            ValueError: If the last dimension of ``tokens`` is larger than
                ``self.model.max_positions()``.
        """
        batch = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        seq_len = batch.size(-1)
        max_len = self.model.max_positions()
        if seq_len > max_len:
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(seq_len, max_len)
            )

        last_layer, extra = self.model(
            batch.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
        )

        if not return_all_hiddens:
            # Only the last layer's features were requested.
            return last_layer

        # Inner states come back as T x B x C; hand them out as B x T x C.
        return [state.transpose(0, 1) for state in extra["inner_states"]]

    def predict(
        self,
        head: str,
        tokens: torch.LongTensor,
        return_logits: bool = False,
        lang_id=None,
    ):
        """Apply the classification head named ``head`` to features of ``tokens``.

        Args:
            head: Key into ``self.model.classification_heads``.
            tokens: Token ids, handled as in :meth:`extract_features`.
            return_logits: When true, return the head's raw output.
            lang_id: Forwarded verbatim to :meth:`extract_features`.

        Returns:
            The head's output if ``return_logits`` is true, otherwise its
            log-softmax over the last dimension.
        """
        device_tokens = tokens.to(device=self.device)
        features = self.extract_features(device_tokens, lang_id=lang_id)

        classifier = self.model.classification_heads[head]
        logits = classifier(features)

        return logits if return_logits else F.log_softmax(logits, dim=-1)
