# sharegpt4v.py
from __future__ import annotations

from typing import Dict, List, Tuple, Any
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# 复用项目内工具
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

console = Console()


class ShareGPT4VVLM(BaseVLM):
    """
    Wrapper for gokul-mbzuai/sharegpt4v-7b-hf (LLaVA / LlavaForConditionalGeneration).

    Key points:
      - ``self.tokenizer`` and ``processor.tokenizer`` are bound to the same
        object and forced to left padding.
      - Every prompt must contain the ``<image>`` placeholder; otherwise the
        model receives image features but zero image tokens and crashes.
        ``generate_batch`` checks this before calling ``generate``.
      - ``calculate_right_padding_length`` reuses the qwen.py logic unchanged.
    """

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda"):
        """
        Args:
            model: LLaVA-style model; ``config``, ``generation_config`` and
                ``generate`` are used by this class.
            tokenizer: Fallback tokenizer, used only when the processor has no
                ``tokenizer`` of its own.
            processor: Required AutoProcessor (LlavaProcessor).
            device: Device the batched inputs are moved to.

        Raises:
            ValueError: if no processor is available after ``BaseVLM.__init__``.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)
        # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
        if self.processor is None:
            raise ValueError("ShareGPT4VVLM 需要 AutoProcessor(LlavaProcessor)")

        # 1) The processor does the tokenization, so prefer its tokenizer and
        #    fall back to the externally supplied one.
        proc_tok = getattr(self.processor, "tokenizer", None)
        self.tokenizer = proc_tok if proc_tok is not None else tokenizer

        # 2) Bind processor.tokenizer to the very same object so the padding
        #    settings below apply to both.
        if hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # 3) Decoder-only generation uses left padding. Do not blindly set
        #    pad_token = eos_token: some models have pad_token_id != eos.
        self._configure_padding(self.tokenizer)
        proc_tok = getattr(self.processor, "tokenizer", None)
        if proc_tok is not None and proc_tok is not self.tokenizer:
            # Defensive: normally unreachable because of the binding in step 2.
            self._configure_padding(proc_tok)

        # 4) Fill in LLaVA processor fields that may be missing (avoids the
        #    ``// NoneType`` error when patch_size is None).
        self._fix_llava_processor_fields()

    @staticmethod
    def _configure_padding(tok) -> None:
        """Force left padding on ``tok`` and make sure it has a pad token.

        If ``pad_token`` is unset but ``pad_token_id`` is known, the token
        string for that id is used, provided it maps back to the same id.
        In every other case (no id, failed round-trip, tokenizer error) the
        eos token is used as the pad token.
        """
        tok.padding_side = "left"
        if tok.pad_token is not None:
            return

        pad_id = getattr(tok, "pad_token_id", None)
        if pad_id is not None:
            try:
                candidate = tok.convert_ids_to_tokens(int(pad_id))
                # Verify the round-trip so we never install a string that does
                # not actually map back to pad_id.
                if candidate is not None and tok.convert_tokens_to_ids(candidate) == int(pad_id):
                    tok.pad_token = candidate
                    return
            except Exception:
                # Best effort only: any tokenizer error falls through to eos.
                pass

        tok.pad_token = tok.eos_token

    def _fix_llava_processor_fields(self):
        """Fill LLaVA processor fields the checkpoint may have left unset."""
        proc = self.processor
        cfg = getattr(self.model, "config", None)

        # 1) patch_size / vision_feature_select_strategy: align with the model config.
        if getattr(proc, "patch_size", None) is None:
            ps = getattr(getattr(cfg, "vision_config", None), "patch_size", None)
            if ps is not None:
                proc.patch_size = int(ps)

        if getattr(proc, "vision_feature_select_strategy", None) is None:
            vfs = getattr(cfg, "vision_feature_select_strategy", None)
            if vfs is not None:
                proc.vision_feature_select_strategy = vfs

        # 2) The CLIP vision tower emits an extra CLS token, so the processor
        #    must count one additional image token. Per the original author,
        #    leaving this at 0 with the default strategy produced a 575 vs 576
        #    token mismatch. The attribute name must be exactly
        #    ``num_additional_image_tokens``; an explicit 0 is overridden too.
        nai = getattr(proc, "num_additional_image_tokens", None)
        if nai is None or nai == 0:
            proc.num_additional_image_tokens = 1

    def _build_text_prompts(self, items: List[dict]) -> List[str]:
        """Build one text prompt (containing an image placeholder) per item.

        The processor's chat template is tried first. If it is unavailable or
        fails, the llava-1.5 style template is formatted by hand.
        """
        texts: List[str] = []
        for item in items:
            prompt = parse_input(item)

            # Preferred path: conversation + apply_chat_template.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            if hasattr(self.processor, "apply_chat_template"):
                try:
                    txt = self.processor.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    texts.append(txt)
                    continue
                except Exception:
                    # Deliberate best effort (e.g. no chat template configured):
                    # fall back to the hand-written template below.
                    pass

            # Fallback: hand-built llava-1.5 template.
            texts.append(f"USER: <image>\n{prompt} ASSISTANT:")

        return texts

    def _load_images(self, items: List[dict]) -> List[Image.Image]:
        """Load one RGB image per item.

        Raises:
            ValueError: if an item has no image path (text-only input is unsupported).
        """
        images: List[Image.Image] = []
        for item in items:
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("ShareGPT4VVLM 当前实现仅支持多模态输入（需要 image_path）")
            # The context manager closes the source file; convert() returns a
            # loaded copy that does not depend on it.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        return images

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of ``total_sequence`` are padding.

        Trailing pad ids are counted. If pad ids were found and the token
        right before them is not an eos id, one is subtracted; this assumes
        pad == eos, so one of those trailing tokens is really the terminating
        eos. Otherwise, trailing eos ids before the pads are also counted as
        padding, except for one eos that is kept.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        # Same None filtering as generate_batch; only membership is used below.
        eos_ids = {e for e in (gc_eos + tok_eos) if e is not None}

        # Count the run of pad tokens at the end of the sequence.
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads exist and pad == eos, so one real eos was counted as pad.
            return right_pad_len - 1

        # Then count the run of eos tokens right before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one eos; treat the remaining eos tokens as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict[str, Any],
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate one completion per item.

        Args:
            items: Input records. Each must yield a prompt via ``parse_input``
                and an image path via ``get_image_path``.
            max_new_tokens: Generation budget; also the threshold used for
                ``hit_limits``.
            gen_cfg: Extra ``generate`` kwargs. They override the
                ``do_sample=False`` default but not ``max_new_tokens``,
                ``eos_token_id`` or ``pad_token_id``.
            oom_estimate: If True, pad the prefill via ``apply_prefill_extra_tokens``.
            bs_estimate_gen_cfg: Source of ``_prefill_extra_tokens`` /
                ``_prefill_token_id`` when ``oom_estimate`` is True.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``: the decoded texts
            without the prompt, the per-row trailing padding length, and
            whether each row used its full budget without ending on eos.

        Raises:
            RuntimeError: if some prompt tokenized without an image token.
        """
        texts = self._build_text_prompts(items)
        images = self._load_images(items)

        # Batch the inputs with the processor, requesting left padding if supported.
        try:
            inputs = self.processor(
                text=texts,
                images=images,
                padding=True,
                padding_side="left",
                return_tensors="pt",
            )
        except TypeError:
            inputs = self.processor(
                text=texts,
                images=images,
                padding=True,
                return_tensors="pt",
            )

        inputs = inputs.to(self.device)

        # Sanity check: every row must contain the image token, otherwise the
        # model would get image features with zero image tokens.
        image_tok = getattr(self.model.config, "image_token_index", None)
        if image_tok is not None:
            cnt = (inputs["input_ids"] == int(image_tok)).sum(dim=1).tolist()
            if any(c == 0 for c in cnt):
                bad = [i for i, c in enumerate(cnt) if c == 0]
                raise RuntimeError(f"ShareGPT4VVLM: 有样本没有 <image> token，idx={bad}。你的 prompt 组装有问题。")

        # Generation config. The defaults are merged explicitly: the previous
        # ``dict(max_new_tokens=..., do_sample=False, **gen_cfg)`` raised
        # TypeError whenever gen_cfg also carried one of those keys.
        # max_new_tokens is set last because hit_limits is computed against it.
        gen_conf: Dict[str, Any] = {"do_sample": False, **(gen_cfg or {})}
        gen_conf["max_new_tokens"] = max_new_tokens

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list({e for e in (gc_eos + tok_eos) if e is not None})
        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            gen_conf["pad_token_id"] = int(pad_id)
            # Persist on the model so calculate_right_padding_length sees the same id.
            self.model.config.pad_token_id = int(pad_id)
            self.model.generation_config.pad_token_id = int(pad_id)

        # ================================ Memory estimate ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            # NOTE(review): any other keys left in bs_estimate_gen_cfg are not
            # used here; confirm this matches the other VLM wrappers.
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===================================================================================

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        # Strip the prompt prefix. With left padding every row shares the
        # padded input length.
        input_len_padded = inputs["input_ids"].shape[1]

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        eos_set = set(eos_ids or [])

        for seq in output_ids:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            gen_ids = seq[input_len_padded:-cut] if cut > 0 else seq[input_len_padded:]
            out_len = int(gen_ids.numel())

            ended_with_eos = out_len > 0 and int(gen_ids[-1].item()) in eos_set
            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)

            txt = self.tokenizer.decode(gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            outputs.append(txt)

        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Single-item convenience wrapper around ``generate_batch`` (no OOM estimate)."""
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

