from __future__ import annotations

import os
import sys
import importlib
from typing import List, Dict, Tuple, Optional, Any, DefaultDict
from collections import defaultdict

import torch
from rich.console import Console

from .base import BaseVLM

# Import the shared helpers from utils (kept consistent with qwen.py / phi.py)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

console = Console()


class H2OVLVLM(BaseVLM):
    """
    Adapter class for H2OVL (h2oai/h2ovl-mississippi-2b).

    Goals:
    - True batching: call the remote-code model.generate(pixel_values, input_ids, attention_mask, ...) directly.
    - No hand-written attention_mask: it is produced by tokenizer(queries, padding=True, return_tensors="pt").
    - right_pad: reuse the existing calculate_right_padding_length as-is (unchanged).
    - hit_limit: out_len >= max_new_tokens and not ended_with_eos (aligned with the qwen/phi adapters).
    - No tokenizer mixing: only the tokenizer passed in is used (never overridden from the processor).
    """

    def __init__(
        self,
        model,
        tokenizer,
        processor=None,
        device: str = "cuda",
        max_tiles: int = 6,
        IMG_START_TOKEN: str = "<img>",
        IMG_END_TOKEN: str = "</img>",
        IMG_CONTEXT_TOKEN: str = "<IMG_CONTEXT>",
    ):
        """
        Wire the adapter to an already-loaded H2OVL model and tokenizer.

        Args:
            model: The loaded model; the module defining its class must expose
                get_conv_template, load_single_image and load_multi_images.
            tokenizer: Tokenizer used for every encode/decode step.
            processor: Forwarded to the base class unchanged.
            device: Forwarded to the base class unchanged.
            max_tiles: Stored as an int and passed as ``max_num`` to the image loaders.
            IMG_START_TOKEN: Text that opens an expanded image placeholder.
            IMG_END_TOKEN: Text that closes an expanded image placeholder.
            IMG_CONTEXT_TOKEN: Text repeated once per image-context position.

        Raises:
            RuntimeError: If one of the required helpers is missing from the
                module that defines the model class.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        self.max_tiles = int(max_tiles)
        self.IMG_START_TOKEN, self.IMG_END_TOKEN, self.IMG_CONTEXT_TOKEN = (
            IMG_START_TOKEN,
            IMG_END_TOKEN,
            IMG_CONTEXT_TOKEN,
        )

        # Decoder-only model: force left padding (aligned with the qwen/phi adapters).
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # The remote code (trust_remote_code=True) is expected to provide these helpers
        # in the module that defines the model class.
        remote = importlib.import_module(self.model.__class__.__module__)
        self._remote_mod = remote
        self._get_conv_template = getattr(remote, "get_conv_template", None)
        self._load_single_image = getattr(remote, "load_single_image", None)
        self._load_multi_images = getattr(remote, "load_multi_images", None)

        if self._get_conv_template is None:
            raise RuntimeError("H2OVL: 未找到 get_conv_template（remote code 未正确加载？）")
        image_loaders = (self._load_single_image, self._load_multi_images)
        if any(loader is None for loader in image_loaders):
            raise RuntimeError("H2OVL: 未找到 load_single_image/load_multi_images（remote code 未正确加载？）")

        # Original author's note: H2OVL generate() asserts img_context_token_id is not None,
        # so it is set eagerly here.
        self._set_img_context_token_id()

    def _model_dtype(self) -> torch.dtype:
        try:
            return next(self.model.parameters()).dtype
        except StopIteration:
            return torch.bfloat16

    def _set_img_context_token_id(self) -> None:
        img_ctx_id = int(self.tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN))
        if self.tokenizer.unk_token_id is not None and img_ctx_id == int(self.tokenizer.unk_token_id):
            raise ValueError(
                f"H2OVL: {self.IMG_CONTEXT_TOKEN} 未在 tokenizer vocab 中（变成 unk_token_id），请确认 tokenizer 加载正确"
            )
        self.model.img_context_token_id = img_ctx_id

    def _ensure_image_placeholders(self, prompt: str, n_images: int) -> str:
        """保证 prompt 里至少有 n_images 个 <image> 占位符（官方示例风格）"""
        prompt = (prompt or "").strip()
        if n_images <= 0:
            return prompt

        cnt = prompt.count("<image>")
        if cnt >= n_images:
            return prompt

        if cnt == 0:
            # 单图：<image>\n + prompt
            if n_images == 1:
                return "<image>\n" + prompt
            # 多图：Image-i: <image>\n...
            prefix = "".join([f"Image-{i+1}: <image>\n" for i in range(n_images)])
            return prefix + prompt

        # cnt in (1..n_images-1): 补齐缺失的
        missing = n_images - cnt
        return prompt + "\n" + "\n".join(["<image>"] * missing)

    def _load_images(self, image_paths) -> Tuple[Optional[torch.Tensor], List[int], int]:
        """
        Load the images behind ``image_paths`` through the remote-code loaders.

        Returns:
            A triple ``(pixel_values, num_patches_list, n_images)``:
              - pixel_values: loader output moved to ``self.device`` with the model
                dtype, or None when no usable path was given
                (presumably [num_patches_total, 3, H, W] - TODO confirm against the loaders)
              - num_patches_list: the loader's per-image patch counts, as ints
              - n_images: number of non-empty paths
        """
        usable_paths = list(filter(None, _normalize_to_list(image_paths)))
        n_images = len(usable_paths)
        if n_images == 0:
            return None, [], 0

        if n_images > 1:
            tiles, patch_counts = self._load_multi_images(usable_paths, max_num=self.max_tiles)
        else:
            # Only the single-image loader receives the model's use_msac flag.
            tiles, patch_counts = self._load_single_image(
                usable_paths[0],
                max_num=self.max_tiles,
                msac=bool(getattr(self.model, "use_msac", False)),
            )

        tiles = tiles.to(device=self.device, dtype=self._model_dtype())
        return tiles, [int(count) for count in (patch_counts or [])], n_images

    def _build_query(
        self,
        prompt: str,
        num_patches_list: List[int],
    ) -> Tuple[str, str, int, int]:
        """
        复刻 model.chat() 内部逻辑：
          - conv_template.append_message(...)
          - query = template.get_prompt()
          - <image> 替换成 <img> + <IMG_CONTEXT>*K + </img>
          - eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        """
        template_name = getattr(self.model, "template", None) or "h2ogpt2"
        template = self._get_conv_template(template_name)

        # system_message 与官方一致：优先用模型上的 system_message
        template.system_message = getattr(self.model, "system_message", getattr(template, "system_message", ""))

        template.append_message(template.roles[0], prompt)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        num_image_token = int(getattr(self.model, "num_image_token", 0))
        total_img_ctx_tokens = 0

        if num_patches_list:
            for num_patches in num_patches_list:
                n_ctx = num_image_token * int(num_patches)
                total_img_ctx_tokens += n_ctx
                image_tokens = self.IMG_START_TOKEN + (self.IMG_CONTEXT_TOKEN * n_ctx) + self.IMG_END_TOKEN
                query = query.replace("<image>", image_tokens, 1)

        sep_str = template.sep
        eos_token_id = int(self.tokenizer.convert_tokens_to_ids(sep_str))
        return query, sep_str, eos_token_id, total_img_ctx_tokens

    def _prepare_generation_config(
        self,
        max_new_tokens: int,
        gen_cfg: Dict,
        eos_token_id: int,
    ) -> Dict:
        """
        整理最终用于 model.generate 的 kwargs（除了 inputs）
        - do_sample 默认 False（与你 qwen/phi 对齐），允许 gen_cfg 覆盖
        - eos_token_id 强制用 template.sep
        - pad_token_id：优先 model.generation_config，其次 tokenizer.pad_token_id，其次 tokenizer.eos_token_id
        """
        generation_config = dict(
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            **(gen_cfg or {}),
        )

        generation_config["eos_token_id"] = int(eos_token_id)

        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            raise ValueError("H2OVL: 无法确定 pad_token_id")

        generation_config["pad_token_id"] = int(pad_id)

        # 同步到 model config（给 right_pad / eos 判断用）
        try:
            self.model.config.pad_token_id = int(pad_id)
        except Exception:
            pass
        try:
            self.model.generation_config.pad_token_id = int(pad_id)
        except Exception:
            pass
        try:
            self.model.generation_config.eos_token_id = int(eos_token_id)
        except Exception:
            pass

        return generation_config

    def _run_group_generate(
        self,
        group_samples: List[Dict[str, Any]],
        max_new_tokens: int,
        gen_cfg: Dict,
        has_image: bool,
        eos_token_id: int,
        sep_str: str,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run one batched generate call for a single homogeneous group of samples.

        - attention_mask comes from the tokenizer's own padding (never hand-built).
        - pixel_values (when the group has images) are concatenated in sample order,
          so they line up with the order of the IMG_CONTEXT tokens in the queries.

        Args:
            group_samples: Prepared samples; each needs a "query" entry and, when
                has_image is True, a "pixel_values" entry.
            max_new_tokens: Generation budget, also used for the hit-limit flag.
            gen_cfg: Extra generate kwargs, forwarded to _prepare_generation_config.
            has_image: Whether the samples of this group carry pixel_values.
            eos_token_id: Token id that ends generation for this group.
            sep_str: Separator string; decoded text is cut at its first occurrence.
            oom_estimate: When True, the tokenized inputs are passed through
                utils.apply_prefill_extra_tokens before generation.
            bs_estimate_gen_cfg: Options for that step; only "_prefill_extra_tokens"
                and "_prefill_token_id" are read, from a local copy.

        Returns:
            (outputs, right_pad_lens, hit_limit_flags), one entry per sample, in
            group order.
        """
        queries = [s["query"] for s in group_samples]

        # The mask is not hand-written: the tokenizer produces attention_mask itself.
        tok = self.tokenizer(
            queries,
            padding=True,
            return_tensors="pt",
        )

        self._set_img_context_token_id()

        pixel_values = None
        if has_image:
            pv_list = [s["pixel_values"] for s in group_samples]
            # presumably each entry of pv_list is [num_patches_i, 3, H, W] - TODO confirm
            pixel_values = torch.cat(pv_list, dim=0).to(self.device, dtype=self._model_dtype())

        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            eos_token_id=eos_token_id,
        )

        tok = tok.to(self.device)

        # ============================ GPU memory estimation ============================
        # NOTE(review): apply_prefill_extra_tokens is defined in utils; it presumably
        # lengthens the inputs by prefill_extra_tokens for memory probing - confirm there.
        if oom_estimate:
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            from utils import apply_prefill_extra_tokens
            tok = apply_prefill_extra_tokens(
                batch_size=len(group_samples),
                inputs=tok,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===============================================================================
        input_ids = tok["input_ids"]
        attention_mask = tok["attention_mask"]

        with torch.no_grad():
            generate_ids = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_config,
            )

        batch_size = input_ids.size(0)
        prefix_len_padded = input_ids.size(1)  # padding=True gives every row of the batch the same length

        # EOS id set (used for the ended_with_eos / hit_limit decision)
        eos_ids = []
        eos_ids += _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        eos_ids += _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids += [int(eos_token_id)]
        eos_ids = list({int(x) for x in eos_ids if x is not None})

        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        for i in range(batch_size):
            out_seq = generate_ids[i]

            # Most HF generate implementations prepend input_ids to the output;
            # check for that explicitly instead of assuming it.
            if out_seq.numel() >= prefix_len_padded and torch.equal(out_seq[:prefix_len_padded], input_ids[i]):
                seq = out_seq[prefix_len_padded:]
            else:
                seq = out_seq  # fallback: the output is not the full sequence

            # Drop the right-side padding before measuring the generated length.
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            seq = seq[:-cut] if cut > 0 else seq

            out_len = int(seq.shape[0])

            ended_with_eos = False
            if out_len > 0:
                if int(seq[-1].item()) in eos_ids:
                    ended_with_eos = True

            # Hit the limit: the budget was used up without ending on an EOS token.
            hit_limit_flags.append(out_len >= int(max_new_tokens) and (not ended_with_eos))
            decoded_token_ids.append(seq.detach().cpu().tolist())

        texts = self.tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        outputs: List[str] = []
        for t in texts:
            t = (t or "").strip()
            # Safety net: even though eos_token_id is the separator, also cut the text at sep_str.
            if sep_str and sep_str in t:
                t = t.split(sep_str)[0].strip()
            outputs.append(t)

        return outputs, right_pad_lens, hit_limit_flags

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Dict = {}
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        真·batch 推理（按“同构”分组后各做一次 generate，保证兼容性）：
          1) 逐样本：load images -> build query（复刻 chat()）
          2) 按 (has_image, total_img_ctx_tokens, eos_token_id) 分组
          3) 每组 tokenizer padding -> 一次 generate（mask 不手写）
          4) right_pad/hit_limit/输出 decode
        """
        batch_size = len(items)
        assert batch_size > 0, "generate_batch: items 不能为空"

        prepared: List[Dict[str, Any]] = []
        for idx, item in enumerate(items):
            prompt = parse_input(item)
            image_paths = get_image_path(item)
            image_paths = _normalize_to_list(image_paths)

            pixel_values, num_patches_list, n_images = self._load_images(image_paths)

            if n_images > 0:
                prompt = self._ensure_image_placeholders(prompt, n_images=n_images)

            query, sep_str, eos_token_id, total_img_ctx_tokens = self._build_query(
                prompt=prompt,
                num_patches_list=num_patches_list,
            )

            prepared.append(
                dict(
                    idx=idx,
                    query=query,
                    sep_str=sep_str,
                    eos_token_id=eos_token_id,
                    has_image=(pixel_values is not None),
                    pixel_values=pixel_values,  # None for text-only
                    total_img_ctx_tokens=int(total_img_ctx_tokens),  # 0 for text-only
                )
            )

        # 分组：避免 remote generate() 对 batch 形态的潜在假设（同一组里 IMG_CONTEXT token 数一致）
        groups: DefaultDict[Tuple[bool, int, int, str], List[Dict[str, Any]]] = defaultdict(list)
        for s in prepared:
            key = (bool(s["has_image"]), int(s["total_img_ctx_tokens"]), int(s["eos_token_id"]), str(s["sep_str"]))
            groups[key].append(s)

        outputs_all = [""] * batch_size
        right_pads_all = [0] * batch_size
        hit_limits_all = [False] * batch_size

        for (has_image, _ctx_tokens, eos_token_id, sep_str), group_samples in groups.items():
            outs, rps, hits = self._run_group_generate(
                group_samples=group_samples,
                max_new_tokens=max_new_tokens,
                gen_cfg=gen_cfg,
                has_image=has_image,
                eos_token_id=eos_token_id,
                sep_str=sep_str,
                oom_estimate=oom_estimate,
                bs_estimate_gen_cfg=bs_estimate_gen_cfg
            )
            for s, o, rp, hl in zip(group_samples, outs, rps, hits):
                j = int(s["idx"])
                outputs_all[j] = o
                right_pads_all[j] = int(rp)
                hit_limits_all[j] = bool(hl)

        for o in outputs_all:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs_all, right_pads_all, hit_limits_all

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        outputs, right_pads, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Return how many trailing tokens of ``total_sequence`` count as right padding.

        The sequence (a tensor is first converted to a list) is scanned from the end:
          1) count the run of trailing tokens equal to the pad id;
          2) if the token before that run is not an EOS id and at least one pad was
             counted, return the count minus one;
          3) otherwise add the run of EOS ids that precedes the pads, minus the one
             EOS that is kept.

        The pad id comes from model.generation_config, falling back to the
        tokenizer; the EOS ids are the union of both sources.

        NOTE(review): when every token equals the pad id, the scan reaches the start
        of the sequence and the full length is returned, so no EOS is kept in that
        case - confirm this is intended.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        # Prefer the pad_token_id that generate actually uses (generation_config)
        # over the tokenizer's own pad_token_id.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Walk backwards from the end and count the consecutive pad tokens; all of them get cut.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads were found but no EOS precedes them: the pad id equals the EOS id,
            # so the scan swallowed one token too many; give one back.
            return right_pad_len - 1

        # Keep walking backwards and count the consecutive EOS tokens.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; the others are treated as right-side padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
