{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f3b9aa08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: gradio in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (5.21.0)\n",
      "Requirement already satisfied: aiofiles<24.0,>=22.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (23.2.1)\n",
      "Requirement already satisfied: anyio<5.0,>=3.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (4.8.0)\n",
      "Requirement already satisfied: fastapi<1.0,>=0.115.2 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.115.11)\n",
      "Requirement already satisfied: ffmpy in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.5.0)\n",
      "Requirement already satisfied: gradio-client==1.7.2 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (1.7.2)\n",
      "Requirement already satisfied: groovy~=0.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.1.2)\n",
      "Requirement already satisfied: httpx>=0.24.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.28.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.28.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.30.1)\n",
      "Requirement already satisfied: jinja2<4.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (3.1.6)\n",
      "Requirement already satisfied: markupsafe~=2.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (2.1.5)\n",
      "Requirement already satisfied: numpy<3.0,>=1.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (1.26.4)\n",
      "Requirement already satisfied: orjson~=3.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (3.10.15)\n",
      "Requirement already satisfied: packaging in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (24.2)\n",
      "Requirement already satisfied: pandas<3.0,>=1.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (2.2.3)\n",
      "Requirement already satisfied: pillow<12.0,>=8.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (10.4.0)\n",
      "Requirement already satisfied: pydantic>=2.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (2.10.6)\n",
      "Requirement already satisfied: pydub in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.25.1)\n",
      "Requirement already satisfied: python-multipart>=0.0.18 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.0.20)\n",
      "Requirement already satisfied: pyyaml<7.0,>=5.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (6.0.2)\n",
      "Requirement already satisfied: ruff>=0.9.3 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.11.0)\n",
      "Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.1.6)\n",
      "Requirement already satisfied: semantic-version~=2.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (2.10.0)\n",
      "Requirement already satisfied: starlette<1.0,>=0.40.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.46.1)\n",
      "Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.13.2)\n",
      "Requirement already satisfied: typer<1.0,>=0.12 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.15.2)\n",
      "Requirement already satisfied: typing-extensions~=4.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (4.13.0)\n",
      "Requirement already satisfied: uvicorn>=0.14.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio) (0.34.0)\n",
      "Requirement already satisfied: fsspec in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio-client==1.7.2->gradio) (2024.12.0)\n",
      "Requirement already satisfied: websockets<16.0,>=10.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from gradio-client==1.7.2->gradio) (15.0.1)\n",
      "Requirement already satisfied: idna>=2.8 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from anyio<5.0,>=3.0->gradio) (3.10)\n",
      "Requirement already satisfied: sniffio>=1.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)\n",
      "Requirement already satisfied: certifi in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from httpx>=0.24.1->gradio) (2025.1.31)\n",
      "Requirement already satisfied: httpcore==1.* in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from httpx>=0.24.1->gradio) (1.0.7)\n",
      "Requirement already satisfied: h11<0.15,>=0.13 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.14.0)\n",
      "Requirement already satisfied: filelock in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from huggingface-hub>=0.28.1->gradio) (3.18.0)\n",
      "Requirement already satisfied: requests in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from huggingface-hub>=0.28.1->gradio) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from huggingface-hub>=0.28.1->gradio) (4.67.1)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from pandas<3.0,>=1.0->gradio) (2025.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from pandas<3.0,>=1.0->gradio) (2025.1)\n",
      "Requirement already satisfied: annotated-types>=0.6.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from pydantic>=2.0->gradio) (0.7.0)\n",
      "Requirement already satisfied: pydantic-core==2.27.2 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from pydantic>=2.0->gradio) (2.27.2)\n",
      "Requirement already satisfied: click>=8.0.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from typer<1.0,>=0.12->gradio) (8.1.8)\n",
      "Requirement already satisfied: shellingham>=1.3.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from typer<1.0,>=0.12->gradio) (1.5.4)\n",
      "Requirement already satisfied: rich>=10.11.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from typer<1.0,>=0.12->gradio) (13.9.4)\n",
      "Requirement already satisfied: six>=1.5 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.17.0)\n",
      "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)\n",
      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.1)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from requests->huggingface-hub>=0.28.1->gradio) (3.4.1)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from requests->huggingface-hub>=0.28.1->gradio) (2.3.0)\n",
      "Requirement already satisfied: mdurl~=0.1 in /home/dsailyt/anaconda3/envs/vpt/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)\n"
     ]
    }
   ],
   "source": [
    "!pip install gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1753bf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7861\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "import json\n",
    "import random\n",
    "from PIL import Image as PILImage\n",
    "\n",
    "with open('./captioning/quad_concept/caption-rap-llava-skip-retreival_RAP-llava-2k-skip-full.json', 'r') as f:\n",
    "    rap = json.load(f)\n",
    "with open('./captioning/quad_concept/caption-rap-llava-skip-retreival_RAP-llava-skip-full.json', 'r') as f:\n",
    "    rap_2k = json.load(f)\n",
    "with open('./captioning/quad_concept/captioning-skip-retrieval-ours.json', 'r') as f:\n",
    "    ours = json.load(f)\n",
    "with open('./captioning/quad_concept/captioning-skip-retrieval-qwen_sft.json', 'r') as f:\n",
    "    rap_qwen = json.load(f)\n",
    "with open('./captioning/quad_concept/captioning-skip-retrieval-zero_shot.json', 'r') as f:\n",
    "    zero_shot = json.load(f)\n",
    "    \n",
    "database_path = './quad-multi-concept/'\n",
    "with open(f\"{database_path}/database.json\", \"r\") as f:\n",
    "    database = json.load(f)\n",
    "    \n",
    "method_names = ['RAP-LLaVA 13B 260K', 'RAP-LLaVA 13B 2K', 'RAP-Qwen 7B 260K', 'Qwen-2.5VL 7B Zero-Shot', 'Ours 7B']\n",
    "\n",
    "def random_samples(n=3):\n",
    "    idx = random.choice(range(len(rap)))\n",
    "    methods = [rap, rap_2k, rap_qwen, zero_shot, ours]\n",
    "\n",
    "    img = PILImage.open(rap[idx][0])\n",
    "    texts = [method[idx][2] for method in methods]    \n",
    "    query =  [rap[idx][-1]]\n",
    "    db_samples = rap[idx][1]\n",
    "    db_imgs = [PILImage.open(database['concept_dict'][s]['image']).convert('RGB').resize((256,256)) for s in db_samples]\n",
    "    db_texts = [f'Name: {s}, Info: ' + database['concept_dict'][s]['info'] for s in db_samples]\n",
    "\n",
    "    return [img] + query + texts + db_imgs + db_texts\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Random Visualization of Personalized Multi-Concept Image Captioning Examples using MLLM\")\n",
    "    with gr.Row():\n",
    "        image_output = gr.Image()\n",
    "\n",
    "    with gr.Row():\n",
    "        q_textbox = [gr.Textbox(label=f\"Query Caption\", lines=1)]\n",
    " \n",
    "    with gr.Row():\n",
    "        text_outputs = [\n",
    "            gr.Textbox(label=f\"Caption {method_names[i]}\", lines=5)\n",
    "            for i in range(5)\n",
    "        ]\n",
    "        \n",
    "    with gr.Row():\n",
    "        db_images = [\n",
    "            gr.Image(label=f\"DB Image {i+1}\") for i in range(4)\n",
    "        ]\n",
    "\n",
    "    with gr.Row():\n",
    "        db_textboxes = [\n",
    "            gr.Textbox(label=f\"DB Caption {i+1}\", lines=4)\n",
    "            for i in range(4)\n",
    "        ]\n",
    "    n_samples = gr.Number(value=5, label=\"Number of Texts\", precision=0)\n",
    "    btn = gr.Button(\"Pick Random Samples\")\n",
    "\n",
    "    btn.click(\n",
    "        fn=random_samples,\n",
    "        inputs=[n_samples],\n",
    "        outputs=[image_output] + q_textbox + text_outputs + db_images + db_textboxes  \n",
    "    )\n",
    "demo.launch(server_name=\"127.0.0.1\", share=False, inbrowser=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
