from typing import List, Dict, Tuple, Optional
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Shared helpers: put the parent directory on sys.path so the top-level `utils` module is importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used by this backend to log batch progress and model outputs.
console = Console()

class LLaVAVisionVLM(BaseVLM):
    """Transformers inference backend for LLaVA-OneVision (``llava-hf/*-hf``) checkpoints.

    Follows the official HuggingFace usage example:
      - build the prompt with ``processor.apply_chat_template``
      - ``inputs = processor(images=..., text=..., return_tensors='pt', padding=True)``
      - ``output = model.generate(**inputs, ...)``

    Batching is real: one processor call and one ``generate`` call per batch.
    Trailing padding of each generated sequence is measured with
    ``calculate_right_padding_length``.

    Constraint: every sample must carry exactly one image.
    """

    def __init__(
        self,
        model,
        tokenizer,
        processor,
        device: str = "cuda",
    ):
        # Prefer the tokenizer bundled with the processor so that encoding and
        # decoding cannot drift from the processor's own tokenizer.
        proc_tok = getattr(processor, "tokenizer", None)
        if proc_tok is not None:
            tokenizer = proc_tok

        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        self.device = device

        # Model dtype, used to cast floating-point inputs (e.g. pixel_values).
        # Falls back to float16 when it cannot be read from the parameters.
        try:
            self.model_dtype = next(self.model.parameters()).dtype
        except Exception:
            self.model_dtype = torch.float16

        # Decoder-only generation: force left padding on every tokenizer in play.
        self._configure_left_padding(self.tokenizer)
        processor_tok = getattr(self.processor, "tokenizer", None)
        if processor_tok is not None and processor_tok is not self.tokenizer:
            self._configure_left_padding(processor_tok)

        # Defensive: some environments/versions leave processor.patch_size or
        # vision_feature_select_strategy unset, which leads to `// None` internally.
        self._ensure_processor_vision_attrs()

    # ------------------------- small helpers -------------------------
    @staticmethod
    def _configure_left_padding(tok) -> None:
        """Set ``tok.padding_side = 'left'`` and default its pad token to EOS when unset."""
        if tok is None:
            return
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

    @staticmethod
    def _best_effort_setattr(obj, name: str, value) -> None:
        """Set ``obj.<name> = value``; print a warning instead of raising on failure.

        Used for optional patches whose failure should not abort model setup.
        """
        try:
            setattr(obj, name, value)
        except Exception as exc:
            console.print(f"[yellow]Could not set {type(obj).__name__}.{name}: {exc}[/yellow]")

    def _collect_eos_ids(self) -> List[int]:
        """EOS ids from ``model.generation_config`` and the tokenizer.

        Duplicates and ``None`` entries are removed; order is stable
        (generation_config ids first, then tokenizer ids).
        """
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        return [e for e in dict.fromkeys(list(gc_eos) + list(tok_eos)) if e is not None]

    # ------------------------- prompt / preprocess -------------------------
    def _ensure_processor_vision_attrs(self) -> None:
        """Fill in processor attributes that OneVision preprocessing may need.

        Missing (``None``) values are copied from the model config so that
        preprocessing does not fail on an integer division by ``None``:
          - ``patch_size`` from ``model.config.vision_config`` (also mirrored onto
            ``processor.image_processor`` when that lacks it);
          - ``vision_feature_select_strategy`` from ``model.config`` or, failing
            that, ``model.config.vision_config``.
        Failures to set an attribute are reported but not raised.
        """
        if self.processor is None:
            return

        cfg = getattr(self.model, "config", None)
        vis_cfg = getattr(cfg, "vision_config", None) if cfg is not None else None

        # patch_size
        if getattr(self.processor, "patch_size", None) is None:
            patch_size = getattr(vis_cfg, "patch_size", None) if vis_cfg is not None else None
            if patch_size is not None:
                self._best_effort_setattr(self.processor, "patch_size", patch_size)
                # The wrapped image processor may need the value as well.
                img_proc = getattr(self.processor, "image_processor", None)
                if img_proc is not None and getattr(img_proc, "patch_size", None) is None:
                    self._best_effort_setattr(img_proc, "patch_size", patch_size)

        # vision_feature_select_strategy: some configs keep it at top level,
        # others inside vision_config.
        if getattr(self.processor, "vision_feature_select_strategy", None) is None:
            strategy = getattr(cfg, "vision_feature_select_strategy", None) if cfg is not None else None
            if strategy is None and vis_cfg is not None:
                strategy = getattr(vis_cfg, "vision_feature_select_strategy", None)
            if strategy is not None:
                self._best_effort_setattr(self.processor, "vision_feature_select_strategy", strategy)

    def _build_chat_prompt(self, user_text: str) -> str:
        """Build a single-turn user conversation (text + one image slot) and render it
        with ``processor.apply_chat_template``, appending the generation prompt."""
        assert self.processor is not None
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text},
                    {"type": "image"},
                ],
            }
        ]
        return self.processor.apply_chat_template(conversation, add_generation_prompt=True)

    def _build_prompt_and_image(self, item: dict) -> Tuple[str, Image.Image]:
        """Return ``(prompt, rgb_image)`` for one sample.

        Raises:
            ValueError: if the sample has no image or more than one image.
        """
        user_text = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))
        if not image_paths:
            raise ValueError("LLaVAVisionVLM 目前仅支持多模态输入（需要至少一张图片）")
        if len(image_paths) != 1:
            raise ValueError("LLaVAVisionVLM 当前实现假设每个样本只有一张图片，请先规整数据。")

        # Close the file handle immediately; convert() returns a new, fully
        # loaded image that no longer needs the source file.
        with Image.open(image_paths[0]) as raw_img:
            img = raw_img.convert("RGB")
        prompt = self._build_chat_prompt(user_text)
        return prompt, img

    def _move_and_cast_inputs(self, inputs):
        """Move processor outputs to ``self.device`` and cast float tensors to the model dtype.

        ``.to(device, dtype)`` is not accepted by every BatchEncoding/BatchFeature
        version, so progressively simpler fallbacks are tried.
        """
        # 1) Preferred path: a single .to(device, dtype) call.
        try:
            inputs = inputs.to(self.device, self.model_dtype)
        except Exception:
            try:
                inputs = inputs.to(self.device)
            except Exception:
                # Last resort: move tensors one by one.
                for k, v in list(inputs.items()):
                    if isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.device)

        # 2) Cast any remaining floating-point tensors (mainly pixel_values /
        #    image embeddings) to the model dtype.
        for k, v in list(inputs.items()):
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                if v.dtype != self.model_dtype:
                    inputs[k] = v.to(dtype=self.model_dtype)
        return inputs

    # ------------------------- generation -------------------------
    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Optional[Dict]) -> Dict:
        """Build the keyword arguments for ``model.generate``.

        Greedy decoding (``do_sample=False``) is the default and ``gen_cfg`` may
        override it or add options. The explicit ``max_new_tokens`` argument
        always wins, so the hit-limit check in ``generate_batch`` compares
        against the same budget ``generate`` used. ``eos_token_id`` and
        ``pad_token_id`` are then filled from the model/tokenizer.

        Side effect: the resolved pad id is also written (best effort) to
        ``model.config`` and ``model.generation_config``.
        """
        # A dict merge (instead of dict(..., **gen_cfg)) avoids a TypeError when
        # gen_cfg also contains max_new_tokens or do_sample.
        generation_config = {"do_sample": False, **(gen_cfg or {})}
        generation_config["max_new_tokens"] = max_new_tokens

        # eos_token_id
        eos_ids = self._collect_eos_ids()
        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: generation_config -> tokenizer pad -> tokenizer eos
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            generation_config["pad_token_id"] = pad_id
            # Keep the model's own configs consistent with the value passed to generate().
            for cfg_name in ("config", "generation_config"):
                cfg = getattr(self.model, cfg_name, None)
                if cfg is not None:
                    self._best_effort_setattr(cfg, "pad_token_id", pad_id)

        return generation_config

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Optional[Dict],
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate answers for a batch with one processor call and one ``generate`` call.

        Args:
            items: samples; each must resolve to one text prompt and exactly one image.
            max_new_tokens: generation budget per sample.
            gen_cfg: extra ``generate`` kwargs (may be ``None``).
            oom_estimate: when True, extend the inputs via
                ``utils.apply_prefill_extra_tokens`` (batch-size / memory estimation).
            bs_estimate_gen_cfg: options for the estimation pass; the private keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are consumed.
                Not mutated.

        Returns:
            ``(texts, right_pad_lens, hit_limits)``, one entry per item:
            the decoded continuation (special tokens skipped, stripped), the
            number of trailing tokens trimmed as padding, and whether the
            sample used the whole budget without ending on an EOS token.
        """
        assert self.processor is not None, "LLaVAVisionVLM 需要 AutoProcessor"
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        prompts: List[str] = []
        images: List[Image.Image] = []
        for item in items:
            prompt, img = self._build_prompt_and_image(item)
            prompts.append(prompt)
            images.append(img)

        console.print(f"[cyan]执行 LLaVA-OneVision(HF) 批处理，batch_size={batch_size}[/cyan]")

        # Encode the whole batch at once (as in the official example).
        inputs = self.processor(
            images=images,
            text=prompts,
            padding=True,
            return_tensors="pt",
        )
        inputs = self._move_and_cast_inputs(inputs)

        generation_config = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # ================================ memory estimation ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            est_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(est_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = est_cfg.pop("_prefill_token_id", None)
            # NOTE(review): keys left in est_cfg after these pops are not applied to
            # generation_config — confirm that is intended.
            inputs = apply_prefill_extra_tokens(
                batch_size=batch_size,
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===================================================================================

        with torch.no_grad():
            generate_ids = self.model.generate(**inputs, **generation_config)

        # Drop the prompt prefix using the padded input length, then trim the
        # tail of each row with the right-padding logic. Move to CPU once.
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_rows: List[List[int]] = generate_ids[:, prefix_len_padded:].detach().cpu().tolist()

        eos_ids = self._collect_eos_ids()

        decoded_token_ids: List[List[int]] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        for row in new_rows:
            cut = self.calculate_right_padding_length(row)
            right_pad_lens.append(cut)
            kept = row[: len(row) - cut]

            ended_with_eos = bool(kept) and kept[-1] in eos_ids
            hit_limits.append(len(kept) >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(kept)

        texts_out = self.tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        outputs = [t.strip() for t in texts_out]

        for out in outputs:
            console.print("[yellow]LLaVA output (without prompt):", out)

        return outputs, right_pad_lens, hit_limits

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """Single-sample convenience wrapper around ``generate_batch``."""
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg=None,
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    # ======== right-side trimming ========
    def calculate_right_padding_length(self, total_sequence) -> int:
        """Count how many trailing tokens of a generated sequence are padding.

        - Every trailing ``pad_token_id`` token is counted.
        - If the pad id is itself an EOS id, that pad run also swallowed the
          real EOS, so one token is given back (one EOS is kept).
        - Otherwise, of the run of EOS tokens preceding the padding, one is kept
          and the rest are counted.

        Args:
            total_sequence: 1-D tensor or list of token ids (generated part only).

        Returns:
            Number of tokens to drop from the right end (0 <= n <= len).
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate() actually used.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        eos_ids = self._collect_eos_ids()

        # Walk left over the trailing run of pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            j -= 1
        right_pad_len = n - 1 - j

        if right_pad_len > 0 and pad_id in eos_ids and (j < 0 or total_sequence[j] not in eos_ids):
            # pad == EOS: the first token of the pad run is the real EOS — keep it.
            return right_pad_len - 1

        # Walk left over the EOS run preceding the padding; keep exactly one EOS.
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            i -= 1
        eos_count = j - i
        return right_pad_len + max(0, eos_count - 1)
