from transformers import AutoModel, AutoTokenizer
import torch, os
from PIL import Image

from .base import BaseVLInference
from .utils_internvl import load_image

class InternVLInference(BaseVLInference):
    """Batch inference wrapper around InternVL chat models.

    The class reads ``self.model_path``, ``self.model_type`` and
    ``self.max_new_tokens``; none of them is set here, so they are expected
    to be provided by ``BaseVLInference`` or by the caller before
    ``load_model_and_processor`` runs.
    """

    # Token id used as both EOS and PAD for the InternVL3-family configs.
    # NOTE(review): presumably the tokenizer's end-of-turn token - confirm
    # against the checkpoint's tokenizer before changing.
    _EOS_PAD_TOKEN_ID = 151645

    # Model types whose generation config pins EOS/PAD explicitly.
    _EOS_PAD_MODEL_TYPES = frozenset({"internvl3", "internvl3_5"})
    _SUPPORTED_MODEL_TYPES = frozenset({"internvl2_5"}) | _EOS_PAD_MODEL_TYPES

    # Value forwarded as ``max_num`` to ``load_image``.
    # NOTE(review): presumably the maximum number of tiles per image - confirm
    # against ``utils_internvl.load_image``.
    _LOAD_IMAGE_MAX_NUM = 12

    def _build_generation_config(self):
        """Return the generation kwargs for ``self.model_type``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported InternVL
                variant. Without this check ``self.gen_cfg`` would silently
                stay unset and inference would fail later with an unrelated
                ``AttributeError``.
        """
        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; expected one of "
                f"{sorted(self._SUPPORTED_MODEL_TYPES)}"
            )

        gen_cfg = dict(max_new_tokens=self.max_new_tokens, do_sample=False)
        if self.model_type in self._EOS_PAD_MODEL_TYPES:
            gen_cfg.update(
                eos_token_id=self._EOS_PAD_TOKEN_ID,
                pad_token_id=self._EOS_PAD_TOKEN_ID,
            )
        return gen_cfg

    def load_model_and_processor(self):
        """Load the model and tokenizer and prepare the generation config.

        Sets ``self.model`` (bfloat16, eval mode, moved to CUDA),
        ``self.tokenizer`` and ``self.gen_cfg``.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        # Validate the model type first so a bad configuration fails before
        # the (expensive) weight loading rather than at the first batch.
        gen_cfg = self._build_generation_config()

        self.model = AutoModel.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=True,
            trust_remote_code=True,
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            use_fast=False,
        )
        self.gen_cfg = gen_cfg

    def run_batch_inference(self, batch):
        """Answer one batch of image/question pairs.

        Args:
            batch: Mapping with an ``'image'`` entry (iterable of inputs
                accepted by ``load_image``) and a ``'text'`` entry (iterable
                of question strings), aligned by position.

        Returns:
            A dict with ``"question"`` (the texts as passed in) and
            ``"answer"`` (the result of ``model.batch_chat``; an empty list
            for an empty batch).
        """
        images = batch['image']
        query_texts = batch['text']

        # ``torch.cat`` raises on an empty list, so short-circuit empty
        # batches. ``len`` is used instead of truthiness because the
        # container type of ``batch['image']`` is not fixed here.
        if len(images) == 0:
            return {"question": query_texts, "answer": []}

        # One tensor per image; its first dimension is recorded per image so
        # ``batch_chat`` can split the concatenated tensor back up.
        per_image = [
            load_image(img, max_num=self._LOAD_IMAGE_MAX_NUM)
            .to(torch.bfloat16)
            .cuda()
            for img in images
        ]
        num_patches_list = [tensor.size(0) for tensor in per_image]
        pixel_values = torch.cat(per_image, dim=0)

        # Each question is prefixed with the "<image>" placeholder.
        questions = ["<image>\n" + text for text in query_texts]

        answers = self.model.batch_chat(
            self.tokenizer,
            pixel_values,
            num_patches_list=num_patches_list,
            questions=questions,
            generation_config=self.gen_cfg,
        )

        return {"question": query_texts, "answer": answers}
