# granite.py
# IBM Granite Vision (AutoModelForVision2Seq) 推理适配：按 qwen.py 的写法（processor 负责 batch padding）

import os
import sys
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Import the shared helpers from ``utils``.
# The grandparent directory of this file is appended to sys.path first,
# presumably so the absolute ``from utils import ...`` below resolves -- TODO confirm layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used for all progress/debug output in this file.
console = Console()


class GraniteVisionVLM(BaseVLM):
    """
    Inference adapter for IBM Granite Vision (e.g. ibm-granite/granite-vision-3.2-2b).

    - Prompting follows the qwen.py adapter: ``apply_chat_template(tokenize=False)``
      produces the prompt text, then ``processor(text, images, padding=True)``
      performs batch padding.
    - Right-padding length is computed by ``calculate_right_padding_length``,
      which is reused as-is from the existing adapters (its logic must not be changed).
    - ``hit_limit`` uses the same "ended with EOS" rule as phi.py / qwen.py.
    - Only one tokenizer is used everywhere: ``processor.tokenizer`` when available.
    """

    def __init__(self, model, tokenizer, processor, device: str = "cuda"):
        """
        Args:
            model: The loaded vision-to-sequence model.
            tokenizer: Fallback tokenizer, used only when ``processor`` has no
                ``tokenizer`` attribute (or it is ``None``).
            processor: The multimodal processor; its tokenizer is preferred so
                that encoding and decoding never mix tokenizers.
            device: Device string forwarded to ``BaseVLM``.

        Raises:
            ValueError: If neither ``processor.tokenizer`` nor ``tokenizer`` is available.
        """
        # Prefer the processor's tokenizer; fall back to the explicit argument.
        # (Previously the argument was shadowed, so a missing processor tokenizer
        # ended in an AttributeError on ``None.padding_side``.)
        resolved_tokenizer = getattr(processor, "tokenizer", None)
        if resolved_tokenizer is None:
            resolved_tokenizer = tokenizer
        if resolved_tokenizer is None:
            raise ValueError(
                "GraniteVisionVLM: neither processor.tokenizer nor the `tokenizer` "
                "argument is available"
            )

        super().__init__(
            model=model,
            tokenizer=resolved_tokenizer,
            processor=processor,
            device=device,
        )

        # Force left padding (usual for decoder-only generation; harmless otherwise).
        self._force_left_padding(self.tokenizer)

        # Keep the processor's internal tokenizer in sync so the processor itself
        # never pads on the right.
        processor_tokenizer = getattr(self.processor, "tokenizer", None)
        if processor_tokenizer is not None:
            self._force_left_padding(processor_tokenizer)

    @staticmethod
    def _force_left_padding(tok) -> None:
        """Set left padding on ``tok`` and default its pad token to the EOS token."""
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

    def _build_messages(self, prompt: str) -> List[dict]:
        """Build a single-turn chat message: one image placeholder followed by the text prompt."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    @staticmethod
    def _infer_model_device(model: torch.nn.Module) -> torch.device:
        """Return the device of the model's first parameter.

        Falls back to CUDA (if available) or CPU when the model has no parameters.
        """
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _value_to_device(value, device: torch.device, np_module):
        """Return ``value`` moved/converted to ``device`` where possible, else unchanged.

        - Tensors are moved when they are not already on ``device``.
        - A non-empty list/tuple of tensors becomes a list of moved tensors.
        - Any other non-empty list/tuple is tensorized (e.g. ``image_sizes`` given
          as ``[(h, w), ...]``); if that fails the original value is returned.
        - numpy arrays are converted with ``torch.from_numpy`` (best effort).
        """
        if isinstance(value, torch.Tensor):
            return value if value.device == device else value.to(device)

        if isinstance(value, (list, tuple)):
            if len(value) == 0:
                return value
            if all(isinstance(x, torch.Tensor) for x in value):
                return [x.to(device) for x in value]
            try:
                return torch.tensor(value, device=device)
            except Exception:
                # Deliberate best effort: values such as lists of strings cannot be
                # tensorized and must be passed through untouched.
                return value

        if np_module is not None and isinstance(value, np_module.ndarray):
            try:
                return torch.from_numpy(value).to(device)
            except Exception:
                # Deliberate best effort: e.g. object-dtype arrays are not convertible.
                return value

        return value

    def _move_batch_to_device(self, batch, device: torch.device):
        """
        Move a processor output to ``device`` more robustly than a plain ``.to(device)``.

        - Objects exposing ``.to`` (BatchEncoding / BatchFeature / tensors): call it first.
        - Mapping-like objects: every value is additionally handled by
          ``_value_to_device`` (tensors, lists/tuples, numpy arrays).

        This exists because some LlavaNext / Granite Vision fields (such as
        ``image_sizes``) may not be tensors, which leads to CPU/CUDA mixing
        errors inside ``index_select``.

        Returns:
            The (possibly replaced, possibly mutated in place) batch; ``None`` is passed through.
        """
        if batch is None:
            return batch

        # 1) Bulk move through the object's own ``.to`` when it has one.
        if callable(getattr(batch, "to", None)):
            try:
                batch = batch.to(device)
            except Exception as exc:
                # Best effort: report the failure and rely on the per-field pass below.
                console.print(
                    f"GraniteVisionVLM: batch.to({device}) failed ({exc!r}); "
                    "falling back to per-field transfer",
                    markup=False,
                )

        # 2) Per-field pass for anything mapping-like.
        if isinstance(batch, dict) or hasattr(batch, "items"):
            try:
                import numpy as np  # local import: numpy is optional here
            except ImportError:
                np = None

            # Snapshot the items so that reassigning keys while looping is safe.
            for key, value in list(batch.items()):
                moved = self._value_to_device(value, device, np)
                if moved is not value:
                    batch[key] = moved

        return batch

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run batched image+text generation.

        Args:
            items: Dataset records; the prompt is read with ``parse_input`` and the
                image path(s) with ``get_image_path``. Only the first image of
                each item is used.
            max_new_tokens: Generation budget. Always takes precedence over a
                ``max_new_tokens`` entry in ``gen_cfg`` so that the hit-limit
                check below matches what was actually requested.
            gen_cfg: Extra ``generate`` keyword arguments (may be ``None``/empty).
                May override the default ``do_sample=False``.
            oom_estimate: When true, the inputs are extended through
                ``utils.apply_prefill_extra_tokens`` before generation.
            bs_estimate_gen_cfg: Options for the OOM estimate; the keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read from a copy.

        Returns:
            ``(outputs, right_pad_lens, hit_limit_flags)`` with one entry per item:
            the decoded text, the number of trailing tokens treated as padding, and
            whether generation used the whole budget without ending on an EOS token.

        Raises:
            ValueError: If an item has no image path.
        """
        assert self.processor is not None, "GraniteVisionVLM 需要 AutoProcessor"
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        # 1) Build chat messages and load images.
        messages_list = []
        image_inputs = []
        for item in items:
            prompt = parse_input(item)
            img_paths = _normalize_to_list(get_image_path(item))
            if not img_paths:
                raise ValueError("GraniteVisionVLM 目前仅支持多模态输入（需要至少一张图片）")

            # Dataset assumption: a single image per item, so take the first one.
            first_path = img_paths[0]
            if first_path is None:
                raise ValueError("GraniteVisionVLM: image_path 为空，请检查数据")

            messages_list.append(self._build_messages(prompt))
            # ``convert`` returns an independent, fully loaded image, so the
            # underlying file can be closed right away instead of leaking the handle.
            with Image.open(first_path) as img:
                image_inputs.append(img.convert("RGB"))

        texts = [
            self.processor.apply_chat_template(
                msg,
                tokenize=False,
                add_generation_prompt=True,
            )
            for msg in messages_list
        ]

        console.print(f"[cyan]执行Granite-Vision批处理，batch_size={batch_size}[/cyan]")

        # 2) Let the processor do the batch padding (same approach as qwen.py).
        try:
            inputs = self.processor(
                text=texts,
                images=image_inputs,
                padding=True,
                padding_side="left",  # request left padding explicitly when supported
                return_tensors="pt",
            )
        except TypeError:
            # Processor versions that do not accept ``padding_side``.
            inputs = self.processor(
                text=texts,
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

        # Granite Vision / LlavaNext sometimes uses non-tensor fields (e.g. image_sizes)
        # internally; leaving them on the CPU triggers CPU/CUDA mixing errors.
        target_device = self._infer_model_device(self.model)
        inputs = self._move_batch_to_device(inputs, target_device)
        # Safety net: make sure input_ids really ended up on the target device.
        # (No ``isinstance(inputs, dict)`` guard: BatchFeature is not a dict subclass.)
        prompt_ids = inputs.get("input_ids", None)
        if isinstance(prompt_ids, torch.Tensor):
            inputs["input_ids"] = prompt_ids.to(target_device)

        # 3) Generation config. A dict-literal merge is used so that a ``do_sample``
        #    or ``max_new_tokens`` key inside gen_cfg no longer raises
        #    "got multiple values for keyword argument".
        gen_conf = {
            "do_sample": False,
            **(gen_cfg or {}),
            "max_new_tokens": max_new_tokens,
        }

        # eos_token_id: union of the generation config's and the tokenizer's ids.
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = sorted({e for e in gc_eos + tok_eos if e is not None})
        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: generation config first, then tokenizer pad, then tokenizer eos.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Persist on the model so calculate_right_padding_length sees the same id.
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        # ============================ memory estimation ============================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            estimate_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(estimate_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = estimate_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===========================================================================

        # 4) Generate.
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        # Decoder-only layout of every output row: [padded prompt + generated + right pad].
        # The padded prompt length is the same for the whole batch.
        input_len = int(inputs["input_ids"].size(1))
        eos_id_set = set(eos_ids)

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for i in range(batch_size):
            seq_full = output_ids[i]
            total_len = int(seq_full.size(0))

            cut = self.calculate_right_padding_length(seq_full)
            right_pad_lens.append(cut)

            output_len = total_len - input_len - cut
            generated_ids = seq_full[input_len:-cut] if cut > 0 else seq_full[input_len:]

            # hit_limit: the budget was used up and the text did not end on an EOS token.
            ended_with_eos = (
                output_len > 0
                and bool(eos_id_set)
                and int(generated_ids[-1].item()) in eos_id_set
            )
            hit_limit_flags.append((output_len >= max_new_tokens) and (not ended_with_eos))

            # Decode with self.tokenizer only, so tokenizers are never mixed.
            outputs.append(
                self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )

        for text in outputs:
            console.print("\n[yellow]output without prompt: ", text)

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Generate for a single item by delegating to ``generate_batch`` with a batch of one.

        Returns:
            ``(output_text, right_pad_len, hit_limit)`` for that item.
        """
        outputs, right_pads, hit_limits = self.generate_batch(
            [item],
            max_new_tokens,
            gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg=None,
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are right padding.

        The logic is intentionally kept identical to the other adapters.

        Args:
            total_sequence: Full output row (prompt + generated + padding) as a
                tensor or a list of token ids.

        Returns:
            Number of trailing tokens to cut. When the sequence ends in EOS
            tokens (possibly because pad == eos), one EOS is kept.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Walk back from the end and count the run of consecutive pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads were found and the token before them is not an EOS: pad equals
            # eos, so the real EOS was counted as padding -- give one token back.
            return right_pad_len - 1

        # Continue walking back over the run of consecutive EOS tokens.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the remaining ones count as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
