{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from loguru import logger\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# The parent of the working directory is put on sys.path before `yolo` is imported below.\n",
    "project_root = Path().resolve().parent\n",
    "if str(project_root) not in sys.path:  # re-running this cell must not stack duplicate entries\n",
    "    sys.path.append(str(project_root))\n",
    "\n",
    "from yolo import (\n",
    "    AugmentationComposer,\n",
    "    bbox_nms,\n",
    "    create_model,\n",
    "    custom_logger,\n",
    "    create_converter,\n",
    "    draw_bboxes,\n",
    "    Vec2Box\n",
    ")\n",
    "from yolo.config.config import NMSConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Settings a reader may want to change ---\n",
    "MODEL = \"v9-c\"\n",
    "DEVICE = \"cuda:0\"\n",
    "IMAGE_SIZE = (640, 640)\n",
    "IMAGE_PATH = \"../demo/images/inference/image.png\"\n",
    "\n",
    "# --- Files derived from the model name (relative to the working directory) ---\n",
    "MODEL_CONFIG = f\"../yolo/config/model/{MODEL}.yaml\"\n",
    "WEIGHT_PATH = f\"../weights/{MODEL}.pt\"\n",
    "TRT_WEIGHT_PATH = f\"../weights/{MODEL}.trt\"\n",
    "\n",
    "# --- Runtime objects shared by the cells below ---\n",
    "custom_logger()\n",
    "device = torch.device(DEVICE)\n",
    "image = Image.open(IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch2trt is only used by this cell, so both of its names are imported here in one place.\n",
    "from torch2trt import TRTModule, torch2trt\n",
    "\n",
    "with open(MODEL_CONFIG) as stream:\n",
    "    cfg_model = OmegaConf.load(stream)\n",
    "\n",
    "if os.path.exists(TRT_WEIGHT_PATH):\n",
    "    # Reuse the state dict that an earlier run of this cell saved to TRT_WEIGHT_PATH.\n",
    "    # NOTE(review): torch.load unpickles the file; only load .trt files you produced yourself.\n",
    "    model_trt = TRTModule()\n",
    "    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))\n",
    "else:\n",
    "    # No cached file yet: build the PyTorch model, convert it with torch2trt and save the result.\n",
    "    model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
    "    model = model.to(device).eval()\n",
    "\n",
    "    # NOTE(review): the example input is hard-coded to 640x640; keep it in agreement with IMAGE_SIZE.\n",
    "    dummy_input = torch.ones((1, 3, 640, 640)).to(device)\n",
    "    logger.info(\"♻️ Creating TensorRT model\")\n",
    "    model_trt = torch2trt(model, [dummy_input])\n",
    "    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)\n",
    "    # Report the path that was actually written.\n",
    "    logger.info(f\"📥 TensorRT model saved to {TRT_WEIGHT_PATH}\")\n",
    "\n",
    "transform = AugmentationComposer([], IMAGE_SIZE)\n",
    "converter = create_converter(cfg_model.name, model_trt, cfg_model.anchor, IMAGE_SIZE, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The preprocessed result gets its own name. Rebinding `image` (the PIL image opened earlier)\n",
    "# would pass the already-converted value back into `transform` when this cell is re-run.\n",
    "image_tensor, bbox, rev_tensor = transform(image, torch.zeros(0, 5))\n",
    "image_tensor = image_tensor.to(device)[None]  # [None] prepends a batch dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    trt_output = model_trt(image_tensor)\n",
    "    predict = converter(trt_output[\"Main\"])\n",
    "# NOTE(review): NMSConfig is built positionally with 0.5 and 0.5; confirm which threshold each sets.\n",
    "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
    "draw_bboxes(image_tensor, predict_box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample Output:\n",
    "\n",
    "![image](../demo/images/output/visualize.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yolomit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
