from __future__ import annotations

from typing import List, Dict, Tuple, Optional
import importlib
from rich.console import Console
import torch
from PIL import Image

from .base import BaseVLM

# Imported from utils: parse_input / get_image_path / _normalize_to_list.
# The parent of this file's directory is appended to sys.path first so that
# `utils` can be imported as a top-level module (presumably it lives there —
# TODO confirm); that is why this import sits below executable code (E402).
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import parse_input, get_image_path, _normalize_to_list  # noqa: E402

# Module-level rich console used for diagnostic printing in this module.
console = Console()

class POINTSVLM(BaseVLM):
    """Backend for WePOINTS / POINTS-* chat models with a truly batched generate.

    The crops of every sample in a batch are encoded by the vision encoders in
    one pass, the prompts are left-padded together, and decoding happens in a
    single ``model.generate`` call.
    """

    def __init__(self, model, tokenizer, image_processor=None, device: str = "cuda"):
        """Set up tokenizer padding, split helpers, special token ids and dtypes.

        Args:
            model: POINTS chat model. The module that defines its class must
                expose ``split_image_with_catty`` and ``split_image``.
            tokenizer: Tokenizer paired with ``model``; its ``padding_side``
                is set to ``"left"`` here.
            image_processor: Image processor used to build ``pixel_values``.
                Required even though the default is ``None``.
            device: Forwarded to ``BaseVLM.__init__``.

        Raises:
            AssertionError: If ``image_processor`` is ``None``.
            RuntimeError: If the split helpers or the image token id cannot
                be resolved.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=image_processor, device=device)
        self.image_processor = image_processor
        assert self.image_processor is not None, "POINTSVLM 需要 CLIPImageProcessor（create_backend 已传入 processor）"

        # Left padding keeps the last column of every row on the real final
        # prompt token, whatever the prompt length.
        self.tokenizer.padding_side = "left"

        # The split helpers are looked up at runtime on the module that
        # defines the model class (the remote-code modeling file).
        mp = importlib.import_module(self.model.__module__)
        self._split_image_with_catty = getattr(mp, "split_image_with_catty", None)
        self._split_image = getattr(mp, "split_image", None)
        if self._split_image_with_catty is None or self._split_image is None:
            raise RuntimeError(
                f"无法从 {self.model.__module__} 获取 split_image_with_catty/split_image。"
                "可能是 POINTS 版本变化，请检查 remote code。"
            )

        # Prefer <|im_end|> as EOS; fall back to the tokenizer's own EOS id.
        self.eos_token_id = self._safe_token_id(
            "<|im_end|>", fallback=getattr(self.tokenizer, "eos_token_id", None)
        )

        # The image placeholder token is either <|image_pad|> or
        # <|endoftext|>, so it is detected instead of hard-coded.
        self.image_token_id = self._resolve_image_token_id()

        # Parameter dtypes used when casting pixel values and image features.
        self._vision_dtype = self._infer_param_dtype(["general_vit", "vision_tower", "vision_model"])
        self._llm_dtype = self._infer_param_dtype(["llm", "language_model", "model"]) or self._vision_dtype

    def _infer_param_dtype(self, attr_candidates: List[str]) -> torch.dtype:
        """Return the dtype of the first parameter of the first usable submodule.

        Args:
            attr_candidates: Attribute names on ``self.model`` to try in order.

        Returns:
            The dtype of the first candidate's first parameter; otherwise the
            dtype of the model's first parameter; otherwise ``torch.float32``.
        """
        for attr in attr_candidates:
            m = getattr(self.model, attr, None)
            if m is None:
                continue
            try:
                return next(m.parameters()).dtype
            except Exception:
                # Best effort: a candidate without usable parameters is skipped.
                continue
        try:
            return next(self.model.parameters()).dtype
        except Exception:
            return torch.float32

    def _safe_token_id(self, token: str, fallback: Optional[int]) -> Optional[int]:
        """Look up the id of ``token``, returning ``fallback`` when it is unknown.

        A token counts as unknown when the lookup raises, yields ``None``, or
        yields the tokenizer's unk id although ``token`` is not the unk token.

        Args:
            token: Token string to resolve.
            fallback: Value returned when the token cannot be resolved.

        Returns:
            The token id, or ``fallback``.
        """
        try:
            tid = self.tokenizer.convert_tokens_to_ids(token)
            # Tokenizers without an unk token may report unknown tokens as None.
            if tid is None:
                return fallback
            unk_id = getattr(self.tokenizer, "unk_token_id", None)
            if unk_id is not None and tid == unk_id and token != getattr(self.tokenizer, "unk_token", None):
                return fallback
            return tid
        except Exception:
            return fallback

    def _resolve_image_token_id(self) -> int:
        """Detect which candidate token the chat template uses as image placeholder.

        A one-image chat template is rendered and tokenized; the candidate
        (``<|image_pad|>`` or ``<|endoftext|>``) occurring most often wins,
        with the earlier candidate winning ties. If the template cannot be
        rendered, the first candidate that has an id is used.

        Returns:
            The id of the selected image placeholder token.

        Raises:
            RuntimeError: If no candidate token can be resolved to an id.
        """
        try:
            tmpl = self.model.apply_chat_template("test", 1)
        except Exception:
            # Fallback: take the first candidate that resolves to an id.
            for tok in ["<|image_pad|>", "<|endoftext|>"]:
                tid = self._safe_token_id(tok, None)
                if tid is not None:
                    return int(tid)
            raise RuntimeError("无法解析 POINTS image token id：apply_chat_template 不可用，且候选 token 都不存在。")

        enc = self.tokenizer(tmpl, return_tensors="pt")
        ids = enc["input_ids"][0]

        best = None
        best_cnt = -1
        for tok in ["<|image_pad|>", "<|endoftext|>"]:
            tid = self._safe_token_id(tok, None)
            if tid is None:
                continue
            cnt = int((ids == int(tid)).sum().item())
            if cnt > best_cnt:
                best_cnt = cnt
                best = int(tid)

        if best is None:
            raise RuntimeError("无法解析 POINTS image token id：候选 token 都不可用。")
        return best

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate answers for a batch of image + text samples in one pass.

        Args:
            items: Samples. ``parse_input`` must yield a prompt and
                ``get_image_path`` a non-empty image path for each of them.
            max_new_tokens: Generation budget passed to ``model.generate``.
            gen_cfg: Generation options. ``catty`` (default ``True``) and
                ``max_splits`` (default ``8``) steer image splitting and are
                not forwarded. ``temperature``, ``top_p``, ``num_beams`` and
                ``do_sample`` are read with defaults; all remaining keys are
                forwarded to ``model.generate`` unchanged.
            oom_estimate: When true, the tokenized prompts are passed through
                ``utils.apply_prefill_extra_tokens`` before generation.
            bs_estimate_gen_cfg: Only read when ``oom_estimate`` is true; the
                keys ``_prefill_extra_tokens`` and ``_prefill_token_id`` are
                consumed from a copy.

        Returns:
            A tuple ``(texts, right_pad_lens, hit_limits)`` with one entry per
            returned sequence: the decoded text, the number of trailing tokens
            treated as right padding, and whether the sequence used the whole
            token budget without ending on an EOS token.

        Raises:
            ValueError: If a sample has no image path.
        """
        # Local import keeps the module's top-level imports untouched.
        from rich.markup import escape

        if not items:
            return [], [], []

        cfg = dict(gen_cfg or {})
        # Splitting options are popped so they never reach model.generate().
        catty = bool(cfg.pop("catty", True))
        max_splits = int(cfg.pop("max_splits", 8))

        # 1) Per-sample image splitting and prompt construction.
        prompts: List[str] = []
        crop_counts: List[int] = []
        all_crops: List[Image.Image] = []

        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("POINTSVLM 目前仅支持多模态（每条样本必须有 image_path）")

            # convert() yields an independent image, so the underlying file
            # can be closed as soon as the conversion is done.
            with Image.open(image_path) as raw_image:
                img = raw_image.convert("RGB")
            if catty:
                crops = self._split_image_with_catty(img, do_resize=True, max_crop_slices=max_splits)
            else:
                crops = self._split_image(img, max_splits=max_splits)

            crop_counts.append(len(crops))
            all_crops.extend(crops)
            prompts.append(self.model.apply_chat_template(prompt, len(crops)))

        # 2) Vision encoding of every crop of every sample in one batch.
        pixel_values = self.image_processor.preprocess(all_crops, return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(self.device, dtype=self._vision_dtype)

        with torch.no_grad():
            general = self.model.extract_image_features(pixel_values, vision_encoder="general_vit")
            ocr = self.model.extract_image_features(pixel_values, vision_encoder="ocr_vit")
            # Equal-weight average of the two encoders, cast to the LLM dtype.
            image_features_flat = (0.5 * general + 0.5 * ocr).to(dtype=self._llm_dtype)

        # Slice the flat crop features back into one tensor per sample.
        image_features: List[torch.Tensor] = []
        off = 0
        for c in crop_counts:
            image_features.append(image_features_flat[off:off + c])
            off += c

        # 3) Batched, left-padded tokenization followed by one generate() call.
        if getattr(self.tokenizer, "pad_token_id", None) is None:
            if getattr(self.tokenizer, "unk_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        model_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        # Union of every EOS id known to this backend, the generation config
        # and the tokenizer.
        eos_ids = list(
            set(
                _normalize_to_list(self.eos_token_id)
                + _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
                + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            )
        )
        eos_ids = [e for e in eos_ids if e is not None]
        if not eos_ids and self.eos_token_id is not None:
            eos_ids = [self.eos_token_id]

        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0

        # Defaults are deterministic: temperature=0 / top_p=0 / num_beams=1.
        temperature = float(cfg.get("temperature", 0.0))
        do_sample = bool(cfg.get("do_sample", False)) or (temperature > 0.0)

        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=float(cfg.get("top_p", 0.0)),
            num_beams=int(cfg.get("num_beams", 1)),
            eos_token_id=eos_ids[0] if len(eos_ids) == 1 else eos_ids,
            pad_token_id=pad_id,
            return_dict_in_generate=False,
            output_scores=False,
            **{k: v for k, v in cfg.items() if k not in {"temperature", "top_p", "num_beams", "do_sample"}},
        )

        # ------------------------- memory estimation -------------------------
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            # NOTE(review): presumably extends the prompts with extra tokens
            # to simulate a longer prefill — confirm against utils.
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids, attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]
        # ---------------------------------------------------------------------

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                image_features=image_features,
                image_token_id=int(self.image_token_id),
                **gen_kwargs,
            )

        # 4) Decode and collect right-padding lengths and budget-exhaustion flags.
        # NOTE(review): the hit-limit test assumes generate() returns only the
        # newly generated tokens (no prompt) — confirm against the remote code.
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        eos_set = set(eos_ids)

        for i in range(output_ids.size(0)):
            seq = output_ids[i]
            # Use exactly the pad/EOS ids that were handed to generate().
            cut = self.calculate_right_padding_length(seq, pad_id=pad_id, eos_ids=eos_ids)
            right_pad_lens.append(cut)

            gen = seq[:-cut] if cut > 0 else seq
            ended_with_eos = (gen.numel() > 0) and (int(gen[-1].item()) in eos_set)
            hit_limits.append((gen.numel() >= max_new_tokens) and (not ended_with_eos))

            txt = self.tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            outputs.append(txt)

        for o in outputs:
            # Escape the model text so brackets in it are printed literally
            # instead of being parsed as rich console markup.
            console.print("\n[yellow]output without prompt: ", escape(o))

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Generate for a single sample by delegating to ``generate_batch``.

        Args:
            item: One sample, as accepted by ``generate_batch``.
            max_new_tokens: Generation budget.
            gen_cfg: Generation options, as accepted by ``generate_batch``.

        Returns:
            A tuple ``(text, right_pad_len, hit_limit)`` for the sample.
        """
        outs, pads, hits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outs[0], pads[0], hits[0]

    def calculate_right_padding_length(self, total_sequence, pad_id=None, eos_ids=None) -> int:
        """Count the trailing tokens of a sequence that are only padding.

        Trailing pad tokens are padding, and of a trailing run of EOS tokens
        only the first one is kept. When trailing pads directly follow a
        non-EOS token, one of them is kept, because the pad id may equal the
        EOS id and the first "pad" is then the real EOS.

        Args:
            total_sequence: Token ids as a list or a 1-D ``torch.Tensor``.
            pad_id: Pad token id. Defaults to the generation config's
                ``pad_token_id``, then the tokenizer's.
            eos_ids: EOS id or ids. Defaults to the union of the generation
                config's and the tokenizer's EOS ids.

        Returns:
            The number of tokens to strip from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        if pad_id is None:
            # Prefer the pad id configured for generation over the tokenizer's.
            pad_id = getattr(self.model.generation_config, "pad_token_id", None)
            if pad_id is None:
                pad_id = getattr(self.tokenizer, "pad_token_id", None)

        if eos_ids is None:
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_lookup = set(gc_eos + tok_eos)
        else:
            eos_lookup = set(_normalize_to_list(eos_ids))

        # Walk back over the trailing run of pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_lookup and right_pad_len > 0:
            # Pads directly follow a normal token: the first of them stands in
            # for the EOS (pad id == EOS id), so it is given back.
            return right_pad_len - 1

        # Continue walking back over the trailing run of EOS tokens.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_lookup:
            eos_count += 1
            i -= 1
        # Keep one EOS; the remaining ones count as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len


