from typing import List, Dict, Tuple, Optional
import os
import sys
import re
from contextlib import nullcontext

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Project root for `utils`: the directory two levels above this file (the
# parent of this module's package directory). It is appended to sys.path so
# the top-level `utils` module can be imported below.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Must come after the sys.path tweak above, hence the E402 suppression.
from utils import parse_input, get_image_path, _normalize_to_list  # noqa: E402

# Module-level rich console; generate_batch uses it to print decoded outputs.
console = Console()


class MiniCPMVLM(BaseVLM):
    """
    Inference wrapper for MiniCPM-V-2 (openbmb/MiniCPM-V-2).

    - Prompt format: ``<用户>`` + (image placeholder + newline + question) + ``<AI>``.
    - To report right_pad / hit_limit exactly, generation goes through the
      model's own pieces (_process_list -> get_vllm_embedding -> llm.generate)
      instead of chat(), so the raw output_ids are available.
    - The model has no batch_chat. Per the original author, its
      generate(data_list, img_list, ...) is itself batch-capable (chat() calls
      it internally); generate_batch reuses that batched path.
    """

    # gen_cfg keys forwarded to llm.generate; every other key is ignored.
    _ALLOWED_GEN_KEYS = frozenset({
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "repetition_penalty",
        "num_beams",
        "max_new_tokens",
    })

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda"):
        """Configure padding on the tokenizer/processor and pad/eos ids on the model.

        Args:
            model: MiniCPM-V-2 model (remote code) exposing ``llm``, ``transform``,
                ``_process_list`` and ``get_vllm_embedding``.
            tokenizer: the single tokenizer used for both encoding and decoding.
            processor: optional processor; if it has a ``tokenizer`` attribute,
                that attribute is pointed at ``tokenizer``.
            device: device that inputs are moved to.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # ---------- Tokenizer: exactly one instance is used everywhere ----------
        # Batched prompts are left-padded (see generate_batch).
        self.tokenizer.padding_side = "left"
        if getattr(self.tokenizer, "pad_token", None) is None:
            # Prefer unk as the pad token, falling back to eos.
            if getattr(self.tokenizer, "unk_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            elif getattr(self.tokenizer, "eos_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        # If a processor was passed in, force processor.tokenizer to be the same
        # object so the two can never diverge, then mirror the padding settings.
        # Both steps are deliberately best-effort; failures are ignored.
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            try:
                self.processor.tokenizer = self.tokenizer
            except Exception:
                pass
            proc_tokenizer = getattr(self.processor, "tokenizer", None)
            if proc_tokenizer is not None:
                try:
                    proc_tokenizer.padding_side = "left"
                    if getattr(proc_tokenizer, "pad_token", None) is None:
                        proc_tokenizer.pad_token = self.tokenizer.pad_token
                except Exception:
                    pass

        # Pin pad to id 0 to match the padding_value=0 that the remote-code
        # _process_list uses for left padding (per the note in generate_batch).
        # If generation_config has no eos, take it from the tokenizer.
        # Best-effort: failures are ignored.
        try:
            gen_config = getattr(self.model, "generation_config", None)
            if gen_config is not None:
                gen_config.pad_token_id = 0
                if (
                    getattr(gen_config, "eos_token_id", None) is None
                    and getattr(self.tokenizer, "eos_token_id", None) is not None
                ):
                    gen_config.eos_token_id = self.tokenizer.eos_token_id
        except Exception:
            pass
        try:
            if getattr(self.model, "config", None) is not None:
                self.model.config.pad_token_id = 0
        except Exception:
            pass

    def _llm_dtype(self) -> torch.dtype:
        """Return the dtype of the language-model submodule.

        It is used for pixel_values, inputs_embeds and autocast. The lookup falls
        back to the first parameter of the whole model, then to torch.bfloat16.
        """
        if hasattr(self.model, "llm"):
            try:
                return next(self.model.llm.parameters()).dtype
            except (AttributeError, StopIteration, TypeError):
                pass
        try:
            return next(self.model.parameters()).dtype
        except (AttributeError, StopIteration, TypeError):
            return torch.bfloat16

    # -----------------------------
    # Prompt / Image utilities
    # -----------------------------
    def _build_prompt_and_images(self, image: Optional[Image.Image], question: str) -> Tuple[str, List[Image.Image]]:
        """Build the prompt string and the image list for one sample.

        When ``config.slice_mode`` is set, the model's get_slice_image_placeholder
        supplies both the image list and the placeholder text. Otherwise the
        placeholder is im_start + ``query_num`` (default 64) unk tokens + im_end,
        and the image is used as-is. A missing image yields a text-only prompt.

        Returns:
            (prompt, images), where prompt is "<用户>" + content + "<AI>".
        """
        question = question or ""
        images: List[Image.Image] = []

        if image is None:
            content = question
        elif getattr(self.model.config, "slice_mode", False):
            images, final_placeholder = self.model.get_slice_image_placeholder(image, self.tokenizer)
            content = final_placeholder + "\n" + question
        else:
            images = [image]
            content = (
                self.tokenizer.im_start
                + self.tokenizer.unk_token * int(getattr(self.model.config, "query_num", 64))
                + self.tokenizer.im_end
                + "\n"
                + question
            )

        prompt = "<用户>" + content + "<AI>"
        return prompt, images

    @staticmethod
    def _load_rgb_image(img_path) -> Optional[Image.Image]:
        """Open ``img_path`` as an RGB image and close the file; return None for a falsy path."""
        if not img_path:
            return None
        with Image.open(img_path) as im:
            # convert() loads the pixel data and returns a new image, so the result
            # stays usable after the file handle is closed.
            return im.convert("RGB")

    def _build_pixel_values(self, img_list: List[List[Image.Image]], dtype: torch.dtype) -> List[List[torch.Tensor]]:
        """Apply model.transform to every image and move each result to self.device / ``dtype``.

        The nesting (one inner list per sample, empty for text-only samples) matches img_list.
        """
        return [
            [self.model.transform(img).to(self.device, dtype=dtype) for img in images]
            for images in img_list
        ]

    def _eos_id_set(self) -> set:
        """Union of the eos ids from model.generation_config and the tokenizer."""
        return set(
            _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        )

    # -----------------------------
    # Main inference
    # -----------------------------
    @torch.inference_mode()
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate one answer per item.

        Args:
            items: raw samples. The question comes from utils.parse_input and the
                optional image path from utils.get_image_path.
            max_new_tokens: generation budget per sample.
            gen_cfg: generation options. Only keys in _ALLOWED_GEN_KEYS are
                forwarded to llm.generate; "max_inp_length" (default 4096) bounds
                tokenization. None is treated as {}.
            oom_estimate: when True, inputs are extended with extra prefill tokens
                via utils.apply_prefill_extra_tokens before generating.
            bs_estimate_gen_cfg: options for that estimate. Only
                "_prefill_extra_tokens" and "_prefill_token_id" are read.

        Returns:
            (texts, right_pad_lens, hit_limit_flags), one entry per item.
        """
        if not items:
            return [], [], []
        gen_cfg = gen_cfg or {}

        data_list: List[str] = []
        img_list: List[List[Image.Image]] = []
        for item in items:
            question = parse_input(item)
            image = self._load_rgb_image(get_image_path(item))
            prompt, images = self._build_prompt_and_images(image, question)
            data_list.append(prompt)
            img_list.append(images)

        max_inp_length = int(gen_cfg.get("max_inp_length", 4096))

        # Tokenize. Per the original author, the remote-code _process_list
        # left-pads input_ids itself with padding_value=0.
        model_inputs = self.model._process_list(self.tokenizer, data_list, max_inp_length=max_inp_length)

        # Move tensor entries to the device; non-tensor entries are left untouched.
        for key, value in list(model_inputs.items()):
            if torch.is_tensor(value):
                model_inputs[key] = value.to(self.device)

        # Use the LLM submodule's dtype throughout, rather than the first
        # parameter of the whole model.
        ldtype = self._llm_dtype()
        # transform() output (presumably float32; TODO confirm) is cast to ldtype up front.
        model_inputs["pixel_values"] = self._build_pixel_values(img_list, ldtype)

        # Run get_vllm_embedding / llm.generate under autocast so that float32
        # intermediates are not fed into bf16/fp16 weights.
        use_amp = str(self.device).startswith("cuda") and ldtype in (torch.float16, torch.bfloat16)
        amp_ctx = torch.autocast(device_type="cuda", dtype=ldtype) if use_amp else nullcontext()

        pad_id = getattr(self.model.generation_config, "pad_token_id", 0)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        gen_kwargs = {k: v for k, v in gen_cfg.items() if k in self._ALLOWED_GEN_KEYS}
        gen_kwargs["max_new_tokens"] = int(max_new_tokens)
        gen_kwargs.setdefault("do_sample", False)

        # NOTE(review): no attention_mask is passed to llm.generate. A mask built as
        # (input_ids != pad_id) would be wrong here: pad is pinned to id 0, and the
        # image placeholder is a run of unk tokens, which can share that id when
        # __init__ defaulted pad_token to unk. A correct left-padding mask would
        # need per-sample prompt lengths from _process_list. TODO if batched
        # outputs with mixed prompt lengths look degraded.

        # ================================ Memory estimation ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens

            est_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(est_cfg.get("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = est_cfg.get("_prefill_token_id", None)
            # NOTE(review): any other keys in bs_estimate_gen_cfg are currently ignored.
            model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # =====================================================================================

        with amp_ctx:
            inputs_embeds, _vision_hs = self.model.get_vllm_embedding(model_inputs)

            # Some code paths may still return float32 embeddings; cast once more.
            if torch.is_tensor(inputs_embeds) and inputs_embeds.dtype != ldtype:
                inputs_embeds = inputs_embeds.to(dtype=ldtype)

            # Only inputs_embeds is given, so output_ids are assumed to hold just
            # the newly generated tokens. TODO confirm for the pinned transformers version.
            output_ids = self.model.llm.generate(
                inputs_embeds=inputs_embeds,
                pad_token_id=pad_id,
                eos_token_id=eos_id,
                **gen_kwargs,
            )

        # right_pad / hit_limit bookkeeping per sequence.
        eos_ids = self._eos_id_set()
        budget = int(max_new_tokens)

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for seq in output_ids:
            cut = self.calculate_right_padding_length(seq)
            kept = seq[:-cut] if cut > 0 else seq
            out_len = int(kept.numel())

            ended_with_eos = out_len > 0 and int(kept[-1].item()) in eos_ids
            # Hit the limit = used the whole budget without terminating on EOS.
            hit_limit = out_len >= budget and not ended_with_eos

            right_pad_lens.append(cut)
            hit_limit_flags.append(bool(hit_limit))
            # Decode with the same single tokenizer used for encoding.
            outputs.append(self.tokenizer.decode(kept.tolist(), skip_special_tokens=True).strip())

        for o in outputs:
            console.print("\n[yellow]output without prompt:", o)

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item: dict, max_new_tokens: int, gen_cfg: Dict):
        """Run generate_batch on a single item without OOM estimation.

        Returns:
            (text, right_pad_len, hit_limit) for that item.
        """
        outputs, right_pads, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Count the trailing tokens of a generated sequence that should be cut as padding.

        Algorithm:
        - The whole trailing run of pad tokens is cut.
        - Of the EOS tokens immediately preceding that run, one is kept and any
          extras are also counted as padding.
        - When pad shares an id with EOS, the pad run swallows the terminating
          EOS, so one token is given back to keep a single EOS.

        Args:
            total_sequence: 1-D tensor or list of token ids.

        Returns:
            Number of tokens to drop from the right.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate actually used.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = self._eos_id_set()

        # Walk back over the trailing run of pad tokens.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            j -= 1
        right_pad_len = n - 1 - j

        # pad == eos: the run above also consumed the real terminating EOS
        # (the token before it is not EOS), so give one token back.
        if pad_id in eos_ids and right_pad_len > 0 and j >= 0 and total_sequence[j] not in eos_ids:
            return right_pad_len - 1

        # Count consecutive EOS tokens before the pad run; keep one, cut the rest.
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            i -= 1
        eos_count = j - i
        return right_pad_len + max(0, eos_count - 1)

