# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SeqGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained task, adopted by BLIP.
    Normally used for generation task.

    Args:
        decoder (dict): Decoder for blip generation head.
        ignore_index (int): Label value that is excluded from the loss.
            Positions holding this value get zero weight in :meth:`loss`.
            Defaults to -100.
        loss (dict): Config of the loss module built through ``MODELS``.
            Defaults to ``dict(type='LabelSmoothLoss',
            label_smooth_val=0.1)``.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(
        self,
        decoder: dict,
        ignore_index=-100,
        loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1),
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)
        self.loss_fn = MODELS.build(loss)
        self.ignore_index = ignore_index

    def forward(self, input_ids: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                encoder_attention_mask: torch.Tensor, labels: torch.Tensor):
        """Forward to get decoder output.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Image embeddings hidden
                states attention mask.
            labels (torch.Tensor): Decoder target for calculate loss.

        Returns:
            The decoder output, requested with ``return_dict=True``.
            :meth:`loss` reads its ``'logits'`` entry.
        """
        return self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
            return_dict=True,
        )

    def loss(self, input_ids: torch.Tensor,
             encoder_hidden_states: torch.Tensor,
             encoder_attention_mask: torch.Tensor,
             labels: torch.Tensor) -> dict:
        """Calculate losses from the extracted features.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Image embeddings hidden
                states attention mask.
            labels (torch.Tensor): Decoder target for calculate loss.
                Entries equal to ``ignore_index`` do not contribute to the
                loss.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        decoder_out = self(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
        )
        prediction_scores = decoder_out['logits']
        vocab_size = prediction_scores.shape[-1]

        # We are doing next-token prediction: the score at position t is
        # compared against the label at position t + 1.
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        # Always flatten the targets so that they line up one-to-one with
        # the flattened scores handed to the loss below, whether or not any
        # position is ignored.
        targets = labels[:, 1:].reshape(-1)

        ignore_mask = targets == self.ignore_index
        if ignore_mask.any():
            # Replace ignored entries by a valid class id (0); their loss is
            # zeroed by ``weight``. ``masked_fill`` is out-of-place, so the
            # caller's ``labels`` tensor is never modified.
            targets = targets.masked_fill(ignore_mask, 0)
            weight = torch.logical_not(ignore_mask)
            # Average over contributing tokens; guard against division by
            # zero when every position is ignored.
            avg_factor = max(weight.sum(), 1)
        else:
            weight = None
            # Every token contributes, so average over the token count
            # (consistent with the masked branch above).
            avg_factor = targets.numel()

        lm_loss = self.loss_fn(
            shifted_prediction_scores.view(-1, vocab_size),
            targets,
            weight=weight,
            avg_factor=avg_factor,
        )
        return {'seq_gen_lm_loss': lm_loss}

    def predict(self,
                input_ids,
                encoder_hidden_states,
                sep_token_id,
                pad_token_id,
                use_nucleus_sampling=False,
                num_beams=3,
                max_length=20,
                min_length=2,
                top_p=0.9,
                repetition_penalty=1.0,
                **kwargs):
        """Decoder prediction method.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            sep_token_id (int): Token id of the separation token. It is
                passed to the decoder as ``eos_token_id``.
            pad_token_id (int): Token id of the pad token.
            use_nucleus_sampling (bool): Whether to use nucleus sampling in
                prediction. If False, beam search is used.
                Defaults to False.
            num_beams (int): Number of beams used in prediction. Only used
                by beam search. Defaults to 3.
            max_length (int): Max length of generated text in prediction.
                Defaults to 20.
            min_length (int): Min length of generated text in prediction.
                Defaults to 2.
            top_p (float):
                If < 1.0, only keep the top tokens with cumulative probability
                 >= top_p (nucleus filtering). Only used by nucleus sampling.
                 Defaults to 0.9.
            repetition_penalty (float): The parameter for repetition penalty.
                Only honored by beam search; nucleus sampling always uses a
                fixed penalty of 1.1. Defaults to 1.0.
            **kwargs: Other arguments that might be used in generation. They
                are merged into the keyword arguments passed to
                ``self.decoder.generate`` and may override
                ``encoder_hidden_states`` / ``encoder_attention_mask``.

        Returns:
            The generation outputs returned by ``self.decoder.generate``.
        """
        # TODO: In old version of transformers
        # Additional repeat interleave of hidden states should be add here.
        # Attend to every image token: an all-ones mask over all but the
        # last (feature) dimension, created directly on the target device.
        image_atts = torch.ones(
            encoder_hidden_states.size()[:-1],
            dtype=torch.long,
            device=encoder_hidden_states.device)

        model_kwargs = {
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': image_atts,
        }
        model_kwargs.update(kwargs)

        if use_nucleus_sampling:
            # nucleus sampling
            # NOTE: the repetition penalty is fixed to 1.1 on this path and
            # the ``repetition_penalty`` argument is ignored; kept as-is to
            # preserve existing sampling behavior.
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=1.1,
                **model_kwargs)
        else:
            # beam search
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

        return outputs
