from typing import Any, Literal

import torch
from tqdm.auto import tqdm

from ..model import MultimodalEmbedderProtocol

# BibTeX entry for the VISTA paper (arXiv:2406.04292), kept as a module-level
# constant. It is not referenced in the visible part of this file.
VISTA_CITATION = """@article{zhou2024vista,
  title={VISTA: Visualized Text Embedding For Universal Multi-Modal Retrieval},
  author={Zhou, Junjie and Liu, Zheng and Xiao, Shitao and Zhao, Bo and Xiong, Yongping},
  journal={arXiv preprint arXiv:2406.04292},
  year={2024}
}"""


def vista_loader(model_name, **kwargs):
    """Build a VISTA (Visualized BGE) embedder wrapped for this package.

    ``visual_bge`` is imported lazily inside this function, so the module can
    be imported without that optional dependency; the wrapper class is
    therefore defined here as well, since it subclasses ``Visualized_BGE``.

    Args:
        model_name: Passed positionally to ``VisualizedBGEWrapper``, i.e. it
            becomes ``model_name_bge``.
        **kwargs: Forwarded to ``VisualizedBGEWrapper.__init__`` (for example
            ``model_weight`` and ``image_tokens_num``).

    Returns:
        A ``VisualizedBGEWrapper`` instance on which ``eval()`` has been called.

    Raises:
        ImportError: If ``visual_bge`` cannot be imported. The original
            import failure is chained as ``__cause__``.
    """
    try:  # a temporary fix for the dependency issues of vista models.
        from visual_bge.modeling import Visualized_BGE
    except ImportError as err:
        # Chain explicitly so the underlying failure (which may be a missing
        # transitive dependency rather than visual_bge itself) stays visible.
        raise ImportError(
            "Please install `visual_bge`, refer to https://github.com/FlagOpen/FlagEmbedding/tree/master/research/visual_bge#install-flagembedding."
        ) from err

    class VisualizedBGEWrapper(Visualized_BGE, MultimodalEmbedderProtocol):
        """Wrapper around ``Visualized_BGE`` exposing text, image and fused embedding methods.

        Setting up VISTA

        ```
        git clone https://github.com/FlagOpen/FlagEmbedding.git
        cd FlagEmbedding/research/visual_bge
        pip install -e .
        pip install torchvision timm einops ftfy
        ```
        back to the root folder of mteb; download the vision tower for bge-base
        ```
        cd ..
        wget https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth?download=true
        ```
        rename it to `visualized_base_en_V1.5.pth`
        ```
        mv Visualized_base_en_v1.5.pth?download=true visualized_base_en_V1.5.pth
        ```
        download the vision tower for bge-m3
        ```
        wget https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_m3.pth?download=true
        ```
        rename it to `visualized_m3.pth`
        ```
        mv Visualized_m3.pth?download=true visualized_m3.pth
        ```
        """

        def __init__(
            self,
            model_name_bge: str | None = None,
            model_weight=None,
            normlized: bool = True,
            sentence_pooling_method: str = "cls",
            negatives_cross_device: bool = False,
            temperature: float = 0.02,
            from_pretrained=None,
            image_tokens_num: int | None = None,
            **kwargs: Any,
        ):
            """Initialise the underlying ``Visualized_BGE`` and the text-length budget.

            Args:
                model_name_bge: Forwarded to ``Visualized_BGE.__init__``.
                model_weight: Forwarded to ``Visualized_BGE.__init__``.
                normlized: Forwarded to ``Visualized_BGE.__init__`` (the
                    misspelling is the upstream parameter name).
                sentence_pooling_method: Forwarded to ``Visualized_BGE.__init__``.
                negatives_cross_device: Forwarded to ``Visualized_BGE.__init__``.
                temperature: Forwarded to ``Visualized_BGE.__init__``.
                from_pretrained: Forwarded to ``Visualized_BGE.__init__``.
                image_tokens_num: Number of token positions reserved for the
                    image; subtracted from the tokenizer's ``model_max_length``
                    to bound the text length in ``embed_multimodal``. Must be
                    provided despite the ``None`` default.
                **kwargs: Accepted and ignored.

            Raises:
                TypeError: If ``image_tokens_num`` is ``None``.
            """
            # Validate before the (potentially expensive) parent constructor
            # runs; previously a None value only failed afterwards with an
            # opaque "unsupported operand type(s) for -" TypeError.
            if image_tokens_num is None:
                raise TypeError(
                    "`image_tokens_num` is required: it is subtracted from the "
                    "tokenizer's `model_max_length` to bound the text length "
                    "used by `embed_multimodal`."
                )
            super().__init__(
                model_name_bge=model_name_bge,
                model_weight=model_weight,
                normlized=normlized,
                sentence_pooling_method=sentence_pooling_method,
                negatives_cross_device=negatives_cross_device,
                temperature=temperature,
                from_pretrained=from_pretrained,
            )
            self.image_tokens_num = image_tokens_num
            # Text budget used as `max_length` when tokenizing for fused encoding.
            self.max_text_len_with_image = (
                self.tokenizer.model_max_length - image_tokens_num
            )
            self.eval()

        def encode_text(self, texts: dict[str, torch.Tensor]):
            """Encode tokenized text into sentence embeddings.

            Overrides ``Visualized_BGE``'s original implementation solely to
            fix a dtype mismatch between the attention mask and the embedding
            output.

            Args:
                texts: {"input_ids": ..., "attention_mask": ...}

            Returns:
                Contiguous tensor of pooled text embeddings, L2-normalized
                along the last dimension when ``self.normlized`` is truthy.
            """
            input_ids = texts["input_ids"]
            attention_mask = texts["attention_mask"]

            input_shape = input_ids.size()
            device = input_ids.device

            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

            head_mask = [None] * self.depth
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
                attention_mask, input_shape
            )

            embedding_output = self.bge_embeddings(
                input_ids=input_ids,
                position_ids=None,
                token_type_ids=token_type_ids,
                inputs_embeds=None,
                past_key_values_length=0,
            )

            # This cast is missing upstream in VISTA; `encode_text` is
            # overridden only to add it.
            extended_attention_mask = extended_attention_mask.to(embedding_output.dtype)

            encoder_outputs = self.bge_encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                past_key_values=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
            )
            sequence_output = encoder_outputs[0]

            # Pooling uses the original (non-extended) attention mask.
            t_reps = self.sentence_embedding(sequence_output, attention_mask)
            if self.normlized:
                t_reps = torch.nn.functional.normalize(t_reps, dim=-1)
            return t_reps.contiguous()

        def embed_text(  # get_text_embeddings
            self,
            texts,
            batch_size: int = 32,
            show_progress_bar: bool = True,
            prompt=None,
            task_name=None,
            input_type: Literal["document", "query"] | None = None,
            **kwargs: Any,
        ) -> torch.Tensor:
            """Embed text inputs in batches.

            Args:
                texts: Iterable of mappings; the ``"data"`` entry of each is
                    handed to the tokenizer (presumably a string).
                batch_size: Number of inputs tokenized and encoded per step.
                show_progress_bar: Whether to display a tqdm progress bar.
                prompt: Accepted and ignored.
                task_name: Accepted and ignored.
                input_type: Accepted and ignored.
                **kwargs: Accepted and ignored.

            Returns:
                CPU tensor with the per-batch embeddings concatenated along
                dim 0. Empty input is not supported: ``torch.cat`` is then
                called with an empty list.
            """
            texts = [entry["data"] for entry in texts]
            all_text_embeddings = []
            with torch.no_grad():
                for start in tqdm(
                    range(0, len(texts), batch_size),
                    disable=not show_progress_bar,
                    desc="Text Encoding",
                ):
                    batch = texts[start : start + batch_size]
                    text_encoding = self.tokenizer(
                        batch,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    )
                    batch_embeddings = self.encode_text(text_encoding.to(self.device))
                    all_text_embeddings.append(batch_embeddings.cpu())
            return torch.cat(all_text_embeddings, dim=0)

        def embed_image(  # get_image_embeddings
            self,
            images,
            batch_size: int = 32,
            show_progress_bar: bool = True,
            prompt=None,
            task_name=None,
            input_type: Literal["document", "query"] | None = None,
            **kwargs: Any,
        ) -> torch.Tensor:
            """Embed image inputs in batches.

            Args:
                images: Iterable of mappings; the ``"data"`` entry of each is
                    passed to ``self.preprocess_val`` (presumably a PIL image).
                batch_size: Number of images preprocessed and encoded per step.
                show_progress_bar: Whether to display a tqdm progress bar.
                prompt: Accepted and ignored.
                task_name: Accepted and ignored.
                input_type: Accepted and ignored.
                **kwargs: Accepted and ignored.

            Returns:
                CPU tensor with the per-batch embeddings concatenated along
                dim 0. Empty input is not supported: ``torch.cat`` is then
                called with an empty list.
            """
            images = [entry["data"] for entry in images]
            all_image_embeddings = []
            with torch.no_grad():
                for start in tqdm(
                    range(0, len(images), batch_size),
                    disable=not show_progress_bar,
                    desc="Image Encoding",
                ):
                    batch = images[start : start + batch_size]
                    imgs = torch.stack([self.preprocess_val(image) for image in batch])

                    batch_embeddings = self.encode_image(images=imgs.to(self.device))
                    all_image_embeddings.append(batch_embeddings.cpu())
            return torch.cat(all_image_embeddings, dim=0)

        def embed_multimodal(
            self,
            inputs,
            batch_size: int = 32,
            show_progress_bar: bool = True,
            prompt=None,
            task_name=None,
            input_type: Literal["document", "query"] | None = None,
            **kwargs: Any,
        ) -> torch.Tensor:
            """Embed interleaved text/image inputs into fused embeddings.

            For each input, every ``str`` element of ``item["data"]`` is joined
            with single spaces (then stripped) to form the text, and the first
            PIL image is used as the image. Further images and elements of any
            other type are ignored.

            Args:
                inputs: Iterable of mappings whose ``"data"`` entry is an
                    iterable of strings and/or PIL images.
                batch_size: Number of inputs encoded per step.
                show_progress_bar: Whether to display a tqdm progress bar.
                prompt: Accepted and ignored.
                task_name: Accepted and ignored.
                input_type: Accepted and ignored.
                **kwargs: Accepted and ignored.

            Returns:
                CPU float32 tensor with the per-batch fused embeddings
                concatenated along dim 0. Empty input is not supported:
                ``torch.cat`` is then called with an empty list.

            Raises:
                ValueError: If an input contains no PIL image.
            """
            from PIL.Image import Image

            data = []
            for index, item in enumerate(inputs):
                text_parts = []
                first_image = None
                for element in item["data"]:
                    if isinstance(element, str):
                        text_parts.append(element)
                    elif isinstance(element, Image) and first_image is None:
                        first_image = element
                if first_image is None:
                    # Fail before any batch is encoded, instead of a bare
                    # KeyError('image') in the middle of the run.
                    raise ValueError(
                        "embed_multimodal requires at least one PIL image per "
                        f"input, but input {index} has none."
                    )
                data.append(
                    {"image": first_image, "text": " ".join(text_parts).strip()}
                )

            all_fused_embeddings = []
            with torch.no_grad():
                for start in tqdm(
                    range(0, len(data), batch_size),
                    disable=not show_progress_bar,
                    desc="Interleaved Encoding",
                ):
                    batch = data[start : start + batch_size]
                    texts = self.tokenizer(
                        [item["text"] for item in batch],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        # Leave room for the image tokens in the sequence.
                        max_length=self.max_text_len_with_image,
                    )
                    images = torch.stack(
                        [self.preprocess_val(item["image"]) for item in batch]
                    )
                    all_fused_embeddings.append(
                        self.encode_mm(
                            images.to(self.device), texts.to(self.device)
                        )
                        .cpu()
                        .to(torch.float32)
                    )
            return torch.cat(all_fused_embeddings, dim=0)

    return VisualizedBGEWrapper(model_name, **kwargs)
