import os
import sys
from typing import List, Dict, Tuple, Any

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# 从 utils 导入通用工具
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path


# Module-level rich console; generate_batch prints each decoded output through it.
console: Console = Console()


class OmChatVLM(BaseVLM):
    """Adapter for OmChat v2 on Hugging Face (omlab/omchat-v2.0-13B-single-beta_hf).

    Key points:
    - The official AutoProcessor / OmChatProcessor ``__call__`` is essentially written
      for one sample at a time. "True" batching is done here by:
      * calling ``processor(...)`` once per sample to get ``input_ids`` plus ``images``
        (the concatenation of all image patches for that sample);
      * left-padding the per-sample ``input_ids`` by hand into one batch, with an
        attention mask derived from the original sequence lengths;
      * concatenating each sample's ``images`` tensor, in sample order, into one large
        tensor that is passed to the model.
    - Left padding is mandatory: otherwise an extractor that reads
      ``hidden_state[:, -1]`` would pick up a pad position.
    """

    def __init__(self, model, tokenizer, processor, device: str = "cuda"):
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # 1) Choose the primary tokenizer: use processor.tokenizer when it exists,
        #    because the processor is what tokenizes prompts (see _encode_one).
        if self.processor is not None and getattr(self.processor, "tokenizer", None) is not None:
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = tokenizer  # externally supplied

        # 2) Bind processor.tokenizer to the same object as self.tokenizer so that the
        #    padding configuration below applies to both.
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # Decoder-only generation: force left padding. pad_token is NOT blindly set to
        # eos (some models have pad_token_id != eos_token_id); see _ensure_pad_token.
        self._ensure_pad_token(self.tokenizer)
        proc_tok = getattr(self.processor, "tokenizer", None)
        if proc_tok is not None and proc_tok is not self.tokenizer:
            # Normally the same object after step 2; handled separately just in case.
            self._ensure_pad_token(proc_tok)

    @staticmethod
    def _ensure_pad_token(tok) -> None:
        """Force left padding on ``tok`` and make sure it has a pad token.

        If ``pad_token`` is missing, first try to recover it from an existing
        ``pad_token_id`` (accepted only when the id -> token -> id round trip is
        consistent); in every other case fall back to ``eos_token``.
        """
        tok.padding_side = "left"
        if tok.pad_token is not None:
            return
        pad_id = getattr(tok, "pad_token_id", None)
        if pad_id is not None:
            try:
                candidate = tok.convert_ids_to_tokens(int(pad_id))
                # Verify the round trip so we never install a string that looks like a
                # token but does not map back to pad_id.
                if candidate is not None and tok.convert_tokens_to_ids(candidate) == int(pad_id):
                    tok.pad_token = candidate
                    return
            except Exception:
                # Best effort: any tokenizer-specific failure falls through to eos below.
                pass
        # Reached when there is no pad id, the lookup raised, or the round trip failed
        # (the last case previously left pad_token as None).
        tok.pad_token = tok.eos_token

    @staticmethod
    def _pad_left(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        """Left-pad 1-D id tensors into a ``(len(seqs), max_len)`` tensor.

        dtype/device follow ``seqs[0]``; each sequence is right-aligned in its row and
        the remaining positions are filled with ``pad_token_id``.
        """
        max_len = max(int(s.size(0)) for s in seqs)
        out = torch.full(
            (len(seqs), max_len),
            pad_token_id,
            dtype=seqs[0].dtype,
            device=seqs[0].device,
        )
        for row, s in enumerate(seqs):
            # Slice from max_len - len (not -len): for an empty sequence `-0:` would
            # select the whole row and the assignment would fail.
            out[row, max_len - int(s.size(0)):] = s
        return out

    def _eos_token_ids(self) -> List[int]:
        """Return the de-duplicated (order-preserving) union of EOS ids from
        ``model.generation_config`` and the tokenizer, with ``None`` dropped."""
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        return [i for i in dict.fromkeys(list(gc_eos) + list(tok_eos)) if i is not None]

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> Dict:
        """Build the kwargs passed to ``model.generate``.

        Greedy decoding (``do_sample=False``) is the default and may be overridden by
        ``gen_cfg``. ``max_new_tokens`` always comes from the explicit argument,
        because ``generate_batch`` uses it to decide ``hit_limit``. ``eos_token_id``
        and ``pad_token_id`` are filled in from the model / tokenizer. The chosen pad
        id is also written back to ``model.config`` and ``model.generation_config``.
        """
        # dict(max_new_tokens=..., do_sample=False, **gen_cfg) raised TypeError whenever
        # gen_cfg also contained one of those keys; a literal merges them instead.
        generation_config = {
            "do_sample": False,
            **(gen_cfg or {}),
            "max_new_tokens": max_new_tokens,
        }

        # eos_token_id: union of model.generation_config and tokenizer EOS ids.
        eos_ids = self._eos_token_ids()
        if eos_ids:
            generation_config["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # pad_token_id: generation_config -> tokenizer pad -> tokenizer eos -> 0.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            pad_id = 0
        pad_id = int(pad_id)
        generation_config["pad_token_id"] = pad_id

        # Sync to the model side (otherwise generate may fall back to None).
        # Best effort: a missing or read-only config is simply skipped.
        for cfg in (getattr(self.model, "config", None), self.model.generation_config):
            try:
                cfg.pad_token_id = pad_id
            except Exception:
                pass

        return generation_config

    def _load_images(self, item: dict) -> List[Image.Image]:
        """Load every image referenced by ``item`` as an RGB PIL image.

        Paths come from ``get_image_path(item)`` (normalized to a list); falsy
        entries are skipped.
        """
        images: List[Image.Image] = []
        for path in _normalize_to_list(get_image_path(item)):
            if not path:
                continue
            # The context manager closes the underlying file. convert() loads the
            # pixels and returns an independent image, so it remains usable afterwards.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        return images

    def _encode_one(self, item: dict) -> Dict[str, Any]:
        """Run the processor on a single sample (prompt + its images).

        Returns the raw processor output. _stack_and_pad reads its "input_ids" and
        "images" entries.

        Raises:
            ValueError: if the sample has no image.
        """
        prompt = parse_input(item)
        images = self._load_images(item)
        if not images:
            raise ValueError(
                "OmChatVLM 当前实现要求 batch 内所有样本都包含至少 1 张图片（不支持图文混合批）。"
            )

        # OmChatProcessor's multi-image mode relies on <image> placeholders; if the
        # prompt has none, prepend one placeholder per image.
        if "<image>" not in prompt:
            prompt = ("<image>\n" * len(images)) + prompt

        # The processor accepts a single PIL image or a list of PIL images.
        inputs = self.processor(
            text=prompt,
            images=images if len(images) > 1 else images[0],
            return_tensors="pt",
        )
        return inputs

    def _stack_and_pad(
        self, encoded_list: List[Dict[str, Any]], pad_token_id: int
    ) -> Dict[str, torch.Tensor]:
        """Collate per-sample processor outputs into one batch on ``self.device``.

        Returns ``input_ids`` (left-padded with ``pad_token_id``), ``attention_mask``
        (1 for real tokens, 0 for left padding) and ``images`` (every sample's image
        tensor concatenated along dim 0, in sample order).

        Raises:
            ValueError: if any encoded sample has no ``images`` entry.
        """
        # input_ids: left pad. Assumes each processor output holds a batch of one
        # (hence [0]) — TODO confirm for multi-image prompts.
        seqs = [feat["input_ids"][0].to(self.device) for feat in encoded_list]
        input_ids = self._pad_left(seqs, pad_token_id)

        # Build the mask from the true lengths instead of `input_ids != pad_token_id`.
        # The latter also masks real prompt tokens whose id happens to equal the pad id
        # (e.g. when pad falls back to eos or to 0).
        max_len = input_ids.size(1)
        attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
        for row, s in enumerate(seqs):
            attention_mask[row, max_len - int(s.size(0)):] = 1

        # images: concatenated in sample order into one large tensor.
        # OmChatProcessor returns the key "images" (not pixel_values).
        img_tensors = []
        for feat in encoded_list:
            if "images" not in feat or feat["images"] is None:
                raise ValueError(
                    "OmChatVLM: processor 没有返回 images；请确认你加载的是 OmChat 的 AutoProcessor。"
                )
            img_tensors.append(feat["images"].to(self.device))
        images = torch.cat(img_tensors, dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "images": images,
        }

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate answers for ``items`` with a single ``model.generate`` call.

        Args:
            items: input samples; each must reference at least one image.
            max_new_tokens: generation budget per sample.
            gen_cfg: extra ``generate`` kwargs (see _prepare_generation_config).
            oom_estimate: if True, pass the inputs through ``apply_prefill_extra_tokens``
                before generating (the memory-estimation path).
            bs_estimate_gen_cfg: options for that path. The private keys
                ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read from a copy.

        Returns:
            ``(outputs, right_pad_lens, hit_limits)``:
            - outputs: the stripped decoded texts;
            - right_pad_lens: the number of trailing tokens cut from each generated
              sequence;
            - hit_limits: whether each sample reached ``max_new_tokens`` without
              ending in EOS.
        """
        assert len(items) > 0, "generate_batch: items 不能为空"

        # 1) encode (per-sample)
        encoded_list = [self._encode_one(it) for it in items]

        # 2) generation config
        generation_config = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)
        pad_token_id = int(generation_config["pad_token_id"])

        # 3) stack + left pad
        inputs = self._stack_and_pad(encoded_list, pad_token_id=pad_token_id)

        # ================================ memory estimate ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===================================================================================

        # 4) one-shot generate
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                **generation_config,
            )

        # 5) Cut the prompt prefix using the padded length (consistent with the official
        #    example, which slices at input_ids.shape[1]).
        prefix_len_padded = inputs["input_ids"].shape[1]
        new_tokens_all = generate_ids[:, prefix_len_padded:]

        eos_ids = set(self._eos_token_ids())  # loop-invariant; computed once
        right_pad_lens: List[int] = []
        hit_limits: List[bool] = []
        decoded_token_ids: List[List[int]] = []

        for seq in new_tokens_all[: len(items)]:
            cut = self.calculate_right_padding_length(seq)
            right_pad_lens.append(cut)
            seq = seq[: seq.shape[0] - cut]

            out_len = int(seq.shape[0])

            # hit max_new_tokens: the more reliable test is "generated length reached
            # the limit and the last token is not EOS".
            ended_with_eos = out_len > 0 and seq[-1].item() in eos_ids
            hit_limits.append(out_len >= max_new_tokens and not ended_with_eos)
            decoded_token_ids.append(seq.cpu().tolist())

        texts = self.tokenizer.batch_decode(
            decoded_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        outputs = [t.strip() for t in texts]
        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)
        return outputs, right_pad_lens, hit_limits

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Single-sample convenience wrapper around generate_batch (no memory estimate).

        Returns ``(output, right_pad_len, hit_limit)`` for ``item``.
        """
        outs, pads, hits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outs[0], pads[0], hits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of a generated sequence are padding.

        Trailing ``pad_token_id`` tokens are counted first. When pad is itself an EOS
        id, the earliest of those stripped tokens is the real EOS and is kept. Any
        extra consecutive EOS tokens beyond the first are also treated as padding, so
        exactly one EOS survives the cut whenever the sequence ended with EOS.

        Args:
            total_sequence: generated token ids (tensor or list).
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad_token_id that generate actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        eos_ids = set(self._eos_token_ids())

        # Count consecutive pad tokens from the end.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # pad == EOS: the pad run swallowed the real EOS as well (the token before it is
        # not EOS, or there is nothing before it), so give one token back.
        if right_pad_len > 0 and pad_id in eos_ids and (j < 0 or total_sequence[j] not in eos_ids):
            return right_pad_len - 1

        # Then count consecutive EOS tokens preceding the pad run.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep 1 EOS; treat the rest as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len

