#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from abc import ABC, abstractmethod

import random
import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_resampler.builder import build_vision_resampler
from .multimodal_projector.builder import build_vision_projector

from llava.constants import *
from llava.mm_utils import get_anyres_image_grid_shape
from llava.train.train_utils import rank0_print


class LlavaMetaModel:
    """Mixin that attaches a vision tower, vision resampler and multimodal projector to a model.

    ``__init__`` forwards ``config`` to the next class in the MRO, so this mixin is meant
    to be combined with a model base class that provides ``self.config`` and ``self.dtype``.
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                # torch.empty leaves the values uninitialized; presumably they are
                # overwritten when a checkpoint is loaded (TODO confirm).
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the one-element list used under FSDP (or None)."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def _get_vision_resampler(self):
        """Return the vision resampler, unwrapping the one-element list used under FSDP (or None)."""
        vision_resampler = getattr(self, "vision_resampler", None)
        if type(vision_resampler) is list:
            vision_resampler = vision_resampler[0]
        return vision_resampler

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build or reload the vision modules and record the vision settings in ``self.config``.

        Args:
            model_args: Object exposing ``vision_tower``, ``mm_vision_select_layer``,
                ``mm_vision_select_feature``, ``pretrain_mm_mlp_adapter``, ``mm_patch_merge_type``
                and optionally ``vision_tower_pretrained`` / ``mm_projector_type``.
            fsdp: Optional FSDP option list; when non-empty, newly built modules are stored
                inside one-element lists (unwrapped again by the getters above).

        If the vision tower already exists it is (re)loaded via ``load_model()`` and the
        resampler/projector parameters are unfrozen. If ``pretrain_mm_mlp_adapter`` is set,
        projector and resampler weights are loaded from that file.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            # Unwrap regardless of how the modules were stored, so a list-wrapped
            # (FSDP) module is never used directly as an nn.Module below.
            vision_tower = self.get_vision_tower()
            vision_resampler = self._get_vision_resampler()
            vision_tower.load_model()

            # In case it is frozen by LoRA. Use the unwrapped resampler: under FSDP
            # ``self.vision_resampler`` is a list and has no ``parameters()``.
            for p in vision_resampler.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
        self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)

            if "unpad" in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")

            def get_w(weights, keyword):
                # Keep entries under "<keyword>." and strip everything up to that prefix.
                prefix = keyword + "."
                return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

            incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
            rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
            incompatible_keys = vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
            rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The image is assumed to have been resized to fit ``tensor`` while keeping its aspect
    ratio, then padded symmetrically along the other axis. This removes that padding.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of the image as (width, height) -- note
        width comes first.

    Returns:
    torch.Tensor: The unpadded image tensor (a slice of ``tensor``).
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    # Compute aspect ratios
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Determine padding size and direction
    if original_aspect_ratio > current_aspect_ratio:
        # Relatively wider image: it was fit to the width and padding was added to the height.
        scale_factor = current_width / original_width
        # Round before truncating so float error (e.g. 49 * (1 / 49) == 0.999...) does not drop a row.
        new_height = int(round(original_height * scale_factor, 7))
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Relatively taller (or same-ratio) image: padding was added to the width.
        scale_factor = current_height / original_height
        new_width = int(round(original_width * scale_factor, 7))
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


class LlavaMetaForCausalLMWithRet(ABC):
    """Multimodal input preparation for a retriever with separate query / document models.

    Host classes implement ``get_model(model_name)`` and must also provide ``config``,
    ``device`` and ``soft_prompt`` (used through ``num_embeddings`` and by calling it on
    token ids -- presumably an ``nn.Embedding``; confirm in the host class).
    ``initialize_vision_tokenizer`` additionally relies on ``resize_token_embeddings``,
    ``get_input_embeddings`` and ``get_output_embeddings``.
    """

    @abstractmethod
    def get_model(self, model_name='query'):
        """Return the underlying model for ``model_name`` ('query' or 'document')."""
        pass

    def get_vision_tower(self, model_name='query'):
        """Return the vision tower of the model selected by ``model_name``."""
        return self.get_model(model_name).get_vision_tower()

    def encode_images(self, images, model_name='query'):
        """Run images through the vision tower, resampler and projector of ``model_name``."""
        model = self.get_model(model_name)
        image_features = model.get_vision_tower()(images)
        image_features = model.vision_resampler(image_features, images=images)
        image_features = model.mm_projector(image_features)
        return image_features

    def prepare_inputs_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, images, is_query, is_doc):
        """Build padded input embeddings with image features spliced in and retrieval tokens appended.

        Args:
            input_ids: Batch of token ids. Positions equal to ``IMAGE_TOKEN_INDEX`` are
                replaced, in order, by the encoded image features.
            position_ids: Optional; only decides whether position ids are returned and which
                dtype/device the rebuilt ones use.
            attention_mask: Optional; masked-out positions are dropped before embedding.
            past_key_values: Passed through unchanged.
            images: Optional images, encoded with the query or document model.
            is_query, is_doc: Which side is encoded (``is_query`` takes precedence).

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds,
            end_tok_indices, section_group_indices)``. Queries are truncated to the length
            budget and ``section_group_indices`` is ``None``. For documents, each row (section)
            is split into passages of at most the budget, and ``section_group_indices[k]`` holds
            the passage rows of section ``k``. ``end_tok_indices[b]`` are the positions of the
            retrieval tokens appended to row ``b``.

        Raises:
            ValueError: if there are more image tokens than encoded images, or if
                ``tokenizer_model_max_length`` leaves no room after the retrieval tokens.
        """
        assert is_query or is_doc
        model_name = 'query' if is_query else 'document'
        if images is not None:
            image_features = self.encode_images(images, model_name)
        else:
            image_features = None

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Work with concrete tensors internally; the caller's originals decide what is returned.
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        # Drop masked-out positions; each row becomes a variable-length 1-D tensor.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]

        num_encoded_images = 0 if image_features is None else len(image_features)
        new_input_embeds = []
        cur_image_idx = 0

        # NOTE(review): text is embedded with the default ``get_model()`` (the 'query' model)
        # even when encoding documents, while images use ``model_name``. Confirm this is
        # intended (e.g. both models share ``embed_tokens``).
        for cur_input_ids in input_ids:
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                continue
            if cur_image_idx + num_images > num_encoded_images:
                # Fail loudly instead of silently reusing stale image features.
                raise ValueError(
                    f"input_ids contain more image tokens than encoded images "
                    f"(need at least {cur_image_idx + num_images}, got {num_encoded_images})."
                )

            # Embed all text segments between image tokens in one call, then split them back.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = [
                cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
                for i in range(len(image_token_indices) - 1)
            ]
            split_sizes = [x.shape[0] for x in cur_input_ids_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            # Interleave text_0, image_0, text_1, image_1, ..., text_n.
            cur_new_input_embeds = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                if i < num_images:
                    cur_new_input_embeds.append(image_features[cur_image_idx])
                    cur_image_idx += 1

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            new_input_embeds.append(torch.cat(cur_new_input_embeds))

        num_soft = self.soft_prompt.num_embeddings
        num_ret_tokens = num_soft // 2
        # Image embeddings can make rows longer than the tokenizer limit; leave room for
        # the retrieval tokens. ``None`` means no limit is configured.
        max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if max_length is not None:
            max_length = max_length - num_ret_tokens
            if max_length <= 0:
                raise ValueError(
                    f"tokenizer_model_max_length leaves no room for input tokens after reserving "
                    f"{num_ret_tokens} retrieval tokens."
                )

        if is_query:
            # Queries are truncated (slicing with None keeps the whole row).
            new_input_embeds = [x[:max_length] for x in new_input_embeds]
            section_group_indices = None
        else:
            # Split every section into passages of at most ``max_length`` tokens
            # (one passage per non-empty section when no limit is configured).
            passages = []
            section_group_indices = []
            cur_passage_id = 0
            for x in new_input_embeds:
                section_length = len(x)
                step = max_length if max_length is not None else max(section_length, 1)
                num_passages = (section_length + step - 1) // step
                section_group_indices.append(torch.arange(cur_passage_id, cur_passage_id + num_passages))
                passages.extend(x[i:i + step] for i in range(0, section_length, step))
                cur_passage_id += num_passages
            new_input_embeds = passages

        batch_size = len(new_input_embeds)
        if is_query:  # The first N // 2 soft-prompt tokens
            ret_inputs = torch.arange(0, num_soft // 2, dtype=torch.int64, device=self.device)
        else:  # The remaining soft-prompt tokens, indices N // 2 .. N - 1
            ret_inputs = torch.arange(num_soft // 2, num_soft, dtype=torch.int64, device=self.device)
        ret_tok = self.soft_prompt(ret_inputs.repeat(batch_size, 1))

        # Record the position of the EoQue or EoSec tokens in each row.
        # NOTE(review): these cover N // 2 tokens; if N were odd the document side would
        # append one token more than recorded here.
        end_tok_indices = [torch.arange(len(x), len(x) + num_ret_tokens) for x in new_input_embeds]
        # Add the retrieval tokens at the end of each row.
        new_input_embeds = [torch.cat([x, ret_tok[batch_idx]], dim=0) for batch_idx, x in enumerate(new_input_embeds)]

        # Pad all rows to a common length and rebuild mask / position ids to match.
        max_len = max(x.shape[0] for x in new_input_embeds)

        new_input_embeds_padded = []
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, cur_new_embed in enumerate(new_input_embeds):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, end_tok_indices, section_group_indices

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and resize / initialize the embeddings.

        New start/end token embeddings are initialized to the mean of the existing rows and,
        when ``pretrain_mm_mlp_adapter`` is set, overwritten from that checkpoint.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False


class LlavaMetaForCausalLMWithReRank(ABC):
    """Multimodal input preparation for a single-model reranker over (query, passage) pairs.

    Host classes implement ``get_model()`` and must also provide ``config``, ``device`` and
    ``soft_prompt`` (used through ``num_embeddings`` and by calling it on token ids --
    presumably an ``nn.Embedding``; confirm in the host class).
    """

    @abstractmethod
    def get_model(self):
        """Return the underlying model holding the vision modules and ``embed_tokens``."""
        pass

    def get_vision_tower(self):
        """Return the vision tower of the underlying model."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run images through the vision tower, resampler and projector."""
        model = self.get_model()
        image_features = model.get_vision_tower()(images)
        image_features = model.vision_resampler(image_features, images=images)
        image_features = model.mm_projector(image_features)
        return image_features

    def prepare_inputs_for_multimodal(self, query_input_ids, query_attention_mask, query_images,
                                      doc_input_ids, doc_attention_mask, doc_images, doc_num_sections, position_ids, past_key_values):
        """Build one padded row per (query, passage) pair with rerank tokens appended.

        Args:
            query_input_ids, query_attention_mask, query_images: Query batch. Masked-out
                positions are dropped, and image tokens are replaced by image features.
            doc_input_ids, doc_attention_mask, doc_images: Document sections, handled the same way.
            doc_num_sections: ``doc_num_sections[q]`` is the number of consecutive document
                rows belonging to query ``q``.
            position_ids: Unused; ``None`` is returned in its place.
            past_key_values: Passed through unchanged.

        Each section is split into passages so that query + passage fits the length budget,
        i.e. ``tokenizer_model_max_length`` minus the soft-prompt tokens. With no limit
        configured, each non-empty section is one passage.

        Returns:
            ``(None, None, attention_mask, past_key_values, inputs_embeds, end_tok_indices,
            section_group_indices)``, where ``section_group_indices[k]`` holds the row indices
            of the passages of section ``k``, and ``end_tok_indices[b]`` are the positions of
            the rerank tokens in row ``b``.

        Raises:
            ValueError: if a query leaves no room for document tokens within the budget.
        """
        if query_images is not None:
            query_image_features = self.encode_images(query_images)
        else:
            query_image_features = None

        if doc_images is not None:
            doc_image_features = self.encode_images(doc_images)
        else:
            doc_image_features = None

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        if query_attention_mask is None:
            query_attention_mask = torch.ones_like(query_input_ids, dtype=torch.bool)
        else:
            query_attention_mask = query_attention_mask.bool()
        if doc_attention_mask is None:
            doc_attention_mask = torch.ones_like(doc_input_ids, dtype=torch.bool)
        else:
            doc_attention_mask = doc_attention_mask.bool()

        # Drop masked-out positions; each row becomes a variable-length 1-D tensor.
        query_input_ids = [cur_query_input_ids[cur_query_attention_mask] for cur_query_input_ids, cur_query_attention_mask in zip(query_input_ids, query_attention_mask)]
        doc_input_ids = [cur_doc_input_ids[cur_doc_attention_mask] for cur_doc_input_ids, cur_doc_attention_mask in zip(doc_input_ids, doc_attention_mask)]

        new_query_embeds = self.prepare_text_image(query_input_ids, query_image_features)
        new_doc_embeds = self.prepare_text_image(doc_input_ids, doc_image_features)

        num_rerank_tokens = self.soft_prompt.num_embeddings
        # Leave room for the rerank tokens. Unlike the retriever, a single LVLM sees both
        # query and document, so all soft-prompt tokens are reserved.
        max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if max_length is not None:
            max_length = max_length - num_rerank_tokens

        new_input_embeds = []
        section_group_indices = []
        cur_passage_id = 0
        cur_sec_id = 0

        for query_idx, new_query in enumerate(new_query_embeds):
            query_length = len(new_query)
            if max_length is None:
                split_len = None
            else:
                split_len = max_length - query_length
                if split_len <= 0:
                    # A non-positive chunk size would divide by zero or corrupt the passage ids.
                    raise ValueError(
                        f"Query {query_idx} has {query_length} tokens, leaving no room for document "
                        f"tokens within {max_length} tokens (after {num_rerank_tokens} rerank tokens)."
                    )

            query_num_sec = doc_num_sections[query_idx]

            for new_doc in new_doc_embeds[cur_sec_id:cur_sec_id + query_num_sec]:
                section_length = len(new_doc)
                step = split_len if split_len is not None else max(section_length, 1)
                num_passages = (section_length + step - 1) // step
                # Row indices of this section's passages.
                section_group_indices.append(torch.arange(cur_passage_id, cur_passage_id + num_passages))
                # Each passage row is the full query followed by one document chunk.
                new_input_embeds.extend(
                    torch.cat([new_query, new_doc[i:i + step]], dim=0) for i in range(0, section_length, step)
                )
                cur_passage_id += num_passages

            cur_sec_id += query_num_sec

        batch_size = len(new_input_embeds)
        # Get learnable new features
        rerank_inputs = torch.arange(0, num_rerank_tokens, dtype=torch.int64, device=self.device).repeat(batch_size, 1)
        rerank_tok = self.soft_prompt(rerank_inputs)

        # Record the position of the EoSec tokens in each row.
        end_tok_indices = [torch.arange(len(x), len(x) + num_rerank_tokens) for x in new_input_embeds]
        # Add the rerank tokens at the end of each row.
        new_input_embeds = [torch.cat([x, rerank_tok[batch_idx]], dim=0) for batch_idx, x in enumerate(new_input_embeds)]

        # Pad all rows to a common length and build the matching attention mask.
        max_len = max(x.shape[0] for x in new_input_embeds)

        new_input_embeds_padded = []
        attention_mask = torch.zeros((batch_size, max_len), dtype=doc_attention_mask.dtype, device=doc_attention_mask.device)

        for i, cur_new_embed in enumerate(new_input_embeds):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    attention_mask[i, -cur_len:] = True
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    attention_mask[i, :cur_len] = True

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        return None, None, attention_mask, past_key_values, new_input_embeds, end_tok_indices, section_group_indices

    def prepare_text_image(self, input_ids, image_features):
        """Embed each row of ``input_ids``, replacing ``IMAGE_TOKEN_INDEX`` tokens with image features.

        Image features are consumed in order across all rows.

        Raises:
            ValueError: if the rows contain more image tokens than ``image_features`` provides.
        """
        num_encoded_images = 0 if image_features is None else len(image_features)
        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids in input_ids:
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                continue
            if cur_image_idx + num_images > num_encoded_images:
                raise ValueError(
                    f"input_ids contain more image tokens than encoded images "
                    f"(need at least {cur_image_idx + num_images}, got {num_encoded_images})."
                )

            # Embed all text segments between image tokens in one call, then split them back.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = [
                cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
                for i in range(len(image_token_indices) - 1)
            ]
            split_sizes = [x.shape[0] for x in cur_input_ids_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            # Interleave text_0, image_0, text_1, image_1, ..., text_n.
            cur_new_input_embeds = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                if i < num_images:
                    cur_new_input_embeds.append(image_features[cur_image_idx])
                    cur_image_idx += 1

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            new_input_embeds.append(torch.cat(cur_new_input_embeds))

        return new_input_embeds

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and resize / initialize the embeddings.

        New start/end token embeddings are initialized to the mean of the existing rows and,
        when ``pretrain_mm_mlp_adapter`` is set, overwritten from that checkpoint.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False


class LlavaMetaForCausalLMWithReRankInterleaved(ABC):
    """Mixin that builds reranker inputs in which a query's sections are interleaved.

    For every query, its document sections are packed after the query embedding. Each
    section is followed by the learnable rerank tokens produced by ``self.soft_prompt``.
    When adding a section would push a packed row past
    ``config.tokenizer_model_max_length``, a new row is started. A single section that
    does not fit even on its own is split into several passages, each paired with the
    query.

    The concrete subclass must implement :meth:`get_model` and provide ``self.config``,
    ``self.device`` and ``self.soft_prompt``. ``soft_prompt`` is used here as a module
    called with an int64 index tensor that exposes ``num_embeddings`` (presumably an
    ``nn.Embedding``; confirm in the subclass).
    """

    @abstractmethod
    def get_model(self):
        """Return the underlying multimodal model.

        The model must expose the vision modules and ``embed_tokens``.
        """
        pass

    def get_vision_tower(self):
        """Return the vision tower of the underlying model."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower, resampler and projector, in that order."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().vision_resampler(image_features, images=images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_for_multimodal(self, query_input_ids, query_attention_mask, query_images,
                                      doc_input_ids, doc_attention_mask, doc_images, doc_num_sections, position_ids, past_key_values):
        """Build padded, interleaved query/section embeddings for the reranker.

        Args:
            query_input_ids, doc_input_ids: padded token-id tensors whose image placeholders
                equal ``IMAGE_TOKEN_INDEX``. Assumed ``(batch, seq)``; TODO confirm.
            query_attention_mask, doc_attention_mask: optional masks selecting the real
                tokens of each row. ``None`` keeps every token.
            query_images, doc_images: optional image batches. Their encoded features are
                consumed in order by the image placeholders of the queries / documents.
            doc_num_sections: number of consecutive document rows belonging to each query.
                Document rows are laid out query after query.
            position_ids: unused; kept for interface compatibility.
            past_key_values: returned unchanged.

        Returns:
            A 7-tuple ``(None, None, attention_mask, past_key_values, inputs_embeds,
            end_tok_indices, section_group_indices)``:
            - ``inputs_embeds`` is the stacked, padded batch of packed rows.
            - ``end_tok_indices[r]`` lists the positions of the rerank tokens in row ``r``,
              counted from the start of the unpadded row.
            - ``section_group_indices`` holds, per section, the ids of the rerank-token
              groups that belong to it.
            The two leading ``None`` values presumably stand in for input ids / position
            ids; confirm against the caller.

        Raises:
            NotImplementedError: if ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end`` are
                both set in the config.
            ValueError: if a query leaves no room for section content within the limit.
        """
        query_image_features = self.encode_images(query_images) if query_images is not None else None
        doc_image_features = self.encode_images(doc_images) if doc_images is not None else None

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        if query_attention_mask is None:
            query_attention_mask = torch.ones_like(query_input_ids, dtype=torch.bool)
        else:
            query_attention_mask = query_attention_mask.bool()
        if doc_attention_mask is None:
            doc_attention_mask = torch.ones_like(doc_input_ids, dtype=torch.bool)
        else:
            doc_attention_mask = doc_attention_mask.bool()

        # Drop padding so every sequence keeps only its real tokens.
        query_input_ids = [ids[mask] for ids, mask in zip(query_input_ids, query_attention_mask)]
        doc_input_ids = [ids[mask] for ids, mask in zip(doc_input_ids, doc_attention_mask)]

        new_query_embeds = self.prepare_text_image(query_input_ids, query_image_features)
        new_doc_embeds = self.prepare_text_image(doc_input_ids, doc_image_features)

        # Image embeddings can make sequences longer than the tokenizer limit, so packing
        # and splitting are done on embeddings rather than on token ids.
        new_input_embeds, end_tok_indices, section_group_indices = self._pack_interleaved_sections(
            new_query_embeds,
            new_doc_embeds,
            doc_num_sections,
            getattr(self.config, "tokenizer_model_max_length", None),
        )

        # NOTE(review): with left padding, end_tok_indices are not shifted by the pad
        # length. Confirm the consumer accounts for that.
        new_input_embeds, attention_mask = self._pad_and_stack(
            new_input_embeds, doc_attention_mask.dtype, doc_attention_mask.device
        )

        return None, None, attention_mask, past_key_values, new_input_embeds, end_tok_indices, section_group_indices

    def _pack_interleaved_sections(self, new_query_embeds, new_doc_embeds, doc_num_sections, max_length):
        """Pack each query with its sections into rows of at most ``max_length`` embeddings.

        ``new_doc_embeds`` holds the sections of all queries back to back;
        ``doc_num_sections[i]`` of them belong to query ``i``. Every section in a row is
        followed by the rerank tokens, and the rerank tokens count toward the limit.
        ``max_length=None`` disables the limit.

        Returns:
            ``(rows, end_tok_indices, section_group_indices)``. See
            :meth:`prepare_inputs_for_multimodal`.

        Raises:
            ValueError: if a section must be split but the query plus rerank tokens already
                reach ``max_length``.
        """
        num_rerank = self.soft_prompt.num_embeddings
        limit = float("inf") if max_length is None else max_length
        rerank_ids = torch.arange(0, num_rerank, dtype=torch.int64, device=self.device)

        rows = []
        end_tok_indices = []
        section_group_indices = []
        next_group_id = 0  # id of the next rerank-token group across the whole batch
        sec_offset = 0  # index of the first document row of the current query

        def flush(row, row_end_toks):
            # Emit a packed row; each of its sections forms a group with a single rerank slot.
            nonlocal next_group_id
            rows.append(torch.cat(row, dim=0))
            end_tok_indices.append(row_end_toks)
            section_group_indices.extend(
                torch.arange(next_group_id + k, next_group_id + k + 1) for k in range(len(row_end_toks))
            )
            next_group_id += len(row_end_toks)

        for query_idx, query in enumerate(new_query_embeds):
            query_length = len(query)
            num_sec = int(doc_num_sections[query_idx])
            sections = new_doc_embeds[sec_offset:sec_offset + num_sec]
            # Advance past this query's sections so the next query reads its own ones.
            sec_offset += num_sec

            row, row_end_toks, row_length = [], [], 0
            for section in sections:
                section_length = len(section)

                if row:
                    # The row already holds the query: append the section if it fits together
                    # with its rerank tokens; otherwise close the row and start a new one.
                    if row_length + section_length + num_rerank <= limit:
                        end_start = row_length + section_length
                        row_end_toks.append(torch.arange(end_start, end_start + num_rerank))
                        row.extend([section, self.soft_prompt(rerank_ids)])
                        row_length = end_start + num_rerank
                        continue
                    flush(row, row_end_toks)
                    row, row_end_toks, row_length = [], [], 0

                # Start a new row: query, section, rerank tokens.
                if query_length + section_length + num_rerank <= limit:
                    end_start = query_length + section_length
                    row_end_toks = [torch.arange(end_start, end_start + num_rerank)]
                    row = [query, section, self.soft_prompt(rerank_ids)]
                    row_length = end_start + num_rerank
                    continue

                # The section does not fit even on its own: split it into passages, each
                # paired with the query and closed by its own rerank tokens.
                split_len = limit - query_length - num_rerank
                if split_len <= 0:
                    raise ValueError(
                        f"tokenizer_model_max_length={max_length} leaves no room for section content "
                        f"after a query of {query_length} embeddings and {num_rerank} rerank tokens."
                    )
                passages = [section[start:start + split_len] for start in range(0, section_length, split_len)]
                section_group_indices.append(torch.arange(next_group_id, next_group_id + len(passages)))
                next_group_id += len(passages)
                for passage in passages:
                    end_start = query_length + len(passage)
                    end_tok_indices.append([torch.arange(end_start, end_start + num_rerank)])
                    rows.append(torch.cat([query, passage, self.soft_prompt(rerank_ids)], dim=0))

            if row:
                flush(row, row_end_toks)

        return rows, end_tok_indices, section_group_indices

    def _pad_and_stack(self, embeds, mask_dtype, mask_device):
        """Zero-pad 2-D ``(length, hidden)`` embeddings to a common length and stack them.

        Padding goes on the side named by ``config.tokenizer_padding_side`` (default
        ``"right"``).

        Returns:
            ``(stacked, attention_mask)``. The mask has ``mask_dtype``, lives on
            ``mask_device``, and is set to True on the real (non-padding) positions.

        Raises:
            ValueError: if ``embeds`` is empty.
        """
        if not embeds:
            raise ValueError("Cannot pad an empty list of sequences.")
        max_len = max(x.shape[0] for x in embeds)
        pad_left = getattr(self.config, "tokenizer_padding_side", "right") == "left"
        attention_mask = torch.zeros((len(embeds), max_len), dtype=mask_dtype, device=mask_device)
        padded = []
        for i, emb in enumerate(embeds):
            cur_len = emb.shape[0]
            pad = torch.zeros((max_len - cur_len, emb.shape[1]), dtype=emb.dtype, device=emb.device)
            if pad_left:
                padded.append(torch.cat((pad, emb), dim=0))
                if cur_len > 0:
                    attention_mask[i, -cur_len:] = True
            else:
                padded.append(torch.cat((emb, pad), dim=0))
                if cur_len > 0:
                    attention_mask[i, :cur_len] = True
        return torch.stack(padded, dim=0), attention_mask

    def prepare_text_image(self, input_ids, image_features):
        """Embed each token-id sequence, splicing image features in at image placeholders.

        Every ``IMAGE_TOKEN_INDEX`` occurrence is replaced by the next entry of
        ``image_features``. Entries are consumed in order across all sequences. Sequences
        without placeholders are embedded directly.

        Returns:
            A list with one embedding tensor per input sequence.
        """
        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids in input_ids:
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                new_input_embeds.append(cur_input_embeds)
                continue

            # Boundaries of the text spans between (and around) image placeholders.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_input_ids_noim]
            # Embed all text spans in one call, then split them back apart.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            new_input_embeds.append(cur_new_input_embeds)

        return new_input_embeds

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the optional image special tokens and initialize their embeddings.

        Depending on ``model_args``, adds ``DEFAULT_IMAGE_PATCH_TOKEN`` and/or the image
        start/end tokens to ``tokenizer`` and resizes the model embeddings. New start/end
        rows are initialized to the mean of the existing rows. If
        ``model_args.pretrain_mm_mlp_adapter`` is set, the rows are then overwritten from
        that checkpoint.

        Raises:
            ValueError: if the checkpoint's ``model.embed_tokens.weight`` has an unexpected shape.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialize the new rows to the mean of the pre-existing vocabulary rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Input embeddings are trainable; the output embeddings stay frozen.
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                # Also guarantees ``input_embeddings`` was bound above (num_new_tokens > 0).
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

class LlavaMetaForCausalLMWithReRankRandomNeg(ABC):
    """Mixin that builds query/section pairs with in-batch random negatives for reranking.

    Each query is paired with its own document (the positive) and with up to
    ``num_random_negatives`` other documents of the batch (the negatives), chosen with
    ``random.sample``. A pair whose document is too long is split into several passages,
    and every passage is closed by the learnable rerank tokens from ``self.soft_prompt``.

    The concrete subclass must implement :meth:`get_model` and provide ``self.config``,
    ``self.device`` and ``self.soft_prompt``. ``soft_prompt`` is called with int64 index
    tensors and exposes ``num_embeddings`` (presumably an ``nn.Embedding``; confirm).
    """

    # Number of in-batch negatives sampled per query (fewer if the batch is smaller).
    num_random_negatives = 3

    @abstractmethod
    def get_model(self):
        """Return the underlying multimodal model.

        The model must expose the vision modules and ``embed_tokens``.
        """
        pass

    def get_vision_tower(self):
        """Return the vision tower of the underlying model."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower, resampler and projector, in that order."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().vision_resampler(image_features, images=images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_for_multimodal(self, query_input_ids, query_attention_mask, query_images,
                                      doc_input_ids, doc_attention_mask, doc_images, doc_num_sections, position_ids, past_key_values):
        """Build padded query/document pair embeddings (positives plus random negatives).

        Args:
            query_input_ids, doc_input_ids: padded token-id tensors whose image placeholders
                equal ``IMAGE_TOKEN_INDEX``. Assumed ``(batch, seq)``; TODO confirm.
            query_attention_mask, doc_attention_mask: optional masks selecting real tokens.
                ``None`` keeps every token.
            query_images, doc_images: optional image batches consumed in order by the
                image placeholders.
            doc_num_sections: per-query section counts. Every entry must be 1, so document
                ``i`` is the positive of query ``i``.
            position_ids: unused; kept for interface compatibility.
            past_key_values: returned unchanged.

        Returns:
            A 7-tuple ``(None, None, attention_mask, past_key_values, inputs_embeds,
            end_tok_indices, section_group_indices)``:
            - ``end_tok_indices[r]`` holds the rerank-token positions of row ``r``, counted
              from the start of the unpadded row.
            - ``section_group_indices`` holds, per (query, document) pair, the ids of the
              rows (passages) that pair was split into.
            The two leading ``None`` values presumably stand in for input ids / position
            ids; confirm against the caller.

        Raises:
            NotImplementedError: if ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end`` are
                both set in the config.
            ValueError: if a query leaves no room for document content within the limit.
        """
        query_image_features = self.encode_images(query_images) if query_images is not None else None
        doc_image_features = self.encode_images(doc_images) if doc_images is not None else None

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        if query_attention_mask is None:
            query_attention_mask = torch.ones_like(query_input_ids, dtype=torch.bool)
        else:
            query_attention_mask = query_attention_mask.bool()
        if doc_attention_mask is None:
            doc_attention_mask = torch.ones_like(doc_input_ids, dtype=torch.bool)
        else:
            doc_attention_mask = doc_attention_mask.bool()

        # Drop padding so every sequence keeps only its real tokens.
        query_input_ids = [ids[mask] for ids, mask in zip(query_input_ids, query_attention_mask)]
        doc_input_ids = [ids[mask] for ids, mask in zip(doc_input_ids, doc_attention_mask)]

        new_query_embeds = self.prepare_text_image(query_input_ids, query_image_features)
        new_doc_embeds = self.prepare_text_image(doc_input_ids, doc_image_features)

        new_input_embeds, section_group_indices = self._build_query_doc_pairs(
            new_query_embeds,
            new_doc_embeds,
            doc_num_sections,
            getattr(self.config, "tokenizer_model_max_length", None),
        )

        num_rerank = self.soft_prompt.num_embeddings
        batch_size = len(new_input_embeds)
        # One set of learnable rerank tokens per row, appended after the row's content.
        rerank_inputs = torch.arange(0, num_rerank, dtype=torch.int64, device=self.device).repeat(batch_size, 1)
        rerank_tok = self.soft_prompt(rerank_inputs)

        end_tok_indices = [torch.arange(len(x), len(x) + num_rerank) for x in new_input_embeds]
        new_input_embeds = [torch.cat([x, rerank_tok[batch_idx]], dim=0) for batch_idx, x in enumerate(new_input_embeds)]

        # NOTE(review): with left padding, end_tok_indices are not shifted by the pad
        # length. Confirm the consumer accounts for that.
        new_input_embeds, attention_mask = self._pad_and_stack(
            new_input_embeds, doc_attention_mask.dtype, doc_attention_mask.device
        )

        return None, None, attention_mask, past_key_values, new_input_embeds, end_tok_indices, section_group_indices

    @staticmethod
    def _split_pair(query, doc, split_len):
        """Concatenate ``query`` with consecutive ``split_len``-long chunks of ``doc``.

        ``split_len=None`` keeps the document whole. An empty document yields no rows.
        """
        doc_length = len(doc)
        step = max(doc_length, 1) if split_len is None else split_len
        return [torch.cat([query, doc[start:start + step]], dim=0) for start in range(0, doc_length, step)]

    def _build_query_doc_pairs(self, new_query_embeds, new_doc_embeds, doc_num_sections, max_length):
        """Build the positive and sampled negative (query, document) rows, without rerank tokens.

        ``max_length`` is the full sequence limit. Room for the rerank tokens, which are
        appended later, is reserved here. ``None`` disables splitting.

        Returns:
            ``(rows, section_group_indices)``.

        Raises:
            ValueError: if a query plus the rerank tokens already reach ``max_length``.
        """
        num_rerank = self.soft_prompt.num_embeddings
        # Leave room for the rerank tokens. Unlike the retriever, the reranker uses a single
        # LVLM over the concatenated pair.
        budget = None if max_length is None else max_length - num_rerank

        rows = []
        section_group_indices = []
        next_passage_id = 0

        def add_pair(query, doc, split_len):
            # Append the passages of one pair and record their row ids as one group.
            nonlocal next_passage_id
            passages = self._split_pair(query, doc, split_len)
            section_group_indices.append(torch.arange(next_passage_id, next_passage_id + len(passages)))
            rows.extend(passages)
            next_passage_id += len(passages)

        for query_idx, new_query in enumerate(new_query_embeds):
            query_num_sec = doc_num_sections[query_idx]
            assert query_num_sec == 1

            split_len = None if budget is None else budget - len(new_query)
            if split_len is not None and split_len <= 0:
                raise ValueError(
                    f"tokenizer_model_max_length={max_length} leaves no room for document content "
                    f"after a query of {len(new_query)} embeddings and {num_rerank} rerank tokens."
                )

            # Positive pair: the query with its own document.
            add_pair(new_query, new_doc_embeds[query_idx], split_len)

            # Negative pairs: the query with documents of other batch entries.
            candidates = [i for i in range(len(new_query_embeds)) if i != query_idx]
            for neg_idx in random.sample(candidates, min(self.num_random_negatives, len(candidates))):
                add_pair(new_query, new_doc_embeds[neg_idx], split_len)

        return rows, section_group_indices

    def _pad_and_stack(self, embeds, mask_dtype, mask_device):
        """Zero-pad 2-D ``(length, hidden)`` embeddings to a common length and stack them.

        Padding goes on the side named by ``config.tokenizer_padding_side`` (default
        ``"right"``).

        Returns:
            ``(stacked, attention_mask)``. The mask has ``mask_dtype``, lives on
            ``mask_device``, and is set to True on the real (non-padding) positions.

        Raises:
            ValueError: if ``embeds`` is empty.
        """
        if not embeds:
            raise ValueError("Cannot pad an empty list of sequences.")
        max_len = max(x.shape[0] for x in embeds)
        pad_left = getattr(self.config, "tokenizer_padding_side", "right") == "left"
        attention_mask = torch.zeros((len(embeds), max_len), dtype=mask_dtype, device=mask_device)
        padded = []
        for i, emb in enumerate(embeds):
            cur_len = emb.shape[0]
            pad = torch.zeros((max_len - cur_len, emb.shape[1]), dtype=emb.dtype, device=emb.device)
            if pad_left:
                padded.append(torch.cat((pad, emb), dim=0))
                if cur_len > 0:
                    attention_mask[i, -cur_len:] = True
            else:
                padded.append(torch.cat((emb, pad), dim=0))
                if cur_len > 0:
                    attention_mask[i, :cur_len] = True
        return torch.stack(padded, dim=0), attention_mask

    def prepare_text_image(self, input_ids, image_features):
        """Embed each token-id sequence, splicing image features in at image placeholders.

        Every ``IMAGE_TOKEN_INDEX`` occurrence is replaced by the next entry of
        ``image_features``. Entries are consumed in order across all sequences. Sequences
        without placeholders are embedded directly.

        Returns:
            A list with one embedding tensor per input sequence.
        """
        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids in input_ids:
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                new_input_embeds.append(cur_input_embeds)
                continue

            # Boundaries of the text spans between (and around) image placeholders.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_input_ids_noim]
            # Embed all text spans in one call, then split them back apart.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            new_input_embeds.append(cur_new_input_embeds)

        return new_input_embeds

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the optional image special tokens and initialize their embeddings.

        Depending on ``model_args``, adds ``DEFAULT_IMAGE_PATCH_TOKEN`` and/or the image
        start/end tokens to ``tokenizer`` and resizes the model embeddings. New start/end
        rows are initialized to the mean of the existing rows. If
        ``model_args.pretrain_mm_mlp_adapter`` is set, the rows are then overwritten from
        that checkpoint.

        Raises:
            ValueError: if the checkpoint's ``model.embed_tokens.weight`` has an unexpected shape.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialize the new rows to the mean of the pre-existing vocabulary rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Input embeddings are trainable; the output embeddings stay frozen.
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                # Also guarantees ``input_embeddings`` was bound above (num_new_tokens > 0).
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
