# phi_vision.py: every model in this family is an instruct model, so it is safe to rely on chat_template

import os
import sys
from typing import List, Dict, Tuple

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# 从 utils 导入通用工具
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path

console = Console()


class PhiVisionVLM(BaseVLM):
    """
    Adapter for the Phi vision family (e.g. microsoft/Phi-3.5-vision-instruct).

    - Prompts are rendered with tokenizer.apply_chat_template and then passed,
      together with the images, through the AutoProcessor.
    - Single-image and multi-image inputs are supported.
    - generate_batch performs real batched inference: every sample is encoded
      on its own, the results are left-padded / concatenated by hand, and one
      model.generate call handles the whole batch.
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Args:
            model: an already loaded AutoModelForCausalLM (trust_remote_code=True).
            processor: the result of AutoProcessor.from_pretrained(...).
            device: device the model lives on (usually 'cuda' or 'cuda:0').
        """
        # The tokenizer is taken from the processor; None when the processor has none.
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            processor=processor,
            device=device,
        )

    def _eos_token_ids(self) -> List[int]:
        """
        Collect the EOS ids known to the model's generation_config and to the
        tokenizer, de-duplicated with their first-seen order preserved.

        Each source is normalized separately because either of them may hold a
        single id or a list of ids.
        """
        candidates = (
            _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        )
        return list(dict.fromkeys(i for i in candidates if i is not None))

    def _build_messages_and_images(self, item: dict) -> Tuple[List[Dict[str, str]], List[Image.Image]]:
        """
        Parse the prompt and the image paths out of ``item``.

        Returns:
            messages: single-turn user message list for the chat template.
            images: the loaded images, converted to RGB.

        Raises:
            ValueError: if the item has no image path or no image was loaded.
        """
        prompt = parse_input(item)

        # image paths may be a str or a list of str; unify to a list
        image_paths = _normalize_to_list(get_image_path(item))
        if not image_paths:
            raise ValueError("PhiVisionVLM 目前仅支持多模态输入（需要至少一张图片）")

        images = []
        for path in image_paths:
            if path is None:
                continue
            # convert() yields an independent, fully loaded copy, so the file
            # handle can be released right away instead of waiting for GC.
            with Image.open(path) as raw:
                images.append(raw.convert("RGB"))

        if not images:
            raise ValueError("PhiVisionVLM: 未成功加载任何图片，请检查 image_path 是否正确")

        if "<|image_" in prompt:
            # The prompt already carries hand-written <|image_x|> placeholders; keep it as is.
            content = prompt
        else:
            # Otherwise prepend one 1-based placeholder per loaded image.
            placeholders = "".join(f"<|image_{i}|>\n" for i in range(1, len(images) + 1))
            content = placeholders + prompt

        messages = [{"role": "user", "content": content}]
        return messages, images

    def _prepare_generation_config(
        self,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Dict:
        """
        Assemble the keyword arguments for model.generate (everything except the inputs).

        ``gen_cfg`` may override the ``do_sample=False`` default. The explicit
        ``max_new_tokens`` argument always wins so that it stays consistent with
        the limit the caller checks against. ``eos_token_id`` / ``pad_token_id``
        are resolved from the model's generation_config and the tokenizer.

        Side effect: when a pad id can be resolved it is also written to
        ``self.model.config`` and ``self.model.generation_config``.
        """
        # Merge step by step: unpacking gen_cfg next to explicit keywords would
        # raise TypeError as soon as gen_cfg repeats one of those keys.
        generation_config: Dict = {"do_sample": False}
        generation_config.update(gen_cfg or {})
        generation_config["max_new_tokens"] = max_new_tokens

        # eos_token_id: a bare int when there is exactly one, otherwise the list
        eos_ids = self._eos_token_ids()
        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: generation_config -> tokenizer pad -> tokenizer eos
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            generation_config["pad_token_id"] = pad_id
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return generation_config

    @staticmethod
    def _pad_left(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        """
        Left-pad 1-D sequences to a common length (the usual layout for decoder-only models).

        The result has shape (len(seqs), max_len) and takes dtype and device
        from the first sequence.
        """
        max_len = max(s.size(0) for s in seqs)
        out = torch.full(
            (len(seqs), max_len),
            pad_token_id,
            dtype=seqs[0].dtype,
            device=seqs[0].device,
        )
        for row, s in enumerate(seqs):
            # Index from the front: a negative start of -0 would select the
            # whole row and break on zero-length sequences.
            out[row, max_len - s.size(0):] = s
        return out

    def _encode_one(self, messages, images) -> Tuple[dict, int]:
        """
        Run the processor on one (messages, images) pair.

        Returns:
            features: the processor output for this single sample.
            prompt_len: length of the last axis of the processor's input_ids,
                i.e. the un-padded prompt length of this sample.
        """
        # Render the chat template as text and let the processor handle text + images.
        prompt_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        features = self.processor(
            text=prompt_text,
            images=images,
            return_tensors="pt",
        )
        # Deliberately no .to(self.device) here; tensors are moved once when stacking.

        # Measure the sequence the model will actually receive instead of
        # tokenizing the template a second time.
        prompt_len = int(features["input_ids"].shape[-1])
        return features, prompt_len

    def _stack_and_pad(self, encoded_list: List[dict], pad_token_id: int) -> Dict[str, torch.Tensor]:
        """
        Merge single-sample features into one batch on ``self.device``.

        - input_ids: left-padded to a common length with ``pad_token_id``
        - attention_mask: 1 on the real tokens of every row, 0 on the left padding
        - every other key: concatenated along dim 0
        """
        # text part: take row 0 of each sample's input_ids, then left-pad
        seqs = [feat["input_ids"][0].to(self.device) for feat in encoded_list]
        input_ids = self._pad_left(seqs, pad_token_id=pad_token_id)

        # Build the mask from the true lengths; comparing against pad_token_id
        # would also hide real prompt tokens that happen to equal the pad id.
        max_len = input_ids.size(1)
        attention_mask = torch.zeros_like(input_ids, dtype=torch.long)
        for row, seq in enumerate(seqs):
            attention_mask[row, max_len - seq.size(0):] = 1

        data: Dict[str, torch.Tensor] = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        # vision part and any other field: merge key by key
        for key in encoded_list[0].keys():
            if key in {"input_ids", "attention_mask"}:
                continue
            tensors = [feat[key] for feat in encoded_list]
            try:
                merged = torch.cat(tensors, dim=0)
            except RuntimeError:
                # Fallback kept from the original best-effort behavior, e.g. for
                # zero-dimensional tensors, which cat cannot join.
                merged = torch.stack(tensors, dim=0)
            data[key] = merged.to(self.device)

        return data

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Batched inference:
          1. encode every sample on its own (messages + images -> processor)
          2. left-pad / concatenate the samples into one batch
          3. run a single model.generate
          4. drop the prompt prefix and decode each sample

        Args:
            items: samples to run; must not be empty.
            max_new_tokens: generation budget per sample.
            gen_cfg: extra keyword arguments for model.generate (may be None).
            oom_estimate: when True, the inputs are passed through
                utils.apply_prefill_extra_tokens before generating.
            bs_estimate_gen_cfg: source of the ``_prefill_extra_tokens`` and
                ``_prefill_token_id`` settings used when oom_estimate is True.

        Returns:
            outputs: decoded text per sample, stripped.
            right_pad_lens: number of trailing tokens cut per sample
                (padding plus surplus EOS).
            hit_limits: whether a sample used up max_new_tokens without ending in EOS.

        Raises:
            ValueError: if no pad_token_id can be determined.
        """
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        # 1) encode sample by sample
        encoded_list: List[dict] = []
        for item in items:
            messages, images = self._build_messages_and_images(item)
            features, _ = self._encode_one(messages, images)
            encoded_list.append(features)

        # 2) generation kwargs; the pad id fallback chain is resolved in there
        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
        )
        pad_token_id = generation_config.get("pad_token_id")
        if pad_token_id is None:
            raise ValueError("PhiVisionVLM: 无法确定 pad_token_id")

        # 3) manual stack & pad
        inputs = self._stack_and_pad(encoded_list, pad_token_id=pad_token_id)

        # ============================ memory estimation ============================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            # copy so the caller's dict is not mutated by the pops below
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===========================================================================

        # 4) one generate call for the whole batch
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                **generation_config,
            )

        # 5) drop the prefix; its length is the padded input_ids length
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_tokens_all = generate_ids[:, prefix_len_padded:]

        # 6) trim and classify every sample
        eos_ids = set(self._eos_token_ids())  # loop-invariant
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        for i in range(batch_size):
            seq = new_tokens_all[i]
            # number of trailing pad / surplus EOS tokens to cut on the right
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            if cut > 0:
                seq = seq[:-cut]

            out_len = int(seq.shape[0])
            ended_with_eos = out_len > 0 and seq[-1].item() in eos_ids
            # Filling the budget only counts as "hit the limit" when the sample
            # did not finish with EOS.
            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)

            decoded_token_ids.append(seq.cpu().tolist())

        # decode the whole batch at once
        texts = self.processor.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        outputs = [t.strip() for t in texts]

        # Model output is arbitrary text: escape it so bracketed fragments are
        # printed literally instead of being parsed as rich markup.
        from rich.markup import escape
        for out in outputs:
            console.print("\n[yellow]output without prompt: ", escape(out))

        return outputs, right_pad_lens, hit_limits

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """
        Single-sample inference; delegates to generate_batch with a batch of
        one so both paths share the same logic.

        Returns:
            (output text, right padding length, hit-limit flag) for the sample.
        """
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are filler.

        Filler is the trailing run of pad tokens plus every EOS beyond the
        first one, so exactly one EOS is kept when the sequence ended with EOS.

        Args:
            total_sequence: generated token ids (1-D tensor or list of int).

        Returns:
            Number of tokens to cut from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = set(self._eos_token_ids())

        # Trailing run of pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # When the pad id is itself an EOS id, the first token of that run is
        # the real EOS (unless another EOS id precedes it), so give one back.
        # This also covers a sequence made up entirely of pad/EOS tokens.
        preceded_by_eos = j >= 0 and total_sequence[j] in eos_ids
        if right_pad_len > 0 and pad_id in eos_ids and not preceded_by_eos:
            return right_pad_len - 1

        # Trailing run of EOS tokens in front of the padding: keep one, cut the rest.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1

        return right_pad_len + max(0, eos_count - 1)

