"""
InternVL3-8B inference adapter (Transformers + trust_remote_code).

Design goals:
- Match the Qwen2.5-VL calling style: (prompt, video_path, frame_indices) -> text + token stats.
- Merge multiple sampled frames into a single chat input (instead of per-frame calls).
- Support multi-GPU by constructing an official-style device_map to avoid device mismatch errors.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoConfig, AutoModel, AutoTokenizer

import decord


# ImageNet per-channel (R, G, B) normalization statistics; used by T.Normalize in
# _build_transform after ToTensor() has scaled pixels to [0, 1].
IMAGENET_MEAN: Tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_STD: Tuple[float, float, float] = (0.229, 0.224, 0.225)


def _build_transform(input_size: int) -> T.Compose:
    """
    Build the per-tile preprocessing pipeline used before feeding InternVL.

    Steps: force RGB, bicubic-resize to a square ``input_size`` x ``input_size``,
    convert to a float tensor, then normalize with the ImageNet mean/std.
    """

    def _as_rgb(picture):
        # Only convert when needed; images already in RGB pass through untouched.
        if getattr(picture, "mode", None) == "RGB":
            return picture
        return picture.convert("RGB")

    pipeline = [
        T.Lambda(_as_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(pipeline)


def _find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: List[Tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> Tuple[int, int]:
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _dynamic_preprocess(
    image: Image.Image,
    min_num: int = 1,
    max_num: int = 12,
    image_size: int = 448,
    use_thumbnail: bool = True,
) -> List[Image.Image]:
    """
    Split ``image`` into a grid of ``image_size`` x ``image_size`` tiles (InternVL dynamic tiling).

    The grid (cols, rows) is the candidate with ``min_num <= cols * rows <= max_num``
    whose aspect ratio best matches the image (see ``_find_closest_aspect_ratio``).
    The image is resized to exactly fill that grid and cut into tiles in row-major
    order. When more than one tile is produced and ``use_thumbnail`` is set, a
    whole-image thumbnail resized to ``image_size`` x ``image_size`` is appended.
    """
    width, height = image.size

    # All (cols, rows) grids whose tile count lies in [min_num, max_num].
    candidate_grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    grids_by_tile_count = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    cols, rows = _find_closest_aspect_ratio(
        width / max(1, height), grids_by_tile_count, width, height, image_size
    )

    canvas = image.resize((image_size * cols, image_size * rows))
    tiles: List[Image.Image] = [
        canvas.crop(
            (
                col * image_size,
                row * image_size,
                (col + 1) * image_size,
                (row + 1) * image_size,
            )
        )
        for row in range(rows)
        for col in range(cols)
    ]

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


def _load_frames_as_pixel_values(
    video_path: str,
    frame_indices: List[int],
    input_size: int = 448,
    max_num_tiles: int = 1,
) -> Tuple[torch.Tensor, List[int], List[float]]:
    """
    Decode the requested frames and apply InternVL dynamic tiling to each one.

    Indices outside ``[0, num_frames)`` are dropped and duplicates removed; the
    remaining frames are processed in ascending index order. If no valid index
    remains (or ``frame_indices`` is empty/None), frame 0 is used.

    Args:
        video_path: Path to the video file, opened with decord on CPU.
        frame_indices: Frame indices to sample.
        input_size: Side length of each square tile.
        max_num_tiles: Upper bound on tiles per frame (1 disables tiling).

    Returns:
        - pixel_values: (sum(num_patches_list), 3, input_size, input_size) tensor.
        - num_patches_list: number of tiles produced for each frame.
        - timestamps_s: time of each frame in seconds (``index / fps``).

    Raises:
        ValueError: if the video has no frames.
    """
    vr = decord.VideoReader(video_path, ctx=decord.cpu(0), num_threads=1)
    total = len(vr)
    if total <= 0:
        # Previously this fell through to get_batch([0]) and failed inside decord.
        raise ValueError(f"Video has no frames: {video_path!r}")

    # Missing, zero, negative or non-finite FPS metadata -> assume 30 fps so the
    # timestamps stay meaningful instead of collapsing to 0.0 or NaN.
    raw_fps = float(vr.get_avg_fps() or 0.0)
    fps = raw_fps if math.isfinite(raw_fps) and raw_fps > 0 else 30.0

    requested = {int(i) for i in (frame_indices or [])}
    valid = sorted(i for i in requested if 0 <= i < total) or [0]

    transform = _build_transform(input_size=input_size)

    pixel_values_list: List[torch.Tensor] = []
    num_patches_list: List[int] = []
    timestamps_s: List[float] = []

    batch = vr.get_batch(valid).asnumpy()
    for idx, arr in zip(valid, batch):
        img = Image.fromarray(arr).convert("RGB")
        tiles = _dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num_tiles)
        pv = torch.stack([transform(tile) for tile in tiles], dim=0)
        pixel_values_list.append(pv)
        num_patches_list.append(int(pv.shape[0]))
        timestamps_s.append(idx / fps)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    return pixel_values, num_patches_list, timestamps_s


def _infer_internvl3_num_layers(cfg: Any) -> Optional[int]:
    for path in (
        ("llm_config", "num_hidden_layers"),
        ("language_model", "config", "num_hidden_layers"),
        ("language_model", "model", "config", "num_hidden_layers"),
        ("text_config", "num_hidden_layers"),
        ("num_hidden_layers",),
    ):
        cur = cfg
        ok = True
        for k in path:
            if not hasattr(cur, k):
                ok = False
                break
            cur = getattr(cur, k)
        if ok:
            try:
                return int(cur)
            except Exception:
                pass
    return None


def build_internvl3_device_map(model_path: str) -> Optional[Dict[str, int]]:
    """
    Build a multi-GPU ``device_map`` for InternVL3 following the official split recipe.

    - Returns ``None`` when at most one CUDA device is visible, or when the LLM
      layer count cannot be read from the config.
    - ``vision_model`` and ``mlp1`` are placed on GPU0. When budgeting LLM layers,
      GPU0 counts as half a GPU to leave room for the ViT.
    - Embeddings, final norm, rotary embedding, output heads, and the first/last
      LLM layers are pinned to GPU0 to avoid device mismatches during inference.

    The config is read with ``trust_remote_code=True`` and ``local_files_only=True``.
    """
    gpu_count = int(torch.cuda.device_count() or 0)
    if gpu_count <= 1:
        return None

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)
    total_layers = _infer_internvl3_num_layers(config)
    if not total_layers:
        return None

    # Per-GPU layer quota, with GPU0 weighted as half a device.
    quota = math.ceil(total_layers / (gpu_count - 0.5))
    capacities = [math.ceil(quota * 0.5)] + [quota] * (gpu_count - 1)

    device_map: Dict[str, int] = {}
    next_layer = 0
    for gpu, capacity in enumerate(capacities):
        # The ceil-based quotas may over-provision; never assign past the last layer.
        take = min(capacity, total_layers - next_layer)
        for layer in range(next_layer, next_layer + take):
            device_map[f"language_model.model.layers.{layer}"] = gpu
        next_layer += take

    # Vision tower and its projector.
    device_map["vision_model"] = 0
    device_map["mlp1"] = 0

    # LLM entry/exit modules share GPU0 with the vision side.
    device_map.update(
        dict.fromkeys(
            (
                "language_model.model.tok_embeddings",
                "language_model.model.embed_tokens",
                "language_model.output",
                "language_model.model.norm",
                "language_model.model.rotary_emb",
                "language_model.lm_head",
            ),
            0,
        )
    )

    # First and last decoder layers live next to the embeddings/heads.
    for edge_layer in (0, total_layers - 1):
        device_map[f"language_model.model.layers.{edge_layer}"] = 0
    return device_map


def _resolve_vision_device(model: Any) -> torch.device:
    try:
        if hasattr(model, "vision_model"):
            p = next(model.vision_model.parameters())
            return p.device
    except Exception:
        pass
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _count_tokens(tokenizer: Any, text: str) -> int:
    try:
        ids = tokenizer(text, return_tensors=None, add_special_tokens=False).get("input_ids")
        if isinstance(ids, list):
            return len(ids)
        return int(len(ids[0]))  # type: ignore[index]
    except Exception:
        return 0


@dataclass
class InternVL3LoadOptions:
    """Loading options used by ``InternVL3ModelHandler.load_model`` for ``from_pretrained``."""

    # Weight dtype, passed as ``torch_dtype`` to ``AutoModel.from_pretrained``.
    torch_dtype: torch.dtype = torch.bfloat16
    # Passed as ``use_flash_attn``; load_model retries without it on TypeError.
    use_flash_attn: bool = True
    # When True, ``load_in_8bit=True`` is added to the model loading kwargs.
    load_in_8bit: bool = False
    # Forwarded as ``local_files_only`` to both the model and tokenizer loaders.
    local_files_only: bool = True


class InternVL3ModelHandler:
    """
    Lazily load an InternVL3 checkpoint (model + tokenizer) with remote code enabled.

    Args:
        model_path: Checkpoint path passed to ``from_pretrained``.
        device_map: ``"auto"``, ``"split"`` (case-insensitive) or ``None`` build the
            official-style multi-GPU split via ``build_internvl3_device_map``, falling
            back to ``"auto"`` when that returns ``None`` (e.g. a single GPU). Any other
            value (a dict, ``"cuda:1"``, ...) is forwarded to ``from_pretrained`` unchanged.
        **kwargs: Stored on ``self.config``. Keys matching ``InternVL3LoadOptions``
            fields (``torch_dtype``, ``use_flash_attn``, ``load_in_8bit``,
            ``local_files_only``) override the option defaults; other keys are ignored
            by ``load_model``.
    """

    def __init__(self, model_path: str, device_map: Any = "auto", **kwargs):
        self.model_path = model_path
        self.device_map = device_map
        self.model = None
        self.tokenizer = None
        self.config = kwargs

    def _resolve_load_options(self) -> InternVL3LoadOptions:
        """Return default load options overridden by matching constructor kwargs."""
        from dataclasses import fields

        known = {f.name for f in fields(InternVL3LoadOptions)}
        return InternVL3LoadOptions(**{k: v for k, v in self.config.items() if k in known})

    def _resolve_device_map(self) -> Any:
        """Return the ``device_map`` value to hand to ``AutoModel.from_pretrained``."""
        requested = self.device_map
        wants_split = requested is None or (
            isinstance(requested, str) and requested.lower() in ("auto", "split")
        )
        if wants_split:
            inferred = build_internvl3_device_map(self.model_path)
            # None means single GPU / unknown layer count: use the default mapping.
            return inferred if inferred is not None else "auto"
        # Explicit user choice: previously discarded in favour of "auto"; honor it now.
        return requested

    def load_model(self) -> Tuple[Any, Any]:
        """
        Load (once) and return ``(model, tokenizer)``; later calls return the cached pair.

        The model is put in eval mode and tagged with ``_scope_backend = "internvl3"``
        (best effort).
        """
        if self.model is not None and self.tokenizer is not None:
            return self.model, self.tokenizer

        opts = self._resolve_load_options()

        kwargs: Dict[str, Any] = dict(
            torch_dtype=opts.torch_dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            local_files_only=opts.local_files_only,
            device_map=self._resolve_device_map(),
        )
        if opts.load_in_8bit:
            kwargs["load_in_8bit"] = True

        # use_flash_attn is not supported by all transformers/remote-code versions; fall back on failure.
        try:
            self.model = AutoModel.from_pretrained(self.model_path, use_flash_attn=opts.use_flash_attn, **kwargs).eval()
        except TypeError:
            self.model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True, use_fast=False, local_files_only=opts.local_files_only
        )

        try:
            setattr(self.model, "_scope_backend", "internvl3")
        except Exception:
            # Tagging is informational only; never fail loading because of it.
            pass

        return self.model, self.tokenizer


def _resolve_vision_dtype(model: Any, default: torch.dtype = torch.bfloat16) -> torch.dtype:
    """
    Return the dtype of the first floating-point parameter of ``model.vision_model``.

    Falls back to ``default`` when there is no vision tower or it has no floating-point
    parameters (e.g. fully quantized), so pixel values are never cast to an int dtype.
    """
    try:
        params = model.vision_model.parameters()
    except (AttributeError, TypeError):
        return default
    return next((p.dtype for p in params if p.is_floating_point()), default)


def _build_internvl3_generation_config(generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Map common HF-style generation kwargs to the dict passed to ``model.chat``.

    ``temperature`` is forwarded only when sampling with a positive value, because a
    non-positive temperature makes HF sampling raise. ``top_p``/``top_k`` are
    forwarded only when sampling; ``repetition_penalty`` whenever it is given.
    """
    gen = generation_kwargs or {}
    do_sample = bool(gen.get("do_sample", False))
    config: Dict[str, Any] = {
        "max_new_tokens": int(gen.get("max_new_tokens", 1024) or 1024),
        "do_sample": do_sample,
        "num_beams": int(gen.get("num_beams", 1) or 1),
    }
    if do_sample:
        temperature = float(gen.get("temperature", 0.0) or 0.0)
        if temperature > 0.0:
            config["temperature"] = temperature
        for key in ("top_p", "top_k"):
            if gen.get(key) is not None:
                config[key] = gen[key]
    if gen.get("repetition_penalty") is not None:
        config["repetition_penalty"] = float(gen["repetition_penalty"])
    return config


@torch.no_grad()
def get_internvl3_response_generic(
    model: Any,
    tokenizer: Any,
    prompt: str,
    video_path: str,
    frame_indices: List[int],
    generation_kwargs: Dict,
    input_size: int = 448,
    max_num_tiles: int = 1,
) -> Dict[str, Any]:
    """
    Run InternVL3 on several video frames merged into a single chat turn.

    Each frame becomes one ``<image>`` placeholder, prefixed with its 1-based index
    and timestamp, followed by ``prompt``.

    Args:
        model: Loaded InternVL3 model exposing ``chat(...)``.
        tokenizer: Matching tokenizer (also used for the token statistics).
        prompt: User question appended after the frame prefix.
        video_path: Video to sample frames from.
        frame_indices: Frame indices to use (see ``_load_frames_as_pixel_values``).
        generation_kwargs: HF-style generation options (may be empty or None).
        input_size: Tile side length.
        max_num_tiles: Maximum tiles per frame.

    Returns:
        ``{"text": str, "tokens": {"prompt": int, "output": int, "total": int}}``.
        ``prompt`` is the token count of the question string only (no special
        tokens); any tokens ``model.chat`` adds on top (image-token expansion,
        chat template) are not counted.
    """
    pixel_values, num_patches_list, timestamps_s = _load_frames_as_pixel_values(
        video_path, frame_indices, input_size=input_size, max_num_tiles=max_num_tiles
    )

    # InternVL expects pixel_values on the vision tower's device and in its dtype;
    # derive the dtype from the weights instead of assuming bf16.
    pixel_values = pixel_values.to(device=_resolve_vision_device(model), dtype=_resolve_vision_dtype(model))

    # Merge multiple frames into a single chat input; the prefix provides temporal hints.
    video_prefix = "".join(
        f"Frame{n} (t={t:.2f}s): <image>\n" for n, t in enumerate(timestamps_s, start=1)
    )
    question = video_prefix + prompt

    generation_config = _build_internvl3_generation_config(generation_kwargs)

    text = model.chat(
        tokenizer,
        pixel_values,
        question,
        generation_config,
        num_patches_list=num_patches_list,
    )
    # Normalize before counting so a None reply does not count the string "None".
    text_out = "" if text is None else str(text)

    prompt_tokens = _count_tokens(tokenizer, question)
    output_tokens = _count_tokens(tokenizer, text_out)

    return {
        "text": text_out,
        "tokens": {"prompt": prompt_tokens, "output": output_tokens, "total": prompt_tokens + output_tokens},
    }
