import torch
import torch.nn as nn

from transformers import CLIPImageProcessor, SiglipVisionModel
import open_clip


# Identifier passed to CLIPImageProcessor.from_pretrained in
# SigLIPVisionTower.load_model; the processor's resize and crop sizes are
# overridden there with the tower's configured resolution.
DEFAULT_IMAGE_PROCESSOR = 'openai/clip-vit-large-patch14-336'


def interpolate_pos_encoding(position_embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Resample a square table of position embeddings to a ``height`` x ``width`` grid.

    The table is interpreted as a row-major ``grid x grid`` lattice, where
    ``grid`` is the integer square root of the number of rows, and is resized
    with bicubic interpolation.

    Args:
        position_embedding: 2-D tensor of shape ``(num_positions, dim)``.
        height: Number of rows in the target grid.
        width: Number of columns in the target grid.

    Returns:
        ``position_embedding`` itself (no copy) when it already holds
        ``height * width`` rows and the target grid is square; otherwise a new
        contiguous tensor of shape ``(height * width, dim)``.

    Raises:
        ValueError: If ``num_positions`` is zero or not a perfect square, or if
            the interpolated grid does not come out as ``height`` x ``width``.
    """
    import math

    num_positions = position_embedding.shape[0]
    if height * width == num_positions and height == width:
        return position_embedding

    # The source table must describe a square lattice; check this up front so
    # the caller gets a clear message instead of an opaque reshape failure.
    grid = math.isqrt(num_positions)
    if grid == 0 or grid * grid != num_positions:
        raise ValueError(
            'Position embedding must contain a non-zero perfect-square number of rows, '
            'got {}'.format(num_positions)
        )

    dim = position_embedding.shape[-1]
    # (num_positions, dim) -> (1, dim, grid, grid), the layout interpolate() expects.
    patch_pos_embed = position_embedding.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        # The +0.1 keeps floor(grid * scale) from landing one below the target
        # size because of floating-point rounding.
        scale_factor=((height + 0.1) / grid, (width + 0.1) / grid),
        mode='bicubic',
        align_corners=False,
    )
    if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
        raise ValueError('Width or height does not match with the interpolated position embeddings')

    # Materialise the channels-last permutation before flattening so the result
    # is a plain contiguous (height * width, dim) tensor rather than a strided
    # view of the interpolation output.
    return patch_pos_embed.permute(0, 2, 3, 1).contiguous().view(-1, dim)


class SigLIPVisionTower(nn.Module):
    """Frozen SigLIP vision encoder that can run at a configurable input resolution.

    The wrapped ``SiglipVisionModel`` has its position-embedding table resized
    at load time when the configured resolution implies a different number of
    patches than the checkpoint was built with.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """Record the configuration and optionally load the weights.

        Args:
            vision_tower: Name or path handed to ``SiglipVisionModel.from_pretrained``.
            args: Object providing ``mm_vision_select_layer`` and, optionally,
                ``mm_vision_select_feature`` (default ``'patch'``),
                ``mm_vision_resolution`` (default ``336``) and
                ``unfreeze_mm_vision_tower`` (default ``False``).
            delay_load: When true, loading is skipped unless
                ``args.unfreeze_mm_vision_tower`` is truthy.
        """
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # NOTE: select_layer / select_feature are stored but not consulted by
        # forward(), which always returns the model's last_hidden_state.
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.resolution = getattr(args, 'mm_vision_resolution', 336)

        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()

    def load_model(self, device_map=None):
        """Load the image processor and vision model, resizing position embeddings if needed.

        Calling this again after a successful load prints a notice and returns.

        Args:
            device_map: Accepted for signature compatibility; currently unused.
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        # Reuse the CLIP processor configuration, overriding only the resize
        # and crop sizes so images come out at the requested resolution.
        self.image_processor = CLIPImageProcessor.from_pretrained(DEFAULT_IMAGE_PROCESSOR)
        self.image_processor.size['shortest_edge'] = self.resolution
        self.image_processor.crop_size['width'] = self.resolution
        self.image_processor.crop_size['height'] = self.resolution

        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)

        embed = self.vision_tower.vision_model.embeddings
        # Take the patch size from the loaded checkpoint rather than assuming 14,
        # so the resized table matches the patch grid for other patch sizes too.
        patch_size = getattr(self.vision_tower.config, 'patch_size', 14)
        num_patches_per_side = self.resolution // patch_size
        num_patches = num_patches_per_side * num_patches_per_side
        if num_patches != embed.num_patches:
            embed.position_embedding.weight.data = interpolate_pos_encoding(
                embed.position_embedding.weight.data,
                num_patches_per_side,
                num_patches_per_side,
            )
            # Keep the module's bookkeeping in step with the resized table.
            embed.num_patches = num_patches
            embed.num_positions = num_patches
            embed.position_ids = torch.arange(embed.num_positions).expand(1, -1).to(
                device=embed.position_ids.device, dtype=embed.position_ids.dtype
            )

        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    @torch.no_grad()
    def forward(self, images):
        """Encode ``images`` and return the model's ``last_hidden_state``.

        The input is first moved to the vision tower's device and dtype.
        Gradients are disabled for the whole call.
        """
        images = images.to(device=self.device, dtype=self.dtype)
        # Only last_hidden_state is consumed, so intermediate hidden states are
        # not requested.
        return self.vision_tower(images).last_hidden_state

    @property
    def dtype(self):
        """Dtype reported by the wrapped vision model (requires ``load_model``)."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision model (requires ``load_model``)."""
        return self.vision_tower.device
