import os
import sys
from typing import List, Dict, Tuple

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Import shared helpers from the top-level `utils` module.
# The parent of this file's directory is appended to sys.path first so that
# `utils` can be resolved (presumably that is where it lives -- confirm against
# the repository layout); this is why the import cannot sit at the top (E402).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used for printing generated outputs.
console = Console()


class OvisVLM(BaseVLM):
    """
    Ovis2 系列适配类（如 AIDC-AI/Ovis2-4B）
    """

    def __init__(
        self,
        model,
        device: str = "cuda",
        max_partition: int = 9,
    ):
        """
        Wrap an already loaded Ovis model.

        Args:
            model: The Ovis model object, e.g. the result of
                AutoModelForCausalLM.from_pretrained("AIDC-AI/Ovis2-4B", trust_remote_code=True).
                It must provide ``get_text_tokenizer()`` and ``get_visual_tokenizer()``.
            device: Device identifier forwarded to the base class.
            max_partition: Forwarded as ``max_partition`` to ``model.preprocess_inputs``.
        """
        # Ovis exposes its text tokenizer through a dedicated accessor.
        tok = model.get_text_tokenizer()
        super().__init__(model=model, tokenizer=tok, processor=None, device=device)

        # Keep both tokenizers around; the visual one is consulted for dtype / device.
        self.text_tokenizer = tok
        self.visual_tokenizer = model.get_visual_tokenizer()

        # Padding policy: always pad on the left and make sure a pad token exists.
        if hasattr(tok, "padding_side"):
            tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        # Mirror the tokenizer's pad id into the model and generation configs.
        pad_token_id = tok.pad_token_id
        if pad_token_id is not None:
            self.model.config.pad_token_id = pad_token_id
            self.model.generation_config.pad_token_id = pad_token_id

        # multimodal_max_length comes from the Ovis config (None when absent).
        self.multimodal_max_length = getattr(self.model.config, "multimodal_max_length", None)
        self.max_partition = max_partition

    # ---------------------------
    # Internal helper functions
    # ---------------------------

    def _build_query_and_images(self, item: dict):
        """
        Build the text query and the image list for one generic sample.

        Args:
            item: One sample dict (one jsonl row). The prompt and the image
                path(s) are extracted from it by ``parse_input`` and
                ``get_image_path``.

        Returns:
            ``(query, images)``: ``query`` is the text handed to
            ``model.preprocess_inputs`` and ``images`` is a list of RGB
            ``PIL.Image`` objects.

        Raises:
            ValueError: If the item yields no image path, or every path is ``None``.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("OvisVLM 目前仅支持多模态输入（需要至少一张图片）")

        images: List[Image.Image] = []
        for path in image_paths:
            if path is None:
                continue
            # Image.open is lazy and may keep the file handle open (notably for
            # multi-frame formats). convert() returns an independent, fully
            # loaded image, so the source can be closed as soon as it is done.
            with Image.open(path) as source:
                images.append(source.convert("RGB"))

        if not images:
            raise ValueError("OvisVLM: 未成功加载任何图片，请检查 image_path 是否正确")

        # A prompt that already carries its own <image> placeholders is used as is.
        if "<image>" in prompt:
            return prompt, images

        # Otherwise prepend one '<image>' line per image, followed by the text:
        # '<image>\n{text}' for one image, '<image>\n<image>\n...\n{text}' for several.
        placeholders = "\n".join(["<image>"] * len(images))
        return f"{placeholders}\n{prompt}", images

    def _prepare_generation_config(
        self,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Dict:
        """
        整理最终用于 model.generate 的 kwargs（除 inputs 外）
        - 合并 / 设置 eos_token_id / pad_token_id
        """
        generation_config = dict(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            **(gen_cfg or {}),
        )

        # eos_token_id 统一收集自 model.generation_config 和 tokenizer
        from_ids = [
            getattr(self.model.generation_config, "eos_token_id", None),
            getattr(self.text_tokenizer, "eos_token_id", None),
        ]

        eos_ids = []
        for x in from_ids:
            if x is None:
                continue
            if isinstance(x, (list, tuple)):
                eos_ids.extend(list(x))
            else:
                eos_ids.append(x)
        eos_ids = list({i for i in eos_ids if i is not None})

        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "eos_token_id", None)

        if pad_id is not None:
            generation_config["pad_token_id"] = pad_id
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return generation_config

    @staticmethod
    def _pad_left(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        """
        左侧 pad 到同一长度（decoder-only 常用形式）
        seqs: List[1D tensor]，支持 int / bool 等
        """
        if len(seqs) == 0:
            raise ValueError("_pad_left: seqs 为空")

        max_len = max(s.size(0) for s in seqs)
        dtype = seqs[0].dtype
        device = seqs[0].device

        out = torch.full(
            (len(seqs), max_len),
            pad_token_id,
            dtype=dtype,
            device=device,
        )
        for i, s in enumerate(seqs):
            out[i, -s.size(0):] = s
        return out

    # ---------------------------
    # Core: batch / single-sample inference
    # ---------------------------

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run one batched generation pass over ``items``.

        Each sample is encoded on its own with ``model.preprocess_inputs``; the
        token ids and attention masks are left-padded into one batch and a
        single ``model.generate`` call produces the outputs, which are then
        trimmed and decoded per sample.

        Args:
            items: Non-empty list of sample dicts (see ``_build_query_and_images``).
            max_new_tokens: Generation budget; also used to compute ``hit_limits``.
            gen_cfg: Extra ``generate`` kwargs (may be ``None`` or empty).
            oom_estimate: When true, the padded batch is passed through
                ``utils.apply_prefill_extra_tokens`` before generating.
            bs_estimate_gen_cfg: Only read when ``oom_estimate`` is true. The keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are taken
                from a copy of it; the caller's dict is not modified.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)`` with one entry per item:
            the decoded text, the number of trailing tokens removed before
            decoding, and whether the sample used up ``max_new_tokens`` without
            ending on an EOS token.

        Raises:
            AssertionError: If ``items`` is empty.
            TypeError: If ``preprocess_inputs`` returns non-tensor ``input_ids``.
            ValueError: If no pad token id can be determined (or a sample has
                no usable image, see ``_build_query_and_images``).
        """
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        # 1) Encode every sample separately.
        # Loop invariants: the tokenizer pad id and the target dtype / device for
        # pixel values (falling back to self.device / bfloat16 when the visual
        # tokenizer does not expose them).
        tokenizer_pad_id = self.text_tokenizer.pad_token_id
        visual_device = getattr(self.visual_tokenizer, "device", self.device)
        visual_dtype = getattr(self.visual_tokenizer, "dtype", torch.bfloat16)

        input_ids_list: List[torch.Tensor] = []
        attn_mask_list: List[torch.Tensor] = []
        pixel_values_list: List[torch.Tensor] = []

        for item in items:
            query, images = self._build_query_and_images(item)

            # preprocess_inputs returns (prompt, input_ids, pixel_values).
            _, input_ids, pixel_values = self.model.preprocess_inputs(
                query,
                images,
                max_partition=self.max_partition,
            )

            if not isinstance(input_ids, torch.Tensor):
                raise TypeError(
                    f"OvisVLM: preprocess_inputs 返回的 input_ids 不是 Tensor，"
                    f"得到类型 {type(input_ids)}"
                )

            input_ids = input_ids.to(self.device)
            input_ids_list.append(input_ids)
            # Attention mask: True wherever the token is not the pad token.
            attn_mask_list.append(torch.ne(input_ids, tokenizer_pad_id))
            pixel_values_list.append(
                pixel_values.to(dtype=visual_dtype, device=visual_device)
            )

        # 2) Generation kwargs.
        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
        )
        pad_id = generation_config.get(
            "pad_token_id",
            getattr(self.text_tokenizer, "pad_token_id", getattr(self.text_tokenizer, "eos_token_id", None)),
        )

        if pad_id is None:
            raise ValueError("OvisVLM: 无法确定 pad_token_id")

        # 3) Left-pad into a batch: List[1D] -> [B, L_max].
        # The per-sample tensors were already moved to self.device above.
        batch_input_ids = self._pad_left(input_ids_list, pad_token_id=pad_id)
        batch_attention_mask = self._pad_left(attn_mask_list, pad_token_id=0).long()

        # Keep only the last multimodal_max_length positions.
        if self.multimodal_max_length is not None:
            batch_input_ids = batch_input_ids[:, -self.multimodal_max_length:]
            batch_attention_mask = batch_attention_mask[:, -self.multimodal_max_length:]

        # ========================= GPU memory estimation =========================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": batch_input_ids, "attention_mask": batch_attention_mask}
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            batch_input_ids, batch_attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]
        # =========================================================================

        # 4) One generate call for the whole batch.
        with torch.no_grad():
            # pixel_values is passed as a List[Tensor]. The decoding below treats
            # the result as newly generated ids only (no prompt part), as noted
            # by the original author for Ovis' generate.
            output_ids = self.model.generate(
                batch_input_ids,
                pixel_values=pixel_values_list,
                attention_mask=batch_attention_mask,
                **generation_config,
            )

        if not isinstance(output_ids, torch.Tensor):
            # Defensive: some implementations may return a List[Tensor].
            output_ids = torch.stack(output_ids, dim=0)

        # 5) Split per sample, strip the right side, decode.
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []

        # EOS ids used for the "ended with EOS" check: generation kwargs + tokenizer.
        eos_ids = set()
        for value in (
            generation_config.get("eos_token_id", None),
            getattr(self.text_tokenizer, "eos_token_id", None),
        ):
            if isinstance(value, (list, tuple)):
                eos_ids.update(v for v in value if v is not None)
            elif value is not None:
                eos_ids.add(value)

        for i in range(batch_size):
            seq = output_ids[i]

            # Number of trailing tokens to drop (padding plus surplus EOS).
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            trimmed = seq[:-cut] if cut > 0 else seq

            out_len = int(trimmed.shape[0])
            ended_with_eos = out_len > 0 and trimmed[-1].item() in eos_ids

            # The budget counts as exhausted only when the sample did not stop on EOS.
            hit_limits.append((out_len >= max_new_tokens) and (not ended_with_eos))

            text = self.text_tokenizer.decode(
                trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()
            outputs.append(text)

        # Model output is arbitrary text: escape it so that bracketed fragments
        # such as "[/INST]" are printed literally instead of being parsed as rich
        # markup (an unmatched closing tag would raise MarkupError).
        from rich.markup import escape

        for text in outputs:
            console.print("\n[green]output without prompt[/green]", escape(text))

        return outputs, right_pad_lens, hit_limits

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """
        单样本推理：复用 generate_batch 确保逻辑完全一致
        """
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are filler.

        Trailing pad tokens are filler, and so are surplus EOS tokens: of a run
        of EOS tokens at the end, exactly one is kept.

        Args:
            total_sequence: Token ids as a 1D tensor or a plain sequence of ints.

        Returns:
            The number of tokens that can be cut from the right end.

        Notes:
            The pad id is taken from ``model.generation_config`` first (the one
            ``generate`` actually uses) and from ``self.tokenizer`` otherwise.
            A sequence made up solely of pad tokens is counted as filler in full.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad_token_id that generate actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = set(gc_eos + tok_eos)

        # Walk back over the run of pad tokens at the end.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            j -= 1
        right_pad_len = n - 1 - j

        # When the pad id doubles as an EOS id, the first token of that run is
        # the real EOS and has to be given back. This correction must not be
        # applied when pad and EOS differ, otherwise one pad token is left behind.
        if (
            right_pad_len > 0
            and j >= 0
            and pad_id in eos_ids
            and total_sequence[j] not in eos_ids
        ):
            return right_pad_len - 1

        # Count the run of EOS tokens that precedes the padding.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the rest is treated as right-side filler.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
