from abc import ABC, abstractmethod
from PIL.Image import Image
from torch import Tensor


class VlmWrapper(ABC):
    """Abstract interface for a vision-language model (VLM) wrapper.

    Concrete subclasses must implement :meth:`load_model`,
    :meth:`prepare_inputs` and :meth:`get_logits`. Because these are
    declared with ``@abstractmethod``, instantiating this base class or a
    subclass that omits any of them raises ``TypeError``. Without this,
    the subclass would silently inherit a no-op that returns ``None``.
    """

    # Backend handles declared for subclasses to populate (presumably in
    # ``load_model`` — confirm against implementations). Their concrete
    # types depend on the model backend, hence ``object``.
    model: object
    image_processor: object
    tokenizer: object

    @abstractmethod
    def load_model(self, model_name: str, quantize: bool = False):
        """Load the model and its preprocessing components.

        Args:
            model_name: Identifier of the model to load (presumably a hub
                name or local path — confirm against implementations).
            quantize: Whether a quantized variant should be loaded.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_inputs(
        self,
        image: Tensor | Image,
        prompt: str,
    ):
        """Build model inputs from an image and a text prompt.

        Args:
            image: The input image, either as a ``torch.Tensor`` or a
                ``PIL.Image.Image``.
            prompt: The text prompt paired with the image.

        Returns:
            Backend-specific model inputs. The return type is not fixed by
            this interface.
        """
        raise NotImplementedError

    @abstractmethod
    def get_logits(
        self,
        images: Tensor | Image,
        prompt: str,
        layer_wise: bool = False,
    ) -> Tensor:
        """Run the model on ``images`` and ``prompt`` and return logits.

        Args:
            images: Input image(s), as a ``torch.Tensor`` or a
                ``PIL.Image.Image``.
            prompt: The text prompt paired with the image(s).
            layer_wise: When True, presumably requests logits for each
                layer rather than only the final output. Shape conventions
                are left to implementations; TODO confirm.

        Returns:
            A ``torch.Tensor`` of logits. Its shape is implementation-defined.
        """
        raise NotImplementedError
