from typing import List, Dict, Tuple
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# ==== Make the project root importable so the shared `utils` module can be imported ====
# PROJECT_ROOT is the parent of the directory that contains this file
# (two `dirname` calls above this file's absolute path).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    # NOTE(review): `append` places the root at the END of sys.path, so any other
    # importable module named `utils` found earlier on the path would be imported
    # instead of the project one — confirm this is intended.
    sys.path.append(PROJECT_ROOT)

from utils import parse_input, get_image_path, _normalize_to_list  # reuse the project's existing helpers

# Module-level rich console used for all logging in this file.
console = Console()


class SmolVLM(BaseVLM):
    """
    Wrapper around HuggingFaceTB/SmolVLM-Synthetic.

    - Uses AutoProcessor + AutoModelForVision2Seq.
    - Text prompt: built from ``utils.parse_input(item)``.
    - Image paths: taken from ``utils.get_image_path(item)``.
    - Interface:
        * generate_batch -> (outputs, right_pad_lens, hit_limit_flags)
        * generate_one   -> (output_text, right_pad_len, hit_limit)
    """

    def __init__(self, model, tokenizer, processor, device: str = "cuda"):
        """
        Wrap an already-loaded SmolVLM model.

        Parameters
        ----------
        model: AutoModelForVision2Seq
        tokenizer: tokenizer to use; when falsy, ``processor.tokenizer`` is used instead.
        processor: AutoProcessor
        device: "cuda" / "cpu"

        Raises
        ------
        ValueError
            If neither ``tokenizer`` nor ``processor.tokenizer`` is available.
        """
        tok = tokenizer or getattr(processor, "tokenizer", None)
        if tok is None:
            raise ValueError("SmolVLM: tokenizer is None (need tokenizer or processor.tokenizer)")
        super().__init__(model=model, tokenizer=None, processor=processor, device=device)

        self.tokenizer = tok
        # Use left padding everywhere, so decoder-only generation does not warn
        # about right padding.
        self._setup_left_padding(self.tokenizer)
        proc_tok = getattr(self.processor, "tokenizer", None)
        if proc_tok is not None:
            self._setup_left_padding(proc_tok)

    @staticmethod
    def _setup_left_padding(tok) -> None:
        """Force left padding on ``tok``; fall back to the EOS token as pad token when none is set."""
        tok.padding_side = "left"
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token

    def _collect_eos_ids(self) -> List[int]:
        """
        Return the de-duplicated union of EOS ids from the generation config and the tokenizer(s).

        Both ``generate_batch`` and ``calculate_right_padding_length`` use this, so the
        trailing-pad stripping and the hit-limit check agree on what counts as EOS.
        """
        ids = list(_normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None)))
        for tok in (self.tokenizer, getattr(self.processor, "tokenizer", None)):
            if tok is not None:
                ids.extend(_normalize_to_list(getattr(tok, "eos_token_id", None)))
        return list(set(ids))

    def _load_item_images(self, idx: int, item: dict) -> List[Image.Image]:
        """
        Load every image referenced by ``item`` as an RGB PIL image.

        Empty paths are skipped. An image that fails to load is logged and replaced by
        a 384x384 black placeholder, so the number of images per sample is kept.
        """
        images: List[Image.Image] = []
        for p in _normalize_to_list(get_image_path(item)):
            if not p:
                continue
            try:
                # `with` closes the underlying file handle. Pillow may otherwise keep it
                # open (e.g. for multi-frame formats) until the object is garbage-collected.
                # `convert` returns an independent image that stays valid after the close.
                with Image.open(p) as im:
                    img = im.convert("RGB")
            except Exception as e:  # deliberate best-effort: any failure -> placeholder
                console.print(
                    f"[red][SmolVLM] 样本 {idx} 图像加载失败：{p}，使用占位黑图代替，error={e}[/red]"
                )
                img = Image.new("RGB", (384, 384), (0, 0, 0))
            images.append(img)
        return images

    def _build_chat_prompt(self, prompt_text: str, n_images: int) -> str:
        """Build the SmolVLM chat prompt: one image placeholder per image, then the text."""
        content = [{"type": "image"} for _ in range(n_images)]
        content.append({"type": "text", "text": prompt_text})
        messages = [
            {
                "role": "user",
                "content": content,
            }
        ]
        return self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
        )

    def _move_inputs_to_device(self, model_inputs):
        """
        Move every tensor in ``model_inputs`` to ``self.device`` (in place) and return it.

        Only floating-point tensors (e.g. ``pixel_values``) are cast to ``model.dtype``.
        Integer and boolean tensors (ids, masks) keep their dtype.
        """
        dtype = getattr(self.model, "dtype", None)  # usually torch.bfloat16
        for k, v in list(model_inputs.items()):
            if not isinstance(v, torch.Tensor):
                continue
            if dtype is not None and v.is_floating_point():
                model_inputs[k] = v.to(device=self.device, dtype=dtype)
            else:
                model_inputs[k] = v.to(self.device)
        return model_inputs

    @staticmethod
    def _build_gen_kwargs(max_new_tokens: int, gen_cfg: Dict) -> Dict:
        """
        Translate ``gen_cfg`` into keyword arguments for ``model.generate``.

        ``temperature`` / ``top_p`` / ``top_k`` are forwarded only when sampling is enabled.
        ``num_beams`` is always forwarded when present.
        """
        gen_cfg = gen_cfg or {}
        gen_kwargs: Dict = {
            "max_new_tokens": max_new_tokens,
            "do_sample": gen_cfg.get("do_sample", False),
        }
        if gen_kwargs["do_sample"]:
            for key in ("temperature", "top_p", "top_k"):
                if key in gen_cfg:
                    gen_kwargs[key] = gen_cfg[key]
        if "num_beams" in gen_cfg:
            gen_kwargs["num_beams"] = gen_cfg["num_beams"]
        return gen_kwargs

    @torch.inference_mode()
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Run batched generation over ``items``.

        Parameters
        ----------
        items: samples. The prompt comes from ``utils.parse_input(item)`` and the image
            path(s) from ``utils.get_image_path(item)``.
        max_new_tokens: generation budget per sample.
        gen_cfg: ``do_sample`` (default False); ``temperature`` / ``top_p`` / ``top_k``
            (used only when sampling); ``num_beams``.
        oom_estimate: when True, inputs are passed through
            ``utils.apply_prefill_extra_tokens`` (the memory-estimation path).
        bs_estimate_gen_cfg: ``_prefill_extra_tokens`` and ``_prefill_token_id`` are read
            from a copy when ``oom_estimate`` is True; the caller's dict is not mutated.

        Returns
        -------
        outputs: List[str]          # answer per sample (generated part only, prompt removed)
        right_pad_lens: List[int]   # per-sample count of trailing pad / extra EOS tokens
        hit_limit_flags: List[bool] # True when max_new_tokens effective tokens were produced
                                    # without ending on an EOS token
        """
        assert self.processor is not None, "SmolVLM 需要 AutoProcessor"
        assert self.model is not None, "SmolVLM 需要已加载的模型"

        if not items:
            return [], [], []

        # 1. Samples -> chat prompts + per-sample image lists
        texts: List[str] = []
        all_images: List[List[Image.Image]] = []
        for idx, item in enumerate(items):
            prompt_text = parse_input(item)
            images = self._load_item_images(idx, item)
            texts.append(self._build_chat_prompt(prompt_text, len(images)))
            all_images.append(images)

        # 2. Pack the batch with the processor. Images are passed only when at least
        #    one sample has any (the text-only path is just a fallback).
        proc_kwargs: Dict = {"text": texts, "padding": True, "return_tensors": "pt"}
        if any(all_images):
            proc_kwargs["images"] = all_images  # List[List[PIL.Image]]
        model_inputs = self._move_inputs_to_device(self.processor(**proc_kwargs))

        # 3. Generation arguments
        gen_kwargs = self._build_gen_kwargs(max_new_tokens, gen_cfg)

        # ================================ Memory estimation =========================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            est_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(est_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = est_cfg.pop("_prefill_token_id", None)
            model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===========================================================================

        # 4. Generate
        generated_ids = self.model.generate(
            **model_inputs,
            **gen_kwargs,
        )

        # 5. Decode and compute padding / hit-limit statistics.
        #    Read input_ids after the optional prefill step so in_len matches what was fed.
        input_ids = model_inputs["input_ids"]  # [B, L_in], includes left padding
        batch_size = input_ids.size(0)
        in_len = input_ids.size(1)
        eos_ids = self._collect_eos_ids()

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for i in range(batch_size):
            gen = generated_ids[i]  # input + generated + trailing pad

            right_pad = self.calculate_right_padding_length(gen)
            eff_out_len = gen.size(0) - right_pad

            # Generated part only, without trailing pad / extra EOS
            answer_ids = gen[in_len:eff_out_len]
            answer_text = self.processor.batch_decode(
                answer_ids.unsqueeze(0),
                skip_special_tokens=True,
            )[0]

            eff_new_len = max(eff_out_len - in_len, 0)
            ended_with_eos = eff_new_len > 0 and gen[eff_out_len - 1].item() in eos_ids
            hit_limit = (eff_new_len >= max_new_tokens) and (not ended_with_eos)

            outputs.append(answer_text)
            right_pad_lens.append(int(right_pad))
            hit_limit_flags.append(hit_limit)

        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limit_flags

    @torch.inference_mode()
    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """
        Single-sample convenience wrapper around ``generate_batch`` (no memory estimation).

        Returns ``(output_text, right_pad_len, hit_limit)`` for ``item``.
        """
        outputs, right_pad_lens, hit_limit_flags = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limit_flags[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of ``total_sequence`` are padding.

        The trailing run of pad tokens is counted. Of the EOS tokens just before it,
        all but one are also counted, so exactly one terminating EOS (if any) is kept.

        Parameters
        ----------
        total_sequence: 1-D tensor or list of token ids (prompt + generation).

        Returns
        -------
        int: number of trailing tokens to drop.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id generate() actually uses; fall back to the tokenizer's.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = self._collect_eos_ids()

        # 1. Trailing run of pad tokens.
        j = len(total_sequence) - 1
        right_pad_len = 0
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # When the pad token is itself an EOS id, the loop above also consumed the real
        # terminating EOS; give one back so the sequence still ends with EOS.
        if (
            right_pad_len > 0
            and pad_id in eos_ids
            and j >= 0
            and total_sequence[j] not in eos_ids
        ):
            return right_pad_len - 1

        # 2. Trailing run of EOS tokens: keep one, treat the rest as padding.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1
        return right_pad_len + max(0, eos_count - 1)
