from typing import List, Dict, Tuple, Any, Optional
from .base import BaseVLM
from PIL import Image
import torch
from rich.console import Console
from collections import defaultdict
import importlib

# 从utils导入_normalize_to_list/parse_input/get_image_path（保持与 qwen.py/phi.py 一致的项目结构）
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path

# Module-level rich console; VilaVLM.generate_batch uses it to print each decoded output.
console = Console()


class VilaVLM(BaseVLM):
    """
    Inference backend for Efficient-Large-Model/VILA15-3b-hf-preview
    (HF preview checkpoint loaded with ``trust_remote_code=True``).

    Key points:
    - Exactly one tokenizer is used: ``model.tokenizer`` when present (never mixed
      with an externally supplied tokenizer).
    - Inference path: ``model._embed(...)`` -> ``model.llm.generate(inputs_embeds=...)``.
      This mirrors the internals of the remote-code ``generate_content`` and avoids the
      dummy prefix that calling ``model.generate`` directly introduces (which produced
      garbled / "AAAA" output).
    - Right-padding length: ``calculate_right_padding_length`` is reused verbatim from
      the qwen.py implementation.
    - ``hit_limit`` uses the same definition as qwen.py / phi.py.
    """

    def __init__(self, model, tokenizer=None, processor=None, model_key="", device: str = "cuda"):
        """
        Wrap a VILA remote-code model.

        Args:
            model: VILA wrapper loaded with ``trust_remote_code=True``; it must expose
                ``_embed``, ``llm`` and ``vision_tower.image_processor``.
            tokenizer: Fallback tokenizer, used only when ``model.tokenizer`` is missing.
            processor: Fallback source of ``processor.tokenizer``; otherwise unused.
            model_key: Model identifier, stored on ``self.model_key`` unless the base
                class already provides that attribute.
            device: Device that input tensors are moved to.

        Raises:
            ValueError: If no tokenizer can be found.
            RuntimeError: If the model lacks a component the inference path needs.
        """
        # 1) Use exactly one tokenizer: prefer model.tokenizer (attached by VILA remote code).
        tok = getattr(model, "tokenizer", None)
        if tok is None:
            # Fallback only when model.tokenizer is missing; still a single tokenizer.
            tok = tokenizer or getattr(processor, "tokenizer", None)

        if tok is None:
            raise ValueError(
                "VILA15 后端无法找到 tokenizer：请确保使用 trust_remote_code=True 加载模型（model.tokenizer 会存在）。"
            )

        super().__init__(model=model, tokenizer=tok, processor=None, device=device)

        self.model = model
        self.tokenizer = tok
        self.device = device
        # The base constructor is not given model_key: keep whatever it may already
        # provide, otherwise record the value passed in.
        if not hasattr(self, "model_key"):
            self.model_key = model_key

        # 2) Left padding for batched generation (same as qwen.py / phi.py).
        self.tokenizer.padding_side = "left"

        # 3) Fall back to eos as pad only when pad_token_id is missing.
        #    Test for None explicitly: 0 is a legitimate pad id for many tokenizers.
        if getattr(self.tokenizer, "pad_token_id", None) is None and getattr(self.tokenizer, "eos_token_id", None) is not None:
            self.tokenizer.pad_token_id = int(self.tokenizer.eos_token_id)

        # Mirror pad_token_id onto the model configs (only where unset) so that
        # generate() and the right-pad logic read the same value.
        tok_pad = getattr(self.tokenizer, "pad_token_id", None)
        for cfg_name in ("generation_config", "config"):
            cfg = getattr(self.model, cfg_name, None)
            if cfg is not None and getattr(cfg, "pad_token_id", None) is None and tok_pad is not None:
                cfg.pad_token_id = int(tok_pad)

        # 4) Model dtype; image tensors are cast to it before _embed.
        self._model_dtype = next(self.model.parameters()).dtype

        # 5) Import helper modules shipped next to the VILA remote-code model class.
        base_mod = self.model.__class__.__module__.rsplit(".", 1)[0]
        self._mm_utils = importlib.import_module(f"{base_mod}.mm_utils")
        self._tok_utils = importlib.import_module(f"{base_mod}.tokenizer_utils")
        self._media_mod = importlib.import_module(f"{base_mod}.media")

        self._process_images = getattr(self._mm_utils, "process_images")
        self._tokenize_conversation = getattr(self._tok_utils, "tokenize_conversation")
        self._extract_media = getattr(self._media_mod, "extract_media")

        # Best-effort diagnostic: warn when "<image>" maps to the unk token.
        try:
            tid = self.tokenizer.convert_tokens_to_ids("<image>")
            unk = getattr(self.tokenizer, "unk_token_id", None)
            if unk is not None and tid == unk:
                print("[VILA][warn] '<image>' 被分到 unk，说明 tokenizer 不认识 image token，会导致 embeddings 无法消费")
        except Exception:
            # Diagnostic only; it must never prevent construction.
            pass

        # 6) Verify the components the inference path relies on.
        if not hasattr(self.model, "_embed"):
            raise RuntimeError("VILA15 模型缺少 _embed 方法：请确认加载的是 VILA remote-code wrapper。")
        if not hasattr(self.model, "llm"):
            raise RuntimeError("VILA15 模型缺少 llm 子模型：请确认加载的是 VILA remote-code wrapper。")
        if not hasattr(self.model, "vision_tower") or not hasattr(self.model.vision_tower, "image_processor"):
            raise RuntimeError("VILA15 模型缺少 vision_tower.image_processor，无法进行图像预处理。")

    # ======== right_pad: reused verbatim from qwen.py (do not modify the logic) ========
    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are padding.

        Trailing pad tokens are counted first. If that run is preceded by a non-EOS
        token, one token is given back (pad == EOS case: the run swallowed the real
        terminating EOS). Otherwise, the run of EOS tokens below the pads is also
        counted, except for one EOS, which is kept.

        Args:
            total_sequence: 1-D tensor or list of token ids.

        Returns:
            Number of tokens to strip from the end of ``total_sequence``.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Count the run of pad tokens at the end; all of them are cut.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Padding exists and pad == eos, so one token too many was counted.
            return right_pad_len - 1

        # Then count the run of EOS tokens just before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; treat the rest as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
    # ======================================================================

    def _build_one(self, item: Dict[str, Any]) -> Tuple[List[int], Image.Image]:
        """
        Turn one sample into ``(input_ids, image)``.

        Returns:
            ``input_ids`` as ``list[int]`` and the single image collected by VILA's
            ``extract_media``.

        Raises:
            ValueError: If the item has no image path or does not yield exactly one image.
        """
        prompt = parse_input(item)
        image_path = get_image_path(item)
        if image_path is None:
            raise ValueError("VILA15 输入缺少 image_path（get_image_path(item) 返回 None）")

        # The context manager releases the file handle; convert() returns a new image.
        with Image.open(image_path) as img:
            raw_image = img.convert("RGB")

        # VILA extract_media accepts value=[PIL.Image, text] and inserts <image> itself.
        conversation = [{"from": "human", "value": [raw_image, prompt]}]
        media = self._extract_media(conversation, config=self.model.config)

        # Single-image data: exactly one image is expected.
        if "image" not in media or len(media["image"]) != 1:
            raise ValueError(
                f"VILA15 期望单图输入，但 extract_media 得到 image 数量={len(media.get('image', []))}"
            )

        # Safety net: the rewritten conversation must contain <image>. This has to run
        # BEFORE tokenization, otherwise it cannot influence input_ids.
        val = conversation[0].get("value", "")
        if not isinstance(val, str) or "<image>" not in val:
            # Force the most common layout.
            conversation[0]["value"] = "<image>\n" + str(prompt)

        input_ids = self._tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True)
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()

        return input_ids, media["image"][0]

    def generate_one(
        self,
        item: Dict[str, Any],
        max_new_tokens: int = 1024,
        gen_cfg: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, int, bool]:
        """Generate for a single sample (``generate_batch`` with batch size 1).

        Returns:
            ``(output_text, right_pad_len, hit_limit)``.
        """
        outs, rps, hits = self.generate_batch(
            [item], max_new_tokens=max_new_tokens, gen_cfg=gen_cfg, oom_estimate=False, bs_estimate_gen_cfg=None
        )
        return outs[0], rps[0], hits[0]

    def _resolve_image_token_id(self) -> Optional[int]:
        """Return the tokenizer's image token id, or None if it cannot be determined."""
        media_ids = getattr(self.tokenizer, "media_token_ids", None)
        img_tid = media_ids.get("image", None) if isinstance(media_ids, dict) else None
        if img_tid is None:
            # Fallback: some tokenizers lack media_token_ids but still know "<image>".
            try:
                img_tid = self.tokenizer.convert_tokens_to_ids("<image>")
            except Exception:
                img_tid = None
        return img_tid

    def _build_gen_kwargs(self, gen_cfg: Dict[str, Any], max_new_tokens: int) -> Tuple[Dict[str, Any], List[int]]:
        """
        Build keyword arguments for ``llm.generate`` from ``gen_cfg``.

        Only fields ``llm.generate`` understands are forwarded. Sampling knobs
        (temperature, top_p, top_k) are included only when ``do_sample`` is set, and
        keys whose value is None are dropped.

        Returns:
            ``(gen_kwargs, eos_ids)``, where ``eos_ids`` is the union of the
            generation-config and tokenizer EOS ids.
        """
        do_sample = bool(gen_cfg.get("do_sample", False))

        # pad: generation_config first, tokenizer as fallback. eos: union of both sources.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = [e for e in set(gc_eos + tok_eos) if e is not None]
        if not eos_ids:
            eos_arg = None
        elif len(eos_ids) == 1:
            eos_arg = eos_ids[0]
        else:
            eos_arg = eos_ids

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": int(pad_id) if pad_id is not None else None,
            "eos_token_id": eos_arg,
            "repetition_penalty": gen_cfg.get("repetition_penalty", None),
            "num_beams": gen_cfg.get("num_beams", None),
        }

        # Pass sampling parameters only when sampling, so implementations do not
        # warn about "ignored flags".
        if do_sample:
            temperature = gen_cfg.get("temperature", None)
            top_p = gen_cfg.get("top_p", None)
            top_k = gen_cfg.get("top_k", None)
            if temperature is not None:
                gen_kwargs["temperature"] = float(temperature)
            if top_p is not None:
                gen_kwargs["top_p"] = float(top_p)
            if top_k is not None and int(top_k) > 0:
                gen_kwargs["top_k"] = int(top_k)

        return {k: v for k, v in gen_kwargs.items() if v is not None}, eos_ids

    @torch.inference_mode()
    def generate_batch(
        self,
        items: List[Dict[str, Any]],
        max_new_tokens: int = 1024,
        gen_cfg: Optional[Dict[str, Any]] = None,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Generate answers for a batch of single-image samples.

        Args:
            items: Samples understood by ``parse_input`` / ``get_image_path``.
            max_new_tokens: Generation budget per sample.
            gen_cfg: Optional generation settings (do_sample, temperature, top_p,
                top_k, repetition_penalty, num_beams).
            oom_estimate: When True, extend the prompt via
                ``utils.apply_prefill_extra_tokens`` before embedding.
            bs_estimate_gen_cfg: Settings for ``oom_estimate``; the keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read from it.
                It is never mutated.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``, one entry per item.
        """
        gen_cfg = gen_cfg or {}
        if not items:
            return [], [], []

        # When image_aspect_ratio is not resize/pad (e.g. dynamic / dynamic_s2), VILA15
        # rewrites the prompt and splits the image into tiles. To stay correct, fall
        # back to one sample at a time for batch > 1.
        iar = getattr(self.model.config, "image_aspect_ratio", "resize")
        if len(items) > 1 and iar not in ("resize", "pad"):
            results = [self.generate_one(it, max_new_tokens=max_new_tokens, gen_cfg=gen_cfg) for it in items]
            outs = [r[0] for r in results]
            rps = [r[1] for r in results]
            hits = [r[2] for r in results]
            return outs, rps, hits

        # 1) Per-sample input_ids (list[int]) and PIL images.
        built = [self._build_one(it) for it in items]
        ids_list = [ids for ids, _ in built]
        pil_images = [im for _, im in built]

        # 2) Left-pad the batch with the tokenizer.
        inputs = self.tokenizer.pad([{"input_ids": ids} for ids in ids_list], return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_id is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=self.device)
            else:
                attention_mask = input_ids.ne(int(pad_id))
        attention_mask = attention_mask.to(self.device).to(torch.bool)

        # 3) Image preprocessing (process_images -> Tensor[B, C, H, W]).
        images_tensor = self._process_images(pil_images, self.model.vision_tower.image_processor, self.model.config)
        if isinstance(images_tensor, torch.Tensor) and images_tensor.dim() == 3:
            images_tensor = images_tensor.unsqueeze(0)
        if not isinstance(images_tensor, torch.Tensor) or images_tensor.dim() != 4:
            raise RuntimeError(
                f"VILA15 process_images 返回异常：type={type(images_tensor)}, shape={getattr(images_tensor, 'shape', None)}"
            )

        images_tensor = images_tensor.to(device=self.device, dtype=self._model_dtype)

        # VILA _embed expects media['image'] as list[tensor], one tensor per image.
        media = {"image": list(images_tensor.unbind(0))}
        media_config = defaultdict(dict)

        # Diagnostics: each sample should contain exactly one <image> token among its
        # non-pad tokens, and there should be one image per item.
        num_images = images_tensor.size(0)
        img_tid = self._resolve_image_token_id()
        if img_tid is not None:
            per_sample_img_tok = ((input_ids == int(img_tid)) & attention_mask).sum(dim=1).tolist()
            if any(x != 1 for x in per_sample_img_tok) or num_images != len(items):
                print("[VILA][debug] img_tid=", img_tid, "per_sample_img_tok=", per_sample_img_tok, "num_images(B)=", num_images, "bs=", len(items))

        # ================================ Memory estimation ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            est_cfg = bs_estimate_gen_cfg or {}
            prefill_extra_tokens = int(est_cfg.get("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = est_cfg.get("_prefill_token_id", None)
            padded = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs={"input_ids": input_ids, "attention_mask": attention_mask},
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids, attention_mask = padded["input_ids"], padded["attention_mask"]
        # =====================================================================================

        # 4) Merge text and image embeddings; _embed also returns the matching mask.
        inputs_embeds, _, embed_attention_mask = self.model._embed(
            input_ids=input_ids,
            media=media,
            media_config=media_config,
            labels=None,
            attention_mask=attention_mask,
        )

        # 5) Generation arguments (single tokenizer, llm.generate-compatible fields only).
        gen_kwargs, eos_ids = self._build_gen_kwargs(gen_cfg, max_new_tokens)

        # 6) Generate with llm.generate to bypass the wrapper's dummy-prefix handling.
        output_ids = self.model.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=embed_attention_mask,
            **gen_kwargs,
        )

        # 7) Decode + right_pad + hit_limit (same definition as qwen.py).
        eos_set = {int(x) for x in eos_ids}
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        for seq_all in output_ids[: len(items)]:
            cut = self.calculate_right_padding_length(seq_all)
            right_pad_lens.append(cut)

            gen_ids = seq_all[:-cut] if cut > 0 else seq_all
            output_len = int(seq_all.size(0)) - cut
            ended_with_eos = gen_ids.numel() > 0 and int(gen_ids[-1].item()) in eos_set

            # The budget was exhausted only if the sequence is full-length and did not stop on EOS.
            hit_limits.append(output_len >= max_new_tokens and not ended_with_eos)
            outputs.append(self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip())

        for o in outputs:
            console.print("\n[yellow]output without prompt:", o)

        return outputs, right_pad_lens, hit_limits
