from typing import List, Dict, Tuple
from .base import BaseVLM
from PIL import Image
import torch
from rich.console import Console

import sys
import os
import importlib

# Append the parent of this file's directory to sys.path so that the shared
# `utils` module (which provides _normalize_to_list and friends) can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path

# Module-level rich console used for this backend's terminal output.
console = Console()


class QTuneVLVLM(BaseVLM):
    """
    Batched inference wrapper for the QTuneVL1.5 family (InternVLChatModel).

    Per the original author's notes, the batch flow is intended to mirror the
    sibling ``qwen.py`` backend (not visible from this file).

    Key points:
    - Only the externally supplied tokenizer is used (``processor.tokenizer``
      is never touched).
    - The tokenizer is forced to left padding before batching prompts.
    - ``right_pad_len`` / ``hit_limit`` are derived per sample from the
      generated ids, using ``calculate_right_padding_length``.
    """

    def __init__(self, model, tokenizer, processor=None, device="cuda"):
        """Store the model components and normalise the tokenizer's padding setup.

        Args:
            model: Chat model; later methods read ``template``, ``system_message``,
                ``num_image_token``, ``generation_config`` and ``config`` from it.
            tokenizer: Text tokenizer used for prompts and decoding.
            processor: Image processor; must not be ``None``.
            device: Device that input tensors are moved to.

        Raises:
            AssertionError: If no image processor is available after base-class
                initialisation.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # The processor is only used for images here; its tokenizer (if any) is ignored.
        assert self.processor is not None, "QTuneVLVLM 需要 image processor（如 AutoImageProcessor/AutoFeatureExtractor）"

        # Force left padding so batched prompts end right where generation starts.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Some tokenizers still lack pad_token_id at this point; fall back to EOS.
        if getattr(self.tokenizer, "pad_token_id", None) is None and getattr(self.tokenizer, "eos_token_id", None) is not None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _get_conv_template_fn(self):
        """Return ``get_conv_template`` from the ``conversation`` module next to the model's module.

        The module name is derived from ``self.model.__module__`` by replacing
        its last dotted component with ``conversation``.

        Raises:
            RuntimeError: If the module cannot be imported or does not expose
                ``get_conv_template``; the underlying error is chained.
        """
        # Typical layout: transformers_modules.xxx.yyy.modeling_internvl_chat
        base_pkg = self.model.__module__.rsplit(".", 1)[0]
        conv_mod_name = f"{base_pkg}.conversation"
        try:
            conv_mod = importlib.import_module(conv_mod_name)
            return getattr(conv_mod, "get_conv_template")
        except Exception as e:
            raise RuntimeError(
                f"无法导入 get_conv_template（尝试导入 {conv_mod_name} 失败）: {e}"
            ) from e

    def _load_questions_and_images(self, items: List[dict]) -> "Tuple[List[str], List[Image.Image]]":
        """Extract one prompt and one RGB image per item.

        Raises:
            ValueError: If an item has no image path.
        """
        questions: List[str] = []
        images = []

        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("QTuneVLVLM 目前仅支持多模态输入（需要 image_path）")

            # Make sure the prompt carries the '<image>' placeholder that is
            # later replaced by the image-context tokens.
            if "<image>" not in prompt:
                prompt = "<image>\n" + prompt

            questions.append(prompt)
            # convert() returns a fully loaded copy, so the source file can be
            # closed right away instead of lingering until garbage collection.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))

        return questions, images

    def _encode_images(self, images) -> torch.Tensor:
        """Run the image processor and move ``pixel_values`` to the model's device and dtype.

        Raises:
            RuntimeError: If the processor output has no ``pixel_values``.
        """
        proc_out = self.processor(images=images, return_tensors="pt")
        if isinstance(proc_out, dict):
            pixel_values = proc_out.get("pixel_values", None)
        else:
            pixel_values = getattr(proc_out, "pixel_values", None)

        if pixel_values is None:
            raise RuntimeError("image processor 没有返回 pixel_values（请检查 processor 类型/版本）")

        # Cast to the model dtype to avoid dtype mismatches inside generate().
        model_dtype = getattr(self.model, "dtype", None)
        if model_dtype is None:
            # Fallback: follow the language model's dtype if present, else float16.
            lm = getattr(self.model, "language_model", None)
            model_dtype = getattr(lm, "dtype", torch.float16) if lm is not None else torch.float16

        return pixel_values.to(device=self.device, dtype=model_dtype)

    def _build_queries(self, questions: List[str]):
        """Render each question through the model's conversation template.

        Also sets ``self.model.img_context_token_id`` as a side effect.

        Returns:
            ``(queries, template)`` where ``template`` is the last template
            instance built; its ``sep`` is used by the caller to derive EOS.

        Raises:
            RuntimeError: If the template module cannot be loaded, or no
                template was built (empty ``questions``).
        """
        get_conv_template = self._get_conv_template_fn()

        IMG_START_TOKEN = "<img>"
        IMG_END_TOKEN = "</img>"
        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

        # The model's generate() relies on this id (per the original author's note).
        self.model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        # One patch per image is assumed; dynamic patching is not handled here.
        patches_per_image = 1
        image_tokens = (
            IMG_START_TOKEN
            + (IMG_CONTEXT_TOKEN * self.model.num_image_token * patches_per_image)
            + IMG_END_TOKEN
        )

        queries: List[str] = []
        template_for_eos = None
        for question in questions:
            template = get_conv_template(self.model.template)
            template.system_message = self.model.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)

            # Only the first '<image>' placeholder is expanded.
            queries.append(template.get_prompt().replace("<image>", image_tokens, 1))
            template_for_eos = template

        if template_for_eos is None:
            raise RuntimeError("构造 prompt 失败：template_for_eos is None")

        return queries, template_for_eos

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate one answer per item in a single batched ``generate`` call.

        Args:
            items: Input records; each must yield a prompt and an image path.
            max_new_tokens: Generation budget. It always takes precedence over
                a ``max_new_tokens`` entry in ``gen_cfg`` because ``hit_limit``
                is measured against it.
            gen_cfg: Extra keyword arguments for ``model.generate``; may
                override the default ``do_sample=False``. ``None`` is treated
                as empty.
            oom_estimate: When true, the inputs are passed through
                ``apply_prefill_extra_tokens`` before generation.
            bs_estimate_gen_cfg: Options for the estimate path; the keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read.

        Returns:
            ``(outputs, right_pad_lens, hit_limit_flags)``, one entry per item.
            Empty ``items`` yields three empty lists.

        Raises:
            ValueError: If an item has no image path.
            RuntimeError: If the image processor returns no ``pixel_values`` or
                the conversation template cannot be loaded.

        Side effects:
            Updates EOS/pad ids on ``model.generation_config`` / ``model.config``
            (best effort), sets ``model.img_context_token_id``, and prints each
            decoded output to the console.
        """
        if not items:
            return [], [], []

        # 1) prompts + images, 2) pixel_values, 3) templated queries
        questions, images = self._load_questions_and_images(items)
        pixel_values = self._encode_images(images)
        queries, template_for_eos = self._build_queries(questions)

        # 4) tokenize with left padding
        self.tokenizer.padding_side = "left"
        model_inputs = self.tokenizer(queries, return_tensors="pt", padding=True)
        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        # 5) generation config. A dict literal is used so that duplicate keys
        # in gen_cfg override the default instead of raising TypeError.
        gen_conf = {
            "do_sample": False,
            **(gen_cfg or {}),
            "max_new_tokens": max_new_tokens,
        }

        # EOS is the token id of the template separator.
        eos_token_id = self.tokenizer.convert_tokens_to_ids(template_for_eos.sep.strip())
        if eos_token_id is not None:
            gen_conf["eos_token_id"] = eos_token_id
            # Mirror onto the model so calculate_right_padding_length sees the
            # same EOS. Deliberately best-effort: failures are ignored.
            try:
                self.model.generation_config.eos_token_id = eos_token_id
                self.model.config.eos_token_id = eos_token_id
            except Exception:
                pass

        # EOS id set used for the hit_limit decision.
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = {x for x in (gc_eos + tok_eos) if x is not None}

        # Pad id: generation_config first, then tokenizer pad, then tokenizer EOS.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Same best-effort mirroring as for EOS.
            try:
                self.model.config.pad_token_id = pad_id
                self.model.generation_config.pad_token_id = pad_id
            except Exception:
                pass

        # ------------------------- memory estimation -------------------------
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            # Copy before pop() so the caller's dict is not mutated.
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids, attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]
        # ---------------------------------------------------------------------

        with torch.no_grad():
            output_ids = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **gen_conf,
            )

        batch_size = input_ids.size(0)
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for i in range(batch_size):
            sequence = output_ids[i]
            total_len = sequence.size(0)

            cut = self.calculate_right_padding_length(sequence)
            right_pad_lens.append(cut)
            output_len = total_len - cut

            generated_ids = sequence[:-cut] if cut > 0 else sequence

            # A sample hit the limit when it used the whole budget without
            # finishing on an EOS token.
            ended_with_eos = (
                output_len > 0
                and generated_ids.numel() > 0
                and int(generated_ids[-1].item()) in eos_ids
            )
            hit_limit_flags.append((output_len >= max_new_tokens) and (not ended_with_eos))

            outputs.append(
                self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )

        # Model output is arbitrary text; escape it so bracketed fragments such
        # as "[/INST]" are printed literally instead of being parsed as markup.
        from rich.markup import escape
        for o in outputs:
            console.print("\n[yellow]output without prompt: ", escape(o))

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Run ``generate_batch`` on a single item and unwrap the one-element results.

        Returns:
            ``(output_text, right_pad_len, hit_limit)`` for ``item``.
        """
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of ``total_sequence`` count as right padding.

        Trailing pad tokens are counted first. If the token just before that
        pad run is not an EOS id, one pad is kept (the case where pad and EOS
        share an id and the first "pad" is really the EOS). Otherwise, of the
        run of EOS tokens preceding the pads, all but one are counted as
        padding too.

        Args:
            total_sequence: Token ids as a list or a 1-D ``torch.Tensor``.

        Returns:
            Number of tokens to strip from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = set(gc_eos + tok_eos)

        # Count the run of pad tokens at the very end.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads directly follow a non-EOS token: treat the first pad as the
            # EOS (pad id == EOS id) and keep it.
            return right_pad_len - 1

        # Count the run of EOS tokens immediately before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1

        # Keep exactly one EOS; the rest is right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
