# llava_interleave.py
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Import shared helpers from utils.
# Append the parent of this file's directory to sys.path so that the top-level
# `utils` module is importable from here.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path

# Module-level rich console, used for batch progress and output logging.
console = Console()


class LlavaInterleaveVLM(BaseVLM):
    """
    Adapter for LLaVA Interleave (llava-hf/llava-interleave-qwen-*-hf).

    - Multimodal encoding uses AutoProcessor.apply_chat_template followed by processor(...).
    - generate_batch does true batched inference: one processor call, one generate call.
    - right_pad_len: computed by the existing calculate_right_padding_length (kept unchanged).
    - hit_limit: out_len >= max_new_tokens and the output does not end with an EOS token.
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Args:
            model: the LLaVA Interleave model; its ``generate``, ``config`` and
                ``generation_config`` are used by this class.
            processor: the matching AutoProcessor; its ``tokenizer`` is used for all
                text encoding and decoding.
            device: device that processor outputs are moved to before generation.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        assert self.processor is not None, "LlavaInterleaveVLM 需要 AutoProcessor"

        # Single tokenizer source: always use processor.tokenizer so encoding (done by
        # the processor) and decoding (done via self.tokenizer) never mix tokenizers.
        processor_tokenizer = getattr(self.processor, "tokenizer", None)
        if processor_tokenizer is not None:
            self.tokenizer = processor_tokenizer

        # Decoder-only model: force left padding (same policy as the qwen/phi adapters).
        # When processor_tokenizer exists, self.tokenizer is that very object, so the
        # processor's tokenizer is configured by these same assignments.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_conversation_and_images(self, item: dict):
        """
        Parse the prompt and image(s) of ``item`` and build:
          - conversation: chat structure for apply_chat_template whose user content is
            [{"type": "text"}, {"type": "image"}, ...]
          - images: List[PIL.Image] converted to RGB (multiple images allowed)

        Raises:
            ValueError: if the item has no image path, or no image could be loaded.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("LlavaInterleaveVLM 目前仅支持多模态输入（需要至少一张图片）")

        images: List[Image.Image] = []
        for p in image_paths:
            if p is None:
                continue
            # Open in a context manager so the file handle is closed deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(p) as img:
                images.append(img.convert("RGB"))

        if not images:
            raise ValueError("LlavaInterleaveVLM: 未成功加载任何图片，请检查 image_path 是否正确")

        # Following the official example: text first, then one image placeholder per image.
        content = [{"type": "text", "text": prompt}] + [{"type": "image"} for _ in images]
        conversation = [{"role": "user", "content": content}]
        return conversation, images

    def _collect_eos_ids(self) -> List[int]:
        """
        Return the de-duplicated EOS ids from model.generation_config and the tokenizer.

        Each source is normalized to a list separately before merging, so a
        list-valued ``eos_token_id`` is flattened instead of being nested.
        ``None`` entries are dropped; first-seen order is preserved.
        """
        candidates = list(
            _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        ) + list(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))
        return list(dict.fromkeys(e for e in candidates if e is not None))

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> Dict:
        """
        Build the keyword arguments for model.generate (everything except the inputs).

        Precedence:
          - ``do_sample=False`` (greedy) is the default and may be overridden by gen_cfg;
          - ``max_new_tokens`` always comes from the explicit argument, because
            generate_batch uses that value to compute hit_limit;
          - ``eos_token_id`` / ``pad_token_id`` are overwritten whenever they can be
            resolved from the model / tokenizer.

        Side effect: the resolved pad id is written to model.config and
        model.generation_config.
        """
        generation_config: Dict = {"do_sample": False}
        # Merge with update() instead of dict(**gen_cfg): the latter raised
        # "got multiple values for keyword argument" when gen_cfg contained
        # "do_sample" or "max_new_tokens".
        generation_config.update(gen_cfg or {})
        generation_config["max_new_tokens"] = max_new_tokens

        # eos_token_id: union of model.generation_config and tokenizer EOS ids.
        eos_ids = self._collect_eos_ids()
        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: prefer model.generation_config, then the tokenizer, finally EOS.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            generation_config["pad_token_id"] = pad_id
            # Persist on the model as well; calculate_right_padding_length reads
            # model.generation_config.pad_token_id.
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return generation_config

    def _to_device_dtype(self, inputs: Any):
        """
        Recursively move every tensor in the processor output to self.device.

        Only floating-point tensors are cast to the model dtype (e.g. FP16/BF16);
        integer tensors such as input_ids / attention_mask keep their dtype.
        Dict-like inputs (BatchEncoding / BatchFeature) are updated in place and returned.
        """
        model_dtype = getattr(self.model, "dtype", None)

        def move(x):
            if isinstance(x, torch.Tensor):
                x = x.to(self.device)
                if model_dtype is not None and torch.is_floating_point(x):
                    x = x.to(dtype=model_dtype)
                return x
            if isinstance(x, (list, tuple)):
                return type(x)(move(v) for v in x)
            if isinstance(x, dict):
                return {k: move(v) for k, v in x.items()}
            return x

        # BatchEncoding / BatchFeature both expose items() and support item assignment.
        if hasattr(inputs, "items"):
            for k, v in list(inputs.items()):
                inputs[k] = move(v)
            return inputs

        return move(inputs)

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        True batched inference:
          1) build conversations + images for every item
          2) apply_chat_template -> prompt texts
          3) processor(text=texts, images=images, padding=True) -> inputs (left-padded)
          4) a single model.generate call
          5) strip the padded prompt prefix, compute right_pad_len on the new tokens,
             and decode

        Args:
            items: non-empty list of input items (each must reference at least one image).
            max_new_tokens: generation budget; also the threshold for hit_limit.
            gen_cfg: extra generate() kwargs (may be None / empty).
            oom_estimate: if True, prompts are extended via utils.apply_prefill_extra_tokens
                using the private keys "_prefill_extra_tokens" / "_prefill_token_id"
                read from bs_estimate_gen_cfg.
            bs_estimate_gen_cfg: options for oom_estimate; never mutated.

        Returns:
            - outputs: decoded text per item (whitespace-stripped)
            - right_pad_lens: per-item count of trailing padding (pad tokens + surplus EOS)
            - hit_limits: whether the item hit max_new_tokens (out_len >= max_new_tokens
              and not ending with EOS)

        Raises:
            AssertionError: if items is empty.
            ValueError: if an item has no loadable image.
        """
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        # 1) conversations & images
        conversations = []
        images_batch = []
        for item in items:
            conv, imgs = self._build_conversation_and_images(item)
            conversations.append(conv)
            # Multi-image items are passed as a list; the processor is expected to
            # accept List[Image] or List[List[Image]].
            images_batch.append(imgs[0] if len(imgs) == 1 else imgs)

        # 2) chat template -> prompt strings (as in the official example; does not
        #    depend on the tokenize= argument).
        texts = [
            self.processor.apply_chat_template(conv, add_generation_prompt=True)
            for conv in conversations
        ]

        # 3) batched encoding; request left padding explicitly and fall back for
        #    processor versions that do not accept padding_side.
        try:
            inputs = self.processor(
                text=texts,
                images=images_batch,
                padding=True,
                padding_side="left",
                return_tensors="pt",
            )
        except TypeError:
            inputs = self.processor(
                text=texts,
                images=images_batch,
                padding=True,
                return_tensors="pt",
            )

        generation_config = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # Pad id for the attention-mask fallback. Check for None explicitly: a getattr
        # default only applies when the attribute is missing, not when it is None.
        pad_id = generation_config.get("pad_token_id")
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        # Fallback attention_mask (1 for non-pad positions) if the processor gave none.
        if "attention_mask" not in inputs and "input_ids" in inputs and pad_id is not None:
            inputs["attention_mask"] = (inputs["input_ids"] != pad_id).long()

        # Move to device and align float dtypes with the model.
        inputs = self._to_device_dtype(inputs)

        console.print(f"[cyan][LlavaInterleave] 执行批处理，batch_size={batch_size}[/cyan]")

        # ================================ memory estimation ==========================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            # Copy so popping the private keys never mutates the caller's dict.
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # =============================================================================

        # 4) generate
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                **generation_config,
            )

        # 5) drop the prompt prefix (its length is the padded input_ids length)
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_tokens_all = generate_ids[:, prefix_len_padded:]  # [B, T_out_max]

        # 6) right_pad / hit_limit / decode
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        eos_ids = self._collect_eos_ids()

        for seq in new_tokens_all:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            seq = seq[:-cut] if cut > 0 else seq
            out_len = int(seq.shape[0])

            ended_with_eos = False
            if out_len > 0 and eos_ids:
                ended_with_eos = int(seq[-1].item()) in eos_ids

            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq.detach().cpu().tolist())

        # Single batch decode; self.tokenizer is processor.tokenizer (no tokenizer mixing).
        texts_out = self.tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        outputs = [t.strip() for t in texts_out]

        for out in outputs:
            console.print("\n[yellow]output without prompt: ", out)

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Run generate_batch on a single item; return (output, right_pad_len, hit_limit)."""
        outputs, right_pads, hit_limits = self.generate_batch(
            [item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg=None
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count the trailing tokens of a generated sequence that should be cut off.

        First, trailing pad tokens are counted. If at least one pad was counted and
        the token right before them is not an EOS id, the pad id is acting as EOS,
        so one of those pads is kept as the real EOS (count - 1 is returned).
        Otherwise, the run of consecutive EOS tokens before the pads is counted and
        all but one EOS are treated as padding too.

        Args:
            total_sequence: 1-D tensor or list of newly generated token ids.

        Returns:
            Number of tokens to strip from the right.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id actually used by generate.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Count consecutive pad tokens from the end; all of them are cut.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # pad_id == eos_id and one real EOS was counted as padding: keep it.
            return right_pad_len - 1

        # Then count the run of consecutive EOS tokens before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the rest count as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
