{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/conda/envs/pi0_test/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Failed to import src.model.gemma3.processing_gemma3 because of the following error (look up to see its traceback):\nNo module named 'src.feature_extraction_utils'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "File \u001b[0;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/conda/envs/pi0_test/lib/python3.10/site-packages/transformers/utils/import_utils.py:1863\u001b[0m, in \u001b[0;36m_LazyModule._get_module\u001b[0;34m(self, module_name)\u001b[0m\n\u001b[1;32m   1862\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1863\u001b[0m     \u001b[39mreturn\u001b[39;00m importlib\u001b[39m.\u001b[39;49mimport_module(\u001b[39m\"\u001b[39;49m\u001b[39m.\u001b[39;49m\u001b[39m\"\u001b[39;49m \u001b[39m+\u001b[39;49m module_name, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m\u001b[39m__name__\u001b[39;49m)\n\u001b[1;32m   1864\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n",
      "File \u001b[0;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/conda/envs/pi0_test/lib/python3.10/importlib/__init__.py:126\u001b[0m, in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m         level \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m--> 126\u001b[0m \u001b[39mreturn\u001b[39;00m _bootstrap\u001b[39m.\u001b[39;49m_gcd_import(name[level:], package, level)\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1050\u001b[0m, in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1027\u001b[0m, in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1006\u001b[0m, in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:688\u001b[0m, in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap_external>:883\u001b[0m, in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:241\u001b[0m, in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
      "File \u001b[0;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/src/model/gemma3/processing_gemma3.py:21\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[39mimport\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mnumpy\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mas\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mnp\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m \u001b[39mfrom\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m.\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mfeature_extraction_utils\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mimport\u001b[39;00m BatchFeature\n\u001b[1;32m     22\u001b[0m \u001b[39mfrom\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m.\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mimage_utils\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mimport\u001b[39;00m ImageInput, make_nested_list_of_images\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'src.feature_extraction_utils'",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/scripts/try_gemma.ipynb Cell 1\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://notebook-inspire.sii.edu.cn/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/scripts/try_gemma.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m\u001b[39m \u001b[39m\u001b[39msrc\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodel\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mgemma3\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mimport\u001b[39;00m Gemma3Processor, Gemma3ForConditionalGeneration\n\u001b[1;32m      <a href='vscode-notebook-cell://notebook-inspire.sii.edu.cn/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/scripts/try_gemma.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39mfrom\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mPIL\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mimport\u001b[39;00m Image\n\u001b[1;32m      <a href='vscode-notebook-cell://notebook-inspire.sii.edu.cn/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/scripts/try_gemma.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mtorch\u001b[39;00m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1075\u001b[0m, in \u001b[0;36m_handle_fromlist\u001b[0;34m(module, fromlist, import_, recursive)\u001b[0m\n",
      "File \u001b[0;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/conda/envs/pi0_test/lib/python3.10/site-packages/transformers/utils/import_utils.py:1851\u001b[0m, in \u001b[0;36m_LazyModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1849\u001b[0m     value \u001b[39m=\u001b[39m Placeholder\n\u001b[1;32m   1850\u001b[0m \u001b[39melif\u001b[39;00m name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_class_to_module\u001b[39m.\u001b[39mkeys():\n\u001b[0;32m-> 1851\u001b[0m     module \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_module(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_class_to_module[name])\n\u001b[1;32m   1852\u001b[0m     value \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(module, name)\n\u001b[1;32m   1853\u001b[0m \u001b[39melif\u001b[39;00m name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modules:\n",
      "File \u001b[0;32m/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/conda/envs/pi0_test/lib/python3.10/site-packages/transformers/utils/import_utils.py:1865\u001b[0m, in \u001b[0;36m_LazyModule._get_module\u001b[0;34m(self, module_name)\u001b[0m\n\u001b[1;32m   1863\u001b[0m     \u001b[39mreturn\u001b[39;00m importlib\u001b[39m.\u001b[39mimport_module(\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m module_name, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m)\n\u001b[1;32m   1864\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m-> 1865\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m   1866\u001b[0m         \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mFailed to import \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m{\u001b[39;00mmodule_name\u001b[39m}\u001b[39;00m\u001b[39m because of the following error (look up to see its\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m   1867\u001b[0m         \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m traceback):\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m   1868\u001b[0m     ) \u001b[39mfrom\u001b[39;00m\u001b[39m \u001b[39m\u001b[39me\u001b[39;00m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Failed to import src.model.gemma3.processing_gemma3 because of the following error (look up to see its traceback):\nNo module named 'src.feature_extraction_utils'"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the saved output of this cell shows the import below failing with\n",
    "# `No module named 'src.feature_extraction_utils'`. The traceback points at the\n",
    "# relative import `from ...feature_extraction_utils import BatchFeature` inside\n",
    "# src/model/gemma3/processing_gemma3.py, so the fix belongs in that module, not here.\n",
    "from src.model.gemma3 import Gemma3Processor, Gemma3ForConditionalGeneration\n",
    "from PIL import Image\n",
    "import torch\n",
    "\n",
    "# Checkpoint directory (gemma-3-4b-it) shared by the model and the processor.\n",
    "model_path = \"/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/checkpoints/gemma-3-4b-it\"\n",
    "\n",
    "# Load the weights with automatic device placement, then switch to eval mode.\n",
    "# `eval()` is called mid-cell on purpose so its return value is not displayed.\n",
    "model = Gemma3ForConditionalGeneration.from_pretrained(model_path, device_map=\"auto\")\n",
    "model.eval()\n",
    "\n",
    "# The processor is loaded from the same checkpoint directory as the model.\n",
    "processor = Gemma3Processor.from_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Two test images that are attached to a single user turn below.\n",
    "# Relies on `Image`, `torch`, `model` and `processor` from the previous cell.\n",
    "eggplant = Image.open(\"/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/eggplant__conf0.61.jpg\")\n",
    "basket = Image.open(\"/inspire/ssd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/public/yjc/cunxin/pi0_suite/open-pi-zero/yellow_basket__conf0.17.jpg\")\n",
    "\n",
    "# Chat history: one system turn, then one user turn interleaving text and images.\n",
    "# Image entries carry the picture under the \"image\" key, which is the key the\n",
    "# upstream transformers chat-template loader reads (alongside \"url\", \"path\",\n",
    "# \"base64\"). NOTE(review): assumes the vendored processor follows upstream here.\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": [{\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}]},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"text\", \"text\": \"What can you see in this image \"},\n",
    "            {\"type\": \"image\", \"image\": eggplant},\n",
    "            {\"type\": \"text\", \"text\": \"and I want to know this one too \"},\n",
    "            {\"type\": \"image\", \"image\": basket},\n",
    "        ]\n",
    "    },\n",
    "]\n",
    "\n",
    "# Render the chat template and tokenize in one step, returning PyTorch tensors,\n",
    "# then move the batch to the model's device with bfloat16 requested as dtype.\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages, add_generation_prompt=True, tokenize=True,\n",
    "    return_dict=True, return_tensors=\"pt\"\n",
    ").to(model.device, dtype=torch.bfloat16)\n",
    "\n",
    "# Prompt length in tokens; used below to drop the prompt from the generated ids.\n",
    "input_len = inputs[\"input_ids\"].shape[-1]\n",
    "\n",
    "with torch.inference_mode():\n",
    "    # Sampling is disabled and at most 100 new tokens are produced.\n",
    "    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)\n",
    "    # Keep only the first sequence's newly generated tokens.\n",
    "    generation = generation[0][input_len:]\n",
    "\n",
    "decoded = processor.decode(generation, skip_special_tokens=True)\n",
    "print(decoded)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pi0_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
