from __future__ import annotations

from typing import Dict, List, Tuple, Any

import os
import sys

import torch
from PIL import Image
from rich.console import Console
from .base import BaseVLM

# Module-level rich console; used in this module to print debug output.
console = Console()

# Reuse the project's input/path parsing helpers: append the parent of the
# directory containing this file to sys.path so the top-level `utils` module
# can be imported (hence the late import below).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402


class R4BVLM(BaseVLM):
    """YannQi/R-4B backend.

    Design goals:
    - Decoupled from the Qwen-VL backend (separate file / separate route).
    - Forwards R-4B chat-template arguments (e.g. ``thinking_mode``) to
      ``processor.apply_chat_template``.
    - Returns the same triple as ``BaseVLM``: outputs, right_pad_lens, hit_limits.

    Current constraints:
    - Assumes exactly one image per sample (same as the existing QwenVLVLM).
      Mixed text-only / multi-image batches would need per-sample processor
      calls followed by manual pad/stack.
    """

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda"):
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # 1) Pick the primary tokenizer: prefer processor.tokenizer when it
        #    exists, because the processor is what tokenizes the batch.
        proc_tokenizer = getattr(self.processor, "tokenizer", None)
        self.tokenizer = proc_tokenizer if proc_tokenizer is not None else tokenizer

        # 2) Force processor.tokenizer and self.tokenizer to be the same object
        #    so the padding settings below apply to both.
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # Decoder-only generation: pad on the left. A pad token is only filled
        # in when missing; an existing one is never overwritten with eos
        # (some models use pad_token_id != eos_token_id).
        self._prepare_tokenizer(self.tokenizer)
        proc_tokenizer = getattr(self.processor, "tokenizer", None)
        if proc_tokenizer is not None and proc_tokenizer is not self.tokenizer:
            # Only reached if the binding above did not take effect.
            self._prepare_tokenizer(proc_tokenizer)

    @staticmethod
    def _prepare_tokenizer(tokenizer) -> None:
        """Set left padding on *tokenizer* and make sure it has a pad token.

        If ``pad_token`` is missing, the token that ``pad_token_id`` maps to is
        used when an id -> token -> id round-trip agrees; otherwise (including
        when the probe raises) ``eos_token`` is used.
        """
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is not None:
            return
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is not None:
            try:
                candidate = tokenizer.convert_ids_to_tokens(int(pad_id))
                # Round-trip check: reject a string that looks like a token
                # but does not map back to pad_id.
                if candidate is not None and tokenizer.convert_tokens_to_ids(candidate) == int(pad_id):
                    tokenizer.pad_token = candidate
                    return
            except Exception:
                # Best-effort probe only; the eos fallback below handles failure.
                pass
        # Also reached when the round-trip check fails, so padding=True never
        # runs with a missing pad token.
        tokenizer.pad_token = tokenizer.eos_token

    def _split_cfg(self, gen_cfg: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
        """Split ``gen_cfg`` into two parts.

        - chat_template_kwargs: passed to ``processor.apply_chat_template()``
          (from a nested ``chat_template_kwargs`` dict and/or a top-level
          ``thinking_mode`` key; the nested dict wins on conflict).
        - generate_kwargs: everything else, passed to ``model.generate()``.

        ``gen_cfg`` itself is not mutated.
        """
        cfg = dict(gen_cfg or {})

        chat_template_kwargs: Dict[str, Any] = {}
        ctk = cfg.pop("chat_template_kwargs", None)
        if isinstance(ctk, dict):
            chat_template_kwargs.update(ctk)

        # Compatibility: callers may pass thinking_mode directly.
        if "thinking_mode" in cfg:
            chat_template_kwargs.setdefault("thinking_mode", cfg.pop("thinking_mode"))

        return chat_template_kwargs, cfg

    def _collect_eos_ids(self) -> List[int]:
        """Return the EOS ids from generation_config and tokenizer, de-duplicated in order.

        generation_config ids come first, so ``eos_ids[0]`` is its primary EOS.
        """
        gen_config = getattr(self.model, "generation_config", None)
        gc_eos = list(_normalize_to_list(getattr(gen_config, "eos_token_id", None)))
        tok_eos = list(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))
        return [e for e in dict.fromkeys(gc_eos + tok_eos) if e is not None]

    def _build_messages(self, items: List[dict]):
        """Build one single-turn user message (image + text) and one RGB image per item.

        Raises:
            ValueError: if an item has no image path.
        """
        messages_list = []
        images: List[Image.Image] = []
        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("R4BVLM 当前实现仅支持多模态输入（每条样本至少一张图）。")

            messages_list.append(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": prompt},
                        ],
                    }
                ]
            )
            # The context manager closes the underlying file; convert() returns
            # a fully loaded, detached copy that outlives it.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        return messages_list, images

    def _render_prompts(self, messages_list, chat_template_kwargs: Dict[str, Any]) -> List[str]:
        """Render each conversation to a prompt string via the processor's chat template.

        If the processor rejects the extra template kwargs with ``TypeError``,
        a warning is printed and the kwargs are dropped for the rest of the batch.
        """
        template_kwargs = dict(chat_template_kwargs)
        texts: List[str] = []
        for msg in messages_list:
            try:
                text = self.processor.apply_chat_template(
                    msg,
                    tokenize=False,
                    add_generation_prompt=True,
                    **template_kwargs,
                )
            except TypeError:
                if not template_kwargs:
                    raise
                console.print(
                    "[yellow]apply_chat_template does not accept "
                    f"{', '.join(sorted(template_kwargs))}; retrying without them."
                )
                template_kwargs = {}
                text = self.processor.apply_chat_template(
                    msg,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            texts.append(text)
        return texts

    def _encode_batch(self, texts: List[str], images: list):
        """Run the processor over the whole batch (left padded) and move it to ``self.device``."""
        try:
            inputs = self.processor(
                text=texts,
                images=images,
                padding=True,
                padding_side="left",
                return_tensors="pt",
            )
        except TypeError:
            # Processors that don't accept padding_side: the tokenizer was
            # already configured for left padding in __init__.
            inputs = self.processor(
                text=texts,
                images=images,
                padding=True,
                return_tensors="pt",
            )
        return inputs.to(self.device)

    def _apply_prefill(self, inputs, batch_size: int, bs_estimate_gen_cfg: Dict):
        """VRAM estimation: hand the inputs to utils.apply_prefill_extra_tokens.

        Reads ``_prefill_extra_tokens`` and ``_prefill_token_id`` from a copy of
        ``bs_estimate_gen_cfg``. NOTE(review): presumably this lengthens the
        prompt by that many tokens to probe memory use — confirm in utils.
        """
        from utils import apply_prefill_extra_tokens
        est_cfg = dict(bs_estimate_gen_cfg or {})
        prefill_extra_tokens = int(est_cfg.pop("_prefill_extra_tokens", 0) or 0)
        prefill_token_id = est_cfg.pop("_prefill_token_id", None)
        return apply_prefill_extra_tokens(
            batch_size=batch_size,
            inputs=inputs,
            prefill_extra_tokens=prefill_extra_tokens,
            tokenizer=self.tokenizer,
            prefill_token_id=prefill_token_id,
        )

    def _build_gen_conf(self, max_new_tokens: int, generate_kwargs: Dict[str, Any], eos_ids: List[int]) -> Dict[str, Any]:
        """Assemble kwargs for ``model.generate`` (greedy by default; gen_cfg may override).

        eos/pad ids are filled from the existing model/tokenizer configuration.
        Side effect: the resolved pad id is written into ``model.config`` and
        ``model.generation_config`` when they exist.
        """
        gen_conf = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            **generate_kwargs,
        }

        if eos_ids:
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # Pad id priority: generation_config -> tokenizer pad -> tokenizer eos.
        gen_config = getattr(self.model, "generation_config", None)
        pad_id = getattr(gen_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            if getattr(self.model, "config", None) is not None:
                self.model.config.pad_token_id = pad_id
            if gen_config is not None:
                gen_config.pad_token_id = pad_id
        return gen_conf

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate one response per item with a single batched ``model.generate`` call.

        Args:
            items: samples; each must yield an image path via ``get_image_path``.
            max_new_tokens: default generation budget (``gen_cfg`` may override it).
            gen_cfg: ``chat_template_kwargs`` / ``thinking_mode`` go to the chat
                template; every other key goes to ``model.generate``.
            oom_estimate: when True, apply extra prefill tokens for VRAM estimation.
            bs_estimate_gen_cfg: source of ``_prefill_extra_tokens`` / ``_prefill_token_id``.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``: decoded text without the
            prompt, trailing padding tokens trimmed per row, and whether each
            row used its whole token budget without ending in EOS.

        Raises:
            ValueError: if an item has no image path.
        """
        assert self.processor is not None, "R4BVLM 需要 AutoProcessor"

        chat_template_kwargs, generate_kwargs = self._split_cfg(gen_cfg)

        # 1) messages + images  2) chat template  3) batched processor inputs
        messages_list, images = self._build_messages(items)
        texts = self._render_prompts(messages_list, chat_template_kwargs)
        inputs = self._encode_batch(texts, images)

        if oom_estimate:
            inputs = self._apply_prefill(inputs, len(items), bs_estimate_gen_cfg)

        # 4) generate
        eos_ids = self._collect_eos_ids()
        gen_conf = self._build_gen_conf(max_new_tokens, generate_kwargs, eos_ids)
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        # 5) Strip the (left-padded) prompt and the right-side padding/eos.
        input_ids = inputs["input_ids"]
        batch_size = input_ids.size(0)
        prompt_len = input_ids.size(1)
        # gen_cfg may override max_new_tokens; judge truncation against the
        # budget generate() actually received.
        token_limit = gen_conf.get("max_new_tokens")
        if token_limit is None:
            token_limit = max_new_tokens
        eos_set = set(eos_ids)

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        for seq in output_ids[:batch_size]:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)

            generated_ids = seq[prompt_len:seq.size(0) - cut]
            ended_with_eos = generated_ids.numel() > 0 and int(generated_ids[-1].item()) in eos_set
            hit_limits.append(generated_ids.numel() >= token_limit and not ended_with_eos)

            outputs.append(
                self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )
        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Single-item convenience wrapper around ``generate_batch`` (no OOM estimation)."""
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of *total_sequence* count as right padding.

        Trailing pad tokens are counted. If the token before them is not an EOS,
        one pad token is kept: presumably pad == eos, so that "pad" was the real
        terminator. Otherwise a run of EOS tokens is collapsed to a single kept EOS.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad_token_id generate() is actually using.
        gen_config = getattr(self.model, "generation_config", None)
        pad_id = getattr(gen_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = set(self._collect_eos_ids())

        # Count the run of trailing pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if right_pad_len > 0 and j >= 0 and total_sequence[j] not in eos_ids:
            # Padding exists but no EOS precedes it: pad coincides with eos,
            # so one token too many was counted.
            return right_pad_len - 1

        # Count the run of EOS tokens before the pads; keep one EOS and treat
        # the rest as right padding.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1
        return right_pad_len + max(0, eos_count - 1)
