# llava-1_5.py
from typing import List, Dict, Tuple, Optional

import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

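# Import shared helpers from `utils`; the parent directory of this file's directory is appended to sys.path first so the module resolves.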
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

console = Console()


class Llava15VLM(BaseVLM):
    """
    Adapter for LLaVA-1.5 (llava-hf/llava-1.5-7b-hf).

    - Follows the official usage pattern: ``processor.apply_chat_template``
      followed by ``processor(images=..., text=...)``.
    - Exactly one image is used per item; if an item supplies several image
      paths, only the first one is used.
    - ``generate_batch`` runs a true batch (one processor call, one
      ``generate`` call).
    - The return format matches qwen.py / phi.py:
        generate_batch -> (outputs, right_pad_lens, hit_limits)
        generate_one   -> (output_text, right_pad_len, hit_limit)

    Notes:
    - The right-padding length is computed by
      ``calculate_right_padding_length``.
    - ``hit_limit`` is True when the trimmed output length is
      ``>= max_new_tokens`` and the output does not end with an EOS token.
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Wrap a LLaVA-1.5 model and its processor.

        Args:
            model: The generation model (must expose ``generate`` and
                ``generation_config``).
            processor: The processor; its ``tokenizer`` attribute, if any, is
                passed to the base class as the tokenizer.
            device: Device that inputs are moved to before generation.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # Force left padding (the usual requirement for decoder-only batch
        # generation; it also keeps the prefix-slicing logic aligned).
        self._force_left_padding(self.tokenizer)
        # Keep the processor's own tokenizer in sync as well.
        if self.processor is not None:
            self._force_left_padding(getattr(self.processor, "tokenizer", None))

    @staticmethod
    def _force_left_padding(tokenizer) -> None:
        """Set left padding on ``tokenizer`` and fall back to EOS as pad token if none is set."""
        if tokenizer is None:
            return
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None and getattr(tokenizer, "eos_token", None) is not None:
            tokenizer.pad_token = tokenizer.eos_token

    @staticmethod
    def _build_conversation(user_text: str) -> List[Dict]:
        """
        Build a single-turn chat conversation for the chat template.

        The user message content is a list of typed parts: the text part
        first, then one image placeholder.
        """
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text},
                    {"type": "image"},
                ],
            }
        ]

    def _eos_token_ids(self) -> List[int]:
        """Return the merged, de-duplicated EOS ids of the generation config and the tokenizer."""
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = (
            _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            if self.tokenizer is not None
            else []
        )
        return sorted({e for e in (gc_eos + tok_eos) if e is not None})

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> Dict:
        """
        Assemble the keyword arguments passed to ``model.generate``.

        Args:
            max_new_tokens: Generation budget. It always takes precedence over
                a ``max_new_tokens`` entry in ``gen_cfg`` because
                ``generate_batch`` evaluates ``hit_limit`` against this value.
            gen_cfg: Extra generation options (may be None or empty). It may
                override the greedy default ``do_sample=False``.

        Returns:
            A dict with ``max_new_tokens``, ``do_sample``, the entries of
            ``gen_cfg`` and, when available, ``eos_token_id`` / ``pad_token_id``.

        Side effects:
            When a pad id is resolved it is also written to
            ``model.config.pad_token_id`` and
            ``model.generation_config.pad_token_id`` (best effort).
        """
        # A plain dict merge is used on purpose: dict(do_sample=..., **gen_cfg)
        # raises TypeError as soon as gen_cfg repeats one of the explicit keys.
        gen_conf: Dict = {"do_sample": False, **(gen_cfg or {})}
        gen_conf["max_new_tokens"] = max_new_tokens

        # eos_token_id: union of model.generation_config and tokenizer EOS ids.
        eos_ids = self._eos_token_ids()
        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: prefer generation_config, then the tokenizer's pad id,
        # and finally the tokenizer's eos id.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Mirror the pad id onto the model so that the padding emitted by
            # batch generation matches what calculate_right_padding_length
            # looks for. Deliberately best effort: a failure to set either
            # attribute must not abort generation.
            try:
                self.model.config.pad_token_id = pad_id
            except Exception:
                pass
            try:
                self.model.generation_config.pad_token_id = pad_id
            except Exception:
                pass

        return gen_conf

    def _move_and_cast_inputs(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Move tensors to ``self.device`` and cast floating-point tensors to the model dtype.

        The model dtype is taken from the first model parameter; if it cannot
        be determined, ``torch.float16`` is used. Non-tensor values are passed
        through unchanged.
        """
        try:
            model_dtype = next(self.model.parameters()).dtype
        except Exception:
            # Best effort: fall back to half precision when the model exposes
            # no parameters to inspect.
            model_dtype = torch.float16

        moved: Dict[str, torch.Tensor] = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                value = value.to(self.device)
                if value.is_floating_point() and value.dtype != model_dtype:
                    value = value.to(model_dtype)
            moved[key] = value
        return moved

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Generate answers for a batch of single-image items.

        Args:
            items: Non-empty list of samples; the prompt and image path are
                extracted with ``parse_input`` and ``get_image_path``.
            max_new_tokens: Maximum number of new tokens per sample.
            gen_cfg: Extra options forwarded to ``model.generate``.
            oom_estimate: When True, the inputs are passed through
                ``apply_prefill_extra_tokens`` before generation (used for
                memory estimation).
            bs_estimate_gen_cfg: Options for the memory-estimation mode; the
                keys ``_prefill_extra_tokens`` and ``_prefill_token_id`` are
                read from it. The caller's dict is never modified.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``: the stripped decoded
            text, the number of trailing tokens trimmed, and the
            "stopped because of the token budget" flag for each sample.

        Raises:
            AssertionError: If no processor is set or ``items`` is empty.
            ValueError: If a sample has no image.
        """
        from rich.markup import escape

        assert self.processor is not None, "Llava15VLM 需要 AutoProcessor"
        assert len(items) > 0, "generate_batch: items 不能为空"

        # 1) Build the prompt string and load the (single) image of every item.
        texts: List[str] = []
        images: List[Image.Image] = []

        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)

            # Single image only: if a list is given, take the first entry.
            image_paths = _normalize_to_list(image_path)
            if not image_paths or image_paths[0] is None:
                raise ValueError("Llava15VLM 目前仅支持多模态输入（每条样本必须有 1 张图）")
            img_path = image_paths[0]

            conv = self._build_conversation(prompt)

            # Ask explicitly for a prompt string rather than token ids.
            try:
                text = self.processor.apply_chat_template(
                    conv,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except TypeError:
                # Some processor versions may not accept the `tokenize` argument.
                text = self.processor.apply_chat_template(
                    conv,
                    add_generation_prompt=True,
                )

            texts.append(text)
            # Image.open is lazy and keeps the file open; convert() returns an
            # independent in-memory image, so the file can be closed right away.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))

        # 2) Batch-encode with padding, then move to the device / cast dtype.
        proc_inputs = self.processor(
            text=texts,
            images=images,
            padding=True,
            return_tensors="pt",
        )
        inputs = self._move_and_cast_inputs(dict(proc_inputs))

        # 3) Generation kwargs (pad/eos aligned with the model).
        gen_conf = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # ---- memory estimation mode ----
        if oom_estimate:
            from utils import apply_prefill_extra_tokens

            # Work on a copy so the caller's dict is left untouched by pop().
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )

        # 4) Generate.
        with torch.no_grad():
            generate_ids = self.model.generate(**inputs, **gen_conf)

        # 5) Drop the prompt prefix. With left padding every row shares the
        #    same padded prompt length. The new tokens are copied to the host
        #    once instead of synchronising with the device for every sample.
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_tokens_all = generate_ids[:, prefix_len_padded:].detach().cpu().tolist()

        # 6) Per sample: trim right padding, evaluate hit_limit, collect ids.
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        eos_ids = set(self._eos_token_ids())

        # One output row per input item.
        for seq in new_tokens_all[: len(items)]:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            if cut > 0:
                seq = seq[:-cut]

            ended_with_eos = bool(seq) and seq[-1] in eos_ids
            hit_limits.append(len(seq) >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq)

        # Decode the whole batch at once (prefer the processor).
        if hasattr(self.processor, "batch_decode"):
            texts_out = self.processor.batch_decode(
                decoded_token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        else:
            texts_out = self.tokenizer.batch_decode(
                decoded_token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        outputs = [t.strip() for t in texts_out]

        for out in outputs:
            # Model output is arbitrary text: escape it so that square
            # brackets are printed literally instead of being parsed as Rich
            # markup (an unmatched closing tag would raise MarkupError).
            console.print("\n[yellow]output without prompt: ", escape(out))

        return outputs, right_pad_lens, hit_limits

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """
        Generate for a single item by delegating to ``generate_batch``.

        Returns:
            ``(output_text, right_pad_len, hit_limit)`` for that item.
        """
        outputs, right_pads, hit_limits = self.generate_batch(
            [item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pads[0], hit_limits[0]

    # ====== right-pad computation: logic intentionally kept as-is ======
    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Return how many trailing tokens of ``total_sequence`` count as right padding.

        Args:
            total_sequence: Generated token ids of one sample, either a
                ``torch.Tensor`` or a list of ints.

        Returns:
            The number of tokens to cut from the end:
            - the run of trailing pad tokens is counted first;
            - if the token before that run is not an EOS and at least one pad
              was counted, one token fewer is returned (pad and EOS share an
              id, so one of the counted tokens is the real EOS);
            - otherwise the consecutive EOS tokens before the pad run are
              counted too, and all but one of them are treated as padding.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate() actually uses; fall back to
        # the tokenizer's pad id.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None and self.tokenizer is not None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)) if self.tokenizer is not None else []
        eos_ids = list(set(gc_eos + tok_eos))

        # Walk backwards from the end and count the consecutive pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads were found and the pad id equals the EOS id, so one token
            # too many was counted: keep it as the terminating EOS.
            return right_pad_len - 1

        # Continue backwards and count the consecutive EOS tokens.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the remaining ones are treated as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
