from typing import List, Dict, Tuple

import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# 从 utils 导入 _normalize_to_list / parse_input / get_image_path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

console = Console()


class QTuneQwenVLVLM(BaseVLM):
    """
    Adapter for the QTuneVL1.5 family (Qwen2.5-VL derivatives, e.g. hanchaow/QTuneVL1_5-3B).

    - Generates with ``model.generate`` directly (no dependency on ``model.chat`` /
      ``model.batch_chat``).
    - Builds prompts with the processor's chat template; the user turn contains an
      ``{"type": "image"}`` placeholder followed by ``{"type": "text"}``.
    - Packs inputs with the multimodal processor (images + text in one call).
    - Strips the left-padded prompt (``prefix_len_padded``) from the generated ids and
      computes ``right_pad_len`` / ``hit_limit`` on the newly generated tokens only.
    - ``right_pad_len`` comes from ``calculate_right_padding_length``, whose logic is
      intentionally kept unchanged.
    """

    def __init__(self, model, tokenizer, processor=None, device="cuda"):
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # QTuneVL (Qwen2.5-VL) needs a processor that handles text + images together
        # (AutoProcessor / Qwen2_5_VLProcessor). Explicit raise instead of `assert`,
        # which is stripped under `python -O`.
        if self.processor is None:
            raise ValueError("QTuneVLVLM 需要 AutoProcessor（能同时处理 images + text）")

        proc_tokenizer = getattr(self.processor, "tokenizer", None)

        # 1) Left padding (required for batched generation with a decoder-only LM).
        if getattr(self.tokenizer, "padding_side", None) != "left":
            self.tokenizer.padding_side = "left"
        if proc_tokenizer is not None:
            proc_tokenizer.padding_side = "left"

        # 2) Minimal fallback: only when the tokenizer has no pad token, use <|endoftext|>,
        #    and mirror it onto the processor's tokenizer.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = "<|endoftext|>"
            self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        if proc_tokenizer is not None and proc_tokenizer.pad_token is None:
            proc_tokenizer.pad_token = self.tokenizer.pad_token
            proc_tokenizer.pad_token_id = self.tokenizer.pad_token_id

        # 3) Make sure the model sits on self.device; otherwise CUDA inputs meet CPU weights.
        self._ensure_model_on_device(warn=True)

    def _ensure_model_on_device(self, warn: bool = False) -> torch.device:
        """Move the model to ``self.device`` if any parameter is still on CPU.

        Args:
            warn: print a warning when the move fails. Construction warns once;
                per-batch calls stay quiet.

        Returns:
            The device inputs should be sent to. This is ``self.device`` on success.
            If the move fails, it is the device of the model's first parameter (or CPU
            as a last resort), so inputs at least follow the weights instead of mixing
            CPU and GPU.
        """
        try:
            target = torch.device(self.device)
            # Move only if something is still on CPU (e.g. a vision tower left behind).
            if any(p.device.type == "cpu" for p in self.model.parameters()):
                self.model.to(target)
            return target
        except Exception as e:  # best effort: keep running with the model where it is
            if warn:
                # Exception text may contain "[...]" that rich would parse as markup.
                from rich.markup import escape
                console.print(
                    f"[yellow][QTuneVL] warning: move model to {self.device} failed, "
                    f"will follow model device at runtime: {escape(str(e))}[/yellow]"
                )
        try:
            return next(self.model.parameters()).device
        except (StopIteration, AttributeError):
            return torch.device("cpu")

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype that floating-point inputs should be cast to.

        Uses the first *floating-point* parameter, so integer (e.g. quantized) weights
        never turn ``pixel_values`` into integers. Falls back to ``model.dtype``, then
        ``torch.float16``.
        """
        try:
            for param in self.model.parameters():
                if param.is_floating_point():
                    return param.dtype
        except AttributeError:
            pass  # model object without .parameters(); use the fallbacks below
        return getattr(self.model, "dtype", torch.float16) or torch.float16

    @staticmethod
    def _load_items(items: List[dict]) -> Tuple[List[str], List]:
        """Extract the question text and an RGB PIL image from every item.

        Raises:
            ValueError: if an item has no image path (text-only input is unsupported).
        """
        questions: List[str] = []
        images: List[Image.Image] = []
        for item in items:
            question = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("QTuneVLVLM 目前仅支持多模态输入（需要 image_path）")
            questions.append(question)
            # `with` closes the file handle; convert() yields an in-memory image.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        return questions, images

    def _build_prompts(self, questions: List[str]) -> List[str]:
        """Render one chat prompt per question, each with a single image slot.

        Literal ``<image>`` markers are removed from the question, because the
        template inserts the image tokens itself. If the chat template fails, or
        yields a prompt without ``<|image_pad|>``, a hand-written Qwen2.5-VL prompt is
        used instead.
        """
        prompts: List[str] = []
        for q in questions:
            q_text = (q or "").replace("<image>", "").strip()

            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},  # placeholder only: the template inserts image_pad
                    {"type": "text", "text": q_text},
                ],
            }]

            try:
                prompt = self.processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
            except Exception:
                # Best effort: any template failure falls through to the manual prompt.
                prompt = ""

            # Hard guarantee: without an image_pad token, fall back to the expected
            # Qwen2.5-VL format.
            if "<|image_pad|>" not in prompt:
                prompt = (
                    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
                    "<|im_start|>user\n"
                    "<|vision_start|><|image_pad|><|vision_end|>"
                    f"{q_text}"
                    "<|im_end|>\n"
                    "<|im_start|>assistant\n"
                )

            prompts.append(prompt)
        return prompts

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Run batched image+text generation.

        Args:
            items: raw samples. ``parse_input`` / ``get_image_path`` extract the question
                and image path from each one.
            max_new_tokens: generation budget, also the length threshold for
                ``hit_limit``. It takes precedence over a ``max_new_tokens`` key in
                ``gen_cfg``.
            gen_cfg: extra keyword arguments for ``model.generate`` (may be empty/None).
                ``do_sample`` defaults to ``False``. Sampling-only keys are dropped when
                not sampling.
            oom_estimate: when True, inputs go through ``apply_prefill_extra_tokens``
                before generation.
            bs_estimate_gen_cfg: settings for that step. The private keys
                ``_prefill_extra_tokens`` / ``_prefill_token_id`` are read here.

        Returns:
            ``(texts, right_pad_lens, hit_limit_flags)``, with one entry per item:
            - the decoded new text (prompt stripped, special tokens skipped);
            - the number of trailing tokens trimmed;
            - whether generation reached ``max_new_tokens`` without ending on an EOS id.

        Raises:
            ValueError: if an item has no image path.
            RuntimeError: if the processor output lacks ``input_ids``.
        """
        if not items:
            return [], [], []

        # 1) Questions and images.
        questions, images = self._load_items(items)

        # 2) Chat template -> prompts.
        prompts = self._build_prompts(questions)

        # 3) Processor packs text + images. padding=True aligns the batch; the padding
        #    side (left) is controlled by the tokenizer configured in __init__.
        inputs = self.processor(
            images=images,
            text=prompts,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )

        # 4) Device placement. Then cast every float tensor (not only pixel_values) to the
        #    model dtype to avoid dtype mismatches; the dtype is read after any move.
        target_device = self._ensure_model_on_device()
        model_dtype = self._model_dtype()

        if hasattr(inputs, "to"):
            inputs = inputs.to(target_device)
        else:
            for k, v in list(inputs.items()):
                if torch.is_tensor(v):
                    inputs[k] = v.to(target_device)

        for k, v in list(inputs.items()):
            if torch.is_tensor(v) and v.is_floating_point():
                inputs[k] = v.to(dtype=model_dtype)

        # 5) Generation config. A dict literal (instead of dict(..., **gen_cfg)) lets
        #    gen_cfg carry do_sample / max_new_tokens without a "multiple values for
        #    keyword argument" TypeError. The explicit max_new_tokens argument always
        #    wins, so the hit_limit check below uses the same budget as generate().
        gen_conf = {"do_sample": False, **(gen_cfg or {}), "max_new_tokens": max_new_tokens}
        if not gen_conf.get("do_sample", False):
            # Sampling-only knobs are meaningless under greedy decoding.
            for key in ("temperature", "top_p", "top_k"):
                gen_conf.pop(key, None)

        # pad id: an explicit gen_cfg value first, then generation_config, tokenizer pad,
        # then tokenizer eos. The same id is used by generate() and synced into
        # generation_config, because calculate_right_padding_length reads it from there.
        pad_id = gen_conf.get("pad_token_id")
        if pad_id is None:
            pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            try:
                self.model.generation_config.pad_token_id = pad_id
                self.model.config.pad_token_id = pad_id
            except Exception:
                pass  # best effort: a failed config sync must not abort generation

        # Every id that can end a sequence; used by the hit_limit check.
        eos_ids = {
            x
            for x in (
                _normalize_to_list(gen_conf.get("eos_token_id"))
                + _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
                + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            )
            if x is not None
        }

        # ================================ Memory estimation ================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            # NOTE(review): the remaining bs_estimate_gen_cfg keys are not used here;
            # confirm whether they were meant to be merged into gen_conf.
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ====================================================================================

        # Fail fast, before spending compute on generate().
        if "input_ids" not in inputs:
            raise RuntimeError("processor 输出缺少 input_ids（请确认使用的是 AutoProcessor / Qwen2_5_VLProcessor）")

        # 6) generate, strip the prompt, then per-sample right_pad / hit_limit / decode.
        with torch.no_grad():
            out_ids = self.model.generate(**inputs, **gen_conf)

        # With left padding, every row's prompt occupies the first prefix_len_padded
        # positions, so slicing keeps only the newly generated tokens.
        prefix_len_padded = int(inputs["input_ids"].shape[1])
        new_tokens_all = out_ids[:, prefix_len_padded:]

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for seq in new_tokens_all:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            seq_trim = seq[:-cut] if cut > 0 else seq
            out_len = int(seq_trim.shape[0])

            # A sequence that stopped on an EOS id did not hit the limit, even at full length.
            ended_with_eos = out_len > 0 and int(seq_trim[-1].item()) in eos_ids
            hit_limit_flags.append(out_len >= max_new_tokens and not ended_with_eos)

            outputs.append(
                self.tokenizer.decode(
                    seq_trim,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )

        # Model text may contain "[...]" sequences that rich would parse as markup
        # (a stray closing tag raises MarkupError), so escape it before printing.
        from rich.markup import escape
        for text in outputs:
            console.print("\n[yellow]output without prompt: ", escape(text))

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Generate for a single item; returns ``(text, right_pad_len, hit_limit)``."""
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    # ====== Kept verbatim by request ("reuse as-is, do not change"); only comments added ======
    def calculate_right_padding_length(self, total_sequence) -> int:
        """Count the trailing tokens of a generated sequence that should be trimmed.

        Steps:
        - Count the run of trailing ``pad_token_id`` tokens. ``generation_config`` is
          preferred over the tokenizer, because it is the id ``generate`` actually used.
        - If pads were found and the token just before them is not an EOS id, return
          ``count - 1``. This keeps back one "pad", which covers the case where pad and
          EOS share an id and the last "pad" is really the terminating EOS.
        - Otherwise, also count the run of EOS ids before the pads. One EOS is kept; the
          rest count as right padding.

        Args:
            total_sequence: 1-D token ids (tensor or list).

        Returns:
            Number of tokens to strip from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Count the trailing run of pad tokens from the end; all of it is trimmed.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads exist and pad == eos, so one token too many was counted: keep it back.
            return right_pad_len - 1

        # Continue from the end: count the consecutive EOS tokens before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the rest are treated as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len