import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Union
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    VideoInput,
)
from transformers.utils import TensorType
from PIL import Image
import math


def smart_resize_tensor(
    tensor: torch.Tensor,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
):
    """
    Resize a ``[C, H, W]`` tensor using the same target-size logic as Qwen2-VL's ``smart_resize``.

    Both output sides are multiples of ``factor``. The total pixel count is pushed
    toward ``[min_pixels, max_pixels]`` while approximately preserving the aspect
    ratio. Each side is never shrunk below ``factor``, so extremely elongated
    inputs may slightly exceed ``max_pixels``. Resampling uses bilinear
    interpolation, so the operation stays differentiable.

    Args:
        tensor: Input tensor of shape ``[C, H, W]``.
        factor: Each output side is rounded to a multiple of this value.
        min_pixels: Lower bound on ``H_out * W_out``.
        max_pixels: Upper bound on ``H_out * W_out``.

    Returns:
        The resized tensor of shape ``[C, H_out, W_out]``.

    Raises:
        ValueError: If either side is smaller than ``factor`` or the absolute
            aspect ratio exceeds 200.
    """
    _, height, width = tensor.shape

    if height < factor or width < factor:
        raise ValueError(f"height:{height} and width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to at least one `factor`: for very elongated inputs the short side
        # can floor to 0, which would produce an empty interpolation target.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # Bilinear resize; F.interpolate expects a batch axis, so add and remove one.
    resized_tensor = F.interpolate(
        tensor.unsqueeze(0),
        size=(h_bar, w_bar),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)

    return resized_tensor


class Qwen2VLTensorImageProcessor(BaseImageProcessor):
    """
    Qwen2-VL image processor that accepts ``torch.Tensor`` inputs.

    Resize, rescale, normalize and patchify are all implemented with torch ops,
    so gradients can flow back to the input tensor
    (see ``preprocess_tensor_keep_grad``).
    """

    model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        """
        Args:
            do_resize: Whether to apply ``smart_resize_tensor`` by default.
            size: Dict with ``"shortest_edge"`` (min pixel count) and
                ``"longest_edge"`` (max pixel count). Defaults to
                ``{"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}``.
            do_rescale: Whether to multiply pixel values by ``rescale_factor``.
            rescale_factor: Multiplicative rescale factor.
            do_normalize: Whether to normalize with ``image_mean`` / ``image_std``.
            image_mean: Per-channel mean (defaults to ``OPENAI_CLIP_MEAN``).
            image_std: Per-channel std (defaults to ``OPENAI_CLIP_STD``).
            do_convert_rgb: Stored for config compatibility; the tensor
                preprocessing path in this class does not read it.
            min_pixels: If given, overrides ``size["shortest_edge"]``.
            max_pixels: If given, overrides ``size["longest_edge"]``.
            patch_size: Spatial patch size.
            temporal_patch_size: Number of frames per temporal patch.
            merge_size: Spatial merge factor; resized sides are multiples of
                ``patch_size * merge_size``.

        Raises:
            ValueError: If ``size`` is given without both required keys.
        """
        super().__init__(**kwargs)

        if size is None:
            size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
        elif "shortest_edge" not in size or "longest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:
            # Copy so the overrides below never mutate the caller's dict.
            size = dict(size)

        # Backward compatibility: explicit min_pixels / max_pixels take precedence over `size`.
        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
        if max_pixels is not None:
            size["longest_edge"] = max_pixels

        self.min_pixels = size["shortest_edge"]
        self.max_pixels = size["longest_edge"]
        self.size = size

        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = torch.tensor(image_mean if image_mean is not None else OPENAI_CLIP_MEAN)
        self.image_std = torch.tensor(image_std if image_std is not None else OPENAI_CLIP_STD)

        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.do_convert_rgb = do_convert_rgb

    def to_dict(self):
        """
        Serialize the processor config, converting tensor attributes to lists.

        ``image_mean`` / ``image_std`` are stored as tensors, which the ``json``
        module cannot encode. Without this conversion, JSON serialization of the
        config (as used by transformers' ``to_json_string``, and hence
        ``save_pretrained`` / ``repr``) raises ``TypeError``.
        """
        output = super().to_dict()
        for key, value in output.items():
            if isinstance(value, torch.Tensor):
                output[key] = value.tolist()
        return output

    def _preprocess_tensor(
        self,
        images: torch.Tensor,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[torch.Tensor, float, List[float]]] = None,
        image_std: Optional[Union[torch.Tensor, float, List[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
    ):
        """
        Resize, rescale, normalize and patchify a tensor input.

        A 4D input's leading axis is treated as the temporal (frame) axis of ONE
        visual input. All frames are patchified together into a single grid; they
        are not processed as independent images. Any argument left as ``None``
        falls back to the processor attribute of the same name.

        Args:
            images: Tensor of shape ``[B, C, H, W]`` or ``[C, H, W]``. Integer
                tensors are cast to ``float32`` before processing.

        Returns:
            flatten_patches: Tensor of shape
                ``[grid_t * grid_h * grid_w, C * temporal_patch_size * patch_size ** 2]``.
            grid_info: ``(grid_t, grid_h, grid_w)`` tuple of ints.

        Raises:
            ValueError: If ``images`` is not 3D/4D, or if the (possibly resized)
                spatial size is not a multiple of ``patch_size * merge_size``.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        patch_size = patch_size if patch_size is not None else self.patch_size
        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
        merge_size = merge_size if merge_size is not None else self.merge_size
        size = size if size is not None else self.size

        if images.dim() == 3:  # [C, H, W] -> [1, C, H, W]
            images = images.unsqueeze(0)
        elif images.dim() != 4:
            raise ValueError(f"Expected 3D or 4D tensor, got {images.dim()}D")

        if not images.is_floating_point():
            # Bilinear resizing is not available for integer dtypes on every
            # backend, and normalization needs floats anyway.
            images = images.float()

        # Accept tensors, lists or scalars for mean/std, on the input's device.
        device = images.device
        image_mean = torch.as_tensor(image_mean, device=device)
        image_std = torch.as_tensor(image_std, device=device)

        # 1. Resize each frame independently, then stack (always a fresh tensor).
        frames = list(images.unbind(0))
        if do_resize:
            frames = [
                smart_resize_tensor(
                    frame,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                for frame in frames
            ]
        patches = torch.stack(frames, dim=0)  # [B, C, H, W]

        # 2. Rescale and 3. normalize; the ops are elementwise, so batching is exact.
        if do_rescale:
            patches = patches * rescale_factor
        if do_normalize:
            # mean/std reshaped to [C, 1, 1] to broadcast over [B, C, H, W].
            patches = (patches - image_mean.view(-1, 1, 1)) / image_std.view(-1, 1, 1)

        resized_height, resized_width = patches.shape[2], patches.shape[3]
        spatial_factor = patch_size * merge_size
        if resized_height % spatial_factor or resized_width % spatial_factor:
            raise ValueError(
                f"Spatial size ({resized_height}, {resized_width}) must be divisible by "
                f"patch_size * merge_size = {spatial_factor}; enable do_resize or pre-resize the input."
            )

        # Pad the frame axis to a multiple of temporal_patch_size by repeating the last frame.
        remainder = patches.shape[0] % temporal_patch_size
        if remainder != 0:
            repeats_needed = temporal_patch_size - remainder
            last_frame = patches[-1:].repeat(repeats_needed, 1, 1, 1)
            patches = torch.cat([patches, last_frame], dim=0)

        grid_t = patches.shape[0] // temporal_patch_size
        grid_h = resized_height // patch_size
        grid_w = resized_width // patch_size
        channel = patches.shape[1]

        # Split time and space into (merge block, in-block position, pixel) axes.
        patches = patches.reshape(
            grid_t,
            temporal_patch_size,
            channel,
            grid_h // merge_size,
            merge_size,
            patch_size,
            grid_w // merge_size,
            merge_size,
            patch_size,
        )

        # Order rows as (t, merged-h, merged-w, in-block h, in-block w) and
        # columns as (channel, temporal, pixel-h, pixel-w).
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)

        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        )

        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess_tensor(
        self,
        images: torch.Tensor,
        **kwargs
    ):
        """
        Main entry point for tensor inputs; returns a detached CPU result.

        Args:
            images: Tensor of shape ``[B, C, H, W]`` or ``[C, H, W]``.
            **kwargs: Forwarded to ``_preprocess_tensor``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` (detached CPU copy) and
            ``image_grid_thw`` (int64, shape ``[1, 3]``).
        """
        flatten_patches, grid_info = self._preprocess_tensor(images, **kwargs)

        # Stay in torch rather than round-tripping through numpy: numpy has no
        # bfloat16, and copy=True keeps the output independent of the graph tensor.
        pixel_values = flatten_patches.detach().to("cpu", copy=True)
        image_grid_thw = torch.tensor([list(grid_info)], dtype=torch.long)

        data = {
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
        }

        return BatchFeature(data=data, tensor_type="pt")

    def preprocess_tensor_keep_grad(
        self,
        images: torch.Tensor,
        **kwargs
    ):
        """
        Tensor preprocessing that keeps the autograd graph (e.g. for adversarial attacks).

        Args:
            images: Tensor of shape ``[B, C, H, W]`` or ``[C, H, W]``.
            **kwargs: Forwarded to ``_preprocess_tensor``.

        Returns:
            tuple: ``(flatten_patches, grid_info)``; ``flatten_patches`` is not
            detached, so gradients flow back to ``images``.
        """
        return self._preprocess_tensor(images, **kwargs)