# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import mmengine.dist as dist
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import track_iter_progress

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from ..blip.blip_retrieval import BlipRetrieval, all_gather_concat


@MODELS.register_module()
class Blip2Retrieval(BlipRetrieval):
    """BLIP2 Retriever.

    Args:
        vision_backbone (dict): Backbone for extracting image features.
        text_backbone (Optional[dict]): Backbone for extracting text
            features. Defaults to None.
        multimodal_backbone (Optional[dict]): Backbone for extracting
            multi-modal features.
        vision_neck (Optional[dict]): The neck module to process image features
            from vision backbone. Defaults to None.
        text_neck (Optional[dict]): The neck module to process text features
            from text backbone. Defaults to None.
        head (Optional[Union[List[dict], dict]]): The head module to calculate
            loss from processed single modality features.
            See :mod:`mmmultimodal.models.heads`.
            Notice that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal
            head module to calculate loss from processed multimodal features.
            See :mod:`mmmultimodal.models.heads`.
            Notice that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        tokenizer (Optional[dict]): The config for tokenizer. Defaults to None.
        temperature (float): Temperature parameter that controls the
            concentration level of the distribution. Defaults to 0.07.
        fast_match (bool): If False, select topk similarity as candidates and
            compute the matching score. If True, return the similarity as the
            matching score directly. Defaults to False.
        topk (int): Number of top-similarity candidates selected for computing
            matching scores. Notice that this is not the topk in evaluation.
            Defaults to 256.
        data_preprocessor (Optional[dict]): The config for preprocessing input
            data. If None or no specified type, it will use
            "MultiModalDataPreprocessor" as type.
            See :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Optional[dict]): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 vision_backbone: dict,
                 text_backbone: Optional[dict] = None,
                 multimodal_backbone: Optional[dict] = None,
                 vision_neck: Optional[dict] = None,
                 text_neck: Optional[dict] = None,
                 head: Optional[Union[List[dict], dict]] = None,
                 multimodal_head: Optional[Union[List[dict], dict]] = None,
                 tokenizer: Optional[dict] = None,
                 temperature: float = 0.07,
                 fast_match: bool = False,
                 topk: int = 256,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None) -> None:
        """Build the sub-modules and learnable parameters of the retriever.

        See the class docstring for the meaning of each argument.

        NOTE(review): ``tokenizer`` and ``multimodal_backbone`` default to
        None but are both used unconditionally below (the tokenizer is always
        built, and the query tokens are sized from the multimodal backbone's
        BERT config), so in practice both have to be provided.
        """
        # Turn the preprocessor config into a built module before handing it
        # to the base initializer; a missing config falls back to the
        # default multi-modal preprocessor type.
        preprocessor = {} if data_preprocessor is None else data_preprocessor
        if isinstance(preprocessor, dict):
            preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
            preprocessor = MODELS.build(preprocessor)

        # Deliberately start the MRO lookup *after* BlipRetrieval so that its
        # own __init__ is skipped and only its parent's initializer runs.
        super(BlipRetrieval, self).__init__(
            data_preprocessor=preprocessor, init_cfg=init_cfg)

        # Vision tower, followed by a LayerNorm over its embedding dimension.
        self.vision_backbone = MODELS.build(vision_backbone)
        self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims)
        self.tokenizer = TOKENIZER.build(tokenizer)

        if text_backbone is not None:
            self.text_backbone = MODELS.build(text_backbone)

        if multimodal_backbone is not None:
            self.multimodal_backbone = MODELS.build(multimodal_backbone)
            # Keep the embedding table in sync with the tokenizer vocabulary.
            vocab_size = len(self.tokenizer)
            self.multimodal_backbone.resize_token_embeddings(vocab_size)

        # Learnable query tokens of shape (1, query_length, hidden_size),
        # drawn from N(0, initializer_range) as configured for the BERT model.
        bert_cfg = self.multimodal_backbone.bert.config
        initial_queries = torch.zeros(1, bert_cfg.query_length,
                                      bert_cfg.hidden_size)
        initial_queries.normal_(mean=0.0, std=bert_cfg.initializer_range)
        self.query_tokens = nn.Parameter(initial_queries)

        # Optional components are only attached when a config is supplied.
        # The tuple order fixes the sub-module registration order.
        optional_components = (
            ('vision_neck', vision_neck),
            ('text_neck', text_neck),
            ('head', head),
            ('multimodal_head', multimodal_head),
        )
        for attr_name, component_cfg in optional_components:
            if component_cfg is not None:
                setattr(self, attr_name, MODELS.build(component_cfg))

        # Learnable scalar temperature for the contrastive similarities.
        self.temp = nn.Parameter(temperature * torch.ones([]))

        # ``topk`` here is the number of candidates re-scored with the
        # image-text matching head; it is unrelated to the top-k used by the
        # evaluation metric.
        self.fast_match = fast_match
        self.topk = topk

    def _extract_feat(self, inputs: Union[torch.Tensor, dict],
                      modality: str) -> dict:
        """Extract features from a single modality.

        Args:
            inputs (Union[torch.Tensor, dict]): A batch of inputs.
                For image, a tensor of shape (N, C, ...) in general.
                For text, the tokenized text inputs; they are read through
                the ``input_ids`` and ``attention_mask`` attributes.
            modality (str): Modality feature to be extracted. Only two
                options are supported.

                - ``images``: Only extract image features, mostly used for
                    inference.
                - ``texts``: Only extract text features, mostly used for
                    inference.

        Returns:
            dict: The extracted features. For ``images`` the keys are
            ``image_embeds`` (layer-normed vision backbone output),
            ``image_feat`` (L2-normalized output of the vision neck) and
            ``query_output`` (the raw output object of the query pass).
            For ``texts`` the keys are ``text_embeds`` (last hidden state)
            and ``text_feat`` (L2-normalized output of the text neck).

        Raises:
            RuntimeError: If ``modality`` is neither ``images`` nor ``texts``.
        """
        if modality == 'images':
            # extract image features
            # TODO:
            # Add layernorm inside backbone and handle the concat outside
            image_embeds = self.ln_vision_backbone(
                self.vision_backbone(inputs)[0])
            # Every image token may be attended to by the queries.
            image_atts = torch.ones(
                image_embeds.size()[:-1], dtype=torch.long).to(self.device)

            # Share the single learned query set across the batch.
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
                                                    -1)
            # ``use_cache=True`` so that ``past_key_values`` is available on
            # the output; ``loss`` feeds it to the language-modeling pass.
            query_output = self.multimodal_backbone.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                use_cache=True,
                return_dict=True,
            )
            image_feat = F.normalize(
                self.vision_neck([query_output.last_hidden_state]), dim=-1)
            return {
                'image_embeds': image_embeds,
                'image_feat': image_feat,
                'query_output': query_output
            }

        if modality == 'texts':
            # extract text features
            text_output = self.multimodal_backbone.bert(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                return_dict=True,
            )
            text_embeds = text_output.last_hidden_state
            # Only the hidden state at the first token position is projected.
            text_feat = F.normalize(
                self.text_neck([text_embeds[:, 0, :]]), dim=-1)
            return {'text_embeds': text_embeds, 'text_feat': text_feat}

        raise RuntimeError(f'Invalid modality "{modality}".')

    def loss(
        self,
        images: torch.Tensor,
        data_samples: Optional[List[DataSample]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Calculate losses from a batch of images and data samples.

        Three objectives are combined: image-text contrastive (ITC),
        image-text matching (ITM, computed by ``self.multimodal_head``) and
        language modeling (LM).

        Args:
            images (torch.Tensor): A batch of images; its first dimension is
                used as the per-rank batch size.
            data_samples (Optional[List[DataSample]]):
                The annotation data of every samples. They are forwarded to
                ``self.extract_feat`` and, together with ``2 * batch_size``
                extra ``is_matched=False`` samples, to the multimodal head.
                The list passed in is not modified. Defaults to None.

        Returns:
            Dict[str, torch.Tensor]: a dictionary with ``itc_loss``, the
                loss components returned by the multimodal head, and
                ``lm_loss``.
        """
        # ``extract_feat`` is defined outside this method; it is expected to
        # provide every key read below.
        output = self.extract_feat(images, data_samples)

        text_ids = output['text_ids']
        text_attn_mask = output['text_attn_mask']
        image_embeds = output['image_embeds']
        image_feat = output['image_feat']
        text_feat = output['text_feat']
        query_output = output['query_output']

        # ---------------- ITC loss ----------------
        # B*world_size, num_query, D
        image_feat_all = torch.cat(dist.all_gather(image_feat))
        # B*world_size, D
        text_feat_all = torch.cat(dist.all_gather(text_feat))

        # (B, 1, num_query, D) @ (B*world_size, D, 1)
        #   -> (B, B*world_size, num_query, 1)
        # Only the trailing singleton is squeezed: a bare ``squeeze()`` would
        # also drop the batch/query axes whenever one of them has size 1.
        sim_q2t = torch.matmul(
            image_feat.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze(-1)

        # image to text similarity: best-matching query per (image, text)
        sim_i2t, _ = sim_q2t.max(-1)
        sim_i2t = sim_i2t / self.temp

        # (B, 1, 1, D) @ (B*world_size, D, num_query)
        #   -> (B, B*world_size, 1, num_query)
        # Same reasoning as above: squeeze only the known singleton axis.
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1),
            image_feat_all.permute(0, 2, 1)).squeeze(-2)

        # text-image similarity
        sim_t2i, _ = sim_t2q.max(-1)
        sim_t2i = sim_t2i / self.temp

        # The positives of this rank sit at columns
        # [rank * bs, rank * bs + bs) of the gathered similarity matrices.
        rank = dist.get_rank()
        bs = images.size(0)
        targets = torch.linspace(
            rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device)

        itc_loss = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
                    F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2

        # ---------------- ITM loss ----------------
        text_input_ids_world = torch.cat(dist.all_gather(text_ids))
        text_attention_mask_world = torch.cat(dist.all_gather(text_attn_mask))
        image_embeds_world = torch.cat(dist.all_gather(image_embeds))
        with torch.no_grad():
            # Sampling weights for hard negatives: more similar pairs are
            # more likely to be drawn; the matching pair gets weight zero.
            weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
            weights_t2i[:, rank * bs:rank * bs + bs].fill_diagonal_(0)
            weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
            weights_i2t[:, rank * bs:rank * bs + bs].fill_diagonal_(0)

        # select a negative image for each text (one draw per row, done in a
        # single batched call instead of a Python loop with host syncs)
        neg_image_idx = torch.multinomial(weights_t2i, 1).squeeze(1)
        image_embeds_neg = image_embeds_world[neg_image_idx]

        # select a negative text for each image
        neg_text_idx = torch.multinomial(weights_i2t, 1).squeeze(1)
        text_ids_neg = text_input_ids_world[neg_text_idx]
        text_atts_neg = text_attention_mask_world[neg_text_idx]

        text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg],
                                 dim=0)  # pos, pos, neg
        text_atts_all = torch.cat(
            [text_attn_mask, text_attn_mask, text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1,
                                                    -1)
        query_atts_itm = torch.ones(
            query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device)
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        image_embeds_all = torch.cat(
            [image_embeds, image_embeds_neg, image_embeds],
            dim=0)  # pos, neg, pos
        image_atts_all = torch.ones(
            image_embeds_all.size()[:-1], dtype=torch.long).to(self.device)

        output_itm = self.multimodal_backbone.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        # Keep only the hidden states at the query-token positions.
        vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm.
                                                     size(1), :]

        # The first ``bs`` rows are the original pairs and keep their own
        # data samples; the remaining ``2 * bs`` rows are the mined negative
        # pairs. A new list is built so the caller's list is left untouched.
        itm_data_samples = list(data_samples) + [
            DataSample(is_matched=False) for _ in range(2 * bs)
        ]
        loss_multimodal = self.multimodal_head.loss((vl_embeddings, ),
                                                    itm_data_samples)

        # ---------------- LM loss ----------------
        # Replace the first token of every caption with BOS.
        decoder_input_ids = text_ids.clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions get label -100 (presumably the ignore index of
        # the backbone's LM loss -- confirm against the backbone).
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_atts = torch.ones(
            query_tokens.size()[:-1], dtype=torch.long).to(self.device)
        attention_mask = torch.cat([query_atts, text_attn_mask], dim=1)
        # The cached key/values of the image-conditioned query pass are
        # reused as the prefix for text generation.
        lm_output = self.multimodal_backbone(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        return dict(
            itc_loss=itc_loss, **loss_multimodal, lm_loss=lm_output.loss)

    def predict_all(self,
                    feats: Dict[str, torch.Tensor],
                    data_samples: List[DataSample],
                    num_images: int = None,
                    num_texts: int = None,
                    cal_i2t: bool = True,
                    cal_t2i: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score this rank's images/texts against those of all ranks.

        Args:
            feats (Dict[str, torch.Tensor]): Features from the current rank.
                The keys ``text_ids``, ``text_attn_mask``, ``image_feat`` and
                ``text_feat`` are required; ``image_embeds`` is optional in
                the dict but is gathered whenever ``fast_match`` is False.
            data_samples (List[DataSample]): Data samples from the current
                rank.
            num_images (int, optional): Number of gathered images to keep.
                A falsy value means the number of local image features.
                Defaults to None.
            num_texts (int, optional): Number of gathered texts to keep.
                A falsy value means the number of local text features.
                Defaults to None.
            cal_i2t (bool, optional): Whether to compute the image-to-text
                direction. Defaults to True.
            cal_t2i (bool, optional): Whether to compute the text-to-image
                direction. Defaults to True.

        Returns:
            tuple: The outputs of ``self._get_predictions`` for each enabled
            direction, image-to-text first, then text-to-image.
        """
        local_text_ids = feats['text_ids']
        local_text_mask = feats['text_attn_mask']
        local_image_embeds = feats.get('image_embeds', None)
        local_image_feat = feats['image_feat']
        local_text_feat = feats['text_feat']

        if not num_images:
            num_images = local_image_feat.size(0)
        if not num_texts:
            num_texts = local_text_feat.size(0)

        # The raw image embeddings are only needed for re-scoring with the
        # matching head, so they are not gathered in fast-match mode.
        gathered_image_embeds = (
            None if self.fast_match else
            all_gather_concat(local_image_embeds)[:num_images])
        gathered_image_feat = all_gather_concat(local_image_feat)[:num_images]
        gathered_text_feat = all_gather_concat(local_text_feat)[:num_texts]
        gathered_text_ids = all_gather_concat(local_text_ids)[:num_texts]
        gathered_text_mask = all_gather_concat(local_text_mask)[:num_texts]

        predictions = []

        if cal_i2t:
            # local images vs. texts from every rank
            i2t_scores = self.compute_score_matrix_i2t(
                local_image_feat,
                local_image_embeds,
                gathered_text_feat,
                gathered_text_ids,
                gathered_text_mask,
            )
            predictions.append(
                self._get_predictions(i2t_scores, data_samples, mode='i2t'))

        if cal_t2i:
            # local texts vs. images from every rank
            t2i_scores = self.compute_score_matrix_t2i(
                gathered_image_feat,
                gathered_image_embeds,
                local_text_feat,
                local_text_ids,
                local_text_mask,
            )
            predictions.append(
                self._get_predictions(t2i_scores, data_samples, mode='t2i'))

        return tuple(predictions)

    def compute_score_matrix_i2t(self, img_feats: torch.Tensor,
                                 img_embeds: List[torch.Tensor],
                                 text_feats: torch.Tensor,
                                 text_ids: torch.Tensor,
                                 text_atts: torch.Tensor) -> torch.Tensor:
        """Compute the score matrix for image-to-text retrieval.

        Every image is compared to all the text features. Unless
        ``fast_match`` is set, the most similar texts of each image are
        re-scored with the multimodal backbone and head.

        Args:
            img_feats (torch.Tensor): Query features of the images with shape
                (M, num_query, C). M stands for numbers of samples on a
                single GPU.
            img_embeds (List[torch.Tensor]): Per-image embeddings fed to the
                multimodal backbone as encoder states; indexed by image.
                Not used when ``fast_match`` is True.
            text_feats (torch.Tensor): The input tensor with shape (N, C).
                N stands for numbers of all samples on all GPUs.
            text_ids (torch.Tensor): Token ids of the N texts.
            text_atts (torch.Tensor): Attention masks of the N texts.

        Returns:
            torch.Tensor: Score matrix of image-to-text retrieval with shape
            (M, N). In the non-fast path, entries outside the selected
            candidates keep the filler value -100.
        """

        # compute i2t sim matrix
        # TODO: check correctness
        # (M, num_query, C) @ (C, N) -> (M, num_query, N); keep the best
        # query for every (image, text) pair -> (M, N)
        sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1)
        if self.fast_match:
            return sim_matrix_i2t

        num_images = img_feats.size(0)
        num_texts = text_feats.size(0)
        # ``topk`` raises if asked for more candidates than exist, so clamp
        # to the number of texts (matters for small evaluation sets).
        k = min(self.topk, num_texts)

        score_matrix_i2t = torch.full((num_images, num_texts),
                                      -100.0).to(self.device)

        # Query tokens and their mask are identical for every image.
        query_tokens = self.query_tokens.expand(k, -1, -1)
        query_atts = torch.ones(
            query_tokens.size()[:-1], dtype=torch.long).to(self.device)
        num_query = query_tokens.size(1)

        for i in track_iter_progress(range(num_images)):
            topk_sim, topk_idx = sim_matrix_i2t[i].topk(k=k, dim=0)
            # pair the same image with each of its k candidate texts
            encoder_output = img_embeds[i].repeat(k, 1, 1)
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long).to(self.device)
            attention_mask = torch.cat([query_atts, text_atts[topk_idx]],
                                       dim=1)
            output = self.multimodal_backbone.bert(
                text_ids[topk_idx],
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
            )
            # Head output at index 1 of the last axis, averaged over the
            # query positions (presumably the "matched" logit -- confirm
            # against the head implementation).
            score = self.multimodal_head(
                (output.last_hidden_state[:, :num_query, :],
                 ))[:, :, 1].mean(dim=1)
            score_matrix_i2t[i, topk_idx] = score + topk_sim

        return score_matrix_i2t

    def compute_score_matrix_t2i(self, img_feats: torch.Tensor,
                                 img_embeds: List[torch.Tensor],
                                 text_feats: torch.Tensor,
                                 text_ids: torch.Tensor,
                                 text_atts: torch.Tensor) -> torch.Tensor:
        """Compute the score matrix for text-to-image retrieval.

        Every text is compared to all the image features. Unless
        ``fast_match`` is set, the most similar images of each text are
        re-scored with the multimodal backbone and head.

        Args:
            img_feats (torch.Tensor): Query features of the images with shape
                (N, num_query, C). N stands for numbers of all samples on
                all GPUs.
            img_embeds (List[torch.Tensor]): Image embeddings fed to the
                multimodal backbone as encoder states. They are indexed with
                an index tensor here, so a stacked tensor whose first
                dimension is N is expected. Not used when ``fast_match`` is
                True.
            text_feats (torch.Tensor): The input tensor with shape (M, C).
                M stands for numbers of samples on a single GPU.
            text_ids (torch.Tensor): Token ids of the M texts.
            text_atts (torch.Tensor): Attention masks of the M texts.

        Returns:
            torch.Tensor: Score matrix of text-to-image retrieval with shape
            (M, N). In the non-fast path, entries outside the selected
            candidates keep the filler value -100.
        """

        # compute t2i sim matrix
        # TODO: check correctness
        # (N, num_query, C) @ (C, M) -> (N, num_query, M); keep the best
        # query for every (image, text) pair -> (N, M)
        sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1)
        # Rows must index texts for text-to-image retrieval -> (M, N)
        sim_matrix_t2i = sim_matrix_i2t.t()
        if self.fast_match:
            # Return the text-major matrix, with the same orientation as the
            # re-scored matrix below (not the image-major one).
            return sim_matrix_t2i

        num_texts = text_feats.size(0)
        num_images = img_feats.size(0)
        # ``topk`` raises if asked for more candidates than exist, so clamp
        # to the number of images (matters for small evaluation sets).
        k = min(self.topk, num_images)

        score_matrix_t2i = torch.full((num_texts, num_images),
                                      -100.0).to(self.device)

        # Query tokens and their mask are identical for every text.
        query_tokens = self.query_tokens.expand(k, -1, -1)
        query_atts = torch.ones(
            query_tokens.size()[:-1], dtype=torch.long).to(self.device)
        num_query = query_tokens.size(1)

        for i in track_iter_progress(range(num_texts)):
            topk_sim, topk_idx = sim_matrix_t2i[i].topk(k=k, dim=0)
            # get topk image embeddings
            encoder_output = img_embeds[topk_idx]
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long).to(self.device)
            # pair the same text with each of its k candidate images
            attention_mask = torch.cat(
                [query_atts, text_atts[i].repeat(k, 1)], dim=1)
            output = self.multimodal_backbone.bert(
                text_ids[i].repeat(k, 1),
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
            )
            # Head output at index 1 of the last axis, averaged over the
            # query positions (presumably the "matched" logit -- confirm
            # against the head implementation).
            score = self.multimodal_head(
                (output.last_hidden_state[:, :num_query, :],
                 ))[:, :, 1].mean(dim=1)
            score_matrix_t2i[i, topk_idx] = score + topk_sim

        return score_matrix_t2i
