import math
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

# Per-channel (R, G, B) mean and standard deviation handed to T.Normalize in
# build_transform to normalize image tensors.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

### Utility Functions


def build_transform(input_size):
    """Create the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB when it is in another mode,
    resizes it to ``input_size`` x ``input_size`` with bicubic interpolation,
    turns it into a tensor and normalizes it with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        input_size: Edge length, in pixels, of the square output.

    Returns:
        A ``T.Compose`` object chaining the steps above.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        if img.mode == "RGB":
            return img
        return img.convert("RGB")

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(
            size=(input_size, input_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
):
    """Split an image into square tiles laid out on an aspect-ratio-matched grid.

    Every (columns, rows) grid whose tile count lies between ``min_num`` and
    ``max_num`` is a candidate; ``find_closest_aspect_ratio`` picks the one that
    best matches the image. The image is resized to fill that grid exactly and
    cut into ``image_size`` x ``image_size`` tiles.

    Args:
        image: Source image; must expose ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each square tile, in pixels.
        use_thumbnail: When true and more than one tile was produced, append a
            single ``image_size`` x ``image_size`` resize of the whole image.

    Returns:
        List of tile images in row-major order, optionally followed by the
        thumbnail.
    """
    src_width, src_height = image.size
    src_ratio = src_width / src_height

    # Candidate grids, ordered by how many tiles they contain.
    candidate_grids = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid_cols, grid_rows = find_closest_aspect_ratio(
        src_ratio, candidate_grids, src_width, src_height, image_size
    )

    canvas_width = image_size * grid_cols
    canvas_height = image_size * grid_rows
    canvas = image.resize((canvas_width, canvas_height))

    tiles_per_row = canvas_width // image_size
    tiles = []
    for index in range(grid_cols * grid_rows):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        tiles.append(
            canvas.crop((left, upper, left + image_size, upper + image_size))
        )

    # A single tile already shows the whole image, so no thumbnail is added.
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid whose width/height ratio best matches ``aspect_ratio``.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of (columns, rows) candidates, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of one square tile in pixels.

    Returns:
        The candidate with the smallest absolute ratio difference, or (1, 1)
        when there are no candidates. When a later candidate ties with the
        current best, it replaces the best only if the source image area is
        larger than half the pixel area that candidate's grid would cover.
    """
    chosen = (1, 1)
    smallest_gap = float("inf")
    source_area = width * height

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif (
            gap == smallest_gap
            and source_area > 0.5 * image_size * image_size * cols * rows
        ):
            # Tie: take the candidate only if the image is big enough for it.
            chosen = candidate

    return chosen


def load_image(image_file, input_size=448, max_num=12):
    """Load an image, tile it and return the tiles as one stacked tensor.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (path or file object).
        input_size: Edge length of each square tile, in pixels.
        max_num: Maximum number of tiles produced by ``dynamic_preprocess``
            (a thumbnail may be appended on top of that).

    Returns:
        A tensor stacking the transformed tiles along a new first dimension,
        so ``result.size(0)`` is the number of tiles (thumbnail included).
    """
    # Image.open is lazy and can keep the underlying file open; the context
    # manager releases it deterministically. convert() returns a separate
    # in-memory image, so it stays usable after the file is closed.
    with Image.open(image_file) as source:
        image = source.convert("RGB")

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    return torch.stack([transform(tile) for tile in tiles])


def split_model(model_path):
    """Build a device map that spreads the language-model layers over all GPUs.

    The number of decoder layers is read from the model config
    (``config.llm_config.num_hidden_layers``). GPU 0 also receives the vision
    model and the other non-layer modules listed below, so it is weighted as
    half a GPU when the layers are divided.

    Args:
        model_path: Model name or path passed to ``AutoConfig.from_pretrained``.

    Returns:
        Dict mapping module names to CUDA device indices.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            "split_model requires at least one CUDA device, but none are visible."
        )

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers

    # Per-GPU layer quota, with GPU 0 counted as half a GPU.
    full_quota = math.ceil(num_layers / (world_size - 0.5))
    quotas = [full_quota] * world_size
    quotas[0] = math.ceil(full_quota * 0.5)

    device_map = {}
    next_layer = 0
    for gpu_index, quota in enumerate(quotas):
        # Rounding up can make the quotas sum to more than num_layers; clamp
        # so that only layers that actually exist get an entry.
        stop = min(next_layer + quota, num_layers)
        for layer_index in range(next_layer, stop):
            device_map[f"language_model.model.layers.{layer_index}"] = gpu_index
        next_layer = stop

    # Everything that is not a decoder layer lives on GPU 0.
    device_map["vision_model"] = 0
    device_map["mlp1"] = 0
    device_map["language_model.model.tok_embeddings"] = 0
    device_map["language_model.model.embed_tokens"] = 0
    device_map["language_model.output"] = 0
    device_map["language_model.model.norm"] = 0
    device_map["language_model.model.rotary_emb"] = 0
    device_map["language_model.lm_head"] = 0
    # The last decoder layer is moved to GPU 0 as well, next to the final
    # norm and output head placed there above.
    device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
    return device_map


### Wrapper Class


class InternVLWrapper:
    """Loads an InternVL-style chat model across GPUs and answers image queries."""

    def __init__(
        self,
        path: str = "OpenGVLab/InternVL2_5-26B-MPO",
        input_size: int = 448,
        max_num: int = 12,
        temperature: float = 0.9,
        max_tokens: int = 1024,
        **kwargs,
    ):
        """
        Initializes the InternVLWrapper with the specified model and parameters.

        Args:
            path (str): Path to the pretrained model.
            input_size (int): Size to resize images to.
            max_num (int): Maximum number of image tiles.
            temperature (float): Sampling temperature for generation.
            max_tokens (int): Maximum number of tokens to generate.
            **kwargs: Additional arguments; accepted but not used here.
        """
        self.path = path
        self.input_size = input_size
        self.max_num = max_num
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.fail_msg = "Failed to obtain answer via API."

        # Load model and tokenizer; layers are placed according to split_model.
        device_map = split_model(self.path)
        self.model = AutoModel.from_pretrained(
            self.path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=True,
            trust_remote_code=True,
            device_map=device_map,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.path, trust_remote_code=True, use_fast=False
        )

    def _load_pixel_values(self, image_path):
        """Load one image as tiled pixel values in bfloat16 on the current CUDA device."""
        pixel_values = load_image(
            image_path, input_size=self.input_size, max_num=self.max_num
        )
        return pixel_values.to(torch.bfloat16).cuda()

    def get_prediction(self, image_path, prompt, passages=None, passage_prompt=None):
        """Ask the model a question about a query image, with optional passages.

        Args:
            image_path: Query image, in any form ``load_image`` accepts.
            prompt: Question text appended at the end of the prompt.
            passages: Optional iterable of dict-like items. Each may carry a
                "caption" string and an "image_path" string. Items without a
                usable "image_path" contribute only their caption.
            passage_prompt: Optional text inserted after the passages and
                before the query image.

        Returns:
            Whatever ``self.model.chat`` returns for the assembled prompt.
        """
        # NOTE(review): do_sample is False, so temperature presumably has no
        # effect on decoding - confirm whether sampling was intended.
        generation_config = dict(
            max_new_tokens=self.max_tokens,
            do_sample=False,
            temperature=self.temperature,
        )
        prompt_parts = []
        image_count = 1
        pixel_values_list = []
        num_patches_list = []

        if passages:
            prompt_parts.append("\n--- Passages ---")  # Add a separator
            for passage in passages:
                passage_text = passage.get("caption", "")
                passage_img_path = passage.get("image_path")
                if passage_img_path and isinstance(passage_img_path, str):
                    prompt_parts.append(f"Image-{image_count}: <image>")
                    # Load the image only when its placeholder was added, so
                    # placeholders, patch counts and numbering stay in step.
                    passage_pixels = self._load_pixel_values(passage_img_path)
                    pixel_values_list.append(passage_pixels)
                    num_patches_list.append(passage_pixels.size(0))
                    image_count += 1
                if passage_text and isinstance(passage_text, str):
                    prompt_parts.append("Caption: " + passage_text)

        if passage_prompt:
            prompt_parts.append(passage_prompt)

        query_pixels = self._load_pixel_values(image_path)
        pixel_values_list.append(query_pixels)
        num_patches_list.append(query_pixels.size(0))
        pixel_values = torch.cat(pixel_values_list, dim=0)
        prompt_parts.append(
            f"Here is query image related to question.\nQuery Image-{image_count}: <image>\n"
        )
        prompt_parts.append(prompt)
        final_question = "\n".join(prompt_parts).strip()
        response = self.model.chat(
            self.tokenizer,
            pixel_values,
            final_question,
            generation_config,
            num_patches_list=num_patches_list,
            history=None,
            return_history=None,
        )
        return response
