{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Initialize SAM2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Default to GPU 1, but keep any CUDA_VISIBLE_DEVICES value the caller already exported.\n",
    "# This is set before torch is imported so it is in place when CUDA is first queried.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"1\")\n",
    "# if using Apple MPS, fall back to CPU for unsupported ops\n",
    "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "# select the device for computation: CUDA first, then Apple MPS, then CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(f\"using device: {device}\")\n",
    "\n",
    "if device.type == \"cuda\":\n",
    "    # use bfloat16 for the entire notebook: the autocast context is entered and never exited on purpose\n",
    "    torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n",
    "    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
    "    if torch.cuda.get_device_properties(0).major >= 8:\n",
    "        torch.backends.cuda.matmul.allow_tf32 = True\n",
    "        torch.backends.cudnn.allow_tf32 = True\n",
    "elif device.type == \"mps\":\n",
    "    print(\n",
    "        \"\\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might \"\n",
    "        \"give numerically different outputs and sometimes degraded performance on MPS. \"\n",
    "        \"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG so the colours drawn by show_mask(random_color=True) are repeatable.\n",
    "np.random.seed(42)\n",
    "\n",
    "def show_mask(mask, ax, random_color=False, borders = True):\n",
    "    \"\"\"Overlay a mask on `ax` as a semi-transparent RGBA image.\n",
    "\n",
    "    The overlay colour is a fixed blue unless `random_color` is set, in which case the RGB part is\n",
    "    drawn from np.random; alpha is 0.6 either way. With `borders` enabled, the external contours of\n",
    "    the mask are traced with OpenCV and drawn on top in white at half opacity.\n",
    "    \"\"\"\n",
    "    rgba = (\n",
    "        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "        if random_color\n",
    "        else np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    )\n",
    "    height, width = mask.shape[-2:]\n",
    "    mask = mask.astype(np.uint8)\n",
    "    # broadcast (h, w, 1) * (1, 1, 4): pixels outside the mask end up fully transparent\n",
    "    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)\n",
    "    if borders:\n",
    "        import cv2\n",
    "        outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
    "        # Try to smooth contours\n",
    "        smoothed = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]\n",
    "        overlay = cv2.drawContours(overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2)\n",
    "    ax.imshow(overlay)\n",
    "\n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    \"\"\"Scatter prompt points on `ax` as white-edged stars: label 1 in green, then label 0 in red.\"\"\"\n",
    "    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    for label_value, colour in ((1, 'green'), (0, 'red')):\n",
    "        selected = coords[labels == label_value]\n",
    "        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)\n",
    "\n",
    "def show_box(box, ax):\n",
    "    \"\"\"Outline a box given as (x0, y0, x1, y1) on `ax` with a green, unfilled rectangle.\"\"\"\n",
    "    left, top, right, bottom = box[0], box[1], box[2], box[3]\n",
    "    outline = plt.Rectangle(\n",
    "        (left, top), right - left, bottom - top,\n",
    "        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,\n",
    "    )\n",
    "    ax.add_patch(outline)\n",
    "\n",
    "def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):\n",
    "    \"\"\"Show one 10x10-inch figure per (mask, score) pair: the image, the mask overlay and any prompts.\n",
    "\n",
    "    `point_coords` with `input_labels` are forwarded to show_points and `box_coords` to show_box when\n",
    "    given. A title with the 1-based mask number and its score is added only when more than one\n",
    "    score was supplied.\n",
    "    \"\"\"\n",
    "    for rank, (mask, score) in enumerate(zip(masks, scores), start=1):\n",
    "        plt.figure(figsize=(10, 10))\n",
    "        plt.imshow(image)\n",
    "        axis = plt.gca()\n",
    "        show_mask(mask, axis, borders=borders)\n",
    "        if point_coords is not None:\n",
    "            # points are meaningless without their positive/negative labels\n",
    "            assert input_labels is not None\n",
    "            show_points(point_coords, input_labels, axis)\n",
    "        if box_coords is not None:\n",
    "            show_box(box_coords, axis)\n",
    "        if len(scores) > 1:\n",
    "            plt.title(f\"Mask {rank}, Score: {score:.3f}\", fontsize=18)\n",
    "        plt.axis('off')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sam2.build_sam import build_sam2\n",
    "from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
    "\n",
    "# Checkpoint file and model config handed to build_sam2; point the checkpoint at your local copy.\n",
    "sam2_checkpoint = \"facebook/sam2.1_hiera_large.pt\" # your own SAM2 path\n",
    "model_cfg = \"configs/sam2.1/sam2.1_hiera_l.yaml\"\n",
    "\n",
    "# Build the model on the device selected in the setup cell.\n",
    "sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)\n",
    "\n",
    "# Image predictor used by every later cell (set_image / predict).\n",
    "predictor = SAM2ImagePredictor(sam2_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Load ReferSegDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))  # add the directory two levels above the working directory to the import path\n",
    "\n",
    "from refer_seg_dataset import ReferSegDataset\n",
    "# dataset_name is reused below as the key into dataset.refer_seg_data and in the output file name\n",
    "dataset_name = \"refcocog\"\n",
    "dataset = ReferSegDataset(base_image_dir=\"data/VisionReasoner_multi_object_1k_840\", refer_seg_data=dataset_name, data_split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the tables for the chosen dataset out of the loaded ReferSegDataset.\n",
    "refer_seg_ds = dataset.refer_seg_data[dataset_name]\n",
    "images = refer_seg_ds[\"images\"]  # per-image records (file_name, id, height, width are read below)\n",
    "annotations = refer_seg_ds[\"annotations\"]  # indexed by ann_id below\n",
    "img2refs = refer_seg_ds[\"img2refs\"]  # indexed by image id below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one image record to inspect; the bare last expression displays it.\n",
    "idx = 7\n",
    "image_info = images[idx]\n",
    "image_path = image_info[\"file_name\"]\n",
    "image_id = image_info[\"id\"]\n",
    "image_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Referring expressions attached to the selected image; display the first one.\n",
    "refs = img2refs[image_id]\n",
    "refs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected image, convert OpenCV's BGR order to RGB and preview it.\n",
    "# NOTE: this cell does not call predictor.set_image; that happens inside the Step 3 loop.\n",
    "import cv2 \n",
    "image = cv2.imread(image_path)\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "plt.imshow(image)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import ndimage\n",
    "def get_two_representative_points(m):\n",
    "    \"\"\"\n",
    "    Find two points that describe the shape of a mask reasonably well.\n",
    "\n",
    "    The first point is the maximum of the Euclidean distance transform, i.e. the\n",
    "    foreground pixel farthest from the background. The second point is the\n",
    "    distance-transform maximum among the foreground pixels that lie farther from\n",
    "    the first point than the median foreground pixel does. When no pixel qualifies\n",
    "    (e.g. a single-pixel mask) it falls back to the distance-transform maximum\n",
    "    inside a 20x20 window around the mask centroid.\n",
    "\n",
    "    Args:\n",
    "        m: 2-D mask array; every non-zero pixel counts as foreground, so masks\n",
    "           obtained by summing several binary maps (values above 1) work too.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ([x1, y1], [x2, y2]) as built-in ints (JSON serialisable), or\n",
    "        (None, None) when the mask has no foreground pixel.\n",
    "    \"\"\"\n",
    "    # Use one foreground definition for both the index lists and the distance transform.\n",
    "    foreground = np.asarray(m) != 0\n",
    "    y_indices, x_indices = np.nonzero(foreground)\n",
    "    if len(x_indices) == 0:\n",
    "        return None, None\n",
    "\n",
    "    def snap_to_mask(x, y):\n",
    "        # Move (x, y) onto the nearest foreground pixel unless it already lies on one.\n",
    "        if foreground[y, x]:\n",
    "            return x, y\n",
    "        nearest_idx = np.argmin((x_indices - x) ** 2 + (y_indices - y) ** 2)\n",
    "        return x_indices[nearest_idx], y_indices[nearest_idx]\n",
    "\n",
    "    # Distance from every foreground pixel to the nearest background pixel\n",
    "    dist_transform = ndimage.distance_transform_edt(foreground)\n",
    "\n",
    "    # First point: global maximum of the distance transform\n",
    "    y1, x1 = np.unravel_index(dist_transform.argmax(), dist_transform.shape)\n",
    "\n",
    "    # Keep the foreground pixels lying farther from the first point than the median pixel\n",
    "    distances_to_first = ((y_indices - y1) ** 2 + (x_indices - x1) ** 2) ** 0.5\n",
    "    is_far = distances_to_first > np.median(distances_to_first)\n",
    "\n",
    "    if is_far.any():\n",
    "        # Second point: the most interior pixel among the far ones\n",
    "        far_y, far_x = y_indices[is_far], x_indices[is_far]\n",
    "        second_point_idx = np.argmax(dist_transform[far_y, far_x])\n",
    "        y2, x2 = far_y[second_point_idx], far_x[second_point_idx]\n",
    "    else:\n",
    "        # No far pixel: take the most interior pixel in a window around the centroid\n",
    "        center_y = int(np.mean(y_indices))\n",
    "        center_x = int(np.mean(x_indices))\n",
    "        top, left = max(0, center_y - 10), max(0, center_x - 10)\n",
    "        local_region = dist_transform[\n",
    "            top:min(foreground.shape[0], center_y + 10),\n",
    "            left:min(foreground.shape[1], center_x + 10)\n",
    "        ]\n",
    "        local_y, local_x = np.unravel_index(local_region.argmax(), local_region.shape)\n",
    "        y2, x2 = local_y + top, local_x + left\n",
    "\n",
    "    # Make sure both points lie on the mask (the window fallback may land on background)\n",
    "    x1, y1 = snap_to_mask(x1, y1)\n",
    "    x2, y2 = snap_to_mask(x2, y2)\n",
    "\n",
    "    return [int(x1), int(y1)], [int(x2), int(y2)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mask_from_point(predictor, input_point, input_label, box):\n",
    "    \"\"\"Prompt the predictor with points and a box and return its masks, best score first.\n",
    "\n",
    "    Args:\n",
    "        predictor: object exposing predict(point_coords, point_labels, box, multimask_output);\n",
    "            its image must already have been set by the caller.\n",
    "        input_point: point prompt(s), forwarded as point_coords.\n",
    "        input_label: label per point, forwarded as point_labels.\n",
    "        box: box prompt, forwarded unchanged.\n",
    "\n",
    "    Returns:\n",
    "        The predicted masks reordered by descending score. Scores and logits are not returned.\n",
    "    \"\"\"\n",
    "    masks, scores, _ = predictor.predict(\n",
    "        point_coords=input_point,\n",
    "        point_labels=input_label,\n",
    "        box=box,\n",
    "        multimask_output=False,\n",
    "    )\n",
    "    # Only the masks are returned, so only they need reordering.\n",
    "    best_first = np.argsort(scores)[::-1]\n",
    "    return masks[best_first]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_iou(mask1, mask2):\n",
    "    \"\"\"Intersection-over-union of two masks (non-zero = foreground); 0 when both are empty.\"\"\"\n",
    "    overlap = np.logical_and(mask1, mask2).sum()\n",
    "    combined = np.logical_or(mask1, mask2).sum()\n",
    "    return overlap / combined if combined != 0 else 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Generate annotation list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pycocotools import mask\n",
    "import numpy as np\n",
    "from tqdm import tqdm  # 导入tqdm\n",
    "import json  # 导入json模块\n",
    "import cv2\n",
    "\n",
    "threshold_iou = 0.6  # a reference is kept only if SAM2's mask reaches this IoU with the ground truth\n",
    "max_records = 21  # stop once this many images have produced a record; set to None to process every image\n",
    "cnt = 0\n",
    "\n",
    "seg_zero_annotation_list = []\n",
    "\n",
    "for idx in tqdm(range(len(images)), desc=\"Processing images\"):\n",
    "    image_info = images[idx]\n",
    "    image_path = image_info[\"file_name\"]\n",
    "    image_id = image_info[\"id\"]\n",
    "    refs = img2refs[image_id]\n",
    "\n",
    "    image = cv2.imread(image_path)\n",
    "    if image is None:\n",
    "        # cv2.imread signals a missing/unreadable file by returning None; fail with a clear message\n",
    "        raise FileNotFoundError(f\"cv2.imread could not read image: {image_path}\")\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    predictor.set_image(image)\n",
    "\n",
    "    texts = []\n",
    "    bboxes = []\n",
    "    points = []\n",
    "    ann_ids = []\n",
    "    for ref in refs:\n",
    "        ann_id = ref[\"ann_id\"]\n",
    "\n",
    "        # first sentence only, without surrounding whitespace / trailing punctuation, lower-cased\n",
    "        text = ref[\"sentences\"][0][\"raw\"].strip().strip(\".?!\").lower()\n",
    "\n",
    "        ann = annotations[ann_id]\n",
    "        if len(ann[\"segmentation\"]) == 0:\n",
    "            # no ground-truth geometry, so there is nothing to prompt SAM2 with\n",
    "            continue\n",
    "\n",
    "        if isinstance(ann[\"segmentation\"][0], list):  # polygon\n",
    "            rle = mask.frPyObjects(\n",
    "                ann[\"segmentation\"], image_info[\"height\"], image_info[\"width\"]\n",
    "            )\n",
    "        else:\n",
    "            rle = ann[\"segmentation\"]\n",
    "            for i in range(len(rle)):\n",
    "                if not isinstance(rle[i][\"counts\"], bytes):\n",
    "                    rle[i][\"counts\"] = rle[i][\"counts\"].encode()\n",
    "        m = mask.decode(rle)\n",
    "        # sometimes there are multiple binary maps (one per segment): merge them and\n",
    "        # re-binarise so overlapping segments do not yield values above 1\n",
    "        m = (np.sum(m, axis=2) > 0).astype(np.uint8)\n",
    "\n",
    "        ys, xs = np.nonzero(m)\n",
    "        if len(xs) == 0:\n",
    "            # annotation decodes to an empty mask: no box or point can be derived\n",
    "            continue\n",
    "        # tight box as [left, top, right, bottom], stored as built-in ints for JSON\n",
    "        box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]\n",
    "\n",
    "        points_1, points_2 = get_two_representative_points(m)\n",
    "        if points_1 is None:\n",
    "            continue\n",
    "\n",
    "        # only the first representative point is used, as a positive prompt\n",
    "        point = [int(v) for v in points_1]\n",
    "        label = 1\n",
    "\n",
    "        mask_pred = get_mask_from_point(predictor, np.array([point]), np.array([label]), np.array(box))\n",
    "\n",
    "        mask_pred = mask_pred[0].astype(bool)\n",
    "        mask_gt = m.astype(bool)\n",
    "        iou = compute_iou(mask_pred, mask_gt)\n",
    "        if iou < threshold_iou:\n",
    "            # SAM2 cannot recover this object from the box + point prompt; drop the reference\n",
    "            continue\n",
    "\n",
    "        bboxes.append(box)\n",
    "        points.append(point)\n",
    "        texts.append(text)\n",
    "        ann_ids.append(str(ann_id))\n",
    "\n",
    "    if len(bboxes) == 0:\n",
    "        continue\n",
    "\n",
    "    seg_zero_annotation_list.append({\n",
    "        # the record id is built from at most the first three annotation ids\n",
    "        \"id\": f\"{dataset_name}_\" + \"_\".join(ann_ids[:3]),\n",
    "        \"image_id\": image_id,\n",
    "        \"image_path\": image_path,\n",
    "        \"problem\": \"'\" + \"' and '\".join(texts) + \"'\",\n",
    "        \"bboxes\": bboxes,\n",
    "        \"center_points\": points\n",
    "    })\n",
    "\n",
    "    cnt += 1\n",
    "\n",
    "    if max_records is not None and cnt >= max_records:\n",
    "        break\n",
    "\n",
    "\n",
    "print(f\"Total: {len(seg_zero_annotation_list)}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one generated record before the coordinates are cast to built-in ints.\n",
    "seg_zero_annotation_list[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast every box / point coordinate to a built-in int so the records can be written with json.dump.\n",
    "for item in seg_zero_annotation_list:\n",
    "    for coord_key in ('bboxes', 'center_points'):\n",
    "        item[coord_key] = [[int(value) for value in coords] for coords in item[coord_key]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same record again, after the int conversion above.\n",
    "seg_zero_annotation_list[10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Save and show examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the records to seg_zero_<dataset_name>_annotation_list.json in the working directory (UTF-8, non-ASCII kept as-is).\n",
    "with open(f'seg_zero_{dataset_name}_annotation_list.json', 'w', encoding='utf-8') as f:\n",
    "    json.dump(seg_zero_annotation_list, f, ensure_ascii=False, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2 \n",
    "\n",
    "# Preview one generated record: its prompt text, boxes and prompt points drawn on the image.\n",
    "# The index is clamped because the generation loop may stop before 31 records exist.\n",
    "example_idx = min(30, len(seg_zero_annotation_list) - 1)\n",
    "item = seg_zero_annotation_list[example_idx]\n",
    "\n",
    "print(item['problem'])\n",
    "print(item['bboxes'])\n",
    "print(item['center_points'])\n",
    "\n",
    "image_path = item['image_path']\n",
    "image = cv2.imread(image_path)\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "\n",
    "# The image is RGB at this point, so (0, 0, 255) draws blue boxes and (0, 255, 0) green dots.\n",
    "for bbox, center_point in zip(item['bboxes'], item['center_points']):\n",
    "    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 2)\n",
    "    cv2.circle(image, (center_point[0], center_point[1]), 5, (0, 255, 0), -1)\n",
    "\n",
    "plt.imshow(image)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Build the training dataset (see `gen_training_dataset.py`)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "visionreasoner",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
