#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_siglip2.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from functools import partial, reduce
import torch.utils.checkpoint
from PIL import Image
from typing import Any, Optional, Tuple, Union, Dict
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig
from transformers.utils import ModelOutput
from llava.utils import rank0_print
from einops import rearrange

if is_flash_attn_2_available():
    # Best-effort import: the helper module is absent in some transformers
    # releases, in which case the flash-attention code path is simply not
    # available. Only import failures are tolerated here; anything else
    # (KeyboardInterrupt, SystemExit, genuine bugs) must propagate.
    try:
        from transformers.modeling_flash_attention_utils import _flash_attention_forward
    except ImportError:
        pass


class SigLipImageProcessor:
    """Minimal image preprocessor: RGB conversion, resize, rescale and normalize.

    The processor stores its settings at construction time and applies them in
    `preprocess`. `crop_size` is validated and stored, but `preprocess` never
    reads it, and the `do_center_crop` flag is accepted without being used.
    """

    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
        """Store the preprocessing settings.

        Args:
            image_mean: Per-channel mean passed to `normalize`.
            image_std: Per-channel standard deviation passed to `normalize`.
            size: Target size passed to `resize`.
            crop_size: Optional crop size; falls back to 384x384 when `None`.
            resample: Resampling filter passed to `resize`.
            rescale_factor: Multiplier passed to `rescale`.
            data_format: Channel layout requested from every transform.
        """
        if crop_size is None:
            crop_size = {"height": 384, "width": 384}
        normalized_crop = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = normalized_crop

    def preprocess(self, images, do_resize = True, do_center_crop = True, do_rescale = True, do_normalize = True, return_tensors = 'pt'):
        """Run the configured transforms over one image or an iterable of images.

        Args:
            images: A single PIL image, or an iterable of images/frames.
            do_resize: Apply `resize` with the stored size and resample filter.
            do_center_crop: Accepted for signature compatibility; not used.
            do_rescale: Apply `rescale` with the stored factor.
            do_normalize: Apply `normalize` with the stored mean and std.
            return_tensors: Tensor type forwarded to `BatchFeature`.

        Returns:
            A `BatchFeature` holding the processed images under `"pixel_values"`.
        """
        if isinstance(images, Image.Image):
            batch = [images]
        else:
            # Non-PIL input (e.g. video frames) is converted frame by frame.
            batch = [to_numpy_array(frame) for frame in images]
            assert isinstance(batch, list)

        layout = self.data_format

        # Build the transform pipeline; each step is applied to every image.
        pipeline = [convert_to_rgb, to_numpy_array]
        if do_resize:
            pipeline.append(partial(resize, size=self.size, resample=self.resample, data_format=layout))
        if do_rescale:
            pipeline.append(partial(rescale, scale=self.rescale_factor, data_format=layout))
        if do_normalize:
            pipeline.append(partial(normalize, mean=self.image_mean, std=self.image_std, data_format=layout))
        pipeline.append(partial(to_channel_dimension_format, channel_dim=layout, input_channel_dim=layout))

        for step in pipeline:
            batch = [step(item) for item in batch]

        return BatchFeature(data={"pixel_values": batch}, tensor_type=return_tensors)


class Siglip2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a
    Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2
    [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`Siglip2Model`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 64):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            The id of the padding token in the vocabulary.
        bos_token_id (`int`, *optional*, defaults to 49406):
            The id of the beginning-of-sequence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 49407):
            The id of the end-of-sequence token in the vocabulary.
        projection_size (`int`, *optional*, defaults to `hidden_size`):
            The size of the projection head. When `None` is passed, `hidden_size` is used instead.

    Example:

    ```python
    >>> from transformers import Siglip2TextConfig, Siglip2TextModel

    >>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration
    >>> configuration = Siglip2TextConfig()

    >>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration
    >>> model = Siglip2TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Identifier of this config type and the key it is stored under in a parent config.
    model_type = "siglip2_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=64,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # This differs from `CLIPTokenizer`'s default and from openai/siglip2
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        projection_size=None,
        **kwargs,
    ):
        # Special-token ids and any extra keyword arguments are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        # The projection head defaults to the hidden size when no explicit size is given.
        self.projection_size = projection_size if projection_size is not None else hidden_size


class Siglip2VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
    Siglip2 vision encoder according to the specified arguments, defining the model architecture.

    NOTE(review): the defaults in this file (`hidden_size=1152`, `intermediate_size=4304`, `num_hidden_layers=27`,
    `num_attention_heads=16`) are larger than the base-size values (768 / 3072 / 12 / 12) that this docstring
    previously listed, so a default instance does not describe a base-size encoder. Confirm which checkpoint
    these defaults are meant to match.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1152):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 4304):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 27):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        num_patches (`int`, *optional*, defaults to 256):
            The number of patches in the image with the size of (`patch_size`, `patch_size`).
            The image is resized to fill maximum of this number of patches, and to preserve
            the aspect ratio. In case the resulted number of patches is lower, the image is
            padded in "patch" dimension.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        image_size (`int`, *optional*, defaults to 384):
            Reference image resolution. The embedding modules in this file read it from the config;
            `Siglip2VisionEmbeddingsCNN` divides it by `patch_size` to get the number of patches per side.

    Example:

    ```python
    >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel

    >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
    >>> configuration = Siglip2VisionConfig()

    >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
    >>> model = Siglip2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Identifier of this config type and the key it is stored under in a parent config.
    model_type = "siglip2_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1152,
        intermediate_size=4304,
        num_hidden_layers=27,
        num_attention_heads=16,
        num_channels=3,
        num_patches=256,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        image_size = 384,
        **kwargs,
    ):
        # Extra keyword arguments are handled by the base class.
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size  # FIXME: open marker left by the original author; its intent is not documented
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.num_patches = num_patches


class Siglip2Config(PretrainedConfig):
    r"""
    [`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to
    instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2
    [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Siglip2TextConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Siglip2VisionConfig`].
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import Siglip2Config, Siglip2Model

    >>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration
    >>> configuration = Siglip2Config()

    >>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration
    >>> model = Siglip2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig
    >>> from transformers import Siglip2TextConfig, Siglip2VisionConfig

    >>> # Initializing a Siglip2Text and Siglip2Vision configuration
    >>> config_text = Siglip2TextConfig()
    >>> config_vision = Siglip2VisionConfig()

    >>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision)
    ```"""

    # Identifier of the composite config and the classes used for its two sub-configs.
    model_type = "siglip2"
    sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig}

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        """Build the composite config from two optional sub-config dictionaries.

        Args:
            text_config: Keyword arguments for `Siglip2TextConfig`; defaults are used when `None`.
            vision_config: Keyword arguments for `Siglip2VisionConfig`; defaults are used when `None`.
            **kwargs: Forwarded to `PretrainedConfig.__init__`.
        """
        super().__init__(**kwargs)

        # `logger` is assigned at module level further down the file; it is only
        # looked up when this constructor runs, i.e. after the module has loaded.
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.")

        self.text_config = Siglip2TextConfig(**text_config)
        self.vision_config = Siglip2VisionConfig(**vision_config)

        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(cls, text_config: Siglip2TextConfig, vision_config: Siglip2VisionConfig, **kwargs):
        r"""
        Instantiate a [`Siglip2Config`] (or a derived class) from siglip2 text model configuration and siglip2 vision
        model configuration.

        Returns:
            [`Siglip2Config`]: An instance of a configuration object
        """

        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

# Module-level logger. It is referenced by `Siglip2Config.__init__` defined
# earlier in the file, which works because the name is resolved at call time.
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Siglip2VisionConfig"


@dataclass
class Siglip2VisionOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    All four fields default to `None`, so an instance can be created with any subset of them.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Projected image embeddings (optional).
    image_embeds: Optional[torch.FloatTensor] = None
    # Final-layer hidden states; annotated as required but defaults to None.
    last_hidden_state: torch.FloatTensor = None
    # Per-layer hidden states (optional).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights (optional).
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class Siglip2VisionEmbeddingsCNN(nn.Module):
    """Convolutional patch embedding with bucketized learned position embeddings.

    Each image is patchified by a strided `nn.Conv2d`. Its patch grid is then
    mapped onto a fixed `num_patches_per_side` x `num_patches_per_side` table of
    learned position embeddings by bucketizing fractional patch coordinates.
    """

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # The position table covers the reference grid derived from `image_size`.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Embed a batch of variable-sized images and add position embeddings.

        Args:
            pixel_values (`List`):
                One tensor per image. Each is fed to the convolution and the result is
                rearranged with the pattern `b d h w`, so a leading batch dimension of
                size 1 is presumably expected (TODO confirm against callers).
            spatial_shapes (`List[Tuple[int, int]]`):
                Per-image (patch rows, patch columns), used to build the position ids.

        Returns:
            `torch.Tensor`: Zero-padded patch embeddings plus position embeddings,
            one row per image.
        """
        grid = self.num_patches_per_side
        conv_dtype = self.patch_embedding.weight.dtype
        num_images = len(pixel_values)
        longest = max(h * w for h, w in spatial_shapes)

        # Bucket edges splitting [0, 1) into `grid` equal intervals.
        boundaries = torch.arange(1 / grid, 1.0, 1 / grid)
        # Slots beyond an image's own patches keep position id 0.
        position_ids = torch.full(size=(num_images, longest), fill_value=0)

        per_image_tokens = []
        for idx, image in enumerate(pixel_values):
            feature_map = self.patch_embedding(image.to(dtype=conv_dtype))  # (bs, dim, h, w)
            tokens = rearrange(feature_map, 'b d h w -> b (h w) d')
            per_image_tokens.append(tokens.squeeze(0))

            rows = spatial_shapes[idx][0]
            cols = spatial_shapes[idx][1]
            # Fractional coordinate of every patch row/column, mapped to a bucket index.
            row_buckets = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / rows), boundaries, right=True)
            col_buckets = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / cols), boundaries, right=True)
            ids = (row_buckets[:, None] * grid + col_buckets).flatten()
            position_ids[idx][:rows * cols] = ids

        embeddings = torch.nn.utils.rnn.pad_sequence(per_image_tokens, batch_first=True, padding_value=0.0)
        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)

class Siglip2VisionEmbeddings(nn.Module):
    """Linear patch embedding with position embeddings resized per image.

    Inputs are already patchified. A learned square grid of position embeddings
    is interpolated to each image's patch grid and added to the patch embeddings.
    """

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Each flattened patch holds num_channels * patch_size * patch_size values.
        patch_dim = config.num_channels * self.patch_size * self.patch_size
        self.patch_embedding = nn.Linear(in_features=patch_dim, out_features=self.embed_dim)

        self.num_patches = config.num_patches
        # Side length of the square grid the position table is reshaped to.
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Interpolate the position grid to every image's patch grid and pad to `max_length`.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
        """
        num_images = spatial_shapes.shape[0]
        dim = positional_embeddings.shape[-1]
        orig_dtype = positional_embeddings.dtype

        output = torch.empty(
            (num_images, max_length, dim),
            device=positional_embeddings.device,
            dtype=orig_dtype,
        )

        # Interpolation expects (1, embed_dim, height, width).
        grid = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Antialiased interpolation is not supported for bfloat16/float16 on CPU, so upcast there.
        if grid.device.type == "cpu":
            grid = grid.to(torch.float32)

        for idx in range(num_images):
            height, width = spatial_shapes[idx]
            num_tokens = height * width

            resized = F.interpolate(
                grid,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            # (1, dim, height, width) -> (height * width, dim), back in the source dtype.
            resized = resized.reshape(dim, num_tokens).transpose(0, 1).to(orig_dtype)

            output[idx, :num_tokens] = resized
            # Padding slots are filled with the first resized embedding.
            output[idx, num_tokens:] = resized[0]

        return output

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`List[Tuple[int, int]]`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to

        Returns:
            `torch.Tensor`: Patch embeddings with the resized position embeddings added.
        """
        # Project the patchified pixels in the dtype of the projection weights.
        weight_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=weight_dtype))

        # View the flat position table as a square grid, then resize it per image.
        side = self.position_embedding_size
        position_grid = self.position_embedding.weight.reshape(side, side, -1)
        resized_positions = self.resize_positional_embeddings(
            position_grid, spatial_shapes, max_length=pixel_values.shape[1]
        )

        return patch_embeds + resized_positions

def apply_rope(xq, xk, freqs_cis, use_flash_attention=False):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        xq: Query tensor of shape (..., num_heads, head_dim).
        xk: Key tensor of shape (..., num_kv_heads, head_dim). Its head count
            may differ from `xq`'s; each tensor is reshaped using its own shape.
        freqs_cis: Complex rotation factors of shape (..., head_dim / 2), or
            `None` to skip the rotation.
        use_flash_attention: When `True` the inputs' layout is kept; otherwise
            dimensions 1 and 2 of both outputs are swapped.

    Returns:
        Tuple `(xq_out, xk_out)` with the same dtypes as the inputs.
    """
    if freqs_cis is None:
        if use_flash_attention:
            return xq, xk
        return xq.transpose(1, 2), xk.transpose(1, 2)

    # Add a broadcast axis for the heads dimension: (..., 1, head_dim / 2).
    freqs_cis = freqs_cis.unsqueeze(-2)

    # View consecutive pairs of the last dimension as complex numbers.
    # Each tensor uses its OWN leading shape so that xq and xk may differ.
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))

    # Rotate, then flatten the (real, imag) pairs back into head_dim.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2).type_as(xq)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2).type_as(xk)

    if use_flash_attention:
        return xq_out, xk_out
    return xq_out.transpose(1, 2), xk_out.transpose(1, 2)

def window_partition_4(qkv, spatial_shapes, attn_type='flash_attn', sdpa_type=None):
    """Split every image's tokens into four quadrant windows and pad them.

    Args:
        qkv: Tensor of shape (batch, seq_len, num_heads, head_dim). For image
            `b`, the first `H_b * W_b` tokens are read as a row-major
            `(H_b, W_b)` grid.
        spatial_shapes: Per-image `(H, W)` patch-grid sizes.
        attn_type: `"flash_attn"` for a 2-D integer validity mask, or `"sdpa"`
            for a 4-D additive mask.
        sdpa_type: Optional dtype of the additive mask; defaults to `qkv.dtype`.

    Returns:
        Tuple `(padded_windows, attn_mask, restore_info, max_len)`:
        - `padded_windows`: (4 * batch, max_len, num_heads, head_dim), zero padded.
        - `attn_mask`: (4 * batch, max_len) ints with 1 on real tokens for
          `"flash_attn"`; (4 * batch, 1, max_len, max_len) with the dtype's
          minimum on pairs that involve padding and 0 elsewhere for `"sdpa"`.
        - `restore_info`: one dict per window holding `batch_idx`, the `(row, col)`
          `coords` of its tokens and the image's `H` and `W`.
        - `max_len`: length of the longest window.

    Raises:
        ValueError: If `attn_type` is not one of the supported values.
    """
    if attn_type not in ("flash_attn", "sdpa"):
        raise ValueError(f"Unsupported attn_type: {attn_type!r} (expected 'flash_attn' or 'sdpa').")

    B, _, num_heads, head_dim = qkv.shape
    windows, restore_info = [], []

    for b in range(B):
        H_i, W_i = spatial_shapes[b]
        qkv_b = qkv[b][:H_i * W_i].view(H_i, W_i, num_heads, head_dim)
        h_mid, w_mid = H_i // 2, W_i // 2
        # Quadrants in order: top-left, top-right, bottom-left, bottom-right.
        regions = [(0, h_mid, 0, w_mid), (0, h_mid, w_mid, W_i),
                   (h_mid, H_i, 0, w_mid), (h_mid, H_i, w_mid, W_i)]

        for h0, h1, w0, w1 in regions:
            windows.append(qkv_b[h0:h1, w0:w1].reshape(-1, num_heads, head_dim))

            # Remember where each window token came from so it can be restored.
            gh, gw = torch.meshgrid(torch.arange(h0, h1), torch.arange(w0, w1), indexing="ij")
            coords = torch.stack([gh.flatten(), gw.flatten()], dim=1)
            restore_info.append({"batch_idx": b, "coords": coords, "H": H_i, "W": W_i})

    lengths = torch.tensor([w.shape[0] for w in windows], device=qkv.device)
    padded_windows = torch.nn.utils.rnn.pad_sequence(windows, batch_first=True)
    max_len = padded_windows.size(1)

    # True on real tokens, False on padding. Position `length` itself is the
    # first padded slot, so the comparison must be strict in both branches.
    valid = torch.arange(max_len, device=qkv.device)[None, :] < lengths[:, None]  # (num_windows, L)

    if attn_type == "flash_attn":
        attn_mask = valid.int()
    else:
        dtype = sdpa_type if sdpa_type is not None else qkv.dtype
        # A (query, key) pair is visible only when both tokens are real.
        visible = valid[:, :, None] & valid[:, None, :]
        attn_mask = ((~visible).to(dtype) * torch.finfo(dtype).min).unsqueeze(1)

    return padded_windows, attn_mask, restore_info, max_len

def window_restore_spatial(padded_out, restore_info, spatial_shapes, seq_len):
    """Scatter per-window attention outputs back into the original padded token sequences.

    Inverse of the quadrant partitioning: window ``i`` of ``padded_out`` contributes its first
    ``len(coords)`` tokens, and a token at grid coordinate ``(h, w)`` of a sample whose grid is
    ``W`` wide goes back to sequence position ``h * W + w`` (row-major, per-sample width).

    Args:
        padded_out: tensor of shape ``(num_windows, max_len, num_heads, head_dim)``.
        restore_info: list of dicts with keys ``"batch_idx"``, ``"coords"``, ``"H"`` and ``"W"``.
        spatial_shapes: unused; kept so existing callers keep working.
        seq_len: length of the sequences to rebuild.

    Returns:
        Tensor of shape ``(batch, seq_len, num_heads, head_dim)``; positions that belong to no
        window (sequence padding) are zero.

    Raises:
        ValueError: if a sample's grid holds more tokens than ``seq_len``.
    """
    batch_size = max(info["batch_idx"] for info in restore_info) + 1
    num_heads, head_dim = padded_out.shape[2], padded_out.shape[3]
    restored = torch.zeros(
        (batch_size, seq_len, num_heads, head_dim), device=padded_out.device, dtype=padded_out.dtype
    )

    for window_idx, info in enumerate(restore_info):
        height, width = int(info["H"]), int(info["W"])
        if height * width > seq_len:
            raise ValueError(
                f"Grid of shape ({height}, {width}) does not fit in a sequence of length {seq_len}."
            )
        coords = info["coords"]
        # Use the sample's own width: flattening a grid padded to the batch-wide maximum
        # width would misplace tokens of narrower images.
        flat_index = coords[:, 0] * width + coords[:, 1]
        restored[info["batch_idx"], flat_index] = padded_out[window_idx, :coords.shape[0]]

    return restored

class Siglip2Attention(nn.Module):
    """Eager multi-head self-attention (as in 'Attention Is All You Need') with rotary embedding applied to q/k.

    This implementation always attends over the full sequence: the `spatial_shapes`,
    `output_attentions` and `window_attention` arguments of `forward` are accepted for
    signature compatibility but are not read.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim, leftover = divmod(self.embed_dim, self.num_heads)
        if leftover:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order is kept (k, v, q, out) so parameter initialisation is reproducible.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        position_embedding: Optional[torch.Tensor] = None,
        window_attention: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention on `hidden_states` of shape `(batch, time, channel)`.

        Returns the projected attention output of shape `(batch, time, channel)` and the
        post-softmax (post-dropout) attention weights.
        """
        bsz, seq_len, _ = hidden_states.size()
        per_head_shape = (bsz, seq_len, self.num_heads, self.head_dim)

        queries = self.q_proj(hidden_states).view(per_head_shape)
        keys = self.k_proj(hidden_states).view(per_head_shape)
        values = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        # Rotary embedding is applied on the (batch, time, heads, head_dim) layout.
        queries, keys = apply_rope(queries, keys, position_embedding)

        kv_len = keys.shape[-2]
        scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scale

        expected_scores = (bsz, self.num_heads, seq_len, kv_len)
        if scores.size() != expected_scores:
            raise ValueError(
                f"Attention weights should be of size {expected_scores}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            expected_mask = (bsz, 1, seq_len, kv_len)
            if attention_mask.size() != expected_mask:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask}, but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # Softmax is computed in fp32 and cast back to the query dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)
        context = torch.matmul(probs, values)

        expected_context = (bsz, self.num_heads, seq_len, self.head_dim)
        if context.size() != expected_context:
            raise ValueError(
                f"`attn_output` should be of size {expected_context}, but is"
                f" {context.size()}"
            )

        # Merge heads back into the channel dimension before the output projection.
        context = context.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.embed_dim)
        return self.out_proj(context), probs


class Siglip2SdpaAttention(Siglip2Attention):
    """
    Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    is_causal = False

    # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        position_embedding: Optional[torch.Tensor] = None,
        window_attention: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute self-attention with SDPA, optionally restricted to four quadrant windows per image.

        Args:
            hidden_states: input of shape `(batch, seq_len, embed_dim)`.
            attention_mask: optional additive mask forwarded to SDPA on the non-windowed path.
                On the windowed path only its dtype is used (for the per-window mask).
            spatial_shapes: per-sample `(H, W)` grid sizes; required when `window_attention` is set.
            output_attentions: when truthy, delegates to the eager parent implementation,
                which is the only one that returns attention weights.
            position_embedding: rotary embedding passed to `apply_rope`.
            window_attention: attend within quadrant windows instead of over the full sequence.

        Returns:
            `(attn_output, None)` with `attn_output` of shape `(batch, seq_len, embed_dim)`, or the
            parent's `(attn_output, attn_weights)` when `output_attentions` is truthy.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            # Forward every argument: the eager path also applies the rotary embedding, so
            # dropping `position_embedding` here would silently change (or break) the result.
            # Note that the eager implementation does not read `window_attention`.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                spatial_shapes=spatial_shapes,
                output_attentions=output_attentions,
                position_embedding=position_embedding,
                window_attention=window_attention,
            )

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        query_states, key_states = apply_rope(query_states, key_states, position_embedding)

        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if self.is_causal and q_len > 1 else False
        dropout_p = self.dropout if self.training else 0.0

        if window_attention:
            # `attention_mask` is optional; without it the window mask falls back to the q/k/v dtype.
            mask_dtype = attention_mask.dtype if attention_mask is not None else None

            # window_partition_4 expects (batch, seq, heads, head_dim); q/k/v are (batch, heads, seq, head_dim) here.
            query_windows, window_mask, restore_info, _ = window_partition_4(
                query_states.transpose(1, 2), spatial_shapes, attn_type='sdpa', sdpa_type=mask_dtype
            )
            key_windows = window_partition_4(
                key_states.transpose(1, 2), spatial_shapes, attn_type='sdpa', sdpa_type=mask_dtype
            )[0]
            value_windows = window_partition_4(
                value_states.transpose(1, 2), spatial_shapes, attn_type='sdpa', sdpa_type=mask_dtype
            )[0]

            window_output = torch.nn.functional.scaled_dot_product_attention(
                query_windows.transpose(1, 2),
                key_windows.transpose(1, 2),
                value_windows.transpose(1, 2),
                attn_mask=window_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )

            # Scatter the window outputs back to (batch, seq, heads, head_dim).
            attn_output = window_restore_spatial(window_output.transpose(1, 2), restore_info, spatial_shapes, q_len)
            attn_output = attn_output.contiguous()
        else:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )

            attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None

class Siglip2FlashAttention2(Siglip2Attention):
    """
    Flash-attention variant of `Siglip2Attention`. The projection weights are inherited unchanged; only the
    forward pass differs, calling the flash attention entry point and handing it the padding mask so that
    padded tokens are dealt with there.
    """

    is_causal = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        spatial_shapes: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        position_embedding: Optional[torch.Tensor] = None,
        window_attention: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention through flash attention.

        `output_attentions` is ignored: attention weights are never produced, so the second
        element of the returned pair is always `None`. With `window_attention` set, attention
        is computed inside quadrant windows derived from `spatial_shapes`.
        """
        bsz, seq_len, _ = hidden_states.size()

        # The flash attention entry point consumes (batch, seq, heads, head_dim), which is exactly
        # the layout produced by splitting the channel dimension -- no transpose is needed.
        per_head_shape = (bsz, seq_len, self.num_heads, self.head_dim)
        queries = self.q_proj(hidden_states).view(per_head_shape)
        keys = self.k_proj(hidden_states).view(per_head_shape)
        values = self.v_proj(hidden_states).view(per_head_shape)

        queries, keys = apply_rope(queries, keys, position_embedding, use_flash_attention=True)
        dropout_rate = self.dropout if self.training else 0.0

        if queries.dtype == torch.float32:
            # Pick the dtype to cast back to: autocast dtype, then pre-quantization dtype,
            # then the dtype of the projection weights.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            queries, keys, values = (
                queries.to(target_dtype),
                keys.to(target_dtype),
                values.to(target_dtype),
            )

        if not window_attention:
            attn_output = _flash_attention_forward(
                queries,
                keys,
                values,
                attention_mask,
                seq_len,
                dropout=dropout_rate,
                is_causal=self.is_causal,
                use_top_left_mask=self._flash_attn_uses_top_left_mask,
            )
        else:
            # The mask, restore info and padded length come from the query partition;
            # keys and values are partitioned the same way.
            query_windows, window_mask, restore_info, window_len = window_partition_4(queries, spatial_shapes)
            key_windows = window_partition_4(keys, spatial_shapes)[0]
            value_windows = window_partition_4(values, spatial_shapes)[0]
            window_output = _flash_attention_forward(
                query_windows,
                key_windows,
                value_windows,
                window_mask,
                window_len,
                dropout=dropout_rate,
                is_causal=self.is_causal,
                use_top_left_mask=self._flash_attn_uses_top_left_mask,
            )
            attn_output = window_restore_spatial(window_output, restore_info, spatial_shapes, seq_len)

        merged = attn_output.reshape(bsz, seq_len, self.embed_dim).contiguous()
        return self.out_proj(merged), None

class Siglip2MLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activation -> project to `hidden_states`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


# Maps the value of `config._attn_implementation` to the attention module class that
# `Siglip2EncoderLayer` instantiates for its `self_attn` sub-module.
SIGLIP2_ATTENTION_CLASSES = {
    "eager": Siglip2Attention,
    "sdpa": Siglip2SdpaAttention,
    "flash_attention_2": Siglip2FlashAttention2,
}

###### WINDOWED TOKEN COMPRESSION ######
class WTC(nn.Module):
    """Windowed token compression: merge every 2x2 neighbourhood of patch tokens into one token.

    Two merging methods are supported:

    - ``"avg_pooling"``: plain 2x2 average pooling.
    - any name containing ``"m_pooling"``: a learned, per-channel softmax weighting of the four
      tokens of each 2x2 window, computed from the tokens concatenated with their window mean.

    Args:
        embed_dim: channel dimension of the tokens (must be divisible by 16 for ``m_pooling``).
        enable_merging: when False, `forward` passes its inputs through unchanged.
        merging_method: see above.
        norm_layer: unused; kept for signature compatibility.
    """

    def __init__(self, embed_dim, enable_merging=True, merging_method="avg_pooling", norm_layer=nn.LayerNorm):
        super().__init__()
        self.enable_merging = enable_merging
        self.merging_method = merging_method
        # Identity placeholders and `zero_init_fc` are kept so the module tree and the
        # state-dict keys stay exactly as before.
        self.reduction = nn.Identity()
        self.norm = nn.Identity()
        self.res_reduction = nn.Identity()
        self.res_norm = nn.Identity()
        self.zero_init_fc = nn.Linear(embed_dim, embed_dim, bias=False)

        # Same membership test as in `forward`, so every method routed to the learned
        # pooling branch actually owns the layers that branch needs.
        if 'm_pooling' in self.merging_method:
            self.attn_layer = nn.Sequential(
                nn.Linear(embed_dim * 2, embed_dim),
                nn.GELU(),
                nn.Linear(embed_dim, embed_dim)
            )
            self.num_head = 16

    def _merge_tokens(self, grid):
        """Reduce one `(H, W, C)` token grid to `((H // 2) * (W // 2), C)` merged tokens."""
        if self.merging_method == 'avg_pooling':
            pooled = F.avg_pool2d(grid.permute(2, 0, 1), kernel_size=2, stride=2)  # (C, H/2, W/2)
            return pooled.flatten(1).transpose(0, 1)

        if 'm_pooling' in self.merging_method:  ## SE-attention
            height, width, channels = grid.shape
            # (H, W, C) -> (num_windows, 4, C): the four tokens of every 2x2 window.
            windows = (
                grid.reshape(height // 2, 2, width // 2, 2, channels)
                .permute(0, 2, 1, 3, 4)
                .reshape(-1, 4, channels)
            )
            window_mean = windows.mean(-2, keepdim=True).expand(-1, 4, -1)
            logits = self.attn_layer(torch.cat([windows, window_mean], dim=-1))
            # Split channels into heads and normalise over the four window positions.
            logits = logits.reshape(logits.shape[0], 4, self.num_head, -1)
            weights = F.softmax(logits, dim=1).reshape(-1, 4, channels)
            return (windows * weights).sum(-2)

        raise ValueError(f"Unsupported merging_method {self.merging_method!r}.")

    def forward(self, x, spatial_shapes, attention_mask=None):
        """Compress the valid tokens of every sample and rebuild padded outputs.

        Args:
            x: tokens of shape `(batch, max_seq_len, embed_dim)`; sample `i` holds `H_i * W_i`
                valid tokens in row-major order followed by padding.
            spatial_shapes: per-sample `(H, W)` grid sizes; must support `// 2` (e.g. a tensor).
            attention_mask: optional mask. If it contains any value equal to 1 it is treated as a
                `(batch, seq)` keep-mask, otherwise as a `(batch, 1, seq, seq)` additive mask.

        Returns:
            `(merged_tokens, spatial_shapes // 2, merged_attention_mask, x)` where `merged_tokens`
            has shape `(batch, max_seq_len // 4, embed_dim)` and the last element is the
            uncompressed input. `merged_attention_mask` is `None` when no mask was given.
            With `enable_merging=False` the inputs are returned unchanged in the same 4-tuple
            layout: `(x, spatial_shapes, attention_mask, x)`.

        Raises:
            ValueError: if `merging_method` is not supported.
        """
        if not self.enable_merging:
            # Same arity as the merging path so callers can unpack uniformly.
            return x, spatial_shapes, attention_mask, x

        batch_size, max_seq_len, embed_dim = x.shape
        merged_len = max_seq_len // 4

        merged_per_image = []
        for i, spatial_shape in enumerate(spatial_shapes):
            H, W = (int(v) for v in spatial_shape)
            merged_per_image.append(self._merge_tokens(x[i][:H * W].reshape(H, W, embed_dim)))
        lengths = [tokens.size(0) for tokens in merged_per_image]

        new_x = self.reduction(self.norm(torch.cat(merged_per_image, dim=0)))

        output_x = x.new_zeros((batch_size, merged_len, embed_dim))

        # Decide the mask layout once instead of re-probing the mask for every sample.
        is_keep_mask = False
        inf_value = None
        if attention_mask is None:
            output_attention_mask = None
        else:
            is_keep_mask = bool((attention_mask == 1).any())
            if is_keep_mask:
                output_attention_mask = attention_mask.new_zeros((batch_size, merged_len))
            else:
                output_attention_mask = attention_mask.new_zeros((batch_size, 1, merged_len, merged_len))
                inf_value = torch.finfo(attention_mask.dtype).min

        for i, (tokens, seq_len) in enumerate(zip(torch.split(new_x, lengths, dim=0), lengths)):
            output_x[i, :seq_len] = tokens
            if output_attention_mask is None:
                continue
            if is_keep_mask:
                output_attention_mask[i, :seq_len] = 1
            else:
                # Additive mask: block attention to the padded key positions.
                output_attention_mask[i, 0, :, seq_len:] = inf_value

        return output_x, spatial_shapes // 2, output_attention_mask, x
            

class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer layer (attention + MLP) with an optional token merger at its end.

    The merger and window-attention settings are read from `config.vision_config` (a mapping)
    when the config has that attribute, otherwise from attributes of `config` itself. Window
    attention is only considered for layers located before the first merger layer.
    """

    def __init__(self, config: Siglip2Config, layer_index):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SIGLIP2_ATTENTION_CLASSES[config._attn_implementation](config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Per-head channel count; the rotary table is sliced to half of it.
        self.position_embed_dim = self.embed_dim//config.num_attention_heads
        self.layer_index = layer_index
        self.window_attention = False

        if hasattr(config, 'vision_config'):
            vision_cfg = config.vision_config
            merger_layers = vision_cfg['merger_layer_index']
            self.merger = (
                WTC(config.hidden_size, merging_method=vision_cfg['merging_method'])
                if layer_index in merger_layers
                else None
            )
            if layer_index < merger_layers[0]:
                self.window_attention = vision_cfg.get("window_attention", False)
        else:
            merger_layers = config.merger_layer_index
            self.merger = (
                WTC(config.hidden_size, merging_method=config.merging_method)
                if layer_index in merger_layers
                else None
            )
            if layer_index < merger_layers[0]:
                self.window_attention = getattr(config, "window_attention", False)

    def get_position_embedding(self, position_embedding, spatial_shapes, target_length=None):
        """Gather the rotary table entries of every sample and pad them to a common length.

        For each `(h, w)` in `spatial_shapes` the top-left `h x w` corner of `position_embedding`
        is flattened to `(h * w, position_embed_dim // 2)`. Shorter samples are padded with
        `1 + 0j`. `target_length` is not used.
        """
        half_dim = self.position_embed_dim // 2
        per_sample = [
            position_embedding[:rows, :cols].reshape(-1, half_dim)
            for rows, cols in spatial_shapes.tolist()
        ]

        # Real and imaginary parts are padded separately and recombined afterwards.
        pad = torch.nn.utils.rnn.pad_sequence
        real_part = pad([entry.real for entry in per_sample], batch_first=True, padding_value=1.0)
        imag_part = pad([entry.imag for entry in per_sample], batch_first=True, padding_value=0.0)
        return torch.complex(real_part, imag_part)

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        spatial_shapes,
        attention_mask: torch.Tensor,
        position_embedding,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            spatial_shapes:
                Per-sample `(height, width)` token grid sizes.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            position_embedding:
                Rotary table for the current resolution, or `None`.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)` for a plain layer, or
            `(hidden_states, spatial_shapes, attention_mask, attn_weights, feature_x)` for a merger layer;
            `attn_weights` is appended once more at the end when `output_attentions` is truthy.
        """
        if position_embedding is not None:
            position_embedding = self.get_position_embedding(position_embedding, spatial_shapes)

        # Attention sub-block with residual connection.
        attn_input = self.layer_norm1(hidden_states)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            position_embedding=position_embedding,
            window_attention=self.window_attention,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if self.merger is None:
            outputs = (hidden_states,)
        else:
            hidden_states, spatial_shapes, attention_mask, feature_x = self.merger(
                hidden_states, spatial_shapes, attention_mask
            )
            outputs = (hidden_states, spatial_shapes, attention_mask, attn_weights, feature_x)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Siglip2EncoderLayer`].

    Args:
        config: Siglip2Config
    """

    def __init__(self, config: Siglip2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Siglip2EncoderLayer(config, layer_index=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        spatial_shapes,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        position_embedding: Optional[list] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded patch tokens fed to the first layer.
            spatial_shapes:
                Per-sample `(height, width)` token grid sizes. Replaced by the value returned from a layer
                whenever that layer merges tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Mask passed to every layer. It is replaced by the mask returned from a layer whenever that
                layer merges tokens. NOTE(review): the expected layout depends on the attention
                implementation; confirm against the caller.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            position_embedding (`list`, *optional*):
                One rotary table per resolution level; the encoder starts with entry 0 and advances to the
                next entry after every layer that merges tokens.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        new_attention_mask = attention_mask
        position_embedding_idx = 0
        cur_position_embedding = None if position_embedding is None else position_embedding[position_embedding_idx]

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    spatial_shapes,
                    new_attention_mask,
                    cur_position_embedding,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    spatial_shapes,
                    new_attention_mask,
                    cur_position_embedding,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            # A plain layer returns (hidden,) or (hidden, attn); a merger layer returns
            # (hidden, spatial_shapes, mask, attn, feature[, attn]). The merged shapes, mask and
            # rotary level must be propagated regardless of `output_attentions`, otherwise the
            # following layers run with stale geometry.
            if len(layer_outputs) > 2:
                spatial_shapes = layer_outputs[1]
                new_attention_mask = layer_outputs[2]
                if position_embedding is not None:
                    position_embedding_idx += 1
                    cur_position_embedding = position_embedding[position_embedding_idx]
            if output_attentions:
                # The layer always appends the attention weights last; index 1 would be
                # `spatial_shapes` on merger layers.
                all_attentions = all_attentions + (layer_outputs[-1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


# Argument documentation that `add_start_docstrings_to_model_forward` attaches to
# `Siglip2VisionTransformer.forward`.
# NOTE(review): that forward also takes `image_list`, `attention_mask` and `spatial_shapes`,
# which are not described in this text.
SIGLIP2_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    The module holds no parameters; it only builds tables of complex rotations
    (`cis x = cos x + i sin x`) for every position of a 2D grid. Use
    `precompute_n_freqs_cis` to get one table for the full grid plus one per successive
    2x downsampling of it.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta

    Raises:
        ValueError: if `dim` is not divisible by 4.
    """

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        if dim % 4 != 0:
            raise ValueError("dim must be divisible by 4")
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.freqs_cis = None

    def _precompute_freqs_cis(self, max_height, max_width, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position of a `max_height x max_width` grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            width axis:  ret[h, w, 2*i]   = cis(w * theta_base**(-4*i/dim))
            height axis: ret[h, w, 2*i+1] = cis(h * theta_base**(-4*i/dim))   with (i in [0, dim//4))
            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = max_height * max_width
        flat_pos = torch.arange(0, N).float().to(device)
        # Decompose with the width of the grid being built (not `self.max_width`), so that
        # entry [h, w] of the reshaped table really encodes position (h, w) for every level.
        x_pos = flat_pos % max_width
        y_pos = flat_pos // max_width
        dim_range = (
            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
        )  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2 -- width rotation first, height rotation second
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
        )
        # max_height, max_width, C/2
        return freqs_cis.reshape(max_height, max_width, -1)

    def precompute_n_freqs_cis(self, merger_layer_num, device):
        """Build `merger_layer_num + 1` tables: the full grid, then one per halving of both sides.

        Returns:
            list of complex tensors; entry `k` has shape
            `(max_height // 2**k, max_width // 2**k, dim // 2)`.
        """
        height, width = self.max_height, self.max_width
        tables = [self._precompute_freqs_cis(height, width, device)]
        for _ in range(merger_layer_num):
            height, width = height // 2, width // 2
            tables.append(self._precompute_freqs_cis(height, width, device))
        return tables


class Siglip2VisionTransformer(nn.Module):
    """Vision backbone: patch embeddings, transformer encoder and a final LayerNorm.

    Optional behaviour is driven by the config, read from ``config.vision_config`` when the
    config carries one and from the config itself otherwise:

    * ``use_rope2d``: build 2D rotary position tables with `Rope2DPosEmb`.
    * ``changed_patch_size``: embed the raw images (`image_list`) with
      `Siglip2VisionEmbeddingsCNN` instead of embedding `pixel_values` with
      `Siglip2VisionEmbeddings`.

    No pooling head is created here. When one is attached afterwards as ``self.head`` it is
    applied to the normalized hidden states; otherwise the pooled output is ``None``.
    """

    @staticmethod
    def _read_option(config, name, default=None):
        """Return option `name` from `config.vision_config` when it is set, else from `config`."""
        nested = getattr(config, "vision_config", None)
        if nested is None:
            return getattr(config, name, default)
        if isinstance(nested, dict):
            return nested.get(name, default)
        return getattr(nested, name, default)

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        # The attention backend is forced here: flash-attention 2 when the config asks for it
        # through `use_flash_attention_2`, SDPA in every other case.
        config._attn_implementation = (
            "flash_attention_2" if getattr(config, "use_flash_attention_2", False) else "sdpa"
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.config = config
        embed_dim = config.hidden_size
        self.encoder = Siglip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.use_head = getattr(config, "vision_use_head", False)

        self.use_rope2d = self._read_option(config, "use_rope2d", False)
        if self.use_rope2d:
            # One rotary table per attention head dimension, for grids of up to 512 x 512.
            self.rope2d = Rope2DPosEmb(embed_dim // config.num_attention_heads, 512, 512)
        self.changed_patch_size = self._read_option(config, "changed_patch_size", False)

        if self.changed_patch_size:
            self.embeddings = Siglip2VisionEmbeddingsCNN(config)
        else:
            self.embeddings = Siglip2VisionEmbeddings(config)

    @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
    def forward(
        self,
        pixel_values,
        image_list,
        attention_mask: Optional[torch.Tensor],
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # NOTE: the decorators above rewrite the bare `Returns:` line of the docstring, so
        # that line has to stay as it is.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The CNN embeddings consume the raw images, the default ones the patchified pixels.
        embedding_input = image_list if self.changed_patch_size else pixel_values
        hidden_states = self.embeddings(embedding_input, spatial_shapes)

        if attention_mask is None:
            # No padding information: let the encoder attend to every position.
            encoder_attention_mask = None
        elif self._use_flash_attention_2:
            # The flash-attention path receives the 2D padding mask as int32.
            encoder_attention_mask = attention_mask.detach().to(dtype=torch.int32)
        else:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        ### position_embedding ###
        position_embedding = None
        if self.use_rope2d:
            merger_layer_index = self._read_option(self.config, "merger_layer_index")
            if merger_layer_index is None:
                raise ValueError("`use_rope2d` requires `merger_layer_index` to be set in the vision config.")
            # One table for the input grid plus one per merger layer.
            position_embedding = self.rope2d.precompute_n_freqs_cis(
                len(merger_layer_index), hidden_states.device
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            spatial_shapes=spatial_shapes,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            position_embedding=position_embedding,
            return_dict=return_dict,
        )

        # The encoder's first output is either the hidden states alone or a
        # `(hidden_states, feature_x_list)` pair.
        last_hidden_state = encoder_outputs[0]
        has_feature_list = isinstance(last_hidden_state, tuple)
        feature_x_list = None
        if has_feature_list:
            last_hidden_state, feature_x_list = last_hidden_state

        last_hidden_state = self.post_layernorm(last_hidden_state)
        # `head` is not created in `__init__`; it only exists when attached from outside.
        head = getattr(self, "head", None)
        pooled_output = head(last_hidden_state) if head is not None else None

        if has_feature_list:
            last_hidden_state = (last_hidden_state, feature_x_list)

        if not return_dict:
            if has_feature_list:
                return (last_hidden_state, pooled_output, feature_x_list) + encoder_outputs[1:]
            # Without intermediate features, use the plain (hidden, pooled, ...) layout.
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input `tensor`, filled in place.
    """
    with torch.no_grad():
        # Sample a standard normal truncated to [a, b], then rescale and shift.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    # Return the tensor so the function honours its annotation.
    return tensor


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize `tensor` in place with variance ``scale / denom``.

    Args:
        tensor: tensor to fill; its fan-in / fan-out are computed with
            `_calculate_fan_in_and_fan_out`.
        scale: numerator of the target variance.
        mode: which fan is used as ``denom``: "fan_in", "fan_out", or "fan_avg"
            (the mean of the two).
        distribution: "truncated_normal", "normal" or "uniform".

    Raises:
        ValueError: if `mode` or `distribution` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Fail explicitly instead of hitting an unbound `denom` below.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform distribution on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """Fill `tensor` in place via `variance_scaling_`: fan-in mode, truncated normal."""
    variance_scaling_(
        tensor,
        scale=1.0,
        mode="fan_in",
        distribution="truncated_normal",
    )


def default_flax_embed_init(tensor):
    """Fill `tensor` in place via `variance_scaling_`: fan-in mode, plain normal."""
    variance_scaling_(
        tensor,
        scale=1.0,
        mode="fan_in",
        distribution="normal",
    )


class Siglip2PreTrainedModel(PreTrainedModel):
    """
    Abstract base class that wires the Siglip2 models into the `PreTrainedModel` machinery
    (pretrained loading/saving) and defines how each kind of sub-module is initialized.
    """

    config_class = Siglip2Config
    base_model_prefix = "siglip2"
    supports_gradient_checkpointing = True

    _no_split_modules = [
        "Siglip2TextEmbeddings",
        "Siglip2EncoderLayer",
        "Siglip2VisionEmbeddings",
        "Siglip2EncoderLayer",
        "Siglip2MultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type.

        The checks run from the most specific module type to the most generic one and
        the first match wins.
        """
        if isinstance(module, Siglip2VisionEmbeddings):
            # Position table: normal with std 1 / sqrt(hidden_size).
            hidden_size = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(hidden_size))
            return

        if isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
            return

        if isinstance(module, Siglip2Attention):
            # Xavier-uniform weights and zero biases for all four projections.
            projections = (module.q_proj, module.k_proj, module.v_proj, module.out_proj)
            for projection in projections:
                nn.init.xavier_uniform_(projection.weight)
            for projection in projections:
                nn.init.zeros_(projection.bias)
            return

        if isinstance(module, Siglip2MLP):
            # Xavier-uniform weights, near-zero normal biases.
            feed_forward = (module.fc1, module.fc2)
            for linear in feed_forward:
                nn.init.xavier_uniform_(linear.weight)
            for linear in feed_forward:
                nn.init.normal_(linear.bias, std=1e-6)
            return

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            # Start as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class Siglip2VisionModel(Siglip2PreTrainedModel):
    """Standalone vision model: a `Siglip2VisionTransformer` exposed as `vision_model`."""

    config_class = Siglip2VisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)

        self.vision_model = Siglip2VisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch embedding module of the wrapped vision transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        image_list,
        pixel_attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        # Fall back to the config default when the caller does not choose an output format.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Everything else is delegated to the wrapped transformer; the pixel mask is
        # passed on as its `attention_mask`.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            image_list=image_list,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return vision_outputs


class SigLip2SVisionTower(nn.Module):
    """Frozen `Siglip2VisionModel` wrapper that encodes a list of variable-size images.

    The pretrained model is loaded at construction time unless `delay_load` is set and
    `vision_tower_cfg` gives no hint that the checkpoint carries vision-tower weights.
    After loading, the pooling head is replaced by `nn.Identity`, the "zero" merger
    parameters are zeroed and all parameters are frozen.

    NOTE(review): `self.config` is a default-constructed `Siglip2VisionConfig`, not the
    config of the loaded checkpoint -- confirm that the sizes reported by the properties
    below match the checkpoint actually in use.
    """

    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.config = Siglip2VisionConfig()
        self.vision_tower_name = vision_tower
        self.image_processor = SigLipImageProcessor()

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
        else:
            # Loading is deferred: only the (default) config is kept around.
            self.cfg_only = self.config
            return

        self.load_model()
        if getattr(vision_tower_cfg, "merger_from_prev", False):
            self._init_merger_from_prev_(self.vision_tower)

    def load_model(self, device_map=None):
        """Load the pretrained vision model once, strip its head and freeze it."""
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        #### ignore_mismatched_sizes=True ####
        self.vision_tower = Siglip2VisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        rank0_print('siglip2_naflex_swin')
        # The tower only exposes token features, so the pooling head becomes a no-op.
        self.vision_tower.vision_model.head = nn.Identity()
        self._init_zero_merger_(self.vision_tower)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def _init_zero_merger_(self, model):
        """Zero every parameter whose name contains both "zero" and "merger"."""
        for name, param in model.named_parameters():
            if "zero" in name and "merger" in name:
                param.data.zero_()

    def _init_merger_from_prev_(self, model):
        """Initialize each merger's token packers from the encoder layer that owns them.

        For every encoder layer with a non-None `merger` exposing `token_packer`, the
        attention projections, the MLP and the layer norms of that layer are copied into
        each token packer (`layer_norm1` feeds both `ln_q` and `ln_kv`).
        """
        rank0_print("Initializing merger from previous model.")
        for layer in model.vision_model.encoder.layers:
            merger = getattr(layer, "merger", None)
            if merger is None or not hasattr(merger, "token_packer"):
                continue

            prev_attn = layer.self_attn
            prev_mlp = layer.mlp
            prev_ln1 = layer.layer_norm1
            prev_ln2 = layer.layer_norm2
            with torch.no_grad():
                for packer in merger.token_packer:
                    # (destination, source) module pairs; weight and bias are copied for each.
                    module_pairs = (
                        (packer.k_proj, prev_attn.k_proj),
                        (packer.v_proj, prev_attn.v_proj),
                        (packer.q_proj, prev_attn.q_proj),
                        (packer.out_proj, prev_attn.out_proj),
                        (packer.ffn[0], prev_mlp.fc1),
                        (packer.ffn[2], prev_mlp.fc2),
                        (packer.ln_q, prev_ln1),
                        (packer.ln_kv, prev_ln1),
                        (packer.ln_ffn, prev_ln2),
                    )
                    for target, source in module_pairs:
                        target.weight.copy_(source.weight)
                        target.bias.copy_(source.bias)

    def forward(self, images, patch_sizes):
        """Encode a list of images and return one feature tensor per image.

        Args:
            images: list of image tensors, one per image; each is unsqueezed to a batch
                of one and unpacked as (batch, channels, height, width).
            patch_sizes: per-image pair of patch-grid sizes whose product is the number
                of valid patches of that image. Assumed to match the grid obtained after
                resizing the image to a multiple of the patch size -- TODO confirm with
                the caller.

        Returns:
            Tuple with one slice of the last hidden state per image (`split(1)` along
            the batch dimension), cast to the dtype of the pixel values.

        Raises:
            NotImplementedError: if `images` is not a list (batched tensor input is not
                supported).
        """
        if not isinstance(images, list):
            raise NotImplementedError(
                "SigLip2SVisionTower.forward only supports a list of per-image tensors; "
                "batched tensor input is not supported."
            )

        encoder_patch_size = self.vision_tower.vision_model.embeddings.patch_size
        device, dtype = self.device, self.dtype
        ### TODO: ###
        # max_length = 16384
        # Every sequence is padded to the longest patch sequence of this call.
        max_length = max(int(patch_size[0] * patch_size[1]) for patch_size in patch_sizes)

        image_list = []
        pixel_values = []
        pixel_attention_masks = []
        spatial_shapes = []
        for image, spatial_shape in zip(images, patch_sizes):
            valid_pixel_num = int(spatial_shape[0] * spatial_shape[1])
            spatial_shape = torch.as_tensor(spatial_shape)[None]

            image = image.to(device=device, dtype=dtype).unsqueeze(0)
            _, _, h, w = image.shape
            # Resize so that both sides are whole multiples of the patch size.
            new_h = (h // encoder_patch_size) * encoder_patch_size
            new_w = (w // encoder_patch_size) * encoder_patch_size
            image = F.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)

            # b, n, c: one flattened patch per token.
            pixel_value = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=encoder_patch_size, p2=encoder_patch_size)
            padding = pixel_value.new_zeros(pixel_value.shape[0], max_length - valid_pixel_num, pixel_value.shape[-1])
            pixel_value = torch.cat([pixel_value, padding], dim=1)
            # 1 for real patches, 0 for padding; shape (1, max_length).
            pixel_attention_mask = torch.zeros_like(pixel_value[:, :, 0])
            pixel_attention_mask[:, :valid_pixel_num] = 1

            image_list.append(image)
            pixel_values.append(pixel_value)
            pixel_attention_masks.append(pixel_attention_mask)
            spatial_shapes.append(spatial_shape)

        pixel_values = torch.cat(pixel_values)
        pixel_attention_masks = torch.cat(pixel_attention_masks)
        spatial_shapes = torch.cat(spatial_shapes)

        image_forward_outs = self.vision_tower(
            pixel_values,
            image_list,
            pixel_attention_mask=pixel_attention_masks,
            spatial_shapes=spatial_shapes,
            output_hidden_states=True,
        )

        image_features = image_forward_outs.last_hidden_state.to(pixel_values.dtype)
        return image_features.split(1)

    @property
    def dummy_feature(self):
        """All-zero feature of shape (1, hidden_size) on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the first parameter of the loaded tower (None if it has none)."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.dtype

    @property
    def device(self):
        """Device of the first parameter of the loaded tower (None if it has none)."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.device

    @property
    def hidden_size(self):
        """`hidden_size` of `self.config`."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches of a square `image_size` image."""
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        """Number of patches along one side of a square `image_size` image."""
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        """`image_size` of `self.config`."""
        return self.config.image_size
