from __future__ import annotations

from typing import List, Dict, Tuple, Optional
import os
import sys
from copy import deepcopy
from contextlib import nullcontext

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Make the project root (the parent of the directory containing this file)
# importable so the shared `utils` module below can be found. This must run
# before the `from utils import ...` line, which is why that import carries a
# `noqa: E402` (module-level import not at top of file). The path is appended,
# so it has the lowest lookup priority in sys.path.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from utils import parse_input, get_image_path, _normalize_to_list  # noqa: E402

# Module-level rich console; the backend below uses it to print generated outputs.
console = Console()


class MiniCPMLlama3V25VLM(BaseVLM):
    """
    Backend ONLY for: openbmb/MiniCPM-Llama3-V-2_5

    Design principles:
    1) Follow the input construction of the model's remote-code ``model.chat()``:
       - inject the image into ``msgs[0]["content"]`` (as ``[PIL.Image, text]``)
       - replace every image with the placeholder ``"(<image>./</image>)"``
       - ``prompt = processor.tokenizer.apply_chat_template(..., tokenize=False, add_generation_prompt=True)``
       - ``inputs = processor(prompt, images, return_tensors="pt", max_length=max_inp_length)``

    2) Generate through the model's own
       ``model.generate(model_inputs, tokenizer=..., decode_text=False)``:
       - avoids the ``_process_list`` / old image_processor APIs that older
         backends depended on and that are incompatible with this model
       - right padding uses the fixed id 0 (consistent with the model's ``_decode()``)
       - end-of-sequence detection considers both eos_id and eot_id (the model's
         ``_decode()`` uses two terminators)
    """

    # gen_cfg keys that are forwarded to generate(); everything else is dropped
    # so unexpected fields never reach Hugging Face generate().
    # presence_penalty / frequency_penalty are intentionally NOT listed: they are
    # not HF GenerationConfig fields. NOTE(review): presumably the remote generate()
    # forwards these kwargs to the LLM's HF generate(), which rejects kwargs it can
    # neither map to GenerationConfig nor to model inputs.
    _GEN_KWARG_WHITELIST = frozenset({
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "repetition_penalty",
        "num_beams",
        "min_new_tokens",
        "max_new_tokens",
        "length_penalty",
        "no_repeat_ngram_size",
    })

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda"):
        """
        Args:
            model: the loaded MiniCPM-Llama3-V-2_5 model (remote code).
            tokenizer: its tokenizer; also installed as ``processor.tokenizer``.
            processor: the model's AutoProcessor (required).
            device: device that batched inputs are moved to.

        Raises:
            RuntimeError: if ``processor`` is None.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        if self.processor is None:
            raise RuntimeError(
                "MiniCPMLlama3V25VLM 需要 AutoProcessor (trust_remote_code=True)；"
                "请确保 HiddenExtractor 给该模型传入的是 AutoProcessor。"
            )

        # Left padding, consistent with the official chat template. Best effort:
        # some tokenizer wrappers may not allow this attribute to be set.
        try:
            self.tokenizer.padding_side = "left"
        except Exception:
            pass

        # pad_token is mainly needed by tokenizer.decode / some HF APIs; the
        # model's own generation always pads with id 0 (see _pad_token_id below).
        if getattr(self.tokenizer, "pad_token", None) is None:
            if getattr(self.tokenizer, "unk_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            elif getattr(self.tokenizer, "eos_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        # Make processor.tokenizer the very same object as self.tokenizer so that
        # prompt templating and decoding never mix two differently configured tokenizers.
        if hasattr(self.processor, "tokenizer") and self.processor.tokenizer is not None:
            try:
                self.processor.tokenizer = self.tokenizer
            except Exception:
                pass
            try:
                self.processor.tokenizer.padding_side = "left"
                if (
                    getattr(self.processor.tokenizer, "pad_token", None) is None
                    and getattr(self.tokenizer, "pad_token", None) is not None
                ):
                    self.processor.tokenizer.pad_token = self.tokenizer.pad_token
            except Exception:
                pass

        # The model's _decode() always pads generated sequences with id 0.
        self._pad_token_id = 0

        # Terminator ids (eos + eot) used for end-of-sequence detection.
        self._eos_ids = self._resolve_eos_ids()

    def _llm_dtype(self) -> torch.dtype:
        """Return the parameter dtype of ``model.llm`` (else of the whole model), defaulting to float32."""
        if hasattr(self.model, "llm"):
            try:
                return next(self.model.llm.parameters()).dtype
            except Exception:
                pass
        try:
            return next(self.model.parameters()).dtype
        except Exception:
            return torch.float32

    def _maybe_autocast(self):
        """
        Return a CUDA autocast context in the LLM's half-precision dtype when running
        on CUDA with a float16/bfloat16 LLM; otherwise a no-op context.
        """
        # str() so that both "cuda:0" strings and torch.device objects work.
        if str(self.device).startswith("cuda"):
            dt = self._llm_dtype()
            if dt in (torch.float16, torch.bfloat16):
                # torch.autocast(device_type="cuda") supersedes the deprecated
                # torch.cuda.amp.autocast with identical semantics.
                return torch.autocast(device_type="cuda", dtype=dt)
        return nullcontext()

    def _safe_token_id(self, token: str) -> Optional[int]:
        """
        Look up the id of ``token``; return None if the lookup fails, yields None,
        or silently falls back to the unk id for a token other than unk itself.
        """
        try:
            tid = self.tokenizer.convert_tokens_to_ids(token)
            unk = getattr(self.tokenizer, "unk_token_id", None)
            if unk is not None and tid == unk and token != getattr(self.tokenizer, "unk_token", None):
                return None
            if tid is None:
                return None
            return int(tid)
        except Exception:
            return None

    def _resolve_eos_ids(self) -> List[int]:
        """
        Collect every token id that should end generation, sorted and de-duplicated.

        Per the original notes, the remote-code ``_decode()`` of
        MiniCPM-Llama3-V-2_5 uses two terminators:
        - ``tokenizer.eos_token_id``
        - ``tokenizer.convert_tokens_to_ids("<|eot_id|>")``
        and its tokenizer wrapper usually also exposes ``eos_id`` / ``eot_id``.
        ``generation_config.eos_token_id`` is included as well.
        """
        eos_ids: List[int] = []

        # First: tokenizer (wrapper) attributes, if present.
        for attr in ["eos_id", "eot_id", "eos_token_id"]:
            v = getattr(self.tokenizer, attr, None)
            if v is None:
                continue
            if isinstance(v, (list, tuple)):
                eos_ids.extend([int(x) for x in v if x is not None])
            else:
                eos_ids.append(int(v))

        # Then: the explicit <|eot_id|> token.
        eot = self._safe_token_id("<|eot_id|>")
        if eot is not None:
            eos_ids.append(eot)

        # Finally: whatever the generation_config declares.
        eos_ids.extend(_normalize_to_list(getattr(getattr(self.model, "generation_config", None), "eos_token_id", None)))

        # Drop None, de-duplicate, sort.
        eos_ids = [int(x) for x in eos_ids if x is not None]
        return sorted(set(eos_ids))

    def _build_single_inputs(self, item: dict, max_inp_length: int):
        """
        Build processor inputs for one item, replicating the input construction of
        ``model.chat()`` in the remote ``modeling_minicpmv.py``.

        Args:
            item: an input record; the prompt comes from ``parse_input`` and the
                optional image path from ``get_image_path``.
            max_inp_length: forwarded to the processor as ``max_length``.

        Returns:
            The processor output for a single sample (per the original notes, a
            ``MiniCPMVBatchFeature``).

        Raises:
            ValueError: on an unsupported message role or content type.
            RuntimeError: if the processor has no ``tokenizer``.
        """
        prompt = parse_input(item)
        image_path = get_image_path(item)

        pil_img = None
        if image_path:
            # Close the file deterministically; convert() returns a new, fully
            # loaded image that remains valid after the source file is closed.
            with Image.open(image_path) as img:
                pil_img = img.convert("RGB")

        # msgs: standard chat format.
        msgs = [{"role": "user", "content": prompt}]
        copy_msgs = deepcopy(msgs)

        # Inject the image into the first user turn.
        if pil_img is not None and isinstance(copy_msgs[0].get("content", None), str):
            copy_msgs[0]["content"] = [pil_img, copy_msgs[0]["content"]]

        # Replace each image in the content with "(<image>./</image>)" and collect the images.
        images: List[Image.Image] = []
        for msg in copy_msgs:
            role = msg.get("role")
            content = msg.get("content")
            if role not in ["user", "assistant", "system"]:
                raise ValueError(f"Unsupported role in msgs: {role}")

            if isinstance(content, str):
                content = [content]

            cur_msgs: List[str] = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")
                elif isinstance(c, str):
                    cur_msgs.append(c)
                else:
                    raise ValueError(f"Unsupported content type in msgs: {type(c)}")

            msg["content"] = "\n".join(cur_msgs)

        # Chat template.
        if not hasattr(self.processor, "tokenizer") or self.processor.tokenizer is None:
            raise RuntimeError("processor.tokenizer 不存在，无法 apply_chat_template（检查 AutoProcessor 是否加载正确）")

        prompt_text = self.processor.tokenizer.apply_chat_template(
            copy_msgs,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Single-sample processor call (prompt_text is a str; images is a list of PIL images).
        return self.processor(
            prompt_text,
            images,
            return_tensors="pt",
            max_length=max_inp_length,
        )

    def _collate_features(self, single_features) -> dict:
        """Left-pad ``input_ids`` into a batch and gather per-sample vision inputs on ``self.device``."""
        # processor.pad comes from the remote processing_minicpmv.py (left pad).
        input_ids = self.processor.pad(
            orig_items=single_features,
            key="input_ids",
            padding_value=self._pad_token_id,
            padding_side="left",
            max_length=None,
        ).to(self.device)

        # Per the original notes, the model's generate() does not move
        # tgt_sizes / image_bound to the device itself, so do it here.
        pixel_values: List[List[torch.Tensor]] = []
        tgt_sizes: List[torch.Tensor] = []
        image_bound: List[torch.Tensor] = []

        for feat in single_features:
            # Each feature is a batch of one, so unwrap index 0.
            pv = feat.get("pixel_values", None)
            pixel_values.append([] if pv is None else pv[0])

            ts = feat.get("tgt_sizes", None)
            if ts is None:
                tgt_sizes.append(torch.zeros((0, 2), dtype=torch.int32, device=self.device))
            else:
                tgt_sizes.append(ts[0].to(self.device))

            ib = feat.get("image_bound", None)
            if ib is None:
                image_bound.append(torch.zeros((0, 2), dtype=torch.int64, device=self.device))
            else:
                image_bound.append(ib[0].to(self.device))

        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "tgt_sizes": tgt_sizes,
            "image_bound": image_bound,
        }

    def _decode_ids(self, seq: torch.Tensor) -> str:
        """
        Decode generated ids into text, aligned with the remote ``_decode_text``
        as described in the original notes: drop pad (0) ids, a leading bos, and
        one trailing eos/eot.
        """
        seq = seq[seq != self._pad_token_id]

        if seq.numel() > 0:
            bos_id = getattr(self.tokenizer, "bos_id", None)
            if bos_id is not None and int(seq[0].item()) == int(bos_id):
                seq = seq[1:]

        if seq.numel() > 0:
            tail = int(seq[-1].item())
            terminators = (getattr(self.tokenizer, "eos_id", None), getattr(self.tokenizer, "eot_id", None))
            if any(t is not None and tail == int(t) for t in terminators):
                seq = seq[:-1]

        text = self.tokenizer.decode(seq.tolist(), skip_special_tokens=True).strip()

        # Fallback: strip explicit terminator strings (as the remote stream mode does).
        eot_token = getattr(self.tokenizer, "eot_token", None)
        eos_token = getattr(self.tokenizer, "eos_token", None)
        if eot_token:
            text = text.replace(eot_token, "")
        if eos_token:
            text = text.replace(eos_token, "")
        return text.strip()

    def _postprocess_sequence(self, total_seq: torch.Tensor, eos_set: set, max_new_tokens: int) -> Tuple[str, int, bool]:
        """Return ``(text, right_pad_len, hit_limit)`` for one generated sequence."""
        total_len = int(total_seq.shape[0])
        cut = self.calculate_right_padding_length(total_seq)
        out_len = total_len - cut
        new_ids = total_seq[:-cut] if cut > 0 else total_seq

        ended_with_eos = new_ids.numel() > 0 and int(new_ids[-1].item()) in eos_set

        # hit_limit keeps the project's existing definition (downstream depends on it).
        hit_limit = (out_len >= max_new_tokens) and (not ended_with_eos)
        return self._decode_ids(new_ids), cut, bool(hit_limit)

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Generate answers for a batch of items through the model's remote-code generate().

        Args:
            items: input records understood by ``parse_input`` / ``get_image_path``.
            max_new_tokens: maximum number of new tokens per sample.
            gen_cfg: generation options. ``max_inp_length`` (default 2048, matching
                the remote ``chat()``) is consumed here; only keys in
                ``_GEN_KWARG_WHITELIST`` are forwarded to generate(). ``do_sample``
                defaults to False.
            oom_estimate: when True, extend the inputs via
                ``utils.apply_prefill_extra_tokens`` (memory / batch-size estimation).
            bs_estimate_gen_cfg: options for the estimation path; only
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read here.

        Returns:
            ``(texts, right_pad_lens, hit_limit_flags)``, one entry per item.
        """
        if not items:
            return [], [], []

        cfg = dict(gen_cfg or {})
        # Aligned with the remote chat(): default max_inp_length=2048.
        max_inp_length = int(cfg.pop("max_inp_length", 2048))

        gen_kwargs = {"do_sample": False}
        gen_kwargs.update({k: v for k, v in cfg.items() if k in self._GEN_KWARG_WHITELIST})
        gen_kwargs["max_new_tokens"] = int(max_new_tokens)

        # 1) One processor(...) call per item.
        single_features = [self._build_single_inputs(it, max_inp_length=max_inp_length) for it in items]

        # 2) + 3) Batch input_ids and gather vision inputs.
        model_inputs = self._collate_features(single_features)

        # Terminators used only for end-of-sequence detection below; they are not
        # passed to generate() (the remote _decode() supplies its own terminators).
        eos_ids = sorted(set(self._eos_ids))
        if not eos_ids:
            # Keep at least one eos if anything is available.
            fallback = getattr(self.tokenizer, "eos_token_id", None)
            if fallback is not None:
                eos_ids = [int(fallback)]

        # ================================ memory estimation ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            # NOTE(review): any remaining bs_estimate_gen_cfg keys are not used here;
            # confirm whether they should override gen_kwargs during estimation.
            model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ====================================================================================

        # 4) Generate through the remote MiniCPMV.generate().
        with torch.inference_mode():
            with self._maybe_autocast():
                output_ids = self.model.generate(
                    model_inputs,
                    tokenizer=self.tokenizer,
                    decode_text=False,
                    **gen_kwargs,
                )

        # 5) right_pad / hit_limit / decode.
        eos_set = set(eos_ids)
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []
        for row in range(output_ids.shape[0]):
            text, cut, hit_limit = self._postprocess_sequence(output_ids[row], eos_set, max_new_tokens)
            outputs.append(text)
            right_pad_lens.append(cut)
            hit_limit_flags.append(hit_limit)

        for o in outputs:
            console.print("\n[yellow]output without prompt:", o)

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item: dict, max_new_tokens: int, gen_cfg: Dict):
        """Generate for a single item; returns ``(text, right_pad_len, hit_limit)``."""
        outputs, right_pads, hit_limits = self.generate_batch(
            [item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={}
        )
        # generate_batch already prints every output; printing again here only
        # duplicated the log line.
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are padding.

        The trailing run of pad ids is always counted. After it, every trailing
        eos/eot id except one is counted as well, so exactly one terminator is kept.
        Only when the pad id is itself an eos id (so the pad run has swallowed the
        real terminator) is one token of the pad run kept as that terminator.

        Args:
            total_sequence: 1-D tensor or list of generated token ids.

        Returns:
            Number of tokens to strip from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        gen_config = getattr(self.model, "generation_config", None)

        # This model's _decode() pads with the fixed id in self._pad_token_id (0);
        # fall back to generation_config / tokenizer only if it is unavailable.
        pad_id = getattr(self, "_pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(gen_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        # Include eot as well as eos (self._eos_ids): generation may stop on either.
        eos_ids = set(getattr(self, "_eos_ids", None) or [])
        eos_ids.update(_normalize_to_list(getattr(gen_config, "eos_token_id", None)))
        eos_ids.update(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))

        # Trailing run of pad ids.
        j = len(total_sequence) - 1
        pad_run = 0
        while j >= 0 and total_sequence[j] == pad_id:
            pad_run += 1
            j -= 1

        # pad == eos: the run includes the real terminator, keep one of it.
        if pad_run > 0 and pad_id in eos_ids and j >= 0 and total_sequence[j] not in eos_ids:
            return pad_run - 1

        # Trailing run of terminators before the pads; keep one, count the rest.
        eos_run = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_run += 1
            i -= 1

        return pad_run + max(0, eos_run - 1)

