# ovis_llama.py
from __future__ import annotations

import os
import sys
from typing import List, Dict, Tuple, Any

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Reuse in-project utilities (same approach as qwen.py / phi.py): put the parent
# of this file's directory on sys.path so the top-level `utils` module imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used below for warnings and output printing.
console = Console()


class OvisLlamaVLM(BaseVLM):
    """
    Adapter for AIDC-AI/Ovis1.5-Llama3-8B (Transformers + trust_remote_code=True).

    Key points:
      - Only model.get_text_tokenizer() is used as the text tokenizer (never mixed
        with the visual tokenizer).
      - conversation_formatter.format_query(...) turns '<image>\\n{prompt}' into
        input_ids.
      - Batching is real batch inference: encode each sample -> manual left-pad
        -> a single generate() call.
      - pixel_values is a list with ONE entry per sample; each entry stacks that
        sample's images along dim 0 (the layout of the official Ovis examples).
      - right_pad_len uses the same calculate_right_padding_length logic as the
        qwen/phi adapters.
      - hit_limit: out_len >= max_new_tokens and the output does not end in EOS.
      - attention_mask is derived from each sample's real (unpadded) length. This
        equals the official `input_ids != pad_token_id` whenever the pad id does
        not occur inside a prompt, and stays correct when it does (pad == eos).
    """

    def __init__(self, model, device: str = "cuda"):
        # Ovis keeps separate text / visual tokenizers.
        self.text_tokenizer = model.get_text_tokenizer()
        self.visual_tokenizer = model.get_visual_tokenizer()
        self.conversation_formatter = model.get_conversation_formatter()

        # BaseVLM.tokenizer is bound to the text tokenizer only, to avoid mixing.
        super().__init__(
            model=model,
            tokenizer=self.text_tokenizer,
            processor=None,
            device=device,
        )

        # pad_token_id fallback: reuse eos_token_id at runtime (only the id is
        # assigned; no pad token string is forced).
        if getattr(self.text_tokenizer, "pad_token_id", None) is None:
            eos_id = getattr(self.text_tokenizer, "eos_token_id", None)
            if eos_id is None:
                raise ValueError("OvisLlamaVLM: text_tokenizer 缺少 pad_token_id 且 eos_token_id 也为空")
            self.text_tokenizer.pad_token_id = eos_id

        # Mirror the pad id into the model configs so generate() pads consistently.
        pad_id = int(self.text_tokenizer.pad_token_id)
        self.model.config.pad_token_id = pad_id
        self.model.generation_config.pad_token_id = pad_id

    @staticmethod
    def _pad_left(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        """
        Left-pad 1-D token tensors to a common length and stack them.

        Args:
            seqs: non-empty list of 1-D tensors (dtype/device taken from seqs[0]).
            pad_token_id: value written into the padding positions.

        Returns:
            Tensor of shape (len(seqs), max_len); row i ends with seqs[i].
        """
        max_len = max(int(s.size(0)) for s in seqs)
        out = torch.full(
            (len(seqs), max_len),
            int(pad_token_id),
            dtype=seqs[0].dtype,
            device=seqs[0].device,
        )
        for i, s in enumerate(seqs):
            out[i, -int(s.size(0)):] = s
        return out

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> Dict[str, Any]:
        """
        Build the keyword arguments passed to model.generate().

        Starts from the official-example defaults (greedy decoding, sampling knobs
        set to None, KV cache on) and lets ``gen_cfg`` override any of them. Then:
          - eos_token_id is set from model.generation_config, falling back to the
            text tokenizer; when one is found it replaces any value in gen_cfg.
          - pad_token_id is always forced to the text tokenizer's pad id (falling
            back to generation_config, then to eos), and mirrored into the model
            config / generation_config.

        Raises:
            ValueError: if no pad_token_id can be determined.
        """
        generation_config: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "top_p": None,
            "top_k": None,
            "temperature": None,
            "repetition_penalty": None,
            "use_cache": True,
        }
        # update() lets gen_cfg override the defaults above; spreading it into
        # dict(..., **gen_cfg) raised TypeError on any duplicated key.
        generation_config.update(gen_cfg or {})

        # eos_token_id: prefer model.generation_config (as in the official example).
        eos_id = getattr(self.model.generation_config, "eos_token_id", None)
        if eos_id is None:
            eos_id = getattr(self.text_tokenizer, "eos_token_id", None)
        if eos_id is not None:
            generation_config["eos_token_id"] = eos_id

        # pad_token_id: always the text tokenizer's pad id (never mixed).
        pad_id = getattr(self.text_tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "eos_token_id", None)
        if pad_id is None:
            raise ValueError("OvisLlamaVLM: 无法确定 pad_token_id")
        pad_id = int(pad_id)
        generation_config["pad_token_id"] = pad_id

        # Keep the model side in sync (matches the other adapters' behavior).
        self.model.config.pad_token_id = pad_id
        self.model.generation_config.pad_token_id = pad_id

        return generation_config

    def _encode_one(self, item: dict) -> Tuple[torch.Tensor, List[Image.Image]]:
        """
        Encode a single sample.

        Returns:
            input_ids: 1-D tensor on self.device, as produced by format_query.
            images: the sample's images converted to RGB (at least one required).

        Raises:
            ValueError: if the item has no image path or no image could be loaded.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("OvisLlamaVLM 目前仅支持多模态输入（需要至少一张图片）")

        images: List[Image.Image] = []
        for p in image_paths:
            if p is None:
                continue
            # Context manager closes the file handle Image.open keeps; convert()
            # returns an independent, fully loaded copy that outlives it.
            with Image.open(p) as img:
                images.append(img.convert("RGB"))

        if not images:
            raise ValueError("OvisLlamaVLM: 未成功加载任何图片，请检查 image_path")

        # Official Ovis prompt shape is '<image>\n{text}'; for several images the
        # placeholder is repeated once per image.
        query = "<image>\n" * len(images) + prompt

        _prompt_text, input_ids = self.conversation_formatter.format_query(query)

        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        input_ids = input_ids.to(device=self.device)

        return input_ids, images

    def _preprocess_images_batch(self, images_list: List[List[Image.Image]]) -> List[torch.Tensor]:
        """
        Preprocess a batch of images into Ovis ``pixel_values``.

        Ovis takes ``pixel_values`` as a list with one entry per batch sample,
        each holding that sample's images stacked along dim 0 (the official
        single-sample example passes ``[preprocess_image(image)]``).

        Every sample must carry the same number of images so the merged
        multimodal sequences keep equal lengths and the text left-padding stays
        aligned (assumes a fixed visual-token count per image — TODO confirm for
        other Ovis versions). A mismatch raises ValueError; the caller then falls
        back to per-sample generation.

        Returns:
            List of length len(images_list); entry b is a tensor whose dim 0 is
            the number of images of sample b, on self.device in the visual
            tokenizer's dtype (float16 if it exposes none).
        """
        if not images_list:
            return []

        num_images = len(images_list[0])
        if any(len(imgs) != num_images for imgs in images_list):
            raise ValueError("OvisLlamaVLM: batch 内图片数量不一致，无法真·batch（请走单样本）")

        vt = self.visual_tokenizer
        # dtype follows the visual tokenizer; tensors are placed on the model device.
        target_dtype = getattr(vt, "dtype", torch.float16)
        target_device = self.device

        pixel_values: List[torch.Tensor] = []
        for imgs in images_list:
            per_image: List[torch.Tensor] = []
            for img in imgs:
                t = vt.preprocess_image(img)
                if not isinstance(t, torch.Tensor):
                    t = torch.tensor(t)
                # Ensure a leading image dim so the per-sample cat below works.
                if t.dim() == 3:
                    t = t.unsqueeze(0)
                per_image.append(t.to(dtype=target_dtype, device=target_device))
            # One entry per SAMPLE (not per image index): Ovis pairs the entries
            # of pixel_values with the rows of input_ids.
            pixel_values.append(torch.cat(per_image, dim=0))

        return pixel_values

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run true batch inference over ``items``.

        Steps:
          1) format_query each sample -> 1-D input_ids (+ its PIL images)
          2) left-pad and stack the input_ids
          3) attention_mask from each sample's real (unpadded) length
          4) preprocess images -> pixel_values (list, one entry per sample)
          5) a single model.generate call
          6) strip right padding per row -> hit_limit -> batch_decode

        If image preprocessing fails for a multi-sample batch (e.g. samples with
        different image counts), falls back to generate_one per item. For a
        single-item batch the error is re-raised instead, since falling back
        would recurse through generate_one -> generate_batch forever.

        Args:
            items: input samples, read via parse_input / get_image_path.
            max_new_tokens: generation budget; also the hit_limit threshold.
            gen_cfg: overrides for the default generation kwargs.
            oom_estimate: if True, extend the text inputs via
                apply_prefill_extra_tokens (memory estimation path).
            bs_estimate_gen_cfg: may carry "_prefill_extra_tokens" and
                "_prefill_token_id" for the OOM-estimate path.

        Returns:
            (outputs, right_pad_lens, hit_limits), each of length len(items).

        Raises:
            ValueError: if ``items`` is empty.
        """
        batch_size = len(items)
        if batch_size == 0:
            raise ValueError("generate_batch: items 不能为空")

        # 1) per-sample encode
        input_ids_list: List[torch.Tensor] = []
        images_list: List[List[Image.Image]] = []
        for it in items:
            ids, imgs = self._encode_one(it)
            input_ids_list.append(ids)
            images_list.append(imgs)

        # Images that cannot be batched -> fall back to one-by-one inference.
        try:
            pixel_values = self._preprocess_images_batch(images_list)
        except Exception as e:
            if batch_size == 1:
                # Nothing smaller to fall back to; retrying via generate_one would
                # hit the same error again and recurse without bound.
                raise
            console.log(f"[yellow]Warn: {e}，回退到单样本推理模式[/yellow]")
            fb_outputs: List[str] = []
            fb_right_pads: List[int] = []
            fb_hit_limits: List[bool] = []
            for it in items:
                o, rp, hl = self.generate_one(it, max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)
                fb_outputs.append(o)
                fb_right_pads.append(rp)
                fb_hit_limits.append(hl)
            return fb_outputs, fb_right_pads, fb_hit_limits

        # 2) left-pad & stack
        pad_id = int(self.text_tokenizer.pad_token_id)
        input_ids_batch = self._pad_left(input_ids_list, pad_token_id=pad_id).to(self.device)

        # 3) attention_mask from the known left-pad lengths. Same result as
        #    `input_ids != pad_id` unless a prompt genuinely contains the pad id
        #    (e.g. pad == eos), where the comparison would mask real tokens.
        max_len = int(input_ids_batch.shape[1])
        mask_device = input_ids_batch.device
        lengths = torch.tensor([int(s.size(0)) for s in input_ids_list], device=mask_device)
        positions = torch.arange(max_len, device=mask_device).unsqueeze(0)
        attention_mask = positions >= (max_len - lengths).unsqueeze(1)

        # ============================ memory estimation ============================
        # Only the text side (input_ids / attention_mask) is extended; pixel_values
        # is left unchanged.
        if oom_estimate:
            from utils import apply_prefill_extra_tokens

            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)

            inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask}
            inputs = apply_prefill_extra_tokens(
                batch_size=batch_size,
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.text_tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids_batch = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
        # ===========================================================================

        # 4) generation config
        generation_config = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # 5) generate
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids_batch,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                **generation_config,
            )

        # 6) No prompt prefix is sliced off: the official Ovis example decodes
        #    generate()'s output directly, i.e. it is treated as new tokens only
        #    (TODO confirm if the remote modeling code changes).
        new_tokens_all = output_ids

        # EOS ids used for the hit_limit check.
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.text_tokenizer, "eos_token_id", None))
        eos_ids = {x for x in (gc_eos + tok_eos) if x is not None}

        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        for i in range(batch_size):
            seq = new_tokens_all[i]

            # right_pad_len: same logic as the qwen/phi adapters.
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            if cut > 0:
                seq = seq[:-cut]

            out_len = int(seq.shape[0])
            ended_with_eos = bool(out_len > 0 and eos_ids and int(seq[-1].item()) in eos_ids)

            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq.detach().cpu().tolist())

        # batch_decode with the text tokenizer only.
        texts = self.text_tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        outputs = [t.strip() for t in texts]

        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item: dict, max_new_tokens: int, gen_cfg: Dict) -> Tuple[str, int, bool]:
        """Generate for a single item via a one-element generate_batch call (no OOM estimate)."""
        outputs, right_pads, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={},
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated row are padding.

        Trailing pad tokens are counted. If the token just before them is not an
        EOS, one of those "pads" is assumed to have been the real EOS (pad == eos)
        and is kept. Otherwise, any run of consecutive EOS tokens before the pads
        is also trimmed except for one EOS.

        Args:
            total_sequence: 1-D tensor or list of token ids.

        Returns:
            Number of tokens to strip from the end of the row.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad id generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = set(gc_eos + tok_eos)

        # Count the trailing run of pad tokens.
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads exist and the token before them is not EOS: pad == eos, so the
            # first "pad" was the real EOS — keep it.
            # NOTE(review): when pad != eos and the model itself emitted trailing
            # pad tokens, this branch also keeps one of them.
            return right_pad_len - 1

        # Count consecutive EOS tokens before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; treat the rest as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
