from typing import List, Dict, Tuple
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Put this file's parent directory on sys.path so the shared `utils` module is importable, then import its helpers.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level Rich console used for this module's log / output messages.
console = Console()


class OvisU1VLM(BaseVLM):
    """
    Ovis-family VLM wrapper (Ovis2 / Ovis-U1): single-sample and batched
    multimodal inference, modeled on the official usage examples.

    - Inputs are built with ``model.preprocess_inputs(query, images, max_partition)``.
      Both the 3-tuple ``(prompt, input_ids, pixel_values)`` (Ovis2) and the
      4-tuple ``(prompt, input_ids, pixel_values, grid_thws)`` (Ovis-U1) are accepted.
    - The text tokenizer comes from ``model.get_text_tokenizer()``.
    - The visual tokenizer comes from ``model.get_visual_tokenizer()``.
    - Return formats of generate_batch / generate_one match QwenVLVLM:
        * generate_batch -> (outputs, right_pad_lens, hit_limit_flags)
        * generate_one   -> (output_text, right_pad_len, hit_limit)
    """

    def __init__(self, model, device: str = "cuda"):
        """
        Wrap an Ovis model.

        Args:
            model: Ovis model exposing ``get_text_tokenizer()``, ``get_visual_tokenizer()``,
                ``preprocess_inputs()``, ``generate()``, ``config`` and ``generation_config``.
            device: device that model inputs are moved to.
        """
        text_tokenizer = model.get_text_tokenizer()
        super().__init__(model=model, tokenizer=text_tokenizer, processor=None, device=device)

        # Keep explicit handles on both tokenizers.
        self.text_tokenizer = text_tokenizer
        self.visual_tokenizer = model.get_visual_tokenizer()

        # Force left padding: avoids tokenizer warnings and keeps prompts
        # right-aligned for batched generation. Fall back to EOS as pad token.
        self.text_tokenizer.padding_side = "left"
        if self.text_tokenizer.pad_token is None:
            self.text_tokenizer.pad_token = self.text_tokenizer.eos_token

        # Propagate the pad id to the model configs only where they have none.
        pad_id = self.text_tokenizer.pad_token_id
        if getattr(self.model.generation_config, "pad_token_id", None) is None:
            self.model.generation_config.pad_token_id = pad_id
        if getattr(self.model.config, "pad_token_id", None) is None:
            self.model.config.pad_token_id = pad_id

    # ----------------------------------------------------------------------
    # Small shared helpers
    # ----------------------------------------------------------------------
    def _collect_eos_ids(self) -> list:
        """Return the union of EOS ids from the generation config and the text tokenizer (order unspecified)."""
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.text_tokenizer, "eos_token_id", None))
        return list(set(gc_eos + tok_eos))

    def _visual_dtype(self) -> torch.dtype:
        """Return the visual tokenizer's dtype, falling back to bfloat16 when it exposes none."""
        dtype = getattr(getattr(self, "visual_tokenizer", None), "dtype", None)
        return dtype if isinstance(dtype, torch.dtype) else torch.bfloat16

    def _build_single_inputs(
        self,
        image_path: str,
        text: str,
        max_partition: int,
    ):
        """
        Build the model inputs for one (image, text) sample, following the official example.

        Args:
            image_path: path of the image file; it is loaded and converted to RGB.
            text: user prompt; it is prefixed with an ``<image>`` placeholder line.
            max_partition: forwarded to ``preprocess_inputs`` when the model accepts it.

        Returns:
            (input_ids, attention_mask, pixel_values, grid_thws), moved to ``self.device``.
            ``grid_thws`` is None when ``preprocess_inputs`` returns a 3-tuple.
            ``pixel_values`` (if any) is cast to the visual tokenizer's dtype.

        Raises:
            ValueError: if ``preprocess_inputs`` returns neither 3 nor 4 values.
        """
        # The context manager closes the file; convert() returns an independent, loaded copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        query = f"<image>\n{text}"

        try:
            out = self.model.preprocess_inputs(
                query,
                [image],
                max_partition=max_partition,
            )
        except TypeError as e:
            if "unexpected keyword argument 'max_partition'" in str(e):
                # Older/other variants do not accept max_partition.
                out = self.model.preprocess_inputs(
                    query,
                    [image],
                )
            else:
                raise

        # Ovis2 returns 3 values, Ovis-U1 additionally returns grid_thws.
        if len(out) == 3:
            _, input_ids, pixel_values = out
            grid_thws = None
        elif len(out) == 4:
            _, input_ids, pixel_values, grid_thws = out
        else:
            raise ValueError(f"Unexpected number of outputs from preprocess_inputs: {len(out)}")

        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)

        device = self.device
        input_ids = input_ids.to(device=device)
        attention_mask = attention_mask.to(device=device)

        if pixel_values is not None:
            # Match the vision tower's dtype instead of hard-coding bfloat16,
            # which would fail for models loaded in fp16 / fp32.
            pixel_values = pixel_values.to(
                dtype=self._visual_dtype(),
                device=device,
                non_blocking=True,
            )
        if grid_thws is not None:
            grid_thws = grid_thws.to(
                device=device,
                non_blocking=True,
            )

        return input_ids, attention_mask, pixel_values, grid_thws

    def _collate_visual_inputs(self, single_pixel_values, single_grid_thws):
        """
        Combine per-sample visual tensors into the form passed to ``model.generate``.

        - When any sample has ``grid_thws`` (4-tuple / Ovis-U1 style), pixel patches
          and ``grid_thws`` are each concatenated along dim 0.
        - Otherwise (3-tuple / Ovis2 style), ``pixel_values`` stays a per-sample list,
          as in the official Ovis2 batch-inference example.
          NOTE(review): this relies on the Ovis2 remote code iterating that list
          per sample; confirm against the model's ``merge_multimodal``.

        Returns:
            (batch_pixel_values, batch_grid_thws); either may be None.
        """
        present_pixels = [pv for pv in single_pixel_values if pv is not None]
        present_grids = [g for g in single_grid_thws if g is not None]

        if present_grids:
            batch_pixel_values = (
                torch.cat(present_pixels, dim=0).to(self.device, non_blocking=True)
                if present_pixels
                else None
            )
            batch_grid_thws = torch.cat(present_grids, dim=0).to(self.device, non_blocking=True)
            return batch_pixel_values, batch_grid_thws

        if present_pixels:
            return list(single_pixel_values), None
        return None, None

    # ----------------------------------------------------------------------
    # Batched inference
    # ----------------------------------------------------------------------
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run batched multimodal generation (one image + one prompt per item).

        Args:
            items: samples; the prompt comes from ``parse_input(item)`` and the image
                path from ``get_image_path(item)``.
            max_new_tokens: generation budget per sample; always takes precedence
                over a ``max_new_tokens`` key in ``gen_cfg``.
            gen_cfg: extra ``generate`` kwargs. The key ``max_partition`` (default 9)
                is consumed here for preprocessing. Other keys override the greedy
                default ``do_sample=False``.
            oom_estimate: if True, extend the prompts with ``apply_prefill_extra_tokens``
                (memory estimation run).
            bs_estimate_gen_cfg: OOM-estimate options; ``_prefill_extra_tokens`` and
                ``_prefill_token_id`` are read from it.

        Returns:
            (outputs, right_pad_lens, hit_limit_flags):
              - outputs: generated text per sample (without the prompt)
              - right_pad_lens: trailing tokens treated as right padding (pads + extra EOS)
              - hit_limit_flags: True if the sample reached max_new_tokens without ending in EOS
            An empty ``items`` list yields three empty lists.

        Raises:
            ValueError: if an item provides no image path.
        """
        console.print(f"[cyan]执行 OvisVLM 批处理，batch_size={len(items)}[/cyan]")
        if not items:
            return [], [], []

        # max_partition is a preprocessing option, not a generate() kwarg.
        local_cfg = dict(gen_cfg) if gen_cfg is not None else {}
        max_partition = local_cfg.pop("max_partition", 9)

        # ---- per-sample inputs ----
        single_input_ids = []
        single_attention_masks = []
        single_pixel_values = []
        single_grid_thws = []
        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("OvisVLM 目前仅支持多模态输入，item 中需要提供图像路径")

            input_ids, attention_mask, pixel_values, grid_thws = self._build_single_inputs(
                image_path=image_path,
                text=prompt,
                max_partition=max_partition,
            )
            single_input_ids.append(input_ids)
            single_attention_masks.append(attention_mask)
            single_pixel_values.append(pixel_values)
            single_grid_thws.append(grid_thws)

        # ---- left padding (official recipe: flip, right-pad, flip back) ----
        input_pad_id = self.text_tokenizer.pad_token_id if self.text_tokenizer.pad_token_id is not None else 0
        batch_input_ids = torch.nn.utils.rnn.pad_sequence(
            [ids.flip(dims=[0]) for ids in single_input_ids],
            batch_first=True,
            padding_value=input_pad_id,
        ).flip(dims=[1])
        batch_attention_mask = torch.nn.utils.rnn.pad_sequence(
            [m.flip(dims=[0]) for m in single_attention_masks],
            batch_first=True,
            padding_value=False,
        ).flip(dims=[1])

        # Keep only the last multimodal_max_length positions (as in the official example).
        mm_max_len = getattr(self.model.config, "multimodal_max_length", None)
        if mm_max_len is not None:
            batch_input_ids = batch_input_ids[:, -mm_max_len:]
            batch_attention_mask = batch_attention_mask[:, -mm_max_len:]

        batch_input_ids = batch_input_ids.to(self.device)
        batch_attention_mask = batch_attention_mask.to(self.device)

        batch_pixel_values, batch_grid_thws = self._collate_visual_inputs(
            single_pixel_values, single_grid_thws
        )

        # ---- generate kwargs ----
        # Greedy by default; gen_cfg keys (e.g. do_sample=True) override it. The
        # explicit max_new_tokens argument always wins so hit-limit detection
        # below uses the same budget generate() was given.
        gen_conf = {"do_sample": False, **local_cfg, "max_new_tokens": max_new_tokens}

        eos_ids = self._collect_eos_ids()
        if eos_ids and all(e is not None for e in eos_ids):
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # Pad id: generation config, then tokenizer pad, then tokenizer EOS.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "eos_token_id", None)
        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        # ---- optional memory (OOM) estimation: extend prompts with filler tokens ----
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": batch_input_ids, "attention_mask": batch_attention_mask}
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            batch_input_ids, batch_attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]

        # ---- generate ----
        with torch.no_grad():
            output_ids = self.model.generate(
                batch_input_ids,
                pixel_values=batch_pixel_values,
                attention_mask=batch_attention_mask,
                grid_thws=batch_grid_thws,  # set for 4-tuple (Ovis-U1) inputs, None otherwise
                **gen_conf,
            )

        # ---- post-process each sequence ----
        batch_size = batch_input_ids.size(0)
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for seq in output_ids[:batch_size]:
            total_len = seq.size(0)

            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            # Strip only the right padding; the generated content is not truncated.
            generated_ids = seq[:-cut] if cut > 0 else seq
            output_len = total_len - cut

            # Hit the limit = used the whole budget without finishing on an EOS.
            ended_with_eos = (
                output_len > 0
                and len(eos_ids) > 0
                and generated_ids[-1].item() in eos_ids
            )
            hit_limit_flags.append((output_len >= max_new_tokens) and (not ended_with_eos))

            outputs.append(
                self.text_tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )

        # Escape model text so brackets in it (e.g. "[/INST]") are not parsed as Rich markup.
        from rich.markup import escape
        for text in outputs:
            console.print("\n[green]output without prompt:", escape(text))

        return outputs, right_pad_lens, hit_limit_flags

    # ----------------------------------------------------------------------
    # Single-sample inference
    # ----------------------------------------------------------------------
    def generate_one(self, item, max_new_tokens: int, gen_cfg: Dict):
        """Run one item through generate_batch; returns (output_text, right_pad_len, hit_limit)."""
        outputs, right_pads, hit_limits = self.generate_batch(
            [item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pads[0], hit_limits[0]

    # ----------------------------------------------------------------------
    # Right-padding length (kept consistent with QwenVLVLM)
    # ----------------------------------------------------------------------
    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count the trailing tokens of a generated sequence that count as right padding.

        The method works in two steps:
        1. Trailing pad tokens are counted.
        2. The token just before that run decides what happens next:
           - If pads were found and that token is not an EOS id, one token is
             given back. This covers pad == EOS, so that exactly one EOS is kept.
           - Otherwise, every EOS in the trailing EOS run except one is added
             to the count.

        Args:
            total_sequence: 1D token ids (tensor or list).

        Returns:
            Number of tokens to strip from the end of ``total_sequence``.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.text_tokenizer, "pad_token_id", None)
        eos_ids = self._collect_eos_ids()

        # Walk back over the trailing run of pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads exist but no EOS precedes them: when pad == EOS the loop above
            # also consumed the real EOS, so give one token back.
            return right_pad_len - 1

        # Walk back over the trailing run of EOS tokens; keep one, count the rest.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len