#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from abc import ABC, abstractmethod

import math
import re
import time
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from .mm_utils import get_anyres_image_grid_shape
from .sampler import topk_softmax


class LlavaMetaModel:
    """Mixin that attaches a vision tower and a multimodal projector to a model.

    ``__init__`` forwards ``config`` through ``super()``, so this class is
    presumably combined with a base model class that follows it in the MRO.
    """

    def __init__(self, config):
        """Initialise the next class in the MRO, then build the vision modules.

        The vision tower and the projector are created only when ``config``
        defines ``mm_vision_tower``. ``image_newline`` is additionally created
        when the configured ``mm_patch_merge_type`` contains ``"unpad"``.
        """
        super().__init__(config)

        # Nothing multimodal to build for configs without a vision tower.
        if not hasattr(config, "mm_vision_tower"):
            return

        lazy_loading = getattr(config, "delay_load", False)
        self.vision_tower = build_vision_tower(config, delay_load=lazy_loading)
        self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

        patch_merge_type = getattr(config, "mm_patch_merge_type", "")
        if "unpad" in patch_merge_type:
            # Uninitialised storage of length hidden_size in the model's dtype.
            newline_storage = torch.empty(config.hidden_size, dtype=self.dtype)
            self.image_newline = nn.Parameter(newline_storage)

    def get_vision_tower(self):
        """Return the vision tower, or ``None`` when the attribute is missing.

        When the attribute holds a plain ``list``, its first element is returned.
        """
        tower = getattr(self, "vision_tower", None)
        # Exact type check on purpose: only a plain list is unwrapped.
        return tower[0] if type(tower) is list else tower

def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The padding is assumed to be distributed symmetrically along one axis, so
    the same number of rows (or columns) is cut from both sides.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of the image as (width, height).

    Returns:
    torch.Tensor: The unpadded image tensor (a slice of ``tensor``).
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    # Compute aspect ratios
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Determine padding size and direction
    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added to the height
        scale_factor = current_width / original_width
        # Round before truncating: a product such as 100 * (29 / 100) evaluates
        # to 28.999999999999996 and would otherwise lose a whole row. This is
        # the same rounding unpad_logits applies, so both agree on the region.
        new_height = int(round(original_height * scale_factor, 7))
        padding = (current_height - new_height) // 2
        return tensor[:, padding : current_height - padding, :]

    # Padding was added to the width
    scale_factor = current_height / original_height
    new_width = int(round(original_width * scale_factor, 7))
    padding = (current_width - new_width) // 2
    return tensor[:, :, padding : current_width - padding]


def unpad_logits(tensor, original_size, unpad_const):
    """
    Marks the padded margins of a 2-D map laid out over a padded, resized image.

    ``unpad_const`` is first added in place to every element of ``tensor``.
    The rows (or columns) that correspond to real image content are then
    overwritten with 0, so only the padded margins keep the added offset.

    Args:
        tensor (`torch.Tensor`):
            A 2-D tensor of shape (height, width). It is modified in place.
        original_size (`list`, `tuple`, `torch.Tensor` or `np.ndarray`):
            The original size of the image as (width, height).
        unpad_const:
            The value added to every element before the content region is
            reset to 0.

    Returns:
        `torch.Tensor`: The same ``tensor`` object, after the in-place update.

    Raises:
        TypeError: If ``original_size`` is not a list, tuple, tensor or ndarray.
    """
    if not isinstance(original_size, (list, tuple)):
        # numpy is not imported at module level in this file, so bring it in
        # locally; without it this branch fails with a NameError.
        import numpy as np

        if not isinstance(original_size, (torch.Tensor, np.ndarray)):
            raise TypeError(
                f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
            )
        original_size = original_size.tolist()
    original_width, original_height = original_size
    current_height, current_width = tensor.shape

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Offset everything; the content region is reset to 0 right below.
    tensor += unpad_const
    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added to the height
        scale_factor = current_width / original_width
        # Rounding to 7 decimals absorbs float error before truncation.
        new_height = int(round(original_height * scale_factor, 7))
        padding = (current_height - new_height) // 2
        tensor[padding : current_height - padding, :] = 0
    else:
        # Padding was added to the width
        scale_factor = current_height / original_height
        new_width = int(round(original_width * scale_factor, 7))
        padding = (current_width - new_width) // 2
        tensor[:, padding : current_width - padding] = 0

    return tensor


class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        """Return the underlying model object; concrete subclasses must override this.

        Other methods of this class read the vision tower, ``mm_projector`` and
        ``embed_tokens`` from the object returned here.
        """

    def get_vision_tower(self):
        """Return the vision tower reported by the object from ``get_model()``."""
        inner_model = self.get_model()
        return inner_model.get_vision_tower()

    def get_2dPool(self, image_feature, stride=2):
        """Spatially downsample per-frame patch features.

        Args:
            image_feature: Tensor of shape (num_frames, num_tokens, num_dim). The
                tokens are laid out as a square grid whose side is the vision
                tower's ``num_patches_per_side``.
            stride: Pooling window/stride for the "average" and "max" modes, and
                the down-scaling divisor for the "bilinear" mode.

        Returns:
            Tensor of shape (num_frames, pooled_tokens, num_dim).

        Raises:
            ValueError: If ``config.mm_spatial_pool_mode`` is not one of
                "average", "max" or "bilinear".
        """
        side = self.get_vision_tower().num_patches_per_side
        frame_count, _, channels = image_feature.shape

        # (frames, tokens, dim) -> (frames, dim, side, side) so that the 2-D
        # functional ops see the token grid as the spatial axes.
        grid = image_feature.view(frame_count, side, side, -1).permute(0, 3, 1, 2).contiguous()

        pool_mode = self.config.mm_spatial_pool_mode
        if pool_mode == "average":
            pooled = nn.functional.avg_pool2d(grid, stride)
        elif pool_mode == "max":
            pooled = nn.functional.max_pool2d(grid, stride)
        elif pool_mode == "bilinear":
            target_size = [math.ceil(grid.shape[2] / stride), math.ceil(grid.shape[3] / stride)]
            pooled = nn.functional.interpolate(grid, size=target_size, mode='bilinear')
        else:
            raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")

        # Back to channels-last and flatten the pooled grid into a token axis.
        return pooled.permute(0, 2, 3, 1).view(frame_count, -1, channels)

    def encode_images(self, images):
        """Run the vision tower and then the multimodal projector on ``images``.

        Returns:
            A tuple ``(projected_features, tower_features)``: the projector output
            followed by the raw vision-tower output it was computed from.
        """
        vision_tower = self.get_model().get_vision_tower()
        tower_features = vision_tower(images)
        projector = self.get_model().mm_projector
        return projector(tower_features), tower_features

    def add_token_per_grid(self, image_feature, gumbel_mask):
        """Append one newline token after every row of each frame's token grid.

        Args:
            image_feature: Tensor of shape (num_frames, num_tokens, dim); the
                token axis is treated as a square grid of side
                ``int(sqrt(num_tokens))``.
            gumbel_mask: Per-token mask with ``num_frames * num_tokens`` elements.

        Returns:
            A tuple ``(features, mask)``. ``features`` has shape
            (num_frames * side * (side + 1), dim), with ``self.model.image_newline``
            inserted after each grid row. ``mask`` is the matching 1-D mask, with
            ``True`` at every inserted newline position.
        """
        frame_count = image_feature.shape[0]
        side = int(math.sqrt(image_feature.shape[1]))

        def as_grid_rows(values, channels):
            # (frames, tokens, C) -> (C, frames * side, side): one grid row per entry.
            values = values.view(frame_count, 1, side, side, channels)
            values = values.permute(4, 0, 2, 1, 3).contiguous()
            return values.flatten(1, 2).flatten(2, 3)

        feature_rows = as_grid_rows(image_feature, -1)
        mask_rows = as_grid_rows(gumbel_mask, 1)

        # One extra column per row: the learned newline embedding for the
        # features and an always-kept flag for the mask.
        newline_column = self.model.image_newline[:, None, None].expand(*feature_rows.shape[:-1], 1)
        feature_rows = torch.cat((feature_rows, newline_column.to(feature_rows.device)), dim=-1)

        keep_column = torch.tensor([True])[:, None, None].expand(*mask_rows.shape[:-1], 1)
        mask_rows = torch.cat((mask_rows, keep_column.to(feature_rows.device)), dim=-1)

        # Collapse rows and columns into a single token axis.
        flat_features = feature_rows.flatten(1, 2).transpose(0, 1)
        flat_mask = mask_rows.flatten(1, 2).transpose(0, 1)[:, 0]
        return flat_features, flat_mask

    def add_token_per_frame(self, image_feature):
        """Append one newline token at the end of every frame's token sequence.

        Args:
            image_feature: Tensor of shape (num_frames, num_tokens, dim).

        Returns:
            Tensor of shape (num_frames, num_tokens + 1, dim) whose last token in
            each frame is ``self.model.image_newline``.
        """
        # Work channels-first so the newline vector broadcasts over frames.
        channels_first = image_feature.permute(2, 0, 1).contiguous()
        newline = self.model.image_newline[:, None, None].expand(*channels_first.shape[:-1], 1)
        extended = torch.cat((channels_first, newline.to(channels_first.device)), dim=-1)
        return extended.permute(1, 2, 0).contiguous()

#####################################################################################
#####################################################################################

    def draw_image(self, gumbel_masks_list, logits_list, image_sizes, pixel_values_list, gumbel_masks_list_for_drawing):
        """Print patch-selection statistics and show a diagnostic figure per image.

        The five list arguments are iterated in lockstep, one entry per image.
        For every image the method prints how many mask entries are set, then
        opens a 2x3 matplotlib figure and calls ``plt.show()``.

        Args:
            gumbel_masks_list: Per-image selection masks. Each is reshaped to
                (num_views, 1, mask_side, mask_side), so it presumably has shape
                (num_views, mask_side ** 2) - confirm against the caller.
            logits_list: Per-image logits. Only index 0 of the last axis is used,
                after a sigmoid.
            image_sizes: Per-image sizes, passed to ``get_anyres_image_grid_shape``.
            pixel_values_list: Per-image pixel tensors. Index 0 along the first
                axis is drawn as the "small" image; the following
                ``num_patch_height * num_patch_width`` entries are drawn as tiles
                of the "big" image.
            gumbel_masks_list_for_drawing: Per-image 2-D masks shown as-is with
                ``imshow``; the last column is excluded from the printed statistics.

        Reads ``self.config.image_grid_pinpoints``, ``self.sampler.sampler_type``,
        ``self.sampler.th`` and the vision tower's ``image_size``.
        """
        import numpy as np
        import matplotlib.pyplot as plt
        from math import ceil  # NOTE(review): `ceil` is not used in this method
        from torch.nn import functional as F

        for gumbel_mask, logits, gumbel_mask_for_drawing, pixel_values, image_size in \
                zip(gumbel_masks_list, logits_list, gumbel_masks_list_for_drawing, pixel_values_list, image_sizes):
            # Keep channel 0 of the logits and squash it into (0, 1).
            logits = torch.sigmoid(logits[..., 0])
            pixel_values = pixel_values[None] # [1, 5, 3, 448, 448]
            ################################################################################
            # Statistics over the full mask.
            non = gumbel_mask.sum().item()
            all = np.prod(gumbel_mask.shape)  # NOTE(review): shadows the builtin `all` inside this method
            pop = gumbel_mask.float().mean().item()
            print(f'Number of patches: {non} from {all}')
            print(f'Persantage of patches: {round(pop, 3)}')
            print('-------------------------------------------------------')
            # Statistics over the first mask row plus the drawing mask without its last column.
            nop = gumbel_mask[0].sum().item() + gumbel_mask_for_drawing[:, :-1].sum().item()
            all = (gumbel_mask.shape[-1] + np.prod(gumbel_mask_for_drawing[:, :-1].shape))
            pop = nop / all
            print(f'Number of patches after truncation of the margins: {nop} from {all}')
            print(f'Persantage of patches after truncation of the margins: {round(pop, 3)}')
            ################################################################################

            vision_tower_image_size = self.get_vision_tower().image_size
            # The helper's result is unpacked as (width, height) in tiles.
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                image_size,
                self.config.image_grid_pinpoints,
                vision_tower_image_size
            )

            # The mask is treated as a square grid; the image side is taken from the last pixel axis.
            mask_side = int(gumbel_mask.shape[1]**0.5)
            image_side = pixel_values.shape[-1]

            # ------------- IMAGE -------------
            small_image = pixel_values[0, 0].float() # [3, 448, 448]
            big_image = pixel_values[0, 1:num_patch_height*num_patch_width+1]\
                .reshape(num_patch_height, num_patch_width,
                        pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]).float() # [2, 5, 3, 448, 448]
            small_image = small_image.permute(1, 2, 0).cpu() # [448, 448, 3]
            big_image = big_image.permute(0, 1, 3, 4, 2).cpu() # [2, 5, 448, 448, 3]
            # --------------- GUMBEL MASK ---------------
            # Upsample the patch-level mask to pixel resolution (default interpolate mode).
            gms = gumbel_mask.reshape(gumbel_mask.shape[0], 1, mask_side, mask_side).float() # [5, 1, 24, 24]
            gms = gms.repeat(1, 3, 1, 1) # [5, 3, 24, 24]
            gms = F.interpolate(gms, size=(image_side, image_side)) > 0 # [5, 3, 24, 24]

            small_mask = gms[0] # [3, 448, 448]
            big_mask = gms[1:num_patch_height*num_patch_width+1] # [10, 3, 448, 448]
            big_mask = big_mask.reshape(num_patch_height, num_patch_width,
                                        big_mask.shape[1], big_mask.shape[2], big_mask.shape[3])

            small_mask = small_mask.permute(1, 2, 0).cpu().to(torch.float32) # [448, 448, 3]
            big_mask = big_mask.permute(0, 1, 3, 4, 2).cpu().to(torch.float32) # [2, 5, 448, 448, 3]
            # --------------- LOGITS ----------------
            # Same upsampling as for the mask, but the values are kept continuous.
            ls = logits.reshape(logits.shape[0], 1, mask_side, mask_side).float()
            ls = ls.repeat(1, 3, 1, 1)
            ls = F.interpolate(ls, size=(image_side, image_side))

            small_ls = ls[0] # [3, 448, 448]
            big_ls = ls[1:num_patch_height*num_patch_width+1] # [10, 3, 448, 448]
            big_ls = big_ls.reshape(num_patch_height, num_patch_width,
                                        big_ls.shape[1], big_ls.shape[2], big_ls.shape[3])

            small_ls = small_ls.permute(1, 2, 0).cpu().to(torch.float32) # [448, 448, 3]
            big_ls = big_ls.permute(0, 1, 3, 4, 2).cpu().to(torch.float32) # [2, 5, 448, 448, 3]
            # --------------- BIG TRANSFORM ----------------
            # Stitch the tile grid into one picture: (rows, cols, H, W, C) -> (rows * H, cols * W, C).
            big_image = big_image.permute(0, 2, 1, 3, 4).reshape(big_image.shape[0] * big_image.shape[2], 
                                                                big_image.shape[1] * big_image.shape[3], 
                                                                big_image.shape[4])
            big_mask = big_mask.permute(0, 2, 1, 3, 4).reshape(big_mask.shape[0] * big_mask.shape[2], 
                                                            big_mask.shape[1] * big_mask.shape[3], 
                                                            big_mask.shape[4])
            big_ls = big_ls.permute(0, 2, 1, 3, 4).reshape(big_ls.shape[0] * big_ls.shape[2], 
                                                            big_ls.shape[1] * big_ls.shape[3], 
                                                            big_ls.shape[4])
            # -------------------------------
            # Min-max normalise to [0, 1] for imshow.
            # NOTE(review): a constant image makes the denominator 0 - confirm this cannot happen.
            big_image = (big_image - big_image.min()) / (big_image.max() - big_image.min())
            small_image = (small_image - small_image.min()) / (small_image.max() - small_image.min())
            # --------------- PLOT ----------------
            #######################################
            #######################################

            # Paint every pixel whose mask value is below 0.5 with the RGB colour (135, 206, 235).
            big_image[big_mask[..., 0] < 0.5] = torch.tensor([135, 206, 235]) / 255
            small_image[small_mask[..., 0] < 0.5] = torch.tensor([135, 206, 235]) / 255

            ncols, nrows, scale = 3, 2, 10
            fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * scale, nrows * scale))

            # Column 0: drawing mask and logits histogram; column 1: masked images; column 2: logit maps.
            axes[0, 0].imshow(gumbel_mask_for_drawing.cpu().detach())
            axes[1, 0].hist(logits.flatten().cpu().detach(), bins=60, density=True)
            axes[0, 1].imshow(big_image)
            axes[1, 1].imshow(small_image)
            axes[0, 2].imshow(big_ls)
            axes[1, 2].imshow(small_ls)
            # For the 'threshold' sampler, mark the threshold as a vertical line on the histogram.
            if self.sampler.sampler_type == 'threshold': 
                axes[1, 0].plot([self.sampler.th, self.sampler.th], [0, 5])
            plt.show()

    def draw_video(self, gumbel_masks_list, logits_list, pixel_values_list):
        """Print patch-selection statistics and show the masked frames of each video.

        The three list arguments are iterated in lockstep, one entry per video.
        For every video the method prints how many mask entries are set, paints
        the pixels of unselected patches with a flat colour and shows all frames
        in a grid via ``plt.show()``.

        Args:
            gumbel_masks_list: Per-video selection masks. Each is reshaped to
                (num_frames, 1, mask_side, mask_side), so it presumably has shape
                (num_frames, mask_side ** 2) - confirm against the caller.
            logits_list: Per-video logits. They are not drawn; the list only takes
                part in the lockstep iteration and is kept for interface
                compatibility.
            pixel_values_list: Per-video frame tensors, permuted from channels-first
                to channels-last for plotting. Frames are assumed square, since
                the last axis is used as the side length for both dimensions.
        """
        import numpy as np
        import matplotlib.pyplot as plt
        from math import ceil
        from torch.nn import functional as F

        for gumbel_mask, _logits, pixel_values in zip(gumbel_masks_list, logits_list, pixel_values_list):
            ################################################################################
            selected = gumbel_mask.sum().item()
            total = np.prod(gumbel_mask.shape)
            share = gumbel_mask.float().mean().item()
            print(f'Number of patches: {selected} from {total}')
            print(f'Persantage of patches: {round(share, 3)}')
            print('-------------------------------------------------------')
            ################################################################################

            mask_side = int(gumbel_mask.shape[1]**0.5)
            image_side = pixel_values.shape[-1]

            # ------------- IMAGE -------------
            frames = pixel_values.permute(0, 2, 3, 1).cpu().to(torch.float32) # [N, H, W, 3]
            # --------------- GUMBEL MASK ---------------
            # Upsample the patch-level mask to pixel resolution (default interpolate mode).
            gms = gumbel_mask.reshape(gumbel_mask.shape[0], 1, mask_side, mask_side).float() # [N, 1, s, s]
            gms = F.interpolate(gms, size=(image_side, image_side)) > 0 # [N, 1, H, W]
            keep_pixels = gms[:, 0].cpu() # [N, H, W]
            # -------------------------------
            # Min-max normalise to [0, 1] for imshow.
            frames = (frames - frames.min()) / (frames.max() - frames.min())
            # --------------- PLOT ----------------
            # Paint every pixel of an unselected patch with the RGB colour (135, 206, 235).
            frames[~keep_pixels] = torch.tensor([135, 206, 235]) / 255

            # Near-square grid with enough rows for every frame; a floor(sqrt(N))
            # square would silently drop frames whenever N is not a perfect square.
            num_frames = frames.shape[0]
            scale = 5
            ncols = max(1, int(num_frames**0.5))
            nrows = max(1, ceil(num_frames / ncols))
            # squeeze=False always yields a 2-D array of Axes; without it a 1x1
            # grid returns a bare Axes object that has no `.flat`.
            _fig, axes = plt.subplots(
                ncols=ncols, nrows=nrows, figsize=(ncols * scale, nrows * scale), squeeze=False
            )
            for idx, ax in enumerate(axes.flat):
                if idx < num_frames:
                    ax.imshow(frames[idx])
                else:
                    # Hide the unused cells of the last row.
                    ax.axis('off')
            plt.show()

    def prepare_inputs_labels_for_multimodal(self, 
                                             input_ids, 
                                             position_ids, 
                                             attention_mask, 
                                             past_key_values, 
                                             labels, 
                                             images, 
                                             modalities=["image"], 
                                             image_sizes=None, 
                                             keep_small_image=None,
                                             calculate_all_tokens_number_with_fields=False,
                                             draw_image=False
                                             ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels
        if isinstance(modalities, str):
            modalities = [modalities]

        ################################################################################
        is_tensor = isinstance(images, torch.Tensor)
        is_list_of_tensors = isinstance(images, list) and isinstance(images[0], torch.Tensor)
        assert is_tensor or is_list_of_tensors
        assert (is_tensor and 4 <= images.ndim <= 5) or (is_list_of_tensors and 4 <= images[0].ndim <= 5)

        if is_list_of_tensors or images.ndim == 5:
            video_idx_in_batch = []
            for i, m in enumerate(modalities):
                if m == "video": video_idx_in_batch.append(i)
            
            images_list = [image for image in images] # ([5, 3, 336, 336], )

            concat_images = torch.cat(images_list, dim=0) # ([sum_i N_i, 3, 336, 336], )
            split_sizes = [image.shape[0] for image in images_list] # [5, ...]
            encoded_image_features, for_selection_image_features = \
                self.encode_images(concat_images) # [sum_i N_i, 576, 4096], [sum_i N_i, 576, 1024]

            encoded_image_features = torch.split(encoded_image_features, split_sizes) # ([N_i, 576, 4096], )
            for_selection_image_features = torch.split(for_selection_image_features, split_sizes) # ([N_i, 576, 1024], )
            image_features = []
            for idx, image_feat in enumerate(encoded_image_features):
                if idx in video_idx_in_batch:
                    image_features.append(self.get_2dPool(image_feat)) # ([N_i, 144, 4096], )
                else:
                    image_features.append(image_feat)

            keep_lists = [
                [keep_small_image] + [None for _ in range(image_features[i].shape[0] - 1)]
                for i in range(len(image_features))
            ]

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") # spatial_unpad
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") # anyres
            mm_newline_position = getattr(self.config, "mm_newline_position", "one_token") # one_token

            assert mm_patch_merge_type == 'spatial_unpad'
            assert image_aspect_ratio == 'anyres'
            # assert mm_newline_position in ['one_token', 'no_token']

            if mm_patch_merge_type.startswith("spatial"):
                new_image_features = [] #
                new_gumbel_masks = [] #
                feature_lens = [] #
                new_gumbel_masks_for_drawing = []
                gumbel_masks_list = []
                logits_list = []
                all_tokens_number_list = []
                selected_tokens_number_list = []
                for image_idx, (image_feature, for_selection_image_feature, keep_list) \
                    in enumerate(zip(image_features, for_selection_image_features, keep_lists)):
#############################################################################################################
                    if image_idx in video_idx_in_batch:
                        logits_shift = torch.zeros_like(for_selection_image_feature[:, :, 0]) # [N, 576]
                        all_tokens_number = image_feature.shape[0] * image_feature.shape[1]

                        val_dict = self.sampler(for_selection_image_feature, logits_shift, all_tokens_number)
                        gumbel_mask_batch, logits_batch = val_dict['gumbel_mask'], val_dict['logits'] # [N, 576], [N, 576, 2]
                        gumbel_mask_batch = gumbel_mask_batch[..., None].to(torch.float32)
                        gumbel_mask_batch, logits_batch = self.get_2dPool(gumbel_mask_batch)[..., 0], self.get_2dPool(logits_batch) # [N, 144], [N, 144, 2]

                        all_tokens_number = gumbel_mask_batch.shape[0] * gumbel_mask_batch.shape[1]
                        gumbel_mask_batch = topk_softmax(logits_batch, all_tokens_number, self.sampler.iva_factor) # [N, 144]
                        gumbel_mask_batch = gumbel_mask_batch[..., 0] > 0.5

                        gumbel_masks_list.append(gumbel_mask_batch)
                        logits_list.append(logits_batch)

                        selected_tokens_number = gumbel_mask_batch.sum().item()
                        selected_tokens_number_list.append(selected_tokens_number)

                        if mm_newline_position == "grid":
                            image_feature, gumbel_mask_batch = self.add_token_per_grid(image_feature, gumbel_mask_batch) #[N, 144, D], [N, 144]
                            new_image_features.append(image_feature)
                            new_gumbel_masks.append(gumbel_mask_batch)
                        elif mm_newline_position == "one_token":
                            image_feature = image_feature.flatten(0, 1)
                            if 'unpad' in mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
                                gumbel_mask_batch = torch.cat(
                                    (gumbel_mask_batch, torch.tensor(True)[None].to(gumbel_mask.device, gumbel_mask.dtype)), dim=-1) # [36, 49]
                            new_image_features.append(image_feature)
                            new_gumbel_masks.append(gumbel_mask)
                        elif mm_newline_position == "no_token":
                            new_image_features.append(image_feature.flatten(0, 1))
                            new_gumbel_masks.append(gumbel_mask_batch.flatten(0, 1))
                        else:
                            raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
#############################################################################################################
                    else:
                        logits_shift = torch.zeros_like(for_selection_image_feature[:, :, 0]) # [5, 576]
                        print(f'{logits_shift.shape=}')
                        base_image_feature, base_logits_shift = image_feature[0], logits_shift[0] # [576, 4096], [576]
                        image_feature, logits_shift = image_feature[1:], logits_shift[1:] # [4, 576, 4096], [4, 576]

                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]

                        ######################################################################
                        assert hasattr(self.get_vision_tower(), "image_size")
                        vision_tower_image_size = self.get_vision_tower().image_size
                        print(f'{image_sizes[image_idx]=}')
                        print(f'{self.config.image_grid_pinpoints=}')
                        print(f'{vision_tower_image_size=}')
                        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                            image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) # ~2,2

                        ######################################################################
                        unpad_const = 1000
                        if keep_list[0] == True:
                            base_logits_shift += unpad_const
                        elif keep_list[0] == False:
                            base_logits_shift -= unpad_const

                        logits_shift = logits_shift.view(num_patch_height, num_patch_width, height, width) # [2, 2, 24, 24]
                        logits_shift = logits_shift.permute(0, 2, 1, 3).contiguous() # [2, 24, 2, 24]
                        logits_shift = logits_shift.flatten(0, 1).flatten(1, 2) # [48, 48]
                        logits_shift = unpad_logits(logits_shift, image_sizes[image_idx], -unpad_const) # [6 + 36 + 6, 48]
                        if calculate_all_tokens_number_with_fields:
                            all_tokens_number = image_feature.shape[0] * image_feature.shape[1] + base_image_feature.shape[0]
                        else:
                            all_tokens_number = (logits_shift > -unpad_const / 2).sum().item() + base_logits_shift.shape[0]
                        all_tokens_number_list.append(all_tokens_number)
                        logits_shift = logits_shift.view(num_patch_height, height, num_patch_width, width) # [2, 24, 2, 24]
                        logits_shift = logits_shift.permute(0, 2, 1, 3) # [2, 2, 24, 24]
                        logits_shift = logits_shift.flatten(0, 1).flatten(1, 2) # [4, 576]
                        logits_shift = torch.cat([base_logits_shift[None], logits_shift], dim=0)

                        val_dict = self.sampler(for_selection_image_feature, logits_shift, all_tokens_number)
                        gumbel_mask_batch, logits_batch = val_dict['gumbel_mask'], val_dict['logits'] # [5, 576], [5, 576, 2]

                        gumbel_masks_list.append(gumbel_mask_batch)
                        logits_list.append(logits_batch)

                        selected_tokens_number = gumbel_mask_batch.sum().item()
                        selected_tokens_number_list.append(selected_tokens_number)
                        base_gumbel_mask, gumbel_mask = gumbel_mask_batch[0], gumbel_mask_batch[1:]

                        gumbel_mask = gumbel_mask.view(num_patch_height, num_patch_width, height, width) # [2, 2, 24, 24]
                        gumbel_mask = gumbel_mask.permute(0, 2, 1, 3).contiguous() # [2, 24, 2, 24]
                        gumbel_mask = gumbel_mask.flatten(0, 1).flatten(1, 2) # [48, 48]
                        gumbel_mask = unpad_image(gumbel_mask[None], image_sizes[image_idx])[0] # [36, 48]

                        ######################################################################
                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) # [2, 2, 24, 24, 4096]
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() # [4096, 2, 24, 2, 24]
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3) # [4096, 48, 48]
                        image_feature = unpad_image(image_feature, image_sizes[image_idx]) # [4096, 36, 48]

                        if mm_newline_position == 'one_token':
                            image_feature = torch.cat(
                                (
                                    image_feature, 
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device, image_feature.dtype)
                                ), 
                                dim=-1
                            ) # [4096, 36, 49]
                            gumbel_mask = torch.cat(
                                (
                                    gumbel_mask, # [36, 48]
                                    torch.tensor(True)[None, None]
                                    .expand(*gumbel_mask.shape[:-1], 1)
                                    .to(gumbel_mask.device, gumbel_mask.dtype), # [36, 1]
                                ),
                                dim=-1,
                            ) # [36, 49]

                        new_gumbel_masks_for_drawing.append(torch.clone(gumbel_mask))
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1) # [36 * 49, 4096]
                        gumbel_mask = gumbel_mask.flatten() # [36 * 49]

                        ######################################################################
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0) # [24 * 24 + 36 * 49, 4096]
                        gumbel_mask = torch.cat((base_gumbel_mask, gumbel_mask), dim=0) # [24 * 24 + 36 * 49]

                        new_image_features.append(image_feature) # ([24 * 24 + 36 * 49, 4096], )
                        new_gumbel_masks.append(gumbel_mask) # ([24 * 24 + 36 * 49], )
                        feature_lens.append(image_feature.size(0)) # ([], )

                if draw_image:
                    if image_idx in video_idx_in_batch:
                        self.draw_video(gumbel_masks_list, logits_list, images)
                    else:
                        self.draw_image(gumbel_masks_list, logits_list, image_sizes, images, new_gumbel_masks_for_drawing)

                image_features = new_image_features # ([24 * 24 + 36 * 49, 4096], )
                gumbel_masks_list = new_gumbel_masks # ([24 * 24 + 36 * 49], )
        else:
            raise Exception('Thet part of code unchecked')
            image_features, for_selection_image_features = self.encode_images(images)

        ######################################################################
        ######################################################################
        ######################################################################
        assert labels is None
        # assert attention_mask is None
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()

        input_ids = [
            cur_input_ids[cur_attention_mask] 
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
            ]

        assert len(input_ids) == 1

        new_input_embeds = []
        new_masks = []
        new_position_ids = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            cur_attention_mask = torch.ones_like(cur_input_ids, dtype=torch.bool)

            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_mask_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_mask_noim.append(cur_attention_mask[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_mask_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            
            cur_new_input_embeds = []
            cur_new_mask = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_mask.append(cur_mask_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_mask = gumbel_masks_list[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_mask.append(cur_image_mask)

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_mask = [x.to(self.device) for x in cur_new_mask]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_mask = torch.cat(cur_new_mask)

            new_input_embeds.append(cur_new_input_embeds)
            new_masks.append(cur_new_mask)
            new_position_ids.append(torch.arange(len(cur_new_mask), device=self.device))

        for i in range(len(new_input_embeds)):
            new_input_embeds[i] = new_input_embeds[i][new_masks[i]]
            new_position_ids[i] = new_position_ids[i][new_masks[i]]
            new_masks[i] = new_masks[i][new_masks[i]]

        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_masks = [x[:tokenizer_model_max_length] for x in new_masks]
            new_position_ids = [x[:tokenizer_model_max_length] for x in new_position_ids]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_attention_mask_padded = []
        new_position_ids_padded = []
        for i, (cur_new_embed,
                cur_new_mask,
                cur_new_position_ids) in \
            enumerate(zip(new_input_embeds, 
                          new_masks,
                          new_position_ids)):
            cur_len = cur_new_embed.shape[0]
            embed_pads = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=self.device)
            mask_pads = torch.zeros(max_len - cur_len, dtype=cur_new_mask.dtype, device=self.device)
            id_pads = torch.zeros(max_len - cur_len, dtype=cur_new_position_ids.dtype, device=self.device)

            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((embed_pads, cur_new_embed), dim=0))
                new_attention_mask_padded.append(torch.cat((mask_pads.to(cur_new_mask.dtype), cur_new_mask), dim=0))
                new_position_ids_padded.append(torch.cat((id_pads.to(cur_new_position_ids.dtype), cur_new_position_ids), dim=0))
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, embed_pads), dim=0))
                new_attention_mask_padded.append(torch.cat((cur_new_mask, mask_pads.to(cur_new_mask.dtype)), dim=0))
                new_position_ids_padded.append(torch.cat((cur_new_position_ids, id_pads.to(cur_new_position_ids.dtype)), dim=0))

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
        new_attention_mask = torch.stack(new_attention_mask_padded, dim=0)
        new_position_ids = torch.stack(new_position_ids_padded, dim=0)
        return None, new_position_ids, new_attention_mask, past_key_values, new_input_embeds, None