# From https://github.com/embeddings-benchmark/mteb/blob/main/mteb/models/model_implementations/vlm2vec_models.py

import logging
from typing import Any

import torch
from tqdm.auto import tqdm

from ..model import MultimodalEmbedderProtocol

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# BibTeX entry for the VLM2Vec paper (arXiv:2410.05160). Built from adjacent
# string literals; the resulting value has no trailing newline.
VLM2VEC_CITATION = (
    "@article{jiang2024vlm2vec,\n"
    "  title={VLM2Vec: Training Vision-Language Models for Massive Multimodal Embedding Tasks},\n"
    "  author={Jiang, Ziyan and Meng, Rui and Yang, Xinyi and Yavuz, Semih and Zhou, Yingbo and Chen, Wenhu},\n"
    "  journal={arXiv preprint arXiv:2410.05160},\n"
    "  year={2024}\n"
    "}"
)


class VLM2VecWrapper(MultimodalEmbedderProtocol):
    """Text, image, and image+text embedder backed by a VLM2Vec checkpoint.

    Adapted from https://github.com/TIGER-AI-Lab/VLM2Vec/blob/main/src/model.py

    Every input is turned into a prompt for the underlying vision-language
    causal LM (by default ``microsoft/Phi-3.5-vision-instruct``). Its embedding
    is the final-layer hidden state of the last non-padding token,
    L2-normalised. The public embedding methods return float32 CPU tensors
    with one row per input.
    """

    def __init__(
        self,
        model_name: str = "TIGER-Lab/VLM2Vec-LoRA",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        base_model_name: str = "microsoft/Phi-3.5-vision-instruct",
        **kwargs,
    ):
        """Load the model and its processor.

        Args:
            model_name: Hub id or path of the VLM2Vec checkpoint. If the name
                contains "LoRA", it is loaded as a PEFT adapter on top of
                ``base_model_name`` and merged into it. Otherwise it is loaded
                directly as a full checkpoint.
            device: Device that the model and all batched inputs are moved to.
            base_model_name: Hub id or path whose config and processor are
                always used (and whose weights are used for LoRA checkpoints).
            **kwargs: Only ``batch_size`` is read. It is the default batch size
                for the ``embed_*`` methods when the caller passes none.
        """
        import flash_attn  # noqa: F401  (fail fast: flash_attention_2 is requested below)
        from peft import LoraConfig, PeftModel
        from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

        self.pooling = "last"
        self.normalize = True
        self.temperature = 1.0
        self.device = device
        # Default for the embed_* entry points. None defers to the per-method
        # default (4 for image/multimodal, 32 for text).
        self.batch_size = kwargs.get("batch_size")

        # Loading the base model
        config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
        config.use_cache = False
        config.padding_side = "right"
        # Embedding width equals the base model's hidden size. 4096 is only a
        # fallback in case the config lacks the field.
        self.hidden_size = getattr(config, "hidden_size", 4096)

        # A "LoRA" checkpoint is only an adapter: load the base weights first.
        checkpoint_path = base_model_name if "LoRA" in model_name else model_name
        base_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            config=config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        base_model.padding_side = "right"

        # Building the model on top of the base
        if "LoRA" in model_name:
            lora_config = LoraConfig.from_pretrained(model_name)
            lora_model = PeftModel.from_pretrained(
                base_model, model_name, config=lora_config
            )
            model = lora_model.merge_and_unload().to(torch.bfloat16)  # propagate dtype
        else:
            model = base_model.to(torch.bfloat16)

        model.eval()
        model.to(device)
        self.mdl = model

        # The processor always comes from the base model.
        self.processor = AutoProcessor.from_pretrained(
            base_model_name,
            trust_remote_code=True,
            num_crops=4,
        )

    def encode_input(self, input):
        """Run the model on a collated batch and return pooled embeddings.

        Args:
            input: Keyword arguments for the model. Must include
                ``attention_mask``, which drives pooling.

        Returns:
            The pooled (and, if enabled, normalised) last-layer states.
        """
        outputs = self.mdl(**input, return_dict=True, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        return self._pooling(last_hidden_state, input["attention_mask"])

    def _pooling(self, last_hidden_state, attention_mask):
        """Reduce per-token states to one vector per sequence.

        "last" pooling selects position ``attention_mask.sum(1) - 1``. Under
        right padding, that is the last attended token. The result is
        L2-normalised when ``self.normalize`` is set.
        """
        if self.pooling != "last":
            raise NotImplementedError(f"Unsupported pooling mode: {self.pooling!r}")
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[
            torch.arange(batch_size, device=last_hidden_state.device),
            sequence_lengths,
        ]
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _process(self, text, image=None):
        """Run the processor on one (prompt, image) pair with the shared settings.

        The inputs are moved onto ``self.device``.
        """
        inputs = self.processor(
            text,
            image,
            return_tensors="pt",
            max_length=256,
            truncation=True,
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def _collate(self, encoded, with_images):
        """Right-pad per-item processor outputs into one model batch.

        Args:
            encoded: Processor outputs, one per item. ``input_ids`` is assumed
                to have a leading batch dimension of 1 (it is squeezed away).
            with_images: Also concatenate ``pixel_values`` and ``image_sizes``.

        Returns:
            A dict suitable for ``encode_input``.
        """
        sequences = [item["input_ids"].squeeze(0) for item in encoded]
        pad_id = self.processor.tokenizer.pad_token_id
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=True,
            # The mask below comes from true lengths, so the pad value is
            # cosmetic. Fall back to 0 when the tokenizer defines no pad token.
            padding_value=pad_id if pad_id is not None else 0,
        )
        # Derive the mask from each item's real length rather than from
        # `input_ids != pad_id`. Otherwise a pad-id token inside a prompt would
        # be masked out and would also shift the last-token pooling index.
        lengths = torch.tensor([seq.shape[0] for seq in sequences], device=input_ids.device)
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        batch = {"input_ids": input_ids, "attention_mask": attention_mask}
        if with_images:
            batch["pixel_values"] = torch.cat([item["pixel_values"] for item in encoded], dim=0)
            batch["image_sizes"] = torch.cat([item["image_sizes"] for item in encoded], dim=0)
        return batch

    def _embed_batches(self, items, batch_size, show_progress_bar, desc, encode_item, with_images):
        """Embed ``items`` in chunks of ``batch_size`` and stack the results.

        Args:
            items: A sliceable sequence of raw inputs.
            batch_size: Number of items per forward pass. Must be >= 1.
            show_progress_bar: Show a tqdm bar labelled ``desc``.
            desc: Progress-bar label.
            encode_item: Callable that maps one item to its on-device
                processor output.
            with_images: Passed through to ``_collate``.

        Returns:
            A float32 CPU tensor with one row per item. For empty input it has
            zero rows and ``self.hidden_size`` columns.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(items) == 0:
            # torch.cat on an empty list would raise; return an empty matrix.
            return torch.empty((0, getattr(self, "hidden_size", 4096)), dtype=torch.float32)

        chunks = []
        with torch.no_grad():
            for start in tqdm(
                range(0, len(items), batch_size), disable=not show_progress_bar, desc=desc
            ):
                encoded = [encode_item(item) for item in items[start : start + batch_size]]
                outputs = self.encode_input(self._collate(encoded, with_images=with_images))
                chunks.append(outputs.cpu().to(torch.float32))
        return torch.cat(chunks, dim=0)

    # reference: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/main/src/collator.py
    def get_image_embeddings(
        self,
        images,
        batch_size: int = 4,
        show_progress_bar: bool = True,
        **kwargs: Any,
    ):
        """Embed images with the fixed "Represent the given image." prompt.

        Args:
            images: A sliceable sequence of images accepted by the processor.
            batch_size: Images per forward pass.
            show_progress_bar: Show a progress bar.
            **kwargs: Ignored.

        Returns:
            A float32 CPU tensor with one row per image.
        """
        text = "<|image_1|> Represent the given image."
        return self._embed_batches(
            images,
            batch_size,
            show_progress_bar,
            "Image Encoding",
            lambda image: self._process(text, image),
            with_images=True,
        )

    def get_text_embeddings(
        self,
        texts,
        batch_size: int = 32,
        show_progress_bar: bool = True,
        **kwargs: Any,
    ):
        """Embed raw texts, which are passed to the processor without a prompt template.

        Args:
            texts: A sliceable sequence of strings.
            batch_size: Texts per forward pass.
            show_progress_bar: Show a progress bar.
            **kwargs: Ignored.

        Returns:
            A float32 CPU tensor with one row per text.
        """
        return self._embed_batches(
            texts,
            batch_size,
            show_progress_bar,
            "Text Encoding",
            lambda text: self._process(text, None),
            with_images=False,
        )

    def encode(
        self,
        mm_inputs,
        batch_size: int = 4,
        show_progress_bar: bool = True,
        **kwargs: Any,
    ):
        """Embed image+question pairs.

        Args:
            mm_inputs: A sliceable sequence of dicts. Each dict must have a
                ``"text"`` key and an ``"image"`` key; a missing key raises
                ``KeyError``.
            batch_size: Pairs per forward pass.
            show_progress_bar: Show a progress bar.
            **kwargs: Ignored.

        Returns:
            A float32 CPU tensor with one row per pair.
        """

        def encode_pair(item):
            prompt = f"<|image_1|> Represent the given image with the following question: {item['text']}"
            return self._process(prompt, item["image"])

        return self._embed_batches(
            mm_inputs,
            batch_size,
            show_progress_bar,
            "Multimodal Encoding",
            encode_pair,
            with_images=True,
        )

    def _with_default_batch_size(self, kwargs):
        """Fill in ``batch_size`` from the instance default if the caller gave none.

        Uses ``getattr`` so instances that bypassed ``__init__`` also work.
        """
        default = getattr(self, "batch_size", None)
        if "batch_size" not in kwargs and default is not None:
            kwargs["batch_size"] = default
        return kwargs

    def embed_text(
        self, inputs, prompt=None, input_type=None, task_name=None, **kwargs
    ) -> torch.Tensor:
        """Embed ``[{"data": str}, ...]`` records. ``prompt``/``input_type``/``task_name`` are unused."""
        texts = [i["data"] for i in inputs]
        return self.get_text_embeddings(texts, **self._with_default_batch_size(kwargs))

    def embed_image(
        self, inputs, prompt=None, input_type=None, task_name=None, **kwargs
    ) -> torch.Tensor:
        """Embed ``[{"data": image}, ...]`` records. ``prompt``/``input_type``/``task_name`` are unused."""
        images = [i["data"] for i in inputs]
        return self.get_image_embeddings(images, **self._with_default_batch_size(kwargs))

    def embed_multimodal(
        self, inputs, prompt=None, input_type=None, task_name=None, **kwargs
    ) -> torch.Tensor:
        """Embed interleaved records ``[{"data": [str | PIL.Image, ...]}, ...]``.

        All string parts are joined with single spaces and stripped. Only the
        first image is kept, because the prompt has a single ``<|image_1|>``
        slot. A record without any image fails in ``encode`` with ``KeyError``.
        ``prompt``/``input_type``/``task_name`` are unused.
        """
        from PIL.Image import Image

        data = []
        for item in inputs:
            texts = [part for part in item["data"] if isinstance(part, str)]
            images = [part for part in item["data"] if isinstance(part, Image)]
            entry = {}
            if images:
                entry["image"] = images[0]
            entry["text"] = " ".join(texts).strip()
            data.append(entry)

        return self.encode(data, **self._with_default_batch_size(kwargs))
