import os
import sys
from typing import List, Dict, Tuple

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Import the shared helpers from utils (kept consistent with qwen.py / phi.py)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-wide rich console; used below to log the decoded generations.
console = Console()

IMAGE_TOKEN_ID = -200  # image placeholder token id used by the official Bunny example


class BunnyVLM(BaseVLM):
    """
    Bunny (BAAI/Bunny-v1_1-Llama-3-8B-V)

    Output decoding strictly follows the official example:
      generated_ids = output_ids[i][input_len:]
      tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    Return formats match the existing qwen.py / phi.py wrappers:
      generate_batch -> (outputs, right_pad_lens, hit_limits)
      generate_one   -> (output_text, right_pad_len, hit_limit)
    """

    def __init__(self, model, tokenizer, processor, device: str = "cuda"):
        """Wrap a Bunny model and force left padding on every tokenizer in use.

        Args:
            model: Model object, forwarded unchanged to the base-class constructor.
            tokenizer: Tokenizer used unless ``processor`` carries a non-None
                ``tokenizer`` attribute, which then takes precedence.
            processor: Optional processor. Only its ``tokenizer`` attribute is
                read here; the base class is always given ``processor=None``.
            device: Device string forwarded to the base-class constructor.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=None, device=device)

        # Single tokenizer source: a usable processor.tokenizer wins over the argument.
        if processor is not None and getattr(processor, "tokenizer", None) is not None:
            self.tokenizer = processor.tokenizer

        # Collect every distinct tokenizer that needs the padding setup.
        # NOTE(review): self.processor is whatever the base class stored (it was
        # passed processor=None above) - a missing or None tokenizer is skipped
        # instead of raising AttributeError.
        tokenizers = [self.tokenizer]
        processor_tokenizer = getattr(self.processor, "tokenizer", None)
        if processor_tokenizer is not None and processor_tokenizer is not self.tokenizer:
            tokenizers.append(processor_tokenizer)

        for tok in tokenizers:
            # Left padding so that batched prompts end at the same column.
            tok.padding_side = "left"
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token

    @staticmethod
    def _pad_left(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        max_len = max(s.size(0) for s in seqs)
        dtype = seqs[0].dtype
        device = seqs[0].device
        out = torch.full((len(seqs), max_len), pad_token_id, dtype=dtype, device=device)
        for i, s in enumerate(seqs):
            out[i, -s.size(0):] = s
        return out

    def _build_input_ids(self, prompt: str) -> torch.Tensor:
        """Tokenize one prompt inside Bunny's chat template with an image placeholder.

        The template text is cut at its ``<image>`` marker, both halves are
        tokenized separately, and ``IMAGE_TOKEN_ID`` is inserted between them.

        Args:
            prompt: User question placed after the image marker.

        Returns:
            1-D ``torch.long`` tensor of token ids (on CPU) containing exactly
            one ``IMAGE_TOKEN_ID`` entry.
        """
        text = (
            f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
        )
        # Cut only at the template's own marker (the first occurrence). Splitting on
        # every "<image>" would silently discard the rest of a prompt that happens
        # to contain that literal text, including the trailing "ASSISTANT:" cue.
        before_image, _, after_image = text.partition("<image>")
        prefix_ids = self.tokenizer(before_image).input_ids
        suffix_ids = self.tokenizer(after_image).input_ids
        # As in the official example, drop the first token of the second chunk
        # (presumably the BOS the tokenizer prepends - confirm for other tokenizers).
        ids = prefix_ids + [IMAGE_TOKEN_ID] + suffix_ids[1:]
        return torch.tensor(ids, dtype=torch.long)

    def _prepare_gen_conf(self, max_new_tokens: int, gen_cfg: Dict) -> Tuple[Dict, List[int], int]:
        """Assemble the keyword arguments for ``model.generate``.

        Args:
            max_new_tokens: Generation budget. Always wins over a
                ``max_new_tokens`` entry in ``gen_cfg`` so that the caller's
                limit bookkeeping stays consistent with what was generated.
            gen_cfg: Optional extra generation kwargs (may be ``None``). They
                override the ``do_sample=False`` default.

        Returns:
            ``(gen_conf, eos_ids, pad_id)`` where ``gen_conf`` is the kwargs
            dict, ``eos_ids`` the de-duplicated EOS ids collected from the
            model's generation config and the tokenizer (in that order), and
            ``pad_id`` the resolved pad id, or ``None`` if none could be found.

        Side effects:
            When a pad id is resolved it is also written to
            ``model.config.pad_token_id`` and
            ``model.generation_config.pad_token_id``.
        """
        # Built as a dict literal: ``dict(a=..., **gen_cfg)`` raises TypeError as
        # soon as gen_cfg repeats one of the explicitly named keys.
        gen_conf = {
            "do_sample": False,
            **(gen_cfg or {}),
            "max_new_tokens": max_new_tokens,
        }

        generation_config = self.model.generation_config
        candidates = _normalize_to_list(getattr(generation_config, "eos_token_id", None))
        candidates = candidates + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        eos_ids = [e for e in dict.fromkeys(candidates) if e is not None]
        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # Pad id fallback chain: generation config -> tokenizer pad -> tokenizer EOS.
        pad_id = getattr(generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Keep the model-side configs in sync with the id actually used.
            self.model.config.pad_token_id = pad_id
            generation_config.pad_token_id = pad_id

        return gen_conf, eos_ids, pad_id

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Dict = {}
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Run batched single-image generation and decode the new tokens.

        Args:
            items: Non-empty list of input records. Each is read through
                ``parse_input`` (prompt) and ``get_image_path`` (image); only
                the first image of a record is used.
            max_new_tokens: Generation budget per sample.
            gen_cfg: Extra kwargs for ``model.generate`` (may be ``None``).
            oom_estimate: When true, the inputs are passed through
                ``apply_prefill_extra_tokens`` before generation.
            bs_estimate_gen_cfg: Settings for the ``oom_estimate`` path; the
                keys ``_prefill_extra_tokens`` and ``_prefill_token_id`` are
                read. The argument is copied before use, so the shared default
                dict is never mutated.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``, one entry per item:
            the decoded text, the trailing padding length reported by
            ``calculate_right_padding_length``, and whether the sample used up
            ``max_new_tokens`` without ending on an EOS token.

        Raises:
            ValueError: If an item provides no image.
        """
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        # 1) Collect one prompt and one image path per item.
        prompts: List[str] = []
        image_paths: List[str] = []
        for item in items:
            prompt = parse_input(item)
            img_list = _normalize_to_list(get_image_path(item))
            if not img_list:
                raise ValueError("BunnyVLM 目前仅支持多模态输入（需要一张图片）")
            prompts.append(prompt)
            image_paths.append(img_list[0])

        # 2) Variable-length token sequences -> left-padded batch.
        pad_token_id = self.tokenizer.pad_token_id
        seqs = [self._build_input_ids(p).to(self.device) for p in prompts]
        input_ids = self._pad_left(seqs, pad_token_id=pad_token_id)

        # Derive the mask from the true sequence lengths rather than from
        # ``input_ids != pad_token_id``: the pad id may alias a real token
        # (e.g. EOS), which would otherwise be masked out of the prompt.
        padded_len = input_ids.size(1)
        seq_lens = torch.tensor(
            [s.size(0) for s in seqs], device=input_ids.device
        ).unsqueeze(1)
        positions = torch.arange(padded_len, device=input_ids.device).unsqueeze(0)
        attention_mask = (positions >= padded_len - seq_lens).long()

        # 3) Images -> image tensor (official API: model.process_images).
        images = []
        for path in image_paths:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded image.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        image_tensor = self.model.process_images(images, self.model.config).to(
            dtype=self.model.dtype, device=self.device
        )

        gen_conf, eos_ids, _ = self._prepare_gen_conf(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)

        # ======================= memory (OOM) estimation =======================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            batch_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids = batch_inputs["input_ids"]
            attention_mask = batch_inputs["attention_mask"]
        # =======================================================================

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=image_tensor,
                use_cache=True,
                **gen_conf,
            )

        # 4) Slice off the prompt as in the official example: output_ids[i][input_len:].
        # NOTE(review): assumes generate() returns prompt + new tokens - confirm
        # for the model revision in use.
        input_len = input_ids.size(1)
        eos_set = set(eos_ids)

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        for row in output_ids[:batch_size]:
            generated_ids = row[input_len:]

            # Trailing pad / surplus EOS tokens to cut from the right.
            cut = self.calculate_right_padding_length(generated_ids)
            right_pad_lens.append(int(cut))
            gen_trim = generated_ids[:-cut] if cut > 0 else generated_ids

            # hit_limit is judged on the trimmed length: the budget was
            # exhausted and the sample did not finish on an EOS token.
            ended_with_eos = False
            if gen_trim.numel() > 0 and eos_set:
                ended_with_eos = int(gen_trim[-1].item()) in eos_set
            hit_limit = (int(gen_trim.numel()) >= int(max_new_tokens)) and (not ended_with_eos)
            hit_limits.append(bool(hit_limit))

            out_text = self.tokenizer.decode(
                gen_trim,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()
            outputs.append(out_text)

        for o in outputs:
            console.log("\n[yellow]output without prompt:", o)

        return outputs, right_pad_lens, hit_limits

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]


    def calculate_right_padding_length(self, total_sequence) -> int:
        """Count the trailing tokens of a generated sequence that are padding.

        Trailing pad tokens and surplus EOS tokens are counted as padding, and
        exactly one EOS is kept whenever the sequence ended with one.

        Args:
            total_sequence: Generated token ids as a 1-D tensor or a list.

        Returns:
            Number of tokens to cut from the right end (0 if there are none).
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = {e for e in gc_eos + tok_eos if e is not None}

        # Walk backwards over the run of trailing pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            j -= 1
        pad_run = n - 1 - j

        # An explicit EOS precedes the pads: keep one EOS and count the rest
        # of the consecutive EOS run as padding.
        if j >= 0 and total_sequence[j] in eos_ids:
            eos_run = 0
            while j >= 0 and total_sequence[j] in eos_ids:
                eos_run += 1
                j -= 1
            return pad_run + eos_run - 1

        # No distinct EOS before the pads. If the pad id doubles as an EOS id,
        # the first "pad" was really the terminating EOS, so give one token
        # back. This also covers a sequence consisting only of such tokens.
        # When pad and EOS are different ids, every trailing pad is padding.
        if pad_run > 0 and pad_id in eos_ids:
            return pad_run - 1
        return pad_run
