import math

from qwen_vl_utils import process_vision_info

# Side-length granularity: smart_resize snaps both image sides to a multiple of this.
IMAGE_FACTOR = 28
# Default pixel budget for smart_resize, written as a number of 28x28 tiles.
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 16384 * IMAGE_FACTOR * IMAGE_FACTOR
# Largest long-side / short-side ratio smart_resize accepts before raising.
MAX_RATIO = 200

# Video-related defaults. Nothing in the visible part of this file reads them;
# presumably they mirror the qwen_vl_utils video settings -- TODO confirm.
VIDEO_MIN_PIXELS = 128 * IMAGE_FACTOR * IMAGE_FACTOR
VIDEO_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768


def round_by_factor(number: int, factor: int) -> int:
    """Snap 'number' to the nearest multiple of 'factor'.

    Rounding is done by Python's built-in round(), so an exact tie goes to the
    even multiple count (42 with factor 28 gives 56, 14 with factor 28 gives 0).
    """
    multiples = round(number / factor)
    return multiples * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Round 'number' up to the next multiple of 'factor' (unchanged if already a multiple)."""
    multiples = math.ceil(number / factor)
    return multiples * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Round 'number' down to the previous multiple of 'factor' (unchanged if already a multiple)."""
    multiples = math.floor(number / factor)
    return multiples * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor' and are
       at least 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    Args:
        height: Original image height in pixels; must be positive.
        width: Original image width in pixels; must be positive.
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound on the returned area.
        max_pixels: Upper bound on the returned area. It can be exceeded only
            in the degenerate case where honouring it would shrink a side
            below 'factor'; that side is then kept at 'factor'.

    Returns:
        The resized dimensions as '(height, width)'.

    Raises:
        ValueError: If a dimension is not positive, or if the ratio of the
            long side to the short side exceeds MAX_RATIO.
    """
    # Reject empty/negative sizes explicitly; otherwise the ratio check below
    # fails with a ZeroDivisionError or negative sizes leak into the result.
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {aspect_ratio}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: scale both sides down by the same ratio, rounding down so
        # the area stays within budget.
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring can reach 0 for a thin image under a tight budget; keep at
        # least one 'factor'-sized step per side so the result is never empty.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Too small: scale both sides up by the same ratio, rounding up so the
        # area reaches the minimum.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


def postprocess_fn(example, region, image_size, processor):
    """Bundle one conversation with its region and image size into a dataset row.

    Args:
        example: Conversation passed to the processor's chat template.
        region: Stored unchanged under the "region" key.
        image_size: Stored unchanged under the "image_size" key.
        processor: Object providing apply_chat_template and tokenizer.encode.

    Returns:
        A dict with keys "input_ids", "conversation", "region" and
        "image_size".
    """
    rendered = processor.apply_chat_template(
        example, tokenize=False, add_generation_prompt=False
    )
    # Token ids are only used for batch grouping.
    token_ids = processor.tokenizer.encode(rendered, return_tensors="pt")
    passthrough = {
        "conversation": example,
        "region": region,
        "image_size": image_size,
    }
    return {"input_ids": token_ids, **passthrough}


def collate_fn(examples, processor):
    """Collate dataset rows into one padded batch with loss labels.

    Args:
        examples: Rows holding "conversation", "region" and "image_size"
            entries (the keys produced by postprocess_fn).
        processor: Object providing apply_chat_template, a tokenizer with
            pad_token_id, and a call interface accepting text, images,
            videos and regions.

    Returns:
        The processor output with an added "labels" entry: a copy of
        "input_ids" in which padding and vision placeholder tokens are
        replaced by -100.
    """
    # One pass per field, in this order.
    conversations, regions, target_sizes = (
        [row[key] for row in examples]
        for key in ("conversation", "region", "image_size")
    )

    # Render every conversation to its prompt string.
    render = processor.apply_chat_template
    prompts = [
        render(conversation, tokenize=False, add_generation_prompt=False)
        for conversation in conversations
    ]

    # Only the first image of each conversation is used.
    first_images = [
        process_vision_info(conversation)[0][0] for conversation in conversations
    ]
    # NOTE(review): if these are PIL images, resize() expects (width, height);
    # confirm the order stored in "image_size".
    resized_images = [
        picture.resize(size) for picture, size in zip(first_images, target_sizes)
    ]

    # Tokenize the prompts and encode the images into padded tensors.
    batch = processor(
        text=prompts,
        images=resized_images,
        videos=None,
        regions=regions,
        padding=True,
        return_tensors="pt",
    )

    # Positions set to this value are meant to be skipped by the loss.
    ignore_index = -100
    # Hard-coded ids for the Qwen2-VL processor; presumably the vision
    # start/end and image placeholder tokens -- verify against the tokenizer.
    vision_token_ids = (151652, 151653, 151655)

    # Labels start as a copy of the inputs; padding and vision tokens are then
    # masked out.
    labels = batch["input_ids"].clone()
    for masked_id in (processor.tokenizer.pad_token_id, *vision_token_ids):
        labels[labels == masked_id] = ignore_index

    batch["labels"] = labels
    return batch
