from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple


class BaseVLM(ABC):
    """Abstract wrapper around a vision-language model.

    Responsible for building the model input from raw items and running
    generation. Hook registration and hidden-state parsing are explicitly
    out of scope for this class.
    """

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda", model_key: str = ""):
        """Store the model components and run-time settings on the instance.

        Args:
            model: The underlying model object.
            tokenizer: The tokenizer paired with ``model``.
            processor: Optional processor object; defaults to ``None``.
            device: Device identifier string; defaults to ``"cuda"``.
            model_key: Identifier string for the model; defaults to ``""``.
        """
        # Run-time settings.
        self.model_key = model_key
        self.device = device
        # Model components.
        self.processor = processor
        self.tokenizer = tokenizer
        self.model = model

    @abstractmethod
    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate text for a batch of raw items.

        Subclasses must override this method.

        Args:
            items: The raw input items.
            max_new_tokens: Limit on the number of newly generated tokens.
            gen_cfg: Generation configuration.

        Returns:
            A tuple ``(outputs, right_pad_lens, hit_limits)``:
              - outputs: the generated text of each sample, already decoded.
              - right_pad_lens: the right-padding lengths.
              - hit_limits: for each sample, whether it hit
                ``max_new_tokens`` (i.e. no EOS was encountered).
        """
        raise NotImplementedError

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """Generate for a single item by wrapping it in a one-element batch.

        Args:
            item: The raw input item.
            max_new_tokens: Forwarded unchanged to ``generate_batch``.
            gen_cfg: Forwarded unchanged to ``generate_batch``.

        Returns:
            The first entry of each of the three lists returned by
            ``generate_batch``.
        """
        single_batch = [item]
        texts, pad_lens, limit_flags = self.generate_batch(
            single_batch,
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
        )
        return texts[0], pad_lens[0], limit_flags[0]
