from typing import List, Dict, Tuple
import os
import sys

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from rich.console import Console
from torchvision.transforms.functional import InterpolationMode

from .base import BaseVLM

# Add the parent of the directory containing this file (presumably the
# project root) to sys.path so the top-level `utils` module can be imported.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# This import must come after the sys.path change above, hence the E402 suppression.
from utils import parse_input, get_image_path, _normalize_to_list  # noqa: E402

# Module-wide rich console (used to print generated outputs).
console = Console()

# ImageNet per-channel mean / std used by build_transform to normalize tiles.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int):
    """Return the per-tile preprocessing pipeline used by the official example.

    Steps: force RGB mode (only if the image is not already RGB), bicubic
    resize to a square ``input_size`` x ``input_size``, convert to a tensor,
    and normalize with the ImageNet channel mean/std.
    """

    def _ensure_rgb(img):
        return img if img.mode == "RGB" else img.convert("RGB")

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid whose aspect ratio is nearest to ``aspect_ratio``.

    Args:
        aspect_ratio: source image width / height.
        target_ratios: candidate (cols, rows) grids, scanned in the given order.
        width, height: source image size in pixels.
        image_size: side length of one square tile in pixels.

    Returns:
        The chosen candidate object itself, or ``(1, 1)`` when
        ``target_ratios`` is empty.

    On an exact tie in aspect-ratio distance, the later candidate replaces
    the current best only if the source image area exceeds half of the area
    that candidate grid would cover (``image_size**2 * cols * rows``).
    """
    pixel_area = width * height
    best = (1, 1)
    smallest_gap = float("inf")
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            best = candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * cols * rows:
            best = candidate
    return best


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Cut ``image`` into square tiles on the grid that best fits its aspect ratio.

    Same algorithm as the official InternVL/Vintern ``dynamic_preprocess``:

    1. enumerate every ``(cols, rows)`` grid with
       ``min_num <= cols * rows <= max_num``, ordered by tile count;
    2. choose one with ``find_closest_aspect_ratio``;
    3. resize the image to ``(cols * image_size, rows * image_size)`` and crop
       it row by row, left to right, into ``image_size`` x ``image_size`` tiles;
    4. if ``use_thumbnail`` and more than one tile was produced, append the
       whole image resized to ``image_size`` x ``image_size``.

    Args:
        image: source PIL image.
        min_num, max_num: bounds on the number of grid tiles.
        image_size: tile side length in pixels.
        use_thumbnail: whether to append the thumbnail described above.

    Returns:
        List of PIL images: the grid tiles followed by the optional thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Candidate grids. The set comprehension mirrors the official code
    # exactly: equal-area grids keep the set's iteration order after the
    # stable sort, and that order can decide ties downstream.
    grids = {
        (c, r)
        for total in range(min_num, max_num + 1)
        for c in range(1, total + 1)
        for r in range(1, total + 1)
        if min_num <= c * r <= max_num
    }
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    best_grid = find_closest_aspect_ratio(src_ratio, grids, src_w, src_h, image_size)
    n_cols, n_rows = best_grid[0], best_grid[1]
    n_tiles = n_cols * n_rows

    # Resize once to the exact grid size, then slice it into tiles.
    canvas_w = image_size * n_cols
    canvas_h = image_size * n_rows
    canvas = image.resize((canvas_w, canvas_h))
    tiles_per_row = canvas_w // image_size

    tiles = []
    for idx in range(n_tiles):
        row, col = divmod(idx, tiles_per_row)
        tiles.append(
            canvas.crop(
                (
                    col * image_size,
                    row * image_size,
                    (col + 1) * image_size,
                    (row + 1) * image_size,
                )
            )
        )

    assert len(tiles) == n_tiles

    # Global thumbnail (the loader in this file turns it on).
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


def load_image(image_file: str, input_size: int = 448, max_num: int = 12) -> torch.Tensor:
    """Load one image file and turn it into a stack of normalized tiles.

    Args:
        image_file: path of the image on disk.
        input_size: side length of every square tile, in pixels.
        max_num: maximum number of grid tiles; a thumbnail tile is appended on
            top of these whenever more than one grid tile is produced.

    Returns:
        Tensor of shape ``[num_patches, 3, input_size, input_size]``.
    """
    # Open inside a context manager so the file handle is closed right away.
    # ``convert`` loads the pixels and returns an independent in-memory
    # image, so nothing is lost when the original is closed.
    with Image.open(image_file) as raw:
        image = raw.convert("RGB")

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    return torch.stack([transform(tile) for tile in tiles])  # [num_patches, 3, H, W]


class VinternVLM(BaseVLM):
    """Vintern-1B-v2 wrapper built on the official InternVL/Vintern example code.

    Key points:
    - Batched inference does NOT go through ``model.batch_chat``. It rebuilds
      batch_chat's prompt construction (conversation template plus expansion
      of ``<image>`` into image-context tokens) and calls ``model.generate``
      directly, so the generated token ids are available.
    - ``right_pad_lens`` and ``hit_limits`` are derived from those generated
      token ids (trailing pad/EOS run, and whether the generation budget was
      used up without stopping on an EOS token).
    """

    # Markers delimiting the image span in the prompt; the context token is
    # repeated ``model.num_image_token`` times per image tile.
    IMG_START_TOKEN = "<img>"
    IMG_END_TOKEN = "</img>"
    IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

    def __init__(self, model, tokenizer, device: str = "cuda"):
        super().__init__(model=model, tokenizer=tokenizer, processor=None, device=device)

        # Intern-series models require left padding for batched generation.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Image settings, matching the official README.
        self.image_size = 448
        self.max_num_tiles = 12

    # ------------------------------------------------------------------
    # Image helpers
    # ------------------------------------------------------------------
    def _build_pixels_and_patches(
        self, items: List[dict]
    ) -> Tuple[torch.Tensor, List[int]]:
        """Load and tile the image of every item.

        Only the first image of each item is used.

        Returns:
            pixel_values: ``[sum(num_patches), 3, H, W]`` tensor on
                ``self.device``, cast to the dtype of the model's first
                parameter (bfloat16 if the model has no parameters).
            num_patches_list: tiles contributed by each item
                (``len(num_patches_list) == len(items)``).

        Raises:
            ValueError: if an item has no image.
        """
        all_pixels: List[torch.Tensor] = []
        num_patches_list: List[int] = []

        for item in items:
            image_paths = _normalize_to_list(get_image_path(item))
            if not image_paths:
                raise ValueError("VinternVLM 目前仅支持多模态输入（需要至少一张图片）")

            # Keep it simple: use only the first image (as in seedbench and
            # most VQA datasets).
            pixels = load_image(
                image_paths[0],
                input_size=self.image_size,
                max_num=self.max_num_tiles,
            )  # [num_patches, 3, H, W]
            all_pixels.append(pixels)
            num_patches_list.append(pixels.shape[0])

        try:
            model_dtype = next(self.model.parameters()).dtype
        except StopIteration:
            model_dtype = torch.bfloat16

        # Concatenate, then move and cast in a single transfer.
        pixel_values = torch.cat(all_pixels, dim=0).to(device=self.device, dtype=model_dtype)
        return pixel_values, num_patches_list

    # ------------------------------------------------------------------
    # Prompt / generation helpers
    # ------------------------------------------------------------------
    def _load_get_conv_template(self):
        """Import ``get_conv_template`` from the model's remote-code package.

        Raises:
            RuntimeError: if the ``conversation`` module lacks the function.
        """
        import importlib

        pkg = self.model.__module__.rsplit(".", 1)[0]
        conv_mod = importlib.import_module(pkg + ".conversation")
        get_conv_template = getattr(conv_mod, "get_conv_template", None)
        if get_conv_template is None:
            raise RuntimeError(f"Cannot import get_conv_template from {pkg}.conversation")
        return get_conv_template

    def _build_queries(
        self,
        questions: List[str],
        num_patches_list: List[int],
        get_conv_template,
        template_name: str,
        system_message: str,
    ) -> List[str]:
        """Build one chat prompt per question, expanding ``<image>`` like batch_chat.

        A leading ``<image>\\n`` is inserted when the question has none. The
        first ``<image>`` is then replaced by
        ``<img>`` + ``<IMG_CONTEXT>`` * (num_image_token * n_patches) + ``</img>``.

        Raises:
            RuntimeError: if ``model.num_image_token`` is missing or not positive.
        """
        num_image_token = int(getattr(self.model, "num_image_token", 0))
        if num_image_token <= 0:
            raise RuntimeError("model.num_image_token is invalid; cannot expand <image>.")

        queries: List[str] = []
        for q, n_patches in zip(questions, num_patches_list):
            if "<image>" not in q:
                q = "<image>\n" + q

            conv = get_conv_template(template_name)
            conv.system_message = system_message
            conv.append_message(conv.roles[0], q)
            conv.append_message(conv.roles[1], None)
            query = conv.get_prompt()

            image_tokens = (
                self.IMG_START_TOKEN
                + self.IMG_CONTEXT_TOKEN * (num_image_token * int(n_patches))
                + self.IMG_END_TOKEN
            )
            queries.append(query.replace("<image>", image_tokens, 1))
        return queries

    def _tokenize_left_padded(self, queries: List[str]):
        """Tokenize ``queries`` with left padding and move the tensors to ``self.device``.

        The tokenizer's previous ``padding_side`` is restored even if
        tokenization raises.

        Returns:
            ``(input_ids, attention_mask)``; ``attention_mask`` may be None.
        """
        old_side = getattr(self.tokenizer, "padding_side", "right")
        self.tokenizer.padding_side = "left"
        try:
            model_inputs = self.tokenizer(queries, return_tensors="pt", padding=True)
        finally:
            self.tokenizer.padding_side = old_side

        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        return input_ids, attention_mask

    def _apply_prefill_for_estimate(self, batch_size, input_ids, attention_mask, bs_estimate_gen_cfg):
        """Apply the extra prefill tokens requested for memory/batch-size estimation.

        Reads ``_prefill_extra_tokens`` and ``_prefill_token_id`` from
        ``bs_estimate_gen_cfg`` (without mutating it) and delegates to
        ``utils.apply_prefill_extra_tokens``.
        """
        from utils import apply_prefill_extra_tokens

        est_cfg = dict(bs_estimate_gen_cfg or {})
        prefill_extra_tokens = int(est_cfg.get("_prefill_extra_tokens", 0) or 0)
        prefill_token_id = est_cfg.get("_prefill_token_id", None)
        model_inputs = apply_prefill_extra_tokens(
            batch_size=batch_size,
            inputs={"input_ids": input_ids, "attention_mask": attention_mask},
            prefill_extra_tokens=prefill_extra_tokens,
            tokenizer=self.tokenizer,
            prefill_token_id=prefill_token_id,
        )
        return model_inputs["input_ids"], model_inputs["attention_mask"]

    def _generate_with_images(self, input_ids, attention_mask, pixel_values, generation_config):
        """Call ``model.generate`` with the image tiles under ``torch.no_grad``.

        If the remote code rejects ``pixel_values`` as an unexpected keyword,
        the call is retried with the argument named ``images``.
        """
        with torch.no_grad():
            try:
                return self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                    **generation_config,
                )
            except TypeError as e:
                msg = str(e)
                if not ("pixel_values" in msg and "unexpected keyword argument" in msg):
                    raise
            # Retry outside the except clause so the first TypeError and its
            # traceback are released before the second (long) call runs.
            return self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=pixel_values,
                **generation_config,
            )

    def _postprocess_sequences(
        self,
        sequences,
        batch_size: int,
        max_new_tokens: int,
        eos_token_id: int,
        pad_token_id,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Strip right padding, flag budget hits and decode generated ids.

        Args:
            sequences: id tensor whose first ``batch_size`` rows are used.
            batch_size: number of items in the batch.
            max_new_tokens: generation budget actually passed to ``generate``.
            eos_token_id: stop id passed to ``generate`` (template separator).
            pad_token_id: pad id passed to ``generate``.

        Raises:
            RuntimeError: if fewer than ``batch_size`` sequences were returned.
        """
        if len(sequences) < batch_size:
            raise RuntimeError(
                f"VinternVLM: model.generate returned {len(sequences)} sequences, "
                f"expected at least batch_size={batch_size}"
            )

        gen_config = getattr(self.model, "generation_config", None)
        gc_eos = _normalize_to_list(getattr(gen_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = {x for x in gc_eos + tok_eos + [eos_token_id] if x is not None}

        decoded_token_ids: List[List[int]] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        for i in range(batch_size):
            seq = sequences[i]

            # Pass the ids generate() really used, so a sequence ending in
            # "<sep> <pad> ..." is trimmed correctly even when the separator
            # is not listed as EOS by the model config or tokenizer.
            cut = self.calculate_right_padding_length(
                seq, pad_token_id=pad_token_id, extra_eos_ids=[eos_token_id]
            )
            right_pad_lens.append(cut)
            if cut > 0:
                seq = seq[:-cut]

            out_len = int(seq.shape[0])
            ended_with_eos = out_len > 0 and seq[-1].item() in eos_ids
            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq.detach().cpu().tolist())

        outputs = self.tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        for o in outputs:
            console.print("\n[yellow]output without prompt:", o)

        return outputs, right_pad_lens, hit_limits

    # ------------------------------------------------------------------
    # Public API: batched & single-sample inference
    # ------------------------------------------------------------------
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        bs_estimate_gen_cfg: Dict,
        oom_estimate: bool
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Batched multimodal generation via ``model.generate`` (bypassing ``batch_chat``).

        Args:
            items: dataset items; questions come from ``parse_input`` and
                images from ``get_image_path``.
            max_new_tokens: generation budget (``gen_cfg`` may override it).
            gen_cfg: extra generation kwargs. ``do_sample`` is always forced to
                False and ``eos_token_id`` is set to the template separator.
            bs_estimate_gen_cfg: only its ``_prefill_extra_tokens`` and
                ``_prefill_token_id`` keys are read, and only when
                ``oom_estimate`` is True.
            oom_estimate: when True, extra prefill tokens are applied through
                ``utils.apply_prefill_extra_tokens`` before generating.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``, one entry per item:
            decoded text of the returned sequences, the number of trailing
            padding/extra-EOS tokens, and whether the budget was exhausted
            without ending on an EOS token. NOTE: assumes the remote-code
            ``generate`` returns only newly generated tokens (no prompt).
        """
        batch_size = len(items)
        if batch_size == 0:
            return [], [], []

        # 1) Questions and image tiles.
        questions: List[str] = [parse_input(it) for it in items]
        pixel_values, num_patches_list = self._build_pixels_and_patches(items)

        # 2) Generation config; greedy decoding is enforced.
        generation_config = dict(max_new_tokens=int(max_new_tokens))
        if gen_cfg is not None:
            generation_config.update(gen_cfg)
        generation_config["do_sample"] = False
        # gen_cfg may override the budget; hit_limits must use the real one.
        effective_max_new_tokens = generation_config.get("max_new_tokens")
        if effective_max_new_tokens is None:
            effective_max_new_tokens = int(max_new_tokens)

        # 3) Conversation template from the model's remote code.
        get_conv_template = self._load_get_conv_template()
        template_name = getattr(self.model, "template", None) or "Hermes-2"
        system_message = getattr(self.model, "system_message", "")

        # batch_chat sets this attribute; the model's generate may rely on it.
        self.model.img_context_token_id = int(
            self.tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN)
        )

        # Like batch_chat, stop on the template separator.
        template = get_conv_template(template_name)
        template.system_message = system_message
        sep = template.sep.strip()
        eos_token_id = (
            int(self.tokenizer.convert_tokens_to_ids(sep))
            if sep
            else int(self.tokenizer.eos_token_id)
        )
        generation_config["eos_token_id"] = eos_token_id
        if "pad_token_id" not in generation_config:
            generation_config["pad_token_id"] = int(self.tokenizer.pad_token_id)

        # 4) Prompts and left-padded tokenization.
        queries = self._build_queries(
            questions, num_patches_list, get_conv_template, template_name, system_message
        )
        input_ids, attention_mask = self._tokenize_left_padded(queries)

        # 5) Optional memory-estimation prefill.
        if oom_estimate:
            input_ids, attention_mask = self._apply_prefill_for_estimate(
                batch_size, input_ids, attention_mask, bs_estimate_gen_cfg
            )

        # 6) Generate.
        gen_out = self._generate_with_images(
            input_ids, attention_mask, pixel_values, generation_config
        )
        sequences = getattr(gen_out, "sequences", gen_out)  # GenerateOutput or Tensor

        # 7) Trim, flag and decode.
        return self._postprocess_sequences(
            sequences,
            batch_size=batch_size,
            max_new_tokens=int(effective_max_new_tokens),
            eos_token_id=eos_token_id,
            pad_token_id=generation_config.get("pad_token_id"),
        )

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """Single-sample inference; delegates to generate_batch so both paths behave the same."""
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence, pad_token_id=None, extra_eos_ids=None) -> int:
        """Return how many trailing tokens of ``total_sequence`` are padding.

        The trailing run of pad tokens counts as padding. In the run of EOS
        tokens right before it, all but one also count, so exactly one EOS is
        kept. When the pad id is itself an EOS id, the real EOS is part of the
        pad run, and one token is given back.

        Args:
            total_sequence: token ids (list or 1-D tensor).
            pad_token_id: pad id used during generation. Defaults to
                ``model.generation_config.pad_token_id``, then the tokenizer's.
            extra_eos_ids: extra stop ids (e.g. the template separator passed
                to ``generate`` as ``eos_token_id``) to treat as EOS in addition
                to the model-config and tokenizer EOS ids.

        Returns:
            Number of tokens to drop from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        gen_config = getattr(self.model, "generation_config", None)
        pad_id = pad_token_id
        if pad_id is None:
            pad_id = getattr(gen_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        eos_ids = set(_normalize_to_list(getattr(gen_config, "eos_token_id", None)))
        eos_ids.update(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))
        if extra_eos_ids:
            eos_ids.update(x for x in extra_eos_ids if x is not None)

        # Trailing run of pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if right_pad_len > 0:
            if j < 0:
                # Entire sequence is pad tokens: if pad doubles as EOS, the
                # first one is the real EOS and is kept.
                return right_pad_len - 1 if pad_id in eos_ids else right_pad_len
            if total_sequence[j] not in eos_ids:
                # A non-EOS token precedes the pad run: the run most likely
                # swallowed the real EOS (pad == eos), so give one token back.
                return right_pad_len - 1

        # Run of EOS tokens before the padding: keep one, the rest are padding.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1
        return right_pad_len + max(0, eos_count - 1)
