{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook Code to read PascalVOC2011 images and extract superpixels to store as pickle files\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import scipy.io as sio\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils import data\n",
    "\n",
    "import random\n",
    "import scipy\n",
    "import pickle\n",
    "from skimage.segmentation import slic\n",
    "from skimage.future import graph\n",
    "from skimage import filters, color\n",
    "\n",
    "import scipy.ndimage\n",
    "import scipy.spatial\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "import time\n",
    "import dgl\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatically downloading the dataset from the link in this repo\n",
    "1. SBD from (https://github.com/shelhamer/fcn.berkeleyvision.org/tree/master/data/pascal)\n",
    "2. Extracting the archive yields the benchmark_RELEASE folder.  \n",
    "3. The benchmark_RELEASE folder will be placed in the current directory.\n",
    "\n",
    "Code adapted from https://github.com/zijundeng/pytorch-semantic-segmentation\n",
    "\n",
    "##### The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the SBD archive only when it is not already on disk.\n",
    "if not os.path.isfile('benchmark.tgz'):\n",
    "    print('downloading..')\n",
    "    !wget http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz\n",
    "else:\n",
    "    print('File already downloaded')\n",
    "\n",
    "# Extraction is checked independently of the download, so an archive that is\n",
    "# present but was never unpacked (or whose folder was deleted) still gets extracted.\n",
    "if not os.path.isdir('benchmark_RELEASE'):\n",
    "    print('extracting..')\n",
    "    !tar -xzf benchmark.tgz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root = '.'\n",
    "num_classes = 21\n",
    "ignore_label = 255\n",
    "\n",
    "\"\"\"\n",
    "color map\n",
    "0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable,\n",
    "12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor\n",
    "\"\"\"\n",
    "\n",
    "class VOC_SBD_Images(data.Dataset):\n",
    "    def __init__(self, mode):\n",
    "        self.imgs = self.read_dataset(mode)\n",
    "        if len(self.imgs) == 0:\n",
    "            raise RuntimeError('Found 0 images, please check the data set')\n",
    "        self.mode = mode\n",
    "        self.img_list = []\n",
    "        self.mask_list = []\n",
    "        self._pack_images_masks()\n",
    "        \n",
    "    def read_dataset(self, mode):\n",
    "        # in this paper, we train on the train set and evaluate on the val set\n",
    "        assert mode in ['train', 'val']\n",
    "        items = []\n",
    "        img_path = os.path.join(root, 'benchmark_RELEASE', 'dataset', 'img')\n",
    "        mask_path = os.path.join(root, 'benchmark_RELEASE', 'dataset', 'cls')\n",
    "\n",
    "        if mode == 'train':\n",
    "            data_list = [l.strip('\\n') for l in open(os.path.join(\n",
    "                root, 'benchmark_RELEASE', 'dataset', 'train.txt')).readlines()]\n",
    "        elif mode == 'val':\n",
    "            data_list = [l.strip('\\n') for l in open(os.path.join(\n",
    "                root, 'benchmark_RELEASE', 'dataset', 'val.txt')).readlines()]        \n",
    "\n",
    "        for it in data_list:\n",
    "            item = (os.path.join(img_path, it + '.jpg'), os.path.join(mask_path, it + '.mat'))\n",
    "            items.append(item)\n",
    "        return items\n",
    "    \n",
    "    def _pack_images_masks(self):\n",
    "        for index in range(self.__len__()):\n",
    "            img_path, mask_path = self.imgs[index]\n",
    "            img = Image.open(img_path).convert('RGB')\n",
    "            \n",
    "            mask = sio.loadmat(mask_path)['GTcls']['Segmentation'][0][0]\n",
    "            mask = Image.fromarray(mask.astype(np.uint8))\n",
    "\n",
    "            self.img_list.append(np.array(img))\n",
    "            self.mask_list.append(np.array(mask))\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.img_list[index], self.mask_list[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_image_slic(params):\n",
    "    \"\"\"Segment one image into SLIC superpixels and compute per-superpixel features.\n",
    "\n",
    "    Args:\n",
    "        params: tuple ``(img, index, n_images, args, to_print, shuffle)``.\n",
    "            img must be a uint8 array (asserted). index and n_images are only\n",
    "            used in the progress line and in assertion messages. args is a dict\n",
    "            read for 'n_sp', 'compactness' and 'split'. to_print enables a\n",
    "            progress line for every 100th index. shuffle randomly permutes the\n",
    "            order in which superpixels are reported.\n",
    "\n",
    "    Returns:\n",
    "        tuple ``(sp_intensity, sp_coord, sp_order, superpixels, g)``:\n",
    "            sp_intensity: float32 array with one row per superpixel holding the\n",
    "                per-channel mean, std, max and min (concatenated in that order)\n",
    "                of the image scaled by 1/255.\n",
    "            sp_coord: float32 array of superpixel centers of mass as (row, col).\n",
    "            sp_order: int32 array of superpixel ids, in the row order of the\n",
    "                two arrays above.\n",
    "            superpixels: the label map returned by ``slic``.\n",
    "            g: the graph from ``graph.rag_boundary`` built on Sobel edges, or a\n",
    "                complete graph with zeroed 'weight' / 'count' edge attributes\n",
    "                when ``rag_boundary`` raises ValueError.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if img is not uint8, if no segmentation with at most\n",
    "            args['n_sp'] superpixels is found, or if the labels are not 0..n-1.\n",
    "    \"\"\"\n",
    "    img, index, n_images, args, to_print, shuffle = params\n",
    "    img_original = img\n",
    "\n",
    "    assert img.dtype == np.uint8, img.dtype\n",
    "    img = (img / 255.).astype(np.float32)\n",
    "\n",
    "    n_sp_extracted = args['n_sp'] + 1  # number of actually extracted superpixels (can be different from requested in SLIC)\n",
    "\n",
    "    # number of superpixels we ask to extract (larger to extract more superpixels - closer to the desired n_sp)\n",
    "    n_sp_query = args['n_sp'] + 50\n",
    "\n",
    "    # The n_sp_query > 0 bound stops the search once there is nothing left to\n",
    "    # request, so a failure surfaces through the assertion below instead of\n",
    "    # through a slic call with a non-positive segment count.\n",
    "    while n_sp_extracted > args['n_sp'] and n_sp_query > 0:\n",
    "        superpixels = slic(img, n_segments=n_sp_query, compactness=args['compactness'], multichannel=len(img.shape) > 2)\n",
    "        sp_indices = np.unique(superpixels)\n",
    "        n_sp_extracted = len(sp_indices)\n",
    "        n_sp_query -= 1  # reducing the number of superpixels until we get <= n superpixels\n",
    "\n",
    "    assert n_sp_extracted <= args['n_sp'] and n_sp_extracted > 0, (args['split'], index, n_sp_extracted, args['n_sp'])\n",
    "\n",
    "    # make sure superpixel indices are numbers from 0 to n-1\n",
    "    # NOTE(review): this relies on slic numbering labels from 0; newer scikit-image\n",
    "    # releases appear to default to start_label=1 - confirm against the installed version.\n",
    "    assert n_sp_extracted == np.max(superpixels) + 1, ('superpixel indices', np.unique(superpixels))\n",
    "\n",
    "    # Creating region adjacency graph based on boundary\n",
    "    gimg = color.rgb2gray(img_original)\n",
    "    edges = filters.sobel(gimg)\n",
    "\n",
    "    try:\n",
    "        g = graph.rag_boundary(superpixels, edges)\n",
    "    except ValueError: # Error thrown when graph size is perhaps 1\n",
    "        print(\"ignored graph\")\n",
    "        g = nx.complete_graph(sp_indices) # so ignoring these for now and placing dummy info\n",
    "        nx.set_edge_attributes(g, 0., \"weight\")\n",
    "        nx.set_edge_attributes(g, 0, \"count\")\n",
    "\n",
    "    if shuffle:\n",
    "        ind = np.random.permutation(n_sp_extracted)\n",
    "    else:\n",
    "        ind = np.arange(n_sp_extracted)\n",
    "\n",
    "    sp_order = sp_indices[ind].astype(np.int32)\n",
    "    if len(img.shape) == 2:\n",
    "        img = img[:, :, None]\n",
    "\n",
    "    n_ch = 1 if img.shape[2] == 1 else 3\n",
    "\n",
    "    sp_intensity, sp_coord = [], []\n",
    "    for seg in sp_order:\n",
    "        mask = (superpixels == seg).squeeze()\n",
    "        avg_value = np.zeros(n_ch)\n",
    "        std_value = np.zeros(n_ch)\n",
    "        max_value = np.zeros(n_ch)\n",
    "        min_value = np.zeros(n_ch)\n",
    "        for c in range(n_ch):\n",
    "            # Select the superpixel's pixels once per channel and reuse them for all four statistics.\n",
    "            pixels = img[:, :, c][mask]\n",
    "            avg_value[c] = np.mean(pixels)\n",
    "            std_value[c] = np.std(pixels)\n",
    "            max_value[c] = np.max(pixels)\n",
    "            min_value[c] = np.min(pixels)\n",
    "        # center_of_mass is called through the public scipy.ndimage namespace\n",
    "        # rather than the deprecated scipy.ndimage.measurements alias.\n",
    "        cntr = np.array(scipy.ndimage.center_of_mass(mask))  # row, col\n",
    "\n",
    "        sp_intensity.append(np.concatenate((avg_value,\n",
    "                                           std_value,\n",
    "                                           max_value,\n",
    "                                           min_value), -1))\n",
    "        sp_coord.append(cntr)\n",
    "    sp_intensity = np.array(sp_intensity, np.float32)\n",
    "    sp_coord = np.array(sp_coord, np.float32)\n",
    "    if to_print and (index % 100 == 0):\n",
    "        print('image={}/{}, shape={}, min={:.2f}, max={:.2f}, n_sp={}'.format(index + 1, n_images, img.shape,\n",
    "                                                                              img.min(), img.max(), sp_intensity.shape[0]))\n",
    "    return sp_intensity, sp_coord, sp_order, superpixels, g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_superpixels_data(args):\n",
    "    \"\"\"Extract superpixel data for one split and pickle it into three files.\n",
    "\n",
    "    Reads the module-level ``train_set`` / ``val_set`` (chosen by args['split']),\n",
    "    runs ``process_image_slic`` on every image, labels each superpixel with the\n",
    "    ground-truth class found at its rounded center of mass, and writes\n",
    "    ``<prefix>.pkl`` (node labels and per-image feature tuples),\n",
    "    ``<prefix>_superpixels.pkl`` and ``<prefix>_rag_boundary_graphs.pkl``, where\n",
    "    ``<prefix>`` is ``<out_dir>/<dataset>_<n_sp>sp_<compactness>cmpt_<split>``.\n",
    "\n",
    "    Args:\n",
    "        args: dict with the keys 'split', 'seed', 'n_sp', 'compactness',\n",
    "            'out_dir' and 'dataset'.\n",
    "    \"\"\"\n",
    "    print(\"Extracting for {} split..\".format(args['split']))\n",
    "    data_set = train_set if args['split']== 'train' else val_set\n",
    "\n",
    "    random.seed(args['seed'])\n",
    "    np.random.seed(args['seed'])\n",
    "\n",
    "    # One entry per image in each list, so index i always refers to the same\n",
    "    # image in sp_data, superpixels, rag_boundary_graphs and data_set.mask_list.\n",
    "    n_images = len(data_set)\n",
    "    sp_data, superpixels, rag_boundary_graphs = [], [], []\n",
    "    for i, img in enumerate(data_set.img_list):\n",
    "        sp_intensity, sp_coord, sp_order, sp_map, rag = process_image_slic((img, i, n_images, args, True, False))\n",
    "        sp_data.append((sp_intensity, sp_coord, sp_order))\n",
    "        superpixels.append(sp_map)\n",
    "        rag_boundary_graphs.append(rag)\n",
    "\n",
    "    # NODE LABELING\n",
    "    # : using the coord value of the superpixel node to select the\n",
    "    #   corresponding label from the ground truth pixel (segmentation mask)\n",
    "    sp_node_labels = []\n",
    "    for mask, (_, coord, _) in zip(data_set.mask_list, sp_data):\n",
    "        # coord holds centers of mass as (row, col); round to the nearest pixel.\n",
    "        rows = np.rint(coord[:, 0]).astype(int)\n",
    "        cols = np.rint(coord[:, 1]).astype(int)\n",
    "\n",
    "        # labeling the superpixel node with the same value of the original pixel\n",
    "        # ground truth  that is on the mean coord of the superpixel node\n",
    "        sp_node_labels.append(mask[rows, cols].astype(np.int32))\n",
    "\n",
    "    # Create the output directory up front so the writes below cannot fail on a missing folder.\n",
    "    if args['out_dir']:\n",
    "        os.makedirs(args['out_dir'], exist_ok=True)\n",
    "    prefix = '%s/%s_%dsp_%dcmpt_%s' % (args['out_dir'], args['dataset'], args['n_sp'], args['compactness'], args['split'])\n",
    "\n",
    "    with open(prefix + '.pkl', 'wb') as f:\n",
    "        pickle.dump((sp_node_labels, sp_data), f, protocol=2)\n",
    "    with open(prefix + '_superpixels.pkl', 'wb') as f:\n",
    "        pickle.dump(superpixels, f, protocol=2)\n",
    "    with open(prefix + '_rag_boundary_graphs.pkl', 'wb') as f:\n",
    "        pickle.dump(rag_boundary_graphs, f, protocol=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load both splits into memory; write_superpixels_data reads these two globals.\n",
    "t0 = time.time()\n",
    "print(\"[I] Reading and loading SBD_VOC Images and Masks..\")\n",
    "train_set = VOC_SBD_Images('train')\n",
    "val_set = VOC_SBD_Images('val')\n",
    "print(\"[I] Time taken: {:.4f}s\".format(time.time()-t0))\n",
    "\n",
    "# Extraction settings for this run (SLIC compactness 10).\n",
    "args = dict(n_sp=500, compactness=10, seed=41, out_dir='.', dataset='VOC')\n",
    "count_ignored_graphs = 0\n",
    "t0 = time.time()\n",
    "print(\"[I] Extracting and writing superpixels data..\")\n",
    "\n",
    "# Train split first, then val; only args['split'] differs between the two calls.\n",
    "for split in ('train', 'val'):\n",
    "    args['split'] = split\n",
    "    write_superpixels_data(args)\n",
    "\n",
    "print(\"[I] Time taken: {:.4f}s\".format(time.time()-t0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the datasets when an earlier cell already loaded them: decoding every\n",
    "# image a second time only costs time and memory, and write_superpixels_data\n",
    "# only reads img_list / mask_list. They are loaded here when the cell is run on its own.\n",
    "if 'train_set' not in globals() or 'val_set' not in globals():\n",
    "    t0 = time.time()\n",
    "    print(\"[I] Reading and loading SBD_VOC Images and Masks..\")\n",
    "    train_set = VOC_SBD_Images('train')\n",
    "    val_set = VOC_SBD_Images('val')\n",
    "    print(\"[I] Time taken: {:.4f}s\".format(time.time()-t0))\n",
    "\n",
    "# Extraction settings for this run (SLIC compactness 30).\n",
    "args= {\n",
    "    'n_sp': 500,\n",
    "    'compactness': 30,\n",
    "    'seed': 41,\n",
    "    'out_dir': '.',\n",
    "    'dataset': 'VOC'\n",
    "}\n",
    "count_ignored_graphs = 0\n",
    "t0 = time.time()\n",
    "print(\"[I] Extracting and writing superpixels data..\")\n",
    "\n",
    "# TRAIN SET\n",
    "args['split'] = 'train'\n",
    "write_superpixels_data(args)\n",
    "\n",
    "# VAL SET\n",
    "args['split'] = 'val'\n",
    "write_superpixels_data(args)\n",
    "\n",
    "print(\"[I] Time taken: {:.4f}s\".format(time.time()-t0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot superpixels with the boundaries overlayed on the image\n",
    "# plt.imshow(mark_boundaries(img, superpixels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
