# varco.py
# VARCO-VISION-2.0 backend (architecture: LLaVA-OneVision)
# Recommended upstream usage:
#   from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
#   model = LlavaOnevisionForConditionalGeneration.from_pretrained(...)
#   processor = AutoProcessor.from_pretrained(...)
#   inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True, ...)

from typing import List, Dict, Tuple
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# 从 utils 导入通用工具（保持与现有后端一致的导入方式）
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console; VarcoVisionVLM.generate_batch prints each decoded output through it for debugging.
console = Console()


class VarcoVisionVLM(BaseVLM):
    """
    Adapter for NCSOFT VARCO-VISION-2.0 (LLaVA-OneVision architecture).

    - Follows the upstream README: ``processor.apply_chat_template(..., tokenize=True,
      return_dict=True)`` builds inputs that are passed straight to ``model.generate``.
    - Supports single-image and multi-image prompts; text-only items are rejected.
    - All items in one batch must share the same "modality structure" (enforced here as
      "same number of images"); upstream does not guarantee correct batched inference
      otherwise.
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Args:
            model: the generation model; its ``generation_config``, ``config``,
                ``device`` and ``dtype`` are read/updated by this class.
            processor: the matching processor; its ``tokenizer`` (if present) becomes
                ``self.tokenizer``.
            device: forwarded to ``BaseVLM``; used as the input device when ``model``
                exposes no ``device`` attribute.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            processor=processor,
            device=device,
        )

        # 1) Prefer processor.tokenizer: the processor is what tokenizes inside
        #    apply_chat_template, so padding must be configured on that object.
        if self.processor is not None and getattr(self.processor, "tokenizer", None) is not None:
            self.tokenizer = self.processor.tokenizer
        else:
            # Fallback: the value read before super().__init__ (None when
            # `processor` has no tokenizer attribute).
            self.tokenizer = tokenizer

        # 2) Bind processor.tokenizer and self.tokenizer to the same object so the
        #    padding configuration below applies to both.
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # Batched decoder-only generation needs left padding. An existing pad token is
        # kept as-is, because some models use pad_token_id != eos_token_id.
        self._configure_padding(self.tokenizer)
        proc_tok = getattr(self.processor, "tokenizer", None)
        if proc_tok is not None and proc_tok is not self.tokenizer:
            # Defensive: only differs if the binding above did not take effect.
            self._configure_padding(proc_tok)

    @staticmethod
    def _configure_padding(tok) -> None:
        """Force left padding on ``tok`` and make sure it has a pad token.

        If ``tok.pad_token`` is already set it is left untouched. Otherwise the token
        string behind ``tok.pad_token_id`` is reused when it round-trips back to the
        same id; in every other case (no id, failed round-trip, conversion error) the
        eos token is used as the pad token. ``tok=None`` is a no-op.
        """
        if tok is None:
            return
        tok.padding_side = "left"
        if tok.pad_token is not None:
            return

        candidate = None
        pad_id = getattr(tok, "pad_token_id", None)
        if pad_id is not None:
            try:
                tok_str = tok.convert_ids_to_tokens(int(pad_id))
                # Verify the round-trip so we never install a string that does not
                # map back to pad_id.
                if tok_str is not None and tok.convert_tokens_to_ids(tok_str) == int(pad_id):
                    candidate = tok_str
            except Exception:
                # Best-effort lookup only; fall back to eos below.
                candidate = None
        tok.pad_token = candidate if candidate is not None else tok.eos_token

    def _build_conversation(self, item: dict):
        """Build a single-turn user conversation for ``item``.

        The message content lists every loaded image (as an RGB ``PIL.Image``) followed
        by the text prompt produced by ``parse_input``.

        Returns:
            (conversation, n_images): the conversation list and the number of images
            actually loaded.

        Raises:
            ValueError: if the item has no image paths, or every path is ``None``.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("VarcoVisionVLM 目前仅支持多模态输入（需要至少一张图片）")

        images: List[Image.Image] = []
        for p in image_paths:
            if p is None:
                continue
            # Close the source file deterministically; convert() returns a new image
            # object that stays valid after the original is closed.
            with Image.open(p) as img:
                images.append(img.convert("RGB"))

        if not images:
            raise ValueError("VarcoVisionVLM: 未成功加载任何图片，请检查 image_path 是否正确")

        # LLaVA-OneVision style conversation: content is a list; image entries may
        # carry the PIL.Image object directly.
        content = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})

        conversation = [
            {
                "role": "user",
                "content": content,
            }
        ]
        return conversation, len(images)

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> Dict:
        """Assemble the kwargs passed to ``model.generate``.

        ``do_sample`` defaults to False and may be overridden by ``gen_cfg``. The
        explicit ``max_new_tokens`` argument always wins over a value in ``gen_cfg``,
        because ``generate_batch`` compares output length against it for ``hit_limits``.
        ``eos_token_id`` is set to the union of the model's generation_config and the
        tokenizer's eos ids. ``pad_token_id`` is resolved (generation_config ->
        tokenizer pad -> tokenizer eos) and also written back to ``model.config`` and
        ``model.generation_config`` (side effect).
        """
        # A dict literal (instead of dict(**gen_cfg)) avoids a TypeError when gen_cfg
        # repeats one of the default keys.
        gen_conf = {"do_sample": False, **(gen_cfg or {})}
        gen_conf["max_new_tokens"] = max_new_tokens

        # eos_token_id: de-duplicated, order-preserving (generation_config first).
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)) if self.tokenizer is not None else []
        eos_ids = [e for e in dict.fromkeys(gc_eos + tok_eos) if e is not None]
        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None and self.tokenizer is not None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None and self.tokenizer is not None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Keep model config in sync so generate() and calculate_right_padding_length
            # see the same pad id.
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return gen_conf

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate answers for a batch of multimodal items.

        Args:
            items: non-empty list of items; each must reference at least one image and
                all items must reference the same number of images.
            max_new_tokens: generation budget; also the threshold used for ``hit_limits``.
            gen_cfg: extra ``model.generate`` kwargs (None/empty allowed).
            oom_estimate: when True, inputs are passed through
                ``utils.apply_prefill_extra_tokens`` for memory estimation.
            bs_estimate_gen_cfg: options for that estimation; ``_prefill_extra_tokens``
                and ``_prefill_token_id`` are popped from a copy (caller's dict untouched).

        Returns:
            (outputs, right_pad_lens, hit_limits): stripped decoded texts, the number of
            trailing tokens removed from each generated sequence, and whether each
            sample produced at least ``max_new_tokens`` tokens without ending on EOS.

        Raises:
            ValueError: if items have differing image counts (or no images).
        """
        assert self.processor is not None, "VarcoVisionVLM 需要 AutoProcessor"
        assert len(items) > 0, "generate_batch: items 不能为空"

        conversations: List[list] = []
        n_images_list: List[int] = []

        for it in items:
            conv, n_img = self._build_conversation(it)
            conversations.append(conv)
            n_images_list.append(n_img)

        # Upstream requires the same modality structure within a batch; enforced here
        # as a hard "same image count" constraint.
        if len(set(n_images_list)) != 1:
            raise ValueError(
                f"VarcoVisionVLM: 同一 batch 内图片数量必须一致（官方要求同模态结构），"
                f"当前={n_images_list}。请把不同图片数的样本分开跑。"
            )

        # apply_chat_template yields generate-ready inputs (upstream usage). Note that
        # `conversations` is a batch: a list of message lists.
        inputs = self.processor.apply_chat_template(
            conversations,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
            return_tensors="pt",
        )

        # Mirrors upstream `inputs.to(model.device, dtype)`. NOTE(review): presumably
        # the returned BatchFeature only casts floating tensors (pixel values) to dtype,
        # leaving input_ids integer — confirm with the installed transformers version.
        device = getattr(self.model, "device", None) or self.device
        model_dtype = getattr(self.model, "dtype", torch.float16)
        inputs = inputs.to(device, model_dtype)

        gen_conf = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # ================================ Memory estimation ===========================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===============================================================================

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        # As upstream: drop the prompt part of each row using its input length.
        input_ids = inputs["input_ids"]
        trimmed: List[torch.Tensor] = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, output_ids)
        ]

        # Strip right padding / surplus EOS and compute hit_limit per sample.
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        cleaned_token_ids: List[List[int]] = []

        eos_ids = [e for e in _normalize_to_list(gen_conf.get("eos_token_id", None)) if e is not None]

        for seq in trimmed:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            seq2 = seq[:-cut] if cut > 0 else seq

            out_len = int(seq2.shape[0])

            ended_with_eos = False
            if out_len > 0 and eos_ids:
                ended_with_eos = int(seq2[-1].item()) in eos_ids

            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            cleaned_token_ids.append(seq2.detach().cpu().tolist())

        # Decode with the processor when it offers batch_decode, else the tokenizer.
        if hasattr(self.processor, "batch_decode"):
            texts = self.processor.batch_decode(
                cleaned_token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        else:
            if self.tokenizer is None:
                raise ValueError("VarcoVisionVLM: processor 无 batch_decode 且 tokenizer 不存在，无法 decode")
            texts = [
                self.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in cleaned_token_ids
            ]

        outputs = [t.strip() for t in texts]

        # Debug output, matching the other backends' style (comment out if noisy).
        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Single-item convenience wrapper around ``generate_batch`` (no OOM probe).

        Returns:
            (output, right_pad_len, hit_limit) for ``item``.
        """
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of a generated sequence should be stripped.

        All trailing pad tokens are stripped, plus every EOS in the trailing EOS run
        except the first, so exactly one EOS is kept when the sequence ended with one.
        When pad == eos, the first token of the trailing pad run is the real EOS and is
        kept (also when the whole sequence is that token).

        Args:
            total_sequence: generated token ids (tensor or list of ints), prompt removed.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id generate() actually used (synced by _prepare_generation_config).
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = {e for e in gc_eos + tok_eos if e is not None}

        # Walk back over the trailing run of pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            j -= 1
        right_pad_len = n - 1 - j

        # pad doubles as an EOS id and nothing EOS-like precedes the run: the first
        # token of the run is the real EOS, so keep it.
        if right_pad_len > 0 and pad_id in eos_ids and (j < 0 or total_sequence[j] not in eos_ids):
            return right_pad_len - 1

        # Count the EOS run that precedes the pads; keep one, strip the rest.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1

        return right_pad_len + max(0, eos_count - 1)

