import torch
from PIL import Image

from llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from llava.eval.run_llava import eval_model
from llava.mm_utils import get_model_name_from_path
from llava.mm_utils import process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model


# Checkpoint identifier handed to the loader below and reused by Args.
model_path = "liuhaotian/llava-v1.6-34b"

# The loader's four return values are unpacked in order. No base model is
# supplied (model_base=None); the model name is derived from the path.
(tokenizer, model, image_processor, context_len) = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)


class Args:
    """Plain attribute holder for the prompt and generation settings.

    Every instance starts with the same values: the module-level
    ``model_path``, no base model, a fixed description query, no explicit
    conversation mode, and the sampling parameters listed below.
    """

    def __init__(self):
        # Insertion order of the mapping matches the attribute order.
        settings = {
            "model_path": model_path,
            "model_base": None,
            "query": "Describe the contents of the image.",
            "conv_mode": None,
            "temperature": 0.7,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 256,
        }
        for field_name, value in settings.items():
            setattr(self, field_name, value)

args = Args()

# Load the image. Converting to RGB normalizes grayscale/CMYK/alpha inputs to
# three channels; the context manager closes the file once the pixel data
# has been copied into the converted image.
image_path = "/scratch/xpy/image_moderation/moderation/images/0.jpg"
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# Preprocess the image with LLaVA's helper, which applies the preprocessing
# selected by the model config and returns pixel values for the batch.
# Calling image_processor(image) directly returns a dict-like batch, which has
# no .unsqueeze(), so the previous code failed before reaching generate().
images_tensor = process_images([image], image_processor, model.config).to(
    model.device, dtype=torch.float16
)
# PIL reports size as (width, height); generate() receives it with the pixels.
image_sizes = [image.size]

# Build the prompt: strip any literal "<image>" from the query, then prepend
# the image placeholder (wrapped in start/end markers when the model uses them).
qs = args.query.replace("<image>", "")
if model.config.mm_use_im_start_end:
    image_token_se = (
        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    )
    qs = image_token_se + "\n" + qs
else:
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
# NOTE(review): the prompt is sent without a conversation template and
# args.conv_mode is never consulted. llava.eval.run_llava.eval_model wraps the
# query in one; confirm whether the raw prompt is intended here.

# Tokenize the prompt. A plain tokenizer call would encode "<image>" as
# ordinary text; tokenizer_image_token substitutes IMAGE_TOKEN_INDEX at the
# placeholder so the model can splice in the image features. It returns a
# 1-D tensor, hence the added batch dimension.
input_ids = (
    tokenizer_image_token(qs, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
)

# Put the model in inference mode.
model.eval()

# Run generation without tracking gradients.
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        images=images_tensor,
        image_sizes=image_sizes,
        # Sample only when a positive temperature is requested.
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=args.num_beams,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
    )

# Decode the first (only) sequence in the batch and print it.
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(outputs)
