{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30fd23f044014458b7cc86d78b60a11a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import requests\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "import torch\n",
    "import json, os\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n",
    "from qwen_vl_utils import process_vision_info\n",
    "import re\n",
    "import random\n",
    "import copy \n",
    "import warnings\n",
    "import random\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "MODEL_PATH = 'Anonymous-3982/OICT_GRPO_Qwen2.5VL_7B_Full' # Temporary name\n",
    "device = 'cuda:0'\n",
    "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n",
    "    MODEL_PATH,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    device_map=device,\n",
    ")\n",
    "\n",
    "# default processer\n",
    "max_pixels = 1024 * 28 * 28\n",
    "processor = AutoProcessor.from_pretrained(MODEL_PATH, max_pixels=max_pixels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query Img: ./__assets__/multi_large/resize_img.jpg\n",
      "Question: Give a detailed personalized caption of the image.\n",
      "Answer: The image features a group of influential tech leaders, including <jeff> (CEO of Amazon), <jenson> (CEO of NVIDIA), <lecun> (Chief AI Scientist at Meta), <sundar> (CEO of Google), <elon> (CEO of TESLA), <mark> (CEO of META), and <sam> (CEO of OPENAI). They are standing together in a desert-like setting, wearing tactical gear, symbolizing their leadership roles in the tech industry.\n",
      "\n",
      "Query Img: ./__assets__/main_figure/concept_13.png\n",
      "Question: Give a personalized caption for the image.\n",
      "Answer: A lively parade scene unfolds on a bustling street, featuring <monster_toy>, <sloth>, <plush>, and <teddy> in colorful costumes. The <monster_toy> leads the way with a cheerful expression, followed by <sloth>, <plush>, and <teddy>, who carries a drum, creating a festive atmosphere as they march through the crowd.\n",
      "\n",
      "Query Img: ./__assets__/main_figure/concept_2.png\n",
      "Question: Give a personalized caption for the image.\n",
      "Answer: In this delightful autumn scene, <ball> the cheerful animated raccoon, <monster> the curious animated creature, and <otter> the excited animated otter enjoy a playful moment together amidst fallen leaves and a warm, golden backdrop.\n",
      "\n",
      "Query Img: ./__assets__/multi/three_mans.jpg\n",
      "Question: Give a personalized caption of the image.\n",
      "Answer: In this image, we see three distinguished individuals: <A>, a professor at the Université de Montréal; <B>, a Chief AI Scientist at Meta; and <C>, a Nobel laureate in physics at 2024. They are captured in a casual setting, sharing a moment together with <A> holding a mug, symbolizing their camaraderie and shared intellectual pursuits.\n",
      "\n",
      "Query Img: ./__assets__/style_2/2-2.jpg\n",
      "Question: Give a rich caption of the image.\n",
      "Answer: A striking <style2> sculpture featuring a golden astronaut in a dynamic pose, seemingly mid-jump on a snowboard. The figure's suit and limbs are thick, shiny, and metallic, with a molten gold appearance that drips downward, enhancing the surreal and luxurious aesthetic. The astronaut's helmet reflects light, adding depth and dimension to the piece.\n",
      "\n",
      "Query Img: ./__assets__/vton/jenson_out.jpg\n",
      "Question: Give a detailed personalized caption of the image.\n",
      "Answer: The query image shows <customer1> wearing <outfit_1329>. The black t-shirt features a white Stussy logo on the front.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "database_paths = ['./__assets__/multi_large', './__assets__/main_figure', './__assets__/main_figure', './__assets__/multi', './__assets__/style_2', './__assets__/vton']\n",
    "query_imgs     = ['./__assets__/multi_large/resize_img.jpg', './__assets__/main_figure/concept_13.png', './__assets__/main_figure/concept_2.png', './__assets__/multi/three_mans.jpg', './__assets__/style_2/2-2.jpg', './__assets__/vton/jenson_out.jpg']\n",
    "concept_tot    = [['<jeff>', '<jenson>', '<lecun>', '<sundar>', '<elon>', '<mark>', '<sam>'], ['<monster_toy>', '<sloth>', '<plush>', '<teddy>'], ['<ball>', '<monster>', '<otter>'], ['<A>', '<B>', '<C>'], ['<style2>'], ['<outfit_1329>', '<customer1>']]\n",
    "caption_prompts = [\"Give a detailed personalized caption of the image.\", \"Give a personalized caption for the image.\", \"Give a personalized caption for the image.\", \\\n",
    "                    \"Give a personalized caption of the image.\", \"Give a rich caption of the image.\", \"Give a detailed personalized caption of the image.\"]\n",
    "\n",
    "for sel_idx in range(len(database_paths)):\n",
    "    database_path = database_paths[sel_idx]\n",
    "    query_img     = query_imgs[sel_idx]\n",
    "    concepts = concept_tot[sel_idx]\n",
    "    with open(f\"{database_path}/database.json\", \"r\") as f:\n",
    "        database = json.load(f)\n",
    "            \n",
    "    sys_prompt = \"You are a captioning assistant. Your task is to generate an accurate caption for the query image while referencing the given reference images without duplication. \\n Below is additional information about the reference images.\" \n",
    "    QUESTION_TEMPLATE = \"{Question}\"\n",
    "\n",
    "    empty_image = {\"type\": \"image\", \"image\": \"\"}\n",
    "    empty_text  = {\"type\": \"text\", \"text\": \"\"}\n",
    "    message = [\n",
    "        {\"role\": \"system\", \"content\": sys_prompt},\n",
    "        {\"role\": \"user\",\n",
    "        \"content\": [\n",
    "        ],\n",
    "        }\n",
    "    ]\n",
    "    template = copy.deepcopy(empty_image)\n",
    "    template1 = copy.deepcopy(empty_text)\n",
    "    for i in range(len(concepts)):\n",
    "        template['image'] =  database['concept_dict'][concepts[i]]['image'] \n",
    "        message[1]['content'].append(template.copy())\n",
    "        template1['text'] = f\"Name : {concepts[i]}, Info: {database['concept_dict'][concepts[i]]['info']}\"\n",
    "        message[1]['content'].append(template1.copy())\n",
    "        \n",
    "    template['image'] = query_img \n",
    "    question = caption_prompts[sel_idx] \n",
    "    message[1]['content'].append(template.copy())\n",
    "    findname = ' and '.join(concepts)\n",
    "    question = question.format(name = findname)\n",
    "    template1['text'] = f\"This is the query image. {QUESTION_TEMPLATE.format(Question=question)}\"\n",
    "    message[1]['content'].append(template1.copy())\n",
    "\n",
    "    text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n",
    "            \n",
    "    image_inputs, video_inputs = process_vision_info(message)\n",
    "    inputs = processor(\n",
    "        text=text,\n",
    "        images=image_inputs,\n",
    "        videos=video_inputs,\n",
    "        padding=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    inputs = inputs.to(\"cuda:0\")\n",
    "\n",
    "    # Inference: Generation of the output\n",
    "    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)\n",
    "\n",
    "    generated_ids_trimmed = [\n",
    "        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
    "    ]\n",
    "    batch_output_text = processor.batch_decode(\n",
    "        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    "    )\n",
    "    outputs = batch_output_text[0]\n",
    "    print('Query Img: {}'.format(query_img))\n",
    "    print(\"Question: {}\\nAnswer: {}\\n\".format(question, outputs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
