#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import math
from functools import partial
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector
from .clustering import iterative_merge, cluster_in_masks

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from llava.mm_utils import get_anyres_image_grid_shape

from transformers import logging
from transformers.models.llama.modeling_llama import LlamaRMSNorm
logger = logging.get_logger(__name__)

class RegionCrossAttentionPooling(nn.Module):
    """Pool patch features into per-region tokens with masked cross-attention.

    One query is built per region mask (from a learned bias, from the
    mask-averaged patch features, or from both). Each query attends over the
    flattened patch grid with the inverse of its mask passed as ``attn_mask``,
    so only patches inside the region are attended to.

    Args:
        hidden_size: Channel size ``C`` of the incoming patch features.
        expand_mult: Number of ``hidden_size``-wide output tokens per region.
        num_heads: Heads per expanded token; the attention layer is built with
            ``num_heads * expand_mult`` heads.
        dropout: Dropout probability forwarded to the attention layer.
        attn_impl: Attention implementation; only ``'MHA'`` is supported.
        expand_side: Where the expansion happens; only ``'hid'`` is supported.
        query_type: One of ``'bias'``, ``'feat'``, ``'linear'``, ``'bias+feat'``.

    Raises:
        ValueError: If ``query_type`` or ``attn_impl`` is not recognised.
        NotImplementedError: If ``expand_side`` is not ``'hid'``.
    """

    _QUERY_TYPES = ('bias', 'feat', 'linear', 'bias+feat')

    def __init__(
        self,
        hidden_size: int,
        expand_mult: int = 1,
        num_heads: int = 8,
        dropout: float = 0.0,
        attn_impl: str = 'MHA',
        expand_side: str = 'hid', # ['hid', 'token']
        query_type: str = 'feat', # ['bias', 'feat', 'linear', 'bias+feat']
    ):
        super().__init__()
        # Fail at construction time; previously an unknown query_type only
        # surfaced in forward() as an unbound local variable.
        if query_type not in self._QUERY_TYPES:
            raise ValueError(f"Unexpected query_type: {query_type}")

        self.hidden_size = hidden_size
        self.expand_mult = expand_mult
        self.expand_side = expand_side
        self.query_type = query_type

        if query_type == 'bias':
            self.query_bias = nn.Parameter(torch.zeros(hidden_size*expand_mult))
        elif query_type == 'feat':
            self.query_proj = nn.Linear(hidden_size, hidden_size*expand_mult, bias=False)
        elif query_type == 'linear':
            self.query_proj = nn.Linear(hidden_size, hidden_size*expand_mult)
        else:  # 'bias+feat'
            assert expand_mult == 1, "expand_mult must be 1 for query_type 'bias+feat'."
            self.query_bias = nn.Parameter(torch.zeros(hidden_size))

        if expand_side == 'hid':
            total_hidden_size = hidden_size * expand_mult
            total_num_heads = num_heads * expand_mult
            if attn_impl == 'MHA':
                # TODO: have out_projection by default, would it mix the features for expanded tokens?
                self.attn = nn.MultiheadAttention(total_hidden_size, total_num_heads, dropout=dropout, bias=True,
                                                  batch_first=True, kdim=hidden_size, vdim=hidden_size)
            else:
                raise ValueError(f"Unexpected attn_impl: {attn_impl}")
        else:
            raise NotImplementedError(f"expand_side {expand_side} is not implemented.")

    def forward(self, features:torch.Tensor, sam_masks:torch.Tensor, pooled_features:torch.Tensor=None):
        """Return one (or ``expand_mult``) pooled token(s) per region.

        Args:
            features: Patch features of shape ``(H, W, C)`` (unbatched).
            sam_masks: Region masks of shape ``(N, H, W)``; non-zero marks
                patches that belong to the region.
            pooled_features: Optional ``(N, C)`` per-region features used to
                build the queries. When omitted and the query type needs
                them, they are computed as the mask-weighted mean of
                ``features``.

        Returns:
            Tensor of shape ``(N * expand_mult, hidden_size)``.
        """
        assert features.dim() == 3, "does not support batched inputs."
        H, W, C = features.shape
        N = sam_masks.shape[0]
        # reshape (not view) so non-contiguous inputs, e.g. a permuted CHW
        # tensor, are accepted as well.
        flat_masks = sam_masks.reshape(N, H*W)
        kv = features.reshape(H*W, C)

        if pooled_features is None and self.query_type in ['feat', 'linear', 'bias+feat']:
            # Mask-weighted mean in float32; an empty mask has area 0 and
            # divides by zero here, so masks are expected to be non-empty.
            mask_areas = sam_masks.sum(dim=(1, 2))
            pooled_features = flat_masks.float() @ kv.float() / mask_areas.unsqueeze(-1)
            pooled_features = pooled_features.to(features.dtype)

        if self.query_type == 'bias':
            query = self.query_bias.unsqueeze(0).expand(N, -1)
        elif self.query_type == 'bias+feat':
            query = self.query_bias.unsqueeze(0) + pooled_features
        else:  # 'feat' and 'linear' both project the pooled features
            query = self.query_proj(pooled_features)

        if self.expand_side == 'hid':
            # attn_mask is True where attention is NOT allowed, hence the inversion.
            attn_output, attn_output_weights = self.attn(query, kv, kv, attn_mask = flat_masks.logical_not())
        else:
            raise NotImplementedError(f"expand_side {self.expand_side} is not implemented.")
        return attn_output.view(N*self.expand_mult, self.hidden_size)

class ProjectorAndEmbedding(nn.Module):
    """Bundle the vision projector with optional auxiliary modules.

    The auxiliary modules (positional embedding, RMS norm, region
    up-projection, region attention and anything passed through ``kwargs``)
    are only stored as attributes; ``forward`` applies the projector alone.
    """

    def __init__(self, projector: nn.Module, embedding: nn.Embedding = None, rmsnorm: LlamaRMSNorm = None, 
                 region_upproj: nn.Module = None, region_attn: RegionCrossAttentionPooling = None, **kwargs):
        super().__init__()
        attachments = {
            'projector': projector,
            'embedding': embedding,
            'rmsnorm': rmsnorm,
            'region_upproj': region_upproj,
            'region_attn': region_attn,
        }
        # Extra keyword modules are attached after the named ones, in call order.
        attachments.update(kwargs)
        for attr_name, module in attachments.items():
            setattr(self, attr_name, module)

    def forward(self, x):
        """Project ``x`` with the wrapped projector."""
        projected = self.projector(x)
        return projected

class LlavaMetaModel:

    def __init__(self, config, delay_load=True):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.mm_projector = build_vision_projector(config)
            extra_modules = {}
            if getattr(config, "mm_vision_feature_pe", "none") != "none":
                pe_type=config.mm_vision_feature_pe
                if 'rmsnorm' in pe_type:
                    extra_modules['rmsnorm'] = LlamaRMSNorm(self.vision_tower.hidden_size)
                if 'plain' in pe_type:
                    img_feature_pos_embed = nn.Embedding(self.vision_tower.num_patches, self.vision_tower.hidden_size)
                    extra_modules['embedding'] = img_feature_pos_embed
            if getattr(config, "region_expand_mult", 1) != 1 and getattr(config, "region_pooling_method", "null") in ["average", "max"]:
                if config.region_expand_mult < 0:
                    raise NotImplementedError("adaptive region token expansion is not supported.")
                extra_modules['region_upproj'] = nn.Linear(self.vision_tower.hidden_size, self.vision_tower.hidden_size * config.region_expand_mult)
            if getattr(config, "region_pooling_method", "null") in ['cross-attn']:
                extra_modules['region_attn'] = RegionCrossAttentionPooling(
                    hidden_size=self.vision_tower.hidden_size,
                    expand_mult=config.region_expand_mult,
                    **config.region_attn_args
                )
            if len(extra_modules) > 0:
                self.mm_projector = ProjectorAndEmbedding(self.mm_projector, **extra_modules)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def apply_vision_feature_postprocess(self, image_features: torch.Tensor):
        pe_type=getattr(self.config, 'mm_vision_feature_pe', 'none')
        if 'normalize' in pe_type:
            image_features = F.normalize(image_features, dim=-1)
        if 'rmsnorm' in pe_type:
            image_features = self.mm_projector.rmsnorm(image_features)
        if 'plain' in pe_type:
            image_features = image_features + self.mm_projector.embedding(torch.arange(image_features.shape[1], device=image_features.device))
        return image_features

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type
        mm_force_imsize = model_args.mm_force_imsize
        mm_vision_feature_pe = model_args.mm_vision_feature_pe

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.mm_force_imsize = mm_force_imsize
        self.config.mm_vision_feature_pe = mm_vision_feature_pe

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
            extra_modules = {}
            if mm_vision_feature_pe != "none":
                pe_type=mm_vision_feature_pe
                if 'rmsnorm' in pe_type:
                    extra_modules['rmsnorm'] = LlamaRMSNorm(vision_tower.hidden_size)
                if 'plain' in pe_type:
                    img_feature_pos_embed = nn.Embedding(vision_tower.num_patches, vision_tower.hidden_size)
                    extra_modules['embedding'] = img_feature_pos_embed
            if getattr(self.config, "region_expand_mult", 1) != 1 and getattr(self.config, "region_pooling_method", "null") in ["average", "max"]:
                if self.config.region_expand_mult < 0:
                    raise NotImplementedError("adaptive region token expansion is not supported.")
                extra_modules['region_upproj'] = nn.Linear(vision_tower.hidden_size, vision_tower.hidden_size * self.config.region_expand_mult)
            if getattr(self.config, "region_pooling_method", "null") in ['cross-attn']:
                extra_modules['region_attn'] = RegionCrossAttentionPooling(
                    hidden_size=self.vision_tower.hidden_size,
                    expand_mult=self.config.region_expand_mult,
                    **self.config.region_attn_args
                )
            if len(extra_modules) > 0:
                self.mm_projector = ProjectorAndEmbedding(self.mm_projector, **extra_modules)

            if 'unpad' in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


def unpad_image(tensor, original_size):
    """
    Crop away the padding of an image tensor that was padded and resized.

    The original aspect ratio is compared with the tensor's: when the original
    is relatively wider, rows are trimmed from top and bottom; otherwise
    columns are trimmed from left and right. The same margin is removed on
    both sides.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of PIL image (width, height).

    Returns:
    torch.Tensor: The unpadded image tensor.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    padded_vertically = orig_w / orig_h > cur_w / cur_h
    if padded_vertically:
        scale = cur_w / orig_w
        kept_rows = int(orig_h * scale)
        margin = (cur_h - kept_rows) // 2
        return tensor[:, margin:cur_h - margin, :]

    scale = cur_h / orig_h
    kept_cols = int(orig_w * scale)
    margin = (cur_w - kept_cols) // 2
    return tensor[:, :, margin:cur_w - margin]


class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        """Return the wrapped model object.

        Subclasses must override this. The other methods of this class call
        ``get_vision_tower()``, ``apply_vision_feature_postprocess()``,
        ``mm_projector`` and ``embed_tokens`` on the returned object.
        """
        return None

    def get_vision_tower(self):
        """Return the vision tower reported by the wrapped model."""
        wrapped_model = self.get_model()
        return wrapped_model.get_vision_tower()

    def encode_images(self, images):
        """Run images through the vision tower, feature post-processing and
        the multimodal projector.

        When ``self.shuffle_patches`` is truthy, the patch axis (dim 1) of
        each sample is permuted before projection, using a generator seeded
        once and cached on ``self.random_generator``. The attribute is
        optional; when a subclass does not define it, no shuffling happens.

        Args:
            images: Input accepted by the vision tower.

        Returns:
            Whatever ``mm_projector`` returns for the processed features.
        """
        model = self.get_model()
        image_features = model.get_vision_tower()(images)
        image_features = model.apply_vision_feature_postprocess(image_features)
        # This mixin never sets shuffle_patches itself, so default to "off"
        # instead of raising AttributeError.
        if getattr(self, 'shuffle_patches', False):
            device=image_features.device
            if getattr(self, "random_generator", None) is None:
                self.random_generator = torch.Generator(device=device)
                self.random_generator.manual_seed(114514)
            N, L, H = image_features.shape
            # One independent permutation per sample, broadcast over the feature dim.
            idx = torch.stack([torch.randperm(L, generator=self.random_generator, device=device) for _ in range(N)], dim=0).unsqueeze(-1).expand(N, L, H)
            image_features = image_features.gather(1, idx)

        image_features = model.mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        """Splice image features into the token embeddings of each sequence.

        Every ``IMAGE_TOKEN_INDEX`` placeholder in ``input_ids`` is replaced by
        the next entry of the encoded image features; the remaining tokens are
        embedded with ``get_model().embed_tokens``. Positions filled with
        image features get ``IGNORE_INDEX`` as label. The per-sample results
        are optionally truncated to ``config.tokenizer_model_max_length`` and
        padded to a common length on the side given by
        ``config.tokenizer_padding_side`` (default ``'right'``).

        Args:
            input_ids: Token ids, indexed as ``(batch, seq)``.
            position_ids: Optional position ids; may be ``None``.
            attention_mask: Optional mask used to strip padding; may be ``None``.
            past_key_values: Passed through unchanged.
            labels: Optional labels aligned with ``input_ids``; may be ``None``.
            images: A tensor, a 5-dim tensor, or a list of tensors.
            image_sizes: Per-image original sizes, used by the ``anyres`` /
                ``unpad`` merge paths.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. ``position_ids``, ``attention_mask`` and
            ``labels`` are returned as ``None`` when the corresponding
            argument was ``None``. If there is no vision tower, no images, or
            ``input_ids`` has a single column, the inputs are returned
            untouched as ``(input_ids, position_ids, attention_mask,
            past_key_values, None, labels)``.

        Raises:
            NotImplementedError: For a multi-crop image when
                ``image_aspect_ratio`` is not ``'anyres'``, or when both
                ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end`` are set.
            ValueError: For an unknown ``mm_patch_merge_type``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Each entry may hold several crops; encode them in one batch and
            # split the result back per entry afterwards.
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            if isinstance(image_features, torch.Tensor):
                image_features = torch.split(image_features, split_sizes, dim=0)
            else:
                # encode_images returned a non-tensor sequence; slice it manually.
                idx = 0
                new_image_features = []
                for split_size in split_sizes:
                    new_image_features.append(image_features[idx:idx+split_size])
                    idx += split_size
                image_features = new_image_features
            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
            if mm_patch_merge_type == 'flat':
                # Concatenate all crops of an entry along the token axis.
                if isinstance(image_features[0], torch.Tensor):
                    image_features = [x.flatten(0, 1) for x in image_features]
                else:
                    new_image_features = []
                    for image_feature in image_features:
                        new_image_features.append(torch.cat(image_feature, dim=0))
                    image_features = new_image_features
            elif mm_patch_merge_type.startswith('spatial'):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if len(image_feature) > 1:
                        # First crop is the base image; the rest form a grid of tiles.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if not isinstance(image_feature, torch.Tensor):
                            for i in range(len(image_feature)):
                                assert image_feature[i].shape[0] == height * width, f"This configuration only supports patch level features. Got {image_feature[i].shape[0]} patches."
                        if image_aspect_ratio == 'anyres':
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if 'unpad' in mm_patch_merge_type:
                            # Rearrange to (C, H_total, W_total), crop the padding,
                            # then append one image_newline column per row.
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        # Single crop: use it as is, plus one newline token in unpad mode.
                        image_feature = image_feature[0]
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[None].to(image_feature.device)
                            ), dim=0)
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        # The originals are kept so that None inputs can be returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        # Running index into image_features, shared across the whole batch.
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # No placeholder in this sample: one image_features entry is still
                # consumed, and its zero-length slice [0:0] adds no rows.
                # NOTE(review): presumably this keeps the image branch in the
                # autograd graph for text-only samples -- confirm.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Sentinels -1 and len(seq) bracket the placeholder positions so the
            # text segments between them can be sliced uniformly.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split back per segment.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave: text segment, image features, text segment, ...
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions are labelled IGNORE_INDEX.
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        # Pad every sample with zero embeddings up to max_len; labels, mask and
        # position ids are written only over the real (unpadded) span.
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Anything the caller passed as None is handed back as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register the image special tokens and prepare their embeddings.

        * ``mm_use_im_patch_token``: adds the image patch token and resizes
          the embedding tables.
        * ``mm_use_im_start_end``: adds the image start/end tokens, resizes
          the tables, initialises the new rows with the mean of the existing
          rows, sets ``requires_grad`` on the input (True) and output (False)
          embeddings when ``tune_mm_mlp_adapter`` is set, and optionally
          copies the new input rows from ``pretrain_mm_mlp_adapter``.
        * Patch token only, with ``tune_mm_mlp_adapter``: freezes both
          embedding tables.

        Raises:
            ValueError: If the pretrained embedding table has an unexpected shape.
        """
        patch_token_enabled = model_args.mm_use_im_patch_token
        if patch_token_enabled:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if not model_args.mm_use_im_start_end:
            # Without start/end tokens the only remaining work is freezing the
            # embeddings when just the adapter is being tuned.
            if patch_token_enabled and model_args.tune_mm_mlp_adapter:
                for param in self.get_input_embeddings().parameters():
                    param.requires_grad = False
                for param in self.get_output_embeddings().parameters():
                    param.requires_grad = False
            return

        added = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if added > 0:
            input_embeddings = self.get_input_embeddings().weight.data
            output_embeddings = self.get_output_embeddings().weight.data

            # New rows start out as the average of all pre-existing rows.
            input_mean = input_embeddings[:-added].mean(dim=0, keepdim=True)
            output_mean = output_embeddings[:-added].mean(dim=0, keepdim=True)

            input_embeddings[-added:] = input_mean
            output_embeddings[-added:] = output_mean

        if model_args.tune_mm_mlp_adapter:
            for param in self.get_input_embeddings().parameters():
                param.requires_grad = True
            for param in self.get_output_embeddings().parameters():
                param.requires_grad = False

        if model_args.pretrain_mm_mlp_adapter:
            checkpoint = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
            pretrained_rows = checkpoint['model.embed_tokens.weight']
            assert added == 2
            if input_embeddings.shape == pretrained_rows.shape:
                # Full table stored: take only its last rows.
                input_embeddings[-added:] = pretrained_rows[-added:]
            elif pretrained_rows.shape[0] == added:
                # Only the new rows were stored.
                input_embeddings[-added:] = pretrained_rows
            else:
                raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {pretrained_rows.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {added}.")

class RLlavaMetaForCausalLM(LlavaMetaForCausalLM, ABC):

    def sort_regions(self, sam_masks: torch.Tensor):
        """Reorder region masks according to ``self.region_sort``.

        Strategies:
          * ``"default"``: return the masks untouched.
          * ``"random"``: shuffle with a generator that is created on first
            use (fixed seed) and cached on ``self.random_generator``.
          * ``"center_of_mass"``: ascending ``y_com * W + x_com``.
          * ``"com_patch"``: ascending ``(int(y_com) // patch_size) * W + x_com``.
          * ``"diag_patch"``: ascending
            ``(int(y_com + x_com) // patch_size) * (H + W) + (x_com - y_com + H)``.

        Args:
            sam_masks: region masks of shape (N, H, W).

        Returns:
            ``sam_masks`` indexed into the selected order.

        Raises:
            ValueError: if ``self.region_sort`` is none of the strategies above.
        """
        strategy = self.region_sort

        if strategy == "default":
            return sam_masks

        if strategy == "random":
            dev = sam_masks.device
            if getattr(self, "random_generator", None) is None:
                self.random_generator = torch.Generator(device=dev)
                self.random_generator.manual_seed(114514)
            order = torch.randperm(len(sam_masks), generator=self.random_generator, device=dev)
            return sam_masks[order]

        if strategy not in ("center_of_mass", "com_patch", "diag_patch"):
            raise ValueError(f"Unexpected region_sort: {self.region_sort}")

        patch_size = self.get_vision_tower().config.patch_size
        _, height, width = sam_masks.shape
        dev = sam_masks.device
        row_idx = torch.arange(height, device=dev).view(1, height, 1)
        col_idx = torch.arange(width, device=dev).view(1, 1, width)

        # Per-mask centre of mass; a mask with zero area yields NaN here.
        area = sam_masks.sum(dim=(1, 2))
        cy = (row_idx * sam_masks).sum(dim=(1, 2)) / area
        cx = (col_idx * sam_masks).sum(dim=(1, 2)) / area

        if strategy == "center_of_mass":
            key = cy * width + cx
        elif strategy == "com_patch":
            key = cy.long() // patch_size * width + cx
        else:  # "diag_patch"
            key = (cy + cx).long() // patch_size * (height + width) + (cx - cy + height)
        return sam_masks[torch.argsort(key)]

    def filter_regions(self, sam_masks: torch.Tensor, method: str = None):
        """Drop region masks that fail the heuristics named in ``method``.

        ``method`` is matched by substring, so several heuristics can be
        combined in one string:

          * ``"corner_heuristic"``: look at the four corner windows, each 20%
            of the mask height/width (at least 1 pixel). A corner counts as
            occupied when the mask covers more than half of the window
            (pixel threshold truncated to int). Masks occupying more than one
            corner are removed.
          * ``"area_heuristic"``: keep only masks whose pixel count lies in
            ``[int(0.0003 * H * W), int(0.5 * H * W)]``.

        Args:
            sam_masks: region masks of shape (N, H, W).
            method: heuristics to apply; falls back to ``self.region_filter``
                when ``None``.

        Returns:
            The surviving masks, in their original relative order.
        """
        _, height, width = sam_masks.shape
        if method is None:
            method = self.region_filter

        if "corner_heuristic" in method:
            corner_frac = 0.2
            occupancy_frac = 0.5
            max_corners = 1
            win_h = max(1, int(height * corner_frac))
            win_w = max(1, int(width * corner_frac))
            min_occupied = int(win_h * win_w * occupancy_frac)
            corner_windows = (
                sam_masks[:, :win_h, :win_w],
                sam_masks[:, :win_h, -win_w:],
                sam_masks[:, -win_h:, :win_w],
                sam_masks[:, -win_h:, -win_w:],
            )
            occupied = torch.stack(
                [window.sum(dim=(1, 2)) > min_occupied for window in corner_windows],
                dim=-1,
            )
            keep = occupied.sum(dim=-1) <= max_corners
            sam_masks = sam_masks[keep]

        if "area_heuristic" in method:
            fewest = int(0.0003 * height * width)
            most = int(0.5 * height * width)
            area = sam_masks.sum(dim=(1, 2))
            sam_masks = sam_masks[(area >= fewest) & (area <= most)]

        return sam_masks

    def add_extra_regions(self, sam_masks: torch.Tensor, extra_regions: str = None):
        """Add synthetic regions requested by ``extra_regions``.

          * ``'global'``: prepend an all-ones mask covering the whole image.
          * ``'uncovered'``: append a mask of the pixels not covered by any
            current mask, but only if at least one such pixel exists.

        NOTE(review): 'uncovered' is evaluated after 'global' has been
        prepended, so when both are enabled the uncovered mask is always
        empty and never added. Confirm this is intended.

        Args:
            sam_masks: region masks of shape (N, H, W); N may be 0.
            extra_regions: which extras to add; falls back to
                ``self.region_extra`` when ``None``.

        Returns:
            The masks with the requested extra regions added.
        """
        if extra_regions is None:
            extra_regions = self.region_extra

        if 'global' in extra_regions:
            whole_image = torch.ones(
                sam_masks.shape[-2:], dtype=torch.bool, device=sam_masks.device
            ).unsqueeze(0)
            if len(sam_masks) > 0:
                sam_masks = torch.cat((whole_image, sam_masks), dim=0)
            else:
                sam_masks = whole_image

        if 'uncovered' in extra_regions:
            if len(sam_masks) > 0:
                leftover = torch.any(sam_masks, dim=0).logical_not()
            else:
                # With no masks at all, every pixel is uncovered.
                leftover = torch.ones(sam_masks.shape[-2:], dtype=torch.bool, device=sam_masks.device)
            if torch.sum(leftover).item() > 0:
                leftover = leftover.unsqueeze(0)
                if len(sam_masks) > 0:
                    sam_masks = torch.cat((sam_masks, leftover), dim=0)
                else:
                    sam_masks = leftover

        return sam_masks

    @torch.inference_mode()
    def process_masks(self, features: torch.Tensor, sam_masks: torch.Tensor):
        """Turn raw region masks into the masks used for feature pooling.

        Steps:
          1. Drop leading extra tokens from ``features`` (anything beyond
             ``num_patches``) and reshape to a (rows, num_patches_per_side, C) grid.
          2. If ``self.region_source == "clustering"``, replace ``sam_masks`` by
             the first mask set returned by ``iterative_merge`` on the grid.
          3. Sort, add extra regions and (unless ``self.region_filter`` is
             ``'none'``) filter. If nothing is left, fall back to one all-ones
             mask with the grid's spatial size.
          4. For clustering, or when ``self.region_interpolate`` starts with
             ``'upsample'``, return the masks as they are. Otherwise resize them
             to the vision tower's ``image_size`` (``...pad``: zero-pad the short
             edge first; ``...crop``: resize by the short edge, then centre
             crop; anything else: direct resize), optionally average-pool to
             patch resolution (``'downsample...'``) and optionally split each
             mask with ``cluster_in_masks`` (``region_source`` ``'split_<algo>'``).

        Args:
            features: per-image vision tower output of shape (num_tokens, C).
            sam_masks: region masks of shape (N, H, W).

        Returns:
            Boolean masks of shape (N', H', W').
        """
        vision_tower = self.get_vision_tower()

        def parse_cluster_args():
            # "k1=v1,k2=v2" -> {k1: v1, k2: v2}.
            # NOTE(review): values are passed through eval(); region_cluster_args
            # must only ever come from trusted configuration.
            parsed = {}
            for entry in self.region_cluster_args.split(','):
                if '=' in entry:
                    key, value = entry.split('=')
                    parsed[key] = eval(value)
            return parsed

        num_patches = vision_tower.num_patches
        if features.shape[0] > num_patches:
            # has extra tokens like cls or reg at the front
            features = features[features.shape[0] - num_patches:]
        # (H_P*W_P, C') -> (H_P, W_P, C')
        features = features.view(-1, vision_tower.num_patches_per_side, features.shape[-1])

        from_clustering = self.region_source == "clustering"
        if from_clustering:
            cluster_args = parse_cluster_args()
            thresholds = cluster_args['t']
            if not isinstance(thresholds, list):
                thresholds = [thresholds]
            merged = iterative_merge(features, thresholds, min_size=cluster_args['m'], merge_masks=True)
            sam_masks = merged[0]

        sam_masks = self.add_extra_regions(self.sort_regions(sam_masks))
        if self.region_filter != 'none':
            sam_masks = self.filter_regions(sam_masks)
        if sam_masks.shape[0] == 0:
            logger.warning_once("No regions in the image, adding a global region.")
            sam_masks = torch.ones((1, features.shape[0], features.shape[1]), dtype=torch.bool, device=features.device)

        if from_clustering or self.region_interpolate.startswith('upsample'):
            return sam_masks

        # Remaining modes: let the masks go through the same resize / pad / crop as the image.
        MASK_DOWNSAMPLE_THRESHOLD = 0.25
        interpolate_mode = self.region_interpolate
        target = vision_tower.config.image_size
        _, h, w = sam_masks.shape
        soft_masks = sam_masks.unsqueeze(0).half()

        def binarize(soft, size, retry_if_empty):
            # Bilinear resize followed by thresholding; optionally retry with a
            # 5x smaller threshold when no pixel at all survives.
            resized = F.interpolate(soft, size=size, mode='bilinear').squeeze(0)
            hard = resized > MASK_DOWNSAMPLE_THRESHOLD
            if retry_if_empty and not hard.any():
                logger.warning_once("region is too small, retry with smaller threshold")
                hard = resized > MASK_DOWNSAMPLE_THRESHOLD/5
            return hard

        if interpolate_mode.endswith("pad"):
            # pad the shortest edge to a square, then resize
            if h < w:
                before = (w - h) // 2
                soft_masks = F.pad(soft_masks, (0, 0, before, w - h - before), value=0)
            elif h > w:
                before = (h - w) // 2
                soft_masks = F.pad(soft_masks, (before, h - w - before), value=0)
            downsampled_masks = binarize(soft_masks, [target, target], True)
        elif interpolate_mode.endswith("crop"):
            # resize so the shortest edge matches the target, then center crop
            short_edge, long_edge = (h, w) if h < w else (w, h)
            scaled_long = int(target * long_edge / short_edge)
            resize_hw = [target, scaled_long] if h < w else [scaled_long, target]
            downsampled_masks = binarize(soft_masks, resize_hw, True)
            downsampled_masks = crop(downsampled_masks, (target, target), center=True)
        else:
            # directly resize
            downsampled_masks = binarize(soft_masks, [target, target], False)

        if interpolate_mode.startswith('downsample'):
            # Average-pool image-resolution masks down to one value per patch.
            MASK_CONV_THRESHOLD = 0.07
            patch_size = vision_tower.config.patch_size
            pooled = F.avg_pool2d(downsampled_masks.half(), kernel_size=patch_size, stride=patch_size)
            downsampled_masks = pooled > MASK_CONV_THRESHOLD

        if self.region_source.startswith('split_'):
            assert interpolate_mode.startswith('downsample'), f"region_source {self.region_source} only works with downsample"
            cluster_algo = self.region_source.split('_')[1]
            cluster_args = parse_cluster_args()
            # Index 0 is the prepended global region when 'global' is enabled.
            skip_ids = [0] if 'global' in self.region_extra else []
            downsampled_masks = cluster_in_masks(features, downsampled_masks, algo=cluster_algo, cluster_args=cluster_args, skip_ids=skip_ids)

        return downsampled_masks

    def extract_region_features(self, features: torch.Tensor, masks: torch.Tensor, return_masks: bool = False):
        """Pool patch features inside every region mask into region features.

        Args:
            features: per-image vision tower output of shape (num_tokens, C).
                Tokens beyond ``num_patches`` at the front are treated as extra
                tokens (cls / register) and are excluded from pooling.
            masks: region masks of shape (N, H, W), e.g. the output of
                ``process_masks``. Masks with zero area are dropped.
            return_masks: also return which input masks were kept.

        Returns:
            ``region_features`` with one row per kept region
            (``region_expand_mult`` rows per region when
            ``self.need_region_upproj`` is set). If ``'patch'`` is in
            ``self.region_extra`` the original ``features`` are prepended;
            otherwise, if extra tokens were present, those are prepended.
            With ``return_masks=True`` the result is
            ``(region_features, nonzero_mask)`` where ``nonzero_mask`` is a
            boolean tensor of shape (N,) marking the masks with non-zero area.

        Raises:
            ValueError: on an unknown ``self.region_interpolate`` or
                ``self.region_pooling_method``.
        """
        vision_tower = self.get_vision_tower()
        patch_features = features
        has_extra_features = False
        if features.shape[0] > vision_tower.num_patches:
            # has extra tokens like cls or reg
            num_extra_features = features.shape[0] - vision_tower.num_patches
            extra_features = features[:num_extra_features]
            features = features[num_extra_features:]
            has_extra_features = True
        if self.need_region_upproj:
            features = self.get_model().mm_projector.region_upproj(features)
        # (H_P*W_P, C') -> (H_P, W_P, C')
        features = features.view(-1, vision_tower.num_patches_per_side, features.shape[-1])

        def get_sam_features(features: torch.Tensor, masks: torch.Tensor, mask_areas: torch.Tensor, parallel_avg: bool = True):
            # features: torch.Tensor of shape (H, W, C)
            # masks: torch.Tensor of shape (N, H, W)
            # mask_areas: torch.Tensor of shape (N,), pixel count of every mask
            H, W, C = features.shape
            N = masks.shape[0]

            if parallel_avg and self.region_pooling_method in ['average', 'cross-attn']:
                # Masked average for all regions at once: (N, H*W) @ (H*W, C) / area.
                # reshape (not view): the feature map may be a permuted / cropped
                # view whose H and W dims cannot be merged without a copy.
                pooled = masks.float().reshape(N, H * W) @ features.reshape(H * W, C).float() / mask_areas.unsqueeze(-1)
                pooled = pooled.to(features.dtype)
                if pooled.isnan().any() or pooled.isinf().any():
                    logger.warning_once("NaN or Inf detected in all_features_in_sam.")
                return list(pooled)

            region_feats = []
            for mask in masks:
                # features[mask] is (num_pixels, C): pool over the pixel axis
                # so that every region yields one C-dimensional vector.
                pixels = features[mask.bool()]
                if self.region_pooling_method in ['average', 'cross-attn']:
                    features_in_sam = pixels.mean(dim=0)
                elif self.region_pooling_method == 'max':
                    features_in_sam = pixels.max(dim=0).values
                else:
                    raise ValueError(f'Unexpected pooling method: {self.region_pooling_method}')
                region_feats.append(features_in_sam)
            return region_feats

        if self.region_source == "clustering":
            upsample_feature = features
        elif self.region_interpolate.startswith('upsample'):
            if len(masks) > 0:
                # sam regions within an image all have the same total size
                new_h, new_w = masks[0].shape
                patch_length = vision_tower.config.patch_size
                # Round the mask size up to a multiple of the patch size.
                padded_h = math.ceil(new_h / patch_length) * patch_length
                padded_w = math.ceil(new_w / patch_length) * patch_length
                mode = 'nearest' if 'nearest' in self.region_interpolate else 'bilinear'
                # (H, W, C) -> (1, C, H, W); first interpolate to the padded size
                upsample_feature = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=[padded_h, padded_w], mode=mode).squeeze(0)
                # then center crop back to the mask size
                upsample_feature = T.CenterCrop((new_h, new_w))(upsample_feature).permute(1, 2, 0)
            else:
                # No masks to pool over: keep the name bound so the code below
                # produces an empty result instead of raising UnboundLocalError.
                upsample_feature = features
        else:  # masks went through the same resize, pad and crop as the image
            if self.region_interpolate.startswith('downsample'):
                upsample_feature = features
            elif self.region_interpolate.startswith('process'):
                target_size = vision_tower.config.image_size
                mode = 'nearest' if 'nearest' in self.region_interpolate else 'bilinear'
                upsample_feature = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=[target_size, target_size], mode=mode).squeeze(0)
                upsample_feature = upsample_feature.permute(1, 2, 0)
            else:
                raise ValueError(f'Unexpected region_interpolate: {self.region_interpolate}')

        # Drop empty masks; they would divide by zero when averaging.
        mask_areas = masks.sum(dim=(1, 2))
        nonzero_mask = mask_areas > 0
        masks = masks[nonzero_mask]
        mask_areas = mask_areas[nonzero_mask]
        all_region_features_in_image = get_sam_features(upsample_feature, masks, mask_areas)

        if len(all_region_features_in_image) == 0:
            logger.error("No regions in the image, should not happen now!")
            region_features = torch.zeros(0, features.shape[-1], device=features.device, dtype=features.dtype)
        else:
            region_features = torch.stack(all_region_features_in_image, dim=0)
        if self.need_region_upproj:
            # [N, C*k] -> [N*k, C]
            region_features = region_features.view(region_features.shape[0] * self.region_expand_mult, patch_features.shape[1])
        elif self.region_pooling_method == 'cross-attn':
            region_features = self.get_model().mm_projector.region_attn(upsample_feature, masks, pooled_features=region_features)

        if 'patch' in self.region_extra:
            region_features = torch.cat((patch_features, region_features), dim=0)
        elif has_extra_features:
            region_features = torch.cat((extra_features, region_features), dim=0)
        return (region_features, nonzero_mask) if return_masks else region_features

    # def get_region_features_single_img(self, image: torch.Tensor, features: torch.Tensor, sam_masks: torch.Tensor, return_masks: bool = False):
    #     # features: torch.Tensor of shape (H_P*W_P, C)
    #     # sam_masks: torch.ByteTensor of shape (N, H, W)
    #     patch_features = features
    #     has_extra_features = False
    #     if features.shape[0] > self.get_vision_tower().num_patches:
    #         # has extra tokens like cls or reg
    #         num_extra_features = features.shape[0] - self.get_vision_tower().num_patches
    #         extra_features = features[:num_extra_features]
    #         features = features[num_extra_features:]
    #         has_extra_features = True
    #     if self.need_region_upproj:
    #         features = self.get_model().mm_projector.region_upproj(features)
    #     # (H_P*W_P, C') -> (H_P, W_P, C')
    #     features = features.view(-1, self.get_vision_tower().num_patches_per_side, features.shape[-1])
    #     # features = features.transpose(0, 1).view(1, features.shape[-1], self.get_vision_tower().num_patches_per_side, -1)
    #     all_region_features_in_image = []
    #     nonzero_mask = None

    #     if self.region_source == "clustering":
    #         cluster_args = {}
    #         for item in self.region_cluster_args.split(','):
    #             if '=' in item:
    #                 k, v = item.split('=')
    #                 cluster_args[k] = eval(v)
    #         # threshes=[0.9, 0.8, 0.7, 0.6, 0.5]
    #         threshes=cluster_args['t']
    #         if not isinstance(threshes, list):
    #             threshes = [threshes]
    #         # start_time = time.time()
    #         all_masks = iterative_merge(features, threshes, min_size=cluster_args['m'])
    #         # self.cluster_time += time.time() - start_time
    #         sam_masks = all_masks[0]
    #     sam_masks = self.sort_regions(sam_masks)
    #     sam_masks = self.add_extra_regions(sam_masks)
    #     if self.region_filter != 'none':
    #         sam_masks = self.filter_regions(sam_masks)
    #     if sam_masks.shape[0] == 0:
    #         logger.warning_once("No regions in the image, adding a global region.")
    #         sam_masks = torch.ones((1, features.shape[0], features.shape[1]), dtype=torch.bool, device=features.device)

    #     def get_sam_features(features: torch.Tensor, sam_masks: torch.Tensor, parallel_avg: bool = True):
    #         # features: torch.Tensor of shape (H, W, C)
    #         # sam_masks: torch.ByteTensor of shape (N, H, W)
    #         # remove columns that are all zeros
    #         mask_areas = sam_masks.sum(dim=(1, 2))
    #         nonlocal nonzero_mask
    #         nonzero_mask = mask_areas > 0
    #         sam_masks = sam_masks[nonzero_mask]
    #         mask_areas = mask_areas[nonzero_mask]
    #         H, W, C = features.shape
    #         N = sam_masks.shape[0]

    #         if parallel_avg and self.region_pooling_method in ['average', 'cross-attn']:
    #             # all_features_in_sam = torch.einsum('nhw,chw->nc', sam_masks.to(features.dtype), features) / mask_areas.unsqueeze(-1)
    #             all_features_in_sam = sam_masks.float().view(N, H*W) @ features.view(H*W, C).float() / mask_areas.unsqueeze(-1)
    #             all_features_in_sam = all_features_in_sam.to(features.dtype)
    #             if all_features_in_sam.isnan().any() or all_features_in_sam.isinf().any():
    #                 logger.warning_once("NaN or Inf detected in all_features_in_sam.")
    #             all_region_features_in_image.extend([f for f in all_features_in_sam])
    #             return sam_masks
    #         for sam_mask in sam_masks:
    #             if self.region_pooling_method in ['average', 'cross-attn']:
    #                 features_in_sam = features[sam_mask].mean(dim=1)
    #             elif self.region_pooling_method == 'max':
    #                 input_max, max_indices = torch.max(features[sam_mask], dim=1)
    #                 features_in_sam = input_max
    #             else:
    #                 raise ValueError(f'Unexpected pooling method: {self.region_pooling_method}')
    #             all_region_features_in_image.append(features_in_sam)
    #         return sam_masks

    #     if self.region_source == "clustering":
    #         upsample_feature = features
    #         final_masks = get_sam_features(upsample_feature, sam_masks)
    #         # demo_img = image.cpu().float().permute(1,2,0).numpy()
    #         # demo_img = ((demo_img - demo_img.min()) / (demo_img.max() - demo_img.min()) * 255).astype(np.uint8)
    #         # target_h, target_w = self.get_vision_tower().config.image_size, self.get_vision_tower().config.image_size
    #         # patch_size = self.get_vision_tower().config.patch_size
    #         # remain_h, remain_w = target_h % patch_size, target_w % patch_size
    #         # from llava.visualize_utils import show_img_and_mask
    #         # print([len(m) for m in all_masks])
    #         # for thresh, masks in zip(threshes, all_masks):
    #         #     demo_masks = F.interpolate(masks.unsqueeze(0).half(), size=[target_h-remain_h, target_w-remain_w], mode='nearest').squeeze(0).bool()
    #         #     demo_masks = F.pad(demo_masks, (0, remain_w, 0, remain_h), value=False)
    #         #     demo_masks = demo_masks.cpu().numpy()
    #         #     show_img_and_mask(demo_img, demo_masks, save_path=f'./playground/tmp-{thresh}.png')
    #         # import pdb; pdb.set_trace()
    #     elif self.region_interpolate.startswith('upsample'):
    #         if len(sam_masks) > 0:
    #             # sam regions within an image all have the same total size
    #             new_h, new_w = sam_masks[0].shape
    #             patch_length = self.get_vision_tower().config.patch_size
    #             padded_h, padded_w = math.ceil(new_h / patch_length) * patch_length, math.ceil(new_w / patch_length) * patch_length # Get the padded height and width
    #             mode = 'nearest' if 'nearest' in self.region_interpolate else 'bilinear'
    #             # (H,W,C) -> (1, C, H, W)
    #             upsample_feature = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=[padded_h,padded_w],mode=mode).squeeze(0) # First interpolate to the padded size
    #             upsample_feature = T.CenterCrop((new_h, new_w)) (upsample_feature).permute(1, 2, 0) # Apply center cropping to the original size
    #             h,w,f = upsample_feature.size()

    #             final_masks = get_sam_features(upsample_feature, sam_masks)
    #     else: # let SAM mask do the same resize, pad and crop as the image
    #         MASK_DOWNSAMPLE_THRESHOLD = 0.25
    #         target_h, target_w = self.get_vision_tower().config.image_size, self.get_vision_tower().config.image_size
    #         nmasks, h, w = sam_masks.shape
    #         if self.region_interpolate.startswith('downsample'):
    #             upsample_feature = features
    #         elif self.region_interpolate.startswith('process'):
    #             mode = 'nearest' if 'nearest' in self.region_interpolate else 'bilinear'
    #             upsample_feature = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=[target_h, target_w], mode=mode).squeeze(0)
    #             upsample_feature = upsample_feature.permute(1, 2, 0)
    #         else:
    #             raise ValueError(f'Unexpected region_interpolate: {self.region_interpolate}')

    #         sam_masks = sam_masks.unsqueeze(0).half()
    #         if self.region_interpolate.endswith("pad"): # pad the shortest edge, then resize
    #             if h < w:
    #                 pad = (w - h) // 2
    #                 sam_masks = F.pad(sam_masks, (0, 0, pad, w - h - pad), value=0)
    #             elif h > w:
    #                 pad = (h - w) // 2
    #                 sam_masks = F.pad(sam_masks, (pad, h - w - pad), value=0)
    #             downsampled_masks = F.interpolate(sam_masks, size=[target_h, target_w], mode='bilinear').squeeze(0) > MASK_DOWNSAMPLE_THRESHOLD
    #             if not downsampled_masks.any():
    #                 logger.warning_once("region is too small, retry with smaller threshold")
    #                 downsampled_masks = F.interpolate(sam_masks, size=[target_h, target_w], mode='bilinear').squeeze(0) > MASK_DOWNSAMPLE_THRESHOLD/5
    #         elif self.region_interpolate.endswith("crop"): # resize according to the shortest edge, then center crop
    #             mask_short, mask_long = (h, w) if h < w else (w, h)
    #             resize_short, resize_long = (target_h, int(target_h*mask_long/mask_short))
    #             resize_h, resize_w = (resize_short, resize_long) if h < w else (resize_long, resize_short)
    #             downsampled_masks = F.interpolate(sam_masks, size=[resize_h, resize_w], mode='bilinear').squeeze(0) > MASK_DOWNSAMPLE_THRESHOLD
    #             if not downsampled_masks.any():
    #                 logger.warning_once("region is too small, retry with smaller threshold")
    #                 downsampled_masks = F.interpolate(sam_masks, size=[resize_h, resize_w], mode='bilinear').squeeze(0) > MASK_DOWNSAMPLE_THRESHOLD/5
    #             downsampled_masks = crop(downsampled_masks, (target_h, target_w), center=True)
    #         else: # directly resize
    #             downsampled_masks = F.interpolate(sam_masks, size=[target_h, target_w], mode='bilinear').squeeze(0) > MASK_DOWNSAMPLE_THRESHOLD

    #         if self.region_interpolate.startswith('downsample'):
    #             # downsampled_mask are further convoluted to get the final mask
    #             MASK_CONV_THRESHOLD = 0.05
    #             patch_size = self.get_vision_tower().config.patch_size
    #             downsampled_masks = F.avg_pool2d(downsampled_masks.half(), kernel_size=patch_size, stride=patch_size) > MASK_CONV_THRESHOLD
    #         if self.region_source.startswith('split_'):
    #             assert self.region_interpolate.startswith('downsample'), f"region_source {self.region_source} only works with downsample"
    #             cluster_algo = self.region_source.split('_')[1]
    #             cluster_args = {}
    #             for item in self.region_cluster_args.split(','):
    #                 if '=' in item:
    #                     k, v = item.split('=')
    #                     cluster_args[k] = eval(v)
    #             skip_ids = []
    #             if 'global' in self.region_extra:
    #                 skip_ids.append(0)
    #             start_time = time.time()
    #             downsampled_masks = cluster_in_masks(features, downsampled_masks, algo=cluster_algo, cluster_args=cluster_args, skip_ids=skip_ids)
    #             self.cluster_time += time.time() - start_time
    #             # print(f"num_masks_before: {nmasks}, num_masks_after: {len(downsampled_masks)}")
    #             # import pdb; pdb.set_trace()

    #         final_masks = get_sam_features(upsample_feature, downsampled_masks)

    #     # demo_img = image.cpu().float().permute(1,2,0).numpy()
    #     # demo_img = ((demo_img - demo_img.min()) / (demo_img.max() - demo_img.min()) * 255).astype(np.uint8)
    #     # demo_masks = final_masks['global' in self.region_extra:]
    #     # if self.region_source == "clustering" or self.region_interpolate.startswith('downsample'):
    #     #     patch_size = self.get_vision_tower().config.patch_size
    #     #     remain_h, remain_w = target_h % patch_size, target_w % patch_size
    #     #     demo_masks = F.interpolate(final_masks.unsqueeze(0).half(), size=[target_h-remain_h, target_w-remain_w], mode='nearest').squeeze(0).bool()
    #     #     demo_masks = F.pad(demo_masks, (0, remain_w, 0, remain_h), value=False)
    #     # demo_masks = demo_masks.cpu().numpy()
    #     # from llava.visualize_utils import show_img_and_mask
    #     # show_img_and_mask(demo_img, demo_masks, save_path='./playground/tmp.png')
    #     # import pdb; pdb.set_trace()

    #     if len(all_region_features_in_image) == 0:
    #         logger.error("No regions in the image, should not happen now!")
    #         region_features = torch.zeros(0, features.shape[-1], device=features.device, dtype=features.dtype)
    #     else:
    #         region_features = torch.stack(all_region_features_in_image, dim=0)
    #     if self.need_region_upproj:
    #         # [N, C*k] -> [N*k, C]
    #         region_features = region_features.view(region_features.shape[0] * self.region_expand_mult, patch_features.shape[1])
    #     elif self.region_pooling_method=='cross-attn':
    #         region_features = self.get_model().mm_projector.region_attn(upsample_feature, final_masks, pooled_features=region_features)

    #     if 'patch' in self.region_extra:
    #         region_features = torch.cat((patch_features, region_features), dim=0)
    #     elif has_extra_features:
    #         region_features = torch.cat((extra_features, region_features), dim=0)
    #     return (region_features, final_masks, nonzero_mask) if return_masks else region_features

    def encode_images(self, images, sam_masks, return_masks: bool = False):
        """Encode a batch of images into projected region features.

        The vision tower is run once on the whole batch; then, per image, the
        region masks are prepared with ``process_masks`` and pooled with
        ``extract_region_features``. All regions are concatenated, passed
        through ``mm_projector`` and split back per image.
        ``apply_vision_feature_postprocess`` is applied to the patch features
        before pooling, or to the pooled features when ``self.region_late_pe``
        is set.

        Args:
            images: batch of images accepted by the vision tower.
            sam_masks: one set of region masks per image.
            return_masks: when true, additionally stores
                ``(image_features, processed_masks, nonzero_masks)`` on
                ``self._cached_masks``.

        Returns:
            Tuple with one projected feature tensor per image.
        """
        assert len(images) == len(sam_masks), "The number of images and sam_masks must be the same."
        model = self.get_model()

        patch_features = model.get_vision_tower()(images)
        if not self.region_late_pe:
            patch_features = model.apply_vision_feature_postprocess(patch_features)

        processed_masks = [
            self.process_masks(per_image_feature, per_image_masks)
            for per_image_feature, per_image_masks in zip(patch_features, sam_masks)
        ]
        pooled = [
            self.extract_region_features(per_image_feature, per_image_masks, return_masks=return_masks)
            for per_image_feature, per_image_masks in zip(patch_features, processed_masks)
        ]
        if return_masks:
            # Each entry is (region_features, nonzero_mask); separate the two.
            pooled, nonzero_masks = zip(*pooled)

        # Project all regions of the batch in one call, then split per image again.
        split_sizes = [region_feature.shape[0] for region_feature in pooled]
        stacked = torch.cat(pooled, dim=0)
        if self.region_late_pe:
            stacked = model.apply_vision_feature_postprocess(stacked)
        stacked = model.mm_projector(stacked)
        image_features = torch.split(stacked, split_sizes, dim=0)

        if return_masks:
            self._cached_masks = (image_features, processed_masks, nonzero_masks)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, sam_masks, image_sizes=None, return_masks=False
    ):
        """Run the parent implementation with region-aware image encoding.

        The parent method calls ``self.encode_images(images)`` and has no way
        of forwarding ``sam_masks`` / ``return_masks``. They are therefore
        bound onto ``self.encode_images`` for the duration of the parent call
        and the previous attribute is put back afterwards.

        Args:
            input_ids, position_ids, attention_mask, past_key_values, labels,
            images, image_sizes: forwarded unchanged to the parent method.
            sam_masks: region masks, one set per image, handed to ``encode_images``.
            return_masks: handed to ``encode_images``.

        Returns:
            Whatever the parent ``prepare_inputs_labels_for_multimodal`` returns.
        """
        orig_encode = self.encode_images
        self.encode_images = partial(orig_encode, sam_masks=sam_masks, return_masks=return_masks)
        try:
            return super().prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels,
                images, image_sizes
            )
        finally:
            # Restore even when the parent call raises, so a failed batch does
            # not leave stale masks bound to encode_images.
            self.encode_images = orig_encode

def crop(img: torch.Tensor, size, center=True):
    """Crop the last two dimensions of ``img`` to ``size``.

    Args:
        img: tensor of shape (..., H, W).
        size: ``(height, width)`` of the crop.
        center: take the crop around the centre (offsets are floor-divided)
            when true; otherwise take it from the top-left corner.

    Returns:
        A slice of ``img`` covering the selected window.
    """
    if not center:
        return img[..., :size[0], :size[1]]

    in_h, in_w = img.shape[-2:]
    crop_h, crop_w = size
    top = (in_h - crop_h) // 2
    left = (in_w - crop_w) // 2
    return img[..., top:top + crop_h, left:left + crop_w]
