import os
import sys
from typing import Dict, List, Optional, Tuple

import torch
from rich.console import Console
from rich.markup import escape

from .base import BaseVLM

# 从 utils 导入通用工具
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console; Gemma3VLM.generate_batch uses it for progress and output logging.
console = Console()


def _flatten_eos_ids(*values) -> List[int]:
    """
    把各种形态的 eos_token_id（int / list / 嵌套 list / tuple 等）
    展平成一个「去重后的 int 列表」，避免 set(list) 触发 'unhashable type: list'
    """
    out: List[int] = []

    def _walk(x):
        if x is None:
            return
        if isinstance(x, (list, tuple, set)):
            for y in x:
                _walk(y)
        elif isinstance(x, int):
            if x not in out:
                out.append(x)
        # 其他类型（比如 str）直接忽略

    for v in values:
        _walk(v)
    return out


class Gemma3VLM(BaseVLM):
    """
    Adapter for Gemma3 multimodal models (e.g. google/gemma-3-4b-it).

    Official usage example:

        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

    To fit the batch-inference + hidden-state-hook framework, this class:

    - encodes every sample separately with processor.apply_chat_template(...)
      (the features contain input_ids plus the vision fields);
    - left-pads and stacks those features into one batch by hand;
    - strips the prompt prefix after model.generate and decodes the rest;
    - returns the same formats as the other VLM adapters:
        * generate_batch -> (outputs, right_pad_lens, hit_limits)
        * generate_one   -> (output_text, right_pad_len, hit_limit)
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Args:
            model: e.g. Gemma3ForConditionalGeneration.from_pretrained(...)
            processor: e.g. AutoProcessor.from_pretrained(...)
            device: device the batched inputs are moved to.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            processor=processor,
            device=device,
        )

        # Force left padding so "prompt + generated" can be split at one
        # column. self.tokenizer and processor.tokenizer are usually the same
        # object, so configure each distinct tokenizer only once.
        configured = set()
        for tok in (self.tokenizer, getattr(self.processor, "tokenizer", None)):
            if tok is None or id(tok) in configured:
                continue
            configured.add(id(tok))
            self._configure_left_padding(tok)

    @staticmethod
    def _configure_left_padding(tok) -> None:
        """Set left padding on *tok* and default its pad token to eos when it has none."""
        tok.padding_side = "left"
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token

    # ------------------------------------------------------------------
    # Message construction: prompt + image paths -> Gemma3 chat messages
    # ------------------------------------------------------------------
    def _build_messages(self, item: dict):
        """
        Build Gemma3 multimodal chat messages from *item*.

        The result has the shape:

        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "/path/to/img1.jpg"},
                    {"type": "image", "image": "/path/to/img2.jpg"},
                    {"type": "text", "text": "<prompt>"},
                ],
            }
        ]

        Raises:
            ValueError: if the item has no image paths, or all of them are None.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("Gemma3VLM 目前仅支持多模态输入（需要至少一张图片）")

        # Local paths are handed to the processor as-is, like the URLs in the
        # official example; None entries are skipped.
        content = [{"type": "image", "image": p} for p in image_paths if p is not None]

        if not content:
            raise ValueError("Gemma3VLM: 未成功解析任何图片路径，请检查 item 中的 image 字段")

        content.append({"type": "text", "text": prompt})
        return [{"role": "user", "content": content}]

    # ------------------------------------------------------------------
    # Single-sample encoding via processor.apply_chat_template
    # ------------------------------------------------------------------
    def _encode_one(self, messages):
        """
        Encode one conversation with processor.apply_chat_template.

        Returns:
            features: dict of the processor outputs (still on their original device).
            prompt_len: number of prompt tokens in input_ids.
        """
        enc = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )

        # The processor may return a BatchEncoding / BatchFeature; normalize to dict.
        if not isinstance(enc, dict):
            enc = dict(enc)

        input_ids = enc["input_ids"]
        if input_ids.dim() == 2 and input_ids.size(0) == 1:
            prompt_len = input_ids.size(1)
        else:
            prompt_len = input_ids.numel()

        # Device transfer is deferred to _stack_and_pad to avoid moving twice.
        return enc, prompt_len

    # ------------------------------------------------------------------
    # Merge per-sample features into one left-padded batch
    # ------------------------------------------------------------------
    @staticmethod
    def _squeeze_batch_dim(t: torch.Tensor) -> torch.Tensor:
        """Drop a leading batch dim of size 1 ([1, L] -> [L]); other shapes pass through."""
        return t[0] if t.dim() == 2 and t.size(0) == 1 else t

    @staticmethod
    def _left_pad_rows(rows: List[torch.Tensor], max_len: int, pad_value) -> torch.Tensor:
        """Right-align 1D tensors inside a [len(rows), max_len] tensor pre-filled with pad_value."""
        out = torch.full(
            (len(rows), max_len),
            pad_value,
            dtype=rows[0].dtype,
            device=rows[0].device,
        )
        for r, row in enumerate(rows):
            # max_len - len (not -len) so an empty row leaves the line untouched.
            out[r, max_len - row.size(0):] = row
        return out

    def _stack_and_pad(self, encoded_list: List[dict], pad_token_id: int) -> Dict[str, torch.Tensor]:
        """
        Merge per-sample features into one left-padded batch on self.device.

        Text fields:
          - input_ids: left-padded with pad_token_id to the longest prompt.
          - attention_mask: rebuilt from the true prompt lengths (1 = real
            token, 0 = padding), so a prompt token that happens to equal
            pad_token_id is never masked out.
          - other per-token fields, i.e. tensors shaped [1, L] or [L] whose
            length equals that sample's input_ids length (for Gemma3 this
            covers token_type_ids, which flags image tokens), are left-padded
            with 0 so they stay aligned with input_ids.

        Vision / other fields:
          - tensors with >= 3 dims (e.g. pixel_values) are concatenated along
            dim 0, falling back to torch.stack when all shapes are equal;
          - everything else (non-tensors, missing keys, 1D/2D fields that
            are not per-token) is dropped.
        """
        seqs: List[torch.Tensor] = [self._squeeze_batch_dim(feat["input_ids"]) for feat in encoded_list]
        seq_lens = [int(s.size(0)) for s in seqs]
        max_len = max(seq_lens)

        input_ids = self._left_pad_rows(seqs, max_len, pad_token_id)
        attention_mask = torch.zeros((len(seqs), max_len), dtype=torch.long, device=input_ids.device)
        for row, length in enumerate(seq_lens):
            attention_mask[row, max_len - length:] = 1

        inputs: Dict[str, torch.Tensor] = {
            "input_ids": input_ids.to(self.device),
            "attention_mask": attention_mask.to(self.device),
        }

        for key in encoded_list[0].keys():
            if key in ("input_ids", "attention_mask"):
                continue

            values = [feat.get(key) for feat in encoded_list]
            if not all(isinstance(v, torch.Tensor) for v in values):
                continue
            head = values[0]

            if head.dim() <= 2:
                # Per-token text fields must be padded exactly like input_ids;
                # anything else of low rank cannot be batched safely.
                flat = [self._squeeze_batch_dim(v) for v in values]
                is_per_token = all(
                    f.dim() == 1 and f.size(0) == n for f, n in zip(flat, seq_lens)
                )
                if is_per_token:
                    inputs[key] = self._left_pad_rows(flat, max_len, 0).to(self.device)
                continue

            # Image / video tensors: concatenate along the batch dimension.
            try:
                merged = torch.cat(values, dim=0)
            except RuntimeError:
                if any(v.shape != head.shape for v in values):
                    continue  # cannot be aligned; drop this key
                merged = torch.stack(values, dim=0)
            inputs[key] = merged.to(self.device)

        return inputs

    @staticmethod
    def _drop_misaligned_token_fields(inputs) -> Dict[str, torch.Tensor]:
        """
        Return a copy of *inputs* without 2D fields whose shape no longer matches input_ids.

        The external apply_prefill_extra_tokens may change input_ids'
        length; a per-token field (e.g. token_type_ids) left at the old
        length would make the model's forward fail, so it is removed.
        input_ids and attention_mask are always kept.
        """
        ref_shape = inputs["input_ids"].shape
        kept: Dict[str, torch.Tensor] = {}
        for key, value in inputs.items():
            misaligned = (
                key not in ("input_ids", "attention_mask")
                and isinstance(value, torch.Tensor)
                and value.dim() == 2
                and value.shape != ref_shape
            )
            if not misaligned:
                kept[key] = value
        return kept

    # ------------------------------------------------------------------
    # Generation kwargs: eos / pad handling
    # ------------------------------------------------------------------
    def _prepare_generation_config(
        self,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Dict:
        """
        Build the keyword arguments for model.generate (everything except the inputs).

        Precedence: do_sample defaults to False but may be overridden by
        gen_cfg; max_new_tokens always comes from the explicit argument
        (generate_batch compares against it for hit_limits). eos_token_id is
        overwritten with the model/tokenizer eos ids when any are known, and
        pad_token_id is resolved from generation_config -> tokenizer pad ->
        tokenizer eos.

        Side effect: the resolved pad_token_id is also written to
        self.model.config and self.model.generation_config.
        """
        # A dict literal lets later keys win; the previous
        # dict(do_sample=False, **gen_cfg) raised TypeError whenever gen_cfg
        # already contained do_sample or max_new_tokens.
        generation_config = {
            "do_sample": False,
            **(gen_cfg or {}),
            "max_new_tokens": max_new_tokens,
        }

        eos_ids = _flatten_eos_ids(
            getattr(self.model.generation_config, "eos_token_id", None),
            getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None,
        )
        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None and self.tokenizer is not None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None and self.tokenizer is not None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            generation_config["pad_token_id"] = pad_id
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return generation_config

    # ------------------------------------------------------------------
    # Decoding helper
    # ------------------------------------------------------------------
    def _decode_token_ids(self, token_id_lists: List[List[int]]) -> List[str]:
        """
        Decode token-id lists to text without special tokens.

        Tries processor.batch_decode, then tokenizer.batch_decode, then
        per-item processor/tokenizer decode; yields "" for an item when no
        decoder is available.
        """
        decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)
        if hasattr(self.processor, "batch_decode"):
            return self.processor.batch_decode(token_id_lists, **decode_kwargs)
        if self.tokenizer is not None and hasattr(self.tokenizer, "batch_decode"):
            return self.tokenizer.batch_decode(token_id_lists, **decode_kwargs)

        texts: List[str] = []
        for ids in token_id_lists:
            if hasattr(self.processor, "decode"):
                texts.append(self.processor.decode(ids, **decode_kwargs))
            elif self.tokenizer is not None and hasattr(self.tokenizer, "decode"):
                texts.append(self.tokenizer.decode(ids, **decode_kwargs))
            else:
                texts.append("")
        return texts

    # ------------------------------------------------------------------
    # Batched inference
    # ------------------------------------------------------------------
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run one batched generate call over *items*.

          1. Build messages per sample and encode them with apply_chat_template.
          2. Left-pad and stack the features into batched inputs.
          3. Call model.generate once.
          4. Strip the prompt prefix and decode each sample's output.

        Args:
            items: samples understood by parse_input / get_image_path.
            max_new_tokens: generation length limit.
            gen_cfg: extra kwargs for model.generate (may be None/empty).
            oom_estimate: if True, pad the prefill via apply_prefill_extra_tokens
                (used for memory / batch-size estimation).
            bs_estimate_gen_cfg: options for the estimate; the private keys
                "_prefill_extra_tokens" and "_prefill_token_id" are read from
                a copy, so the caller's dict is never mutated.

        Returns:
            outputs: stripped decoded text per sample.
            right_pad_lens: trailing tokens removed per sample (pad + extra eos).
            hit_limits: whether each sample reached max_new_tokens without
                ending in an eos token (heuristic).
        """
        assert self.processor is not None, "Gemma3VLM 需要 AutoProcessor"
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        console.print(f"[cyan][Gemma3VLM] 批量推理，batch_size={batch_size}[/cyan]")

        # 1) Encode every sample on its own (prompt lengths differ).
        encoded_list: List[dict] = [
            self._encode_one(self._build_messages(item))[0] for item in items
        ]

        # 2) Generation kwargs and the pad id used for left padding.
        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
        )
        pad_token_id = generation_config.get(
            "pad_token_id",
            getattr(self.tokenizer, "pad_token_id", None) if self.tokenizer is not None else None,
        )
        if pad_token_id is None and self.tokenizer is not None:
            pad_token_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_token_id is None:
            raise ValueError("Gemma3VLM: 无法确定 pad_token_id")

        # 3) Left-pad and batch.
        inputs = self._stack_and_pad(encoded_list, pad_token_id=pad_token_id)

        # ============================ memory estimation ============================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            estimate_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(estimate_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = estimate_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            # Keep per-token fields only if they still line up with input_ids.
            inputs = self._drop_misaligned_token_fields(inputs)
        # ===========================================================================

        # 4) One generate call for the whole batch.
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                **generation_config,
            )

        # 5) With left padding every prompt ends at the same column, so the
        #    generated part starts right after the padded input length.
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_tokens_all = generate_ids[:, prefix_len_padded:]  # [B, T_out_max]

        # 6) Trim right padding, compute hit_limit, collect ids for decoding.
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        eos_ids = _flatten_eos_ids(
            getattr(self.model.generation_config, "eos_token_id", None),
            getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None,
        )

        for i in range(batch_size):
            seq = new_tokens_all[i]  # generated tokens only
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            if cut > 0:
                seq = seq[:-cut]

            out_len = int(seq.shape[0])
            ended_with_eos = bool(out_len > 0 and eos_ids and seq[-1].item() in eos_ids)

            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq.cpu().tolist())

        # 7) Decode to text.
        outputs = [t.strip() for t in self._decode_token_ids(decoded_token_ids)]

        for o in outputs:
            # Escape model text so brackets in it are not parsed as rich markup
            # (an unmatched "[/...]" would otherwise raise MarkupError).
            console.print("\n[yellow]output without prompt: ", escape(o))

        return outputs, right_pad_lens, hit_limits

    # ------------------------------------------------------------------
    # Single-sample inference: delegates to generate_batch
    # ------------------------------------------------------------------
    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """Run generate_batch on a single item and unpack its first (only) result."""
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg=None,
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence to cut.

        Trailing pad tokens are counted first. If there were any and the
        token before them is not an eos id, one pad is given back
        (count - 1); per the original author this covers pad == eos, where
        the run of "pads" also swallowed the real terminating eos.
        Otherwise, trailing eos tokens before the pads are counted too, and
        all but one of them are cut.

        Args:
            total_sequence: token ids of the generated part (tensor or list).

        Returns:
            Number of tokens to remove from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id generate() actually used.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        # _flatten_eos_ids tolerates int / list / nested eos configurations.
        eos_ids = _flatten_eos_ids(
            getattr(self.model.generation_config, "eos_token_id", None),
            getattr(self.tokenizer, "eos_token_id", None),
        )

        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            return right_pad_len - 1

        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1
        # Keep one eos; the rest counts as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
