# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Preprocessor for the Qwen2-VL / Qwen2.5-VL multi-modal models (image and video inputs).
"""

from io import BytesIO
from typing import Optional, Union
import os
import torch
from PIL import Image
from qwen_vl_utils import fetch_image, fetch_video

from .base_processor import BasicPreprocessor
from .registry import PREPROCESSOR_REGISTER

# Public API of this module: only the preprocessor class is exported.
__all__ = [
    "QwenVLPreProcessor",
]

# Default resize bounds for still images, expressed as multiples of 28 * 28
# pixels. QwenVLPreProcessor uses these unless they are overridden through its
# constructor keyword arguments.
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28**2
MAX_PIXELS = 16384 * 28**2
# Not referenced in this module; kept for importers (presumably mirrors the
# aspect-ratio limit of qwen_vl_utils -- confirm before relying on it).
MAX_RATIO = 200

# Default pixel budgets for video inputs. The total budget is read once, at
# import time, from the VIDEO_MAX_PIXELS environment variable (note that the
# variable name differs from the constant's name).
VIDEO_MIN_PIXELS = 128 * 28**2
VIDEO_TOTAL_PIXELS = int(
    float(os.getenv("VIDEO_MAX_PIXELS", 128000 * 28**2 * 0.9))
)

# Message raised (as NotImplementedError) when process_video receives an
# unsupported input.
VIDEO_FORMAT_HELP = """Currently, we only support the video formats introduced in qwen2-vl.
Refer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.

eg.
{
    "type": "video",
    "video": [
        "file:///path/to/frame1.jpg",
        "file:///path/to/frame2.jpg"
    ]
}

{
    "type": "video",
    "video": "file:///path/to/video.mp4"
}
# Defaults to fps=2, min_frames=4, max_frames=768

{
    "type": "video",
    "video": "file:///path/to/video.mp4",
    "fps": 2,
    "min_frames": 1,
    "max_frames": 32
}
"""
@PREPROCESSOR_REGISTER.register()
class QwenVLPreProcessor(BasicPreprocessor):
    """Preprocessor for the Qwen-VL family; images and videos go through ``qwen_vl_utils``.

    Resize limits default to the module-level constants. They can be overridden
    per instance through constructor keyword arguments, or per call through the
    keyword arguments of ``process_image`` / ``process_video``.
    """

    def __init__(self, processor, image_key="image", video_key="video", **kwargs):
        """
        Args:
            processor: forwarded unchanged to ``BasicPreprocessor``.
            image_key: forwarded to ``BasicPreprocessor``.
            video_key: forwarded to ``BasicPreprocessor``.
            **kwargs: optional overrides ``max_pixels``, ``min_pixels``,
                ``factor``, ``video_min_pixels`` and ``video_total_pixels``.
                Other keys are ignored; they are not forwarded to the base class.
        """
        super().__init__(processor, image_key=image_key, video_key=video_key)
        self.max_pixels = kwargs.get("max_pixels", MAX_PIXELS)
        self.min_pixels = kwargs.get("min_pixels", MIN_PIXELS)
        self.factor = kwargs.get("factor", IMAGE_FACTOR)
        self.video_min_pixels = kwargs.get("video_min_pixels", VIDEO_MIN_PIXELS)
        self.video_total_pixels = kwargs.get("video_total_pixels", VIDEO_TOTAL_PIXELS)

    def process_image(self, image, **kwargs) -> Image.Image:
        """Load and resize one image with ``qwen_vl_utils.fetch_image``.

        Args:
            image: one of
                * a ``PIL.Image.Image``, which is converted to RGB and returned
                  without resizing;
                * a path/URL string, which is treated as ``{"image": image}``;
                * a dict in the qwen-vl message format. Raw encoded bytes may be
                  given under ``"bytes"`` instead of ``"image"``.
                The caller's dict is never modified.
            **kwargs: per-call overrides ``max_pixels``, ``min_pixels`` and
                ``size_factor``. Each defaults to the corresponding instance
                attribute.

        Returns:
            The image returned by ``fetch_image``, or the RGB-converted input.

        Raises:
            AssertionError: if a dict carries both ``"bytes"`` and ``"image"``.
        """
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        max_pixels = kwargs.get("max_pixels", self.max_pixels)
        min_pixels = kwargs.get("min_pixels", self.min_pixels)
        factor = kwargs.get("size_factor", self.factor)
        # Work on a shallow copy. Writing into the caller's dict would leave both
        # "bytes" and "image" behind, so processing the same sample a second
        # time would fail the check below.
        image = {"image": image} if isinstance(image, str) else dict(image)
        if "bytes" in image:
            assert "image" not in image, "Cannot have both `bytes` and `image`"
            # fetch_image expects a path/URL string or a PIL image under "image",
            # so decode the raw bytes into a PIL image here instead of passing
            # a bare BytesIO.
            image["image"] = Image.open(BytesIO(image["bytes"]))
        image["max_pixels"] = max_pixels
        image["min_pixels"] = min_pixels
        return fetch_image(image, size_factor=factor)

    def process_audio(self, audio, **kwargs):
        """Always raises ValueError; this preprocessor has no audio path."""
        raise ValueError("QwenVL series does not support audio input")

    def process_video(self, video, **kwargs) -> torch.Tensor:
        """Load and sample a video with ``qwen_vl_utils.fetch_video``.

        Args:
            video: dict in the qwen-vl message format with a ``"video"`` key
                (see ``VIDEO_FORMAT_HELP``). The dict is shallow-copied, so the
                caller's object is never modified.
            **kwargs: optional per-call settings:
                ``nframes`` / ``fps``: the sampling rule. They are mutually
                    exclusive and are applied only when ``video`` does not
                    already carry ``"nframes"`` or ``"fps"``.
                ``fps_min_frames`` / ``fps_max_frames``: written as
                    ``min_frames`` / ``max_frames``, but only together with ``fps``.
                ``video_total_pixels`` / ``video_min_pixels``: override the
                    instance defaults.
                ``image_factor``: overrides ``self.factor``.
                ``return_video_sample_fps``: forwarded to ``fetch_video``.

        Returns:
            The result of ``fetch_video``. Per the original contract, this is a
            ``[n_frames, 3, H, W]`` tensor. When ``return_video_sample_fps`` is
            true it is presumably a ``(video, sample_fps)`` pair; confirm
            against the installed qwen_vl_utils.

        Raises:
            NotImplementedError: if ``video`` is not a dict with a ``"video"`` key.
            AssertionError: if both ``nframes`` and ``fps`` are given.
        """
        nframes = kwargs.get("nframes", None)
        fps = kwargs.get("fps", None)
        # No trailing commas here. A trailing comma would make these 1-tuples
        # such as ``(None,)``, which are never ``None`` and would be forwarded
        # to fetch_video as bogus min/max frame counts.
        fps_min_frames = kwargs.get("fps_min_frames", None)
        fps_max_frames = kwargs.get("fps_max_frames", None)
        if not isinstance(video, dict) or "video" not in video:
            raise NotImplementedError(VIDEO_FORMAT_HELP)
        assert nframes is None or fps is None, "Can't use both `nframes` and `fps`"

        # Shallow copy so the keys added below don't leak into the caller's dict.
        video = dict(video)

        # A sampling rule embedded in the sample takes precedence over call kwargs.
        has_sampling_rule = "nframes" in video or "fps" in video
        if not has_sampling_rule:
            if nframes is not None:
                video["nframes"] = nframes
            elif fps is not None:
                video["fps"] = fps
                if fps_min_frames is not None:
                    video["min_frames"] = fps_min_frames
                if fps_max_frames is not None:
                    video["max_frames"] = fps_max_frames

        video["total_pixels"] = kwargs.get("video_total_pixels", self.video_total_pixels)
        video["min_pixels"] = kwargs.get("video_min_pixels", self.video_min_pixels)
        return fetch_video(
            video,
            image_factor=kwargs.get("image_factor", self.factor),
            return_video_sample_fps=kwargs.get("return_video_sample_fps", False),
        )