{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization of the superpixel images, graphs, edge connections and node labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Superpixels from\n",
    "Superpixels are generated using the file `'./scripts/COCO/generate_cocosuperpixels_raw.py'`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "from scipy.spatial.distance import cdist\n",
    "import scipy.io as sio\n",
    "from scipy import ndimage\n",
    "import numpy as np\n",
    "\n",
    "import dgl\n",
    "import torch\n",
    "from torch.utils import data\n",
    "from PIL import Image\n",
    "import time\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 22})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare COCO Images and mask\n",
    "\n",
    "### Automatically downloading dataset from the link in this repo\n",
    "The links are from the source https://cocodataset.org/#download\n",
    "There are 118K Train images and 5K Val images\n",
    "\n",
    "#### Before proceeding further:\n",
    "- Download the repo https://github.com/cocodataset/cocoapi in the current directory \n",
    "- Then run `make` inside the `cocoapi/PythonAPI` directory  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('train2017.zip'):\n",
    "    print('downloading..')\n",
    "    !curl http://images.cocodataset.org/zips/train2017.zip -o train2017.zip\n",
    "    !unzip train2017.zip\n",
    "    !mv train2017 cocoapi/images/train2017\n",
    "else:\n",
    "    print('File already downloaded')\n",
    "    \n",
    "if not os.path.isfile('val2017.zip'):\n",
    "    print('downloading..')\n",
    "    !curl http://images.cocodataset.org/zips/val2017.zip -o val2017.zip\n",
    "    !unzip val2017.zip\n",
    "    !mv val2017 cocoapi/images/val2017\n",
    "else:\n",
    "    print('File already downloaded')\n",
    "    \n",
    "if not os.path.isfile('annotations_trainval2017.zip'):\n",
    "    print('downloading..')\n",
    "    !curl http://images.cocodataset.org/annotations/annotations_trainval2017.zip -o annotations_trainval2017.zip\n",
    "    !unzip annotations_trainval2017.zip\n",
    "    !mv annotations cocoapi/annotations\n",
    "else:\n",
    "    print('File already downloaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from cocoapi.PythonAPI.pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import pylab\n",
    "pylab.rcParams['figure.figsize'] = (8.0, 10.0)\n",
    "pylab.rcParams.update({'font.size': 22})\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    COCO categories: \n",
    "    person bicycle car motorcycle airplane bus train truck boat traffic light fire hydrant stop\n",
    "    sign parking meter bench bird cat dog horse sheep cow elephant bear zebra giraffe backpack\n",
    "    umbrella handbag tie suitcase frisbee skis snowboard sports ball kite baseball bat baseball\n",
    "    glove skateboard surfboard tennis racket bottle wine glass cup fork knife spoon bowl banana\n",
    "    apple sandwich orange broccoli carrot hot dog pizza donut cake chair couch potted plant bed\n",
    "    dining table toilet tv laptop mouse remote keyboard cell phone microwave oven toaster sink\n",
    "    refrigerator book clock vase scissors teddy bear hair drier toothbrush\n",
    "\"\"\"\n",
    "\n",
    "class COCO_Images_Masks(data.Dataset):\n",
    "    def __init__(self, mode, root='./cocoapi'):\n",
    "        self.root = root\n",
    "        self.mode = mode\n",
    "        self.all_superpixels = []\n",
    "        self.all_rag_boundary_graphs = []\n",
    "        self.all_sp_data = []\n",
    "        self.all_sp_node_labels = []\n",
    "        \n",
    "        self.n_sp = 100\n",
    "        self.compactness = 10\n",
    "        self.seed = 41\n",
    "        self.out_dir = '.'\n",
    "        self.dataset = 'COCO'\n",
    "        \n",
    "        self.args = self.mode, self.seed, self.n_sp, self.compactness\n",
    "        self.img_list = []\n",
    "        self.mask_list = []\n",
    "        \n",
    "        self.num_images = self._pack_images_masks(mode)\n",
    "        \n",
    "    def _pack_images_masks(self, mode):\n",
    "        # in this paper, we train on the train set and evaluate on the val set\n",
    "        assert mode in ['train', 'val']\n",
    "        \n",
    "        dataType = 'val2017' if mode == 'val' else 'train2017'\n",
    "        annFile = '{}/annotations/instances_{}.json'.format(self.root, dataType)\n",
    "        \n",
    "        # initialize COCO api for instance annotations\n",
    "        coco=COCO(annFile)\n",
    "        \n",
    "        # cats = coco.loadCats(coco.getCatIds())\n",
    "        # print(cats)\n",
    "        \n",
    "        imgIds = coco.getImgIds()\n",
    "        cat_ids = coco.getCatIds()\n",
    "        \n",
    "        all_imgs = coco.loadImgs(imgIds)#[0]\n",
    "        # for index in tqdm(range(len(all_imgs))):\n",
    "        \n",
    "        sample_length = 50\n",
    "        \n",
    "        for index in tqdm(range(sample_length)):\n",
    "            img_meta_info = all_imgs[index]\n",
    "            # img = io.imread(img_meta_info['coco_url']) # This command actually fetches the img from url each time\n",
    "            img = Image.open(os.path.join(self.root, 'images', dataType, img_meta_info['file_name'])).convert('RGB')\n",
    "        \n",
    "            anns_ids = coco.getAnnIds(imgIds=img_meta_info['id'], catIds=cat_ids, iscrowd=None)\n",
    "            anns = coco.loadAnns(anns_ids)\n",
    "            \n",
    "            mask = np.zeros((img_meta_info['height'],img_meta_info['width']))\n",
    "            for ann in anns:\n",
    "                mask = np.maximum(mask,coco.annToMask(ann)*ann['category_id'])\n",
    "\n",
    "            self.img_list.append(np.array(img))\n",
    "            self.mask_list.append(mask)\n",
    "        return sample_length\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.img_list[index], self.mask_list[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_images\n",
    "\n",
    "val_set = COCO_Images_Masks('val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_set = COCO_Images_Masks('val')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functions definition for graph construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigma(dists, kth=8):\n",
    "    # Compute sigma and reshape\n",
    "    try:\n",
    "        # Get k-nearest neighbors for each node\n",
    "        knns = np.partition(dists, kth, axis=-1)[:, kth::-1]\n",
    "        sigma = knns.sum(axis=1).reshape((knns.shape[0], 1))/kth\n",
    "    except ValueError:     # handling for graphs with num_nodes less than kth\n",
    "        num_nodes = dists.shape[0]\n",
    "        # this sigma value is irrelevant since not used for final compute_edge_list\n",
    "        sigma = np.array([1]*num_nodes).reshape(num_nodes,1)\n",
    "        \n",
    "    return sigma + 1e-8 # adding epsilon to avoid zero value of sigma\n",
    "\n",
    "def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8):\n",
    "    coord = coord.reshape(-1, 2)\n",
    "    # Compute coordinate distance\n",
    "    c_dist = cdist(coord, coord)\n",
    "    \n",
    "    if use_feat:\n",
    "        # Compute feature distance\n",
    "        f_dist = cdist(feat, feat)\n",
    "        # Compute adjacency\n",
    "        A = np.exp(- (c_dist/sigma(c_dist))**2 - (f_dist/sigma(f_dist))**2 )\n",
    "    else:\n",
    "        A = np.exp(- (c_dist/sigma(c_dist))**2)\n",
    "        \n",
    "    # Convert to symmetric matrix\n",
    "    A = 0.5 * (A + A.T)\n",
    "    A[np.diag_indices_from(A)] = 0\n",
    "    return A        \n",
    "\n",
    "\n",
    "def compute_edges_list(A, kth=8+1):\n",
    "    # Get k-similar neighbor indices for each node\n",
    "\n",
    "    num_nodes = A.shape[0]\n",
    "    new_kth = num_nodes - kth\n",
    "    \n",
    "    if num_nodes > 9:\n",
    "        knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]\n",
    "        knn_values = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1] # NEW\n",
    "    else:\n",
    "        # handling for graphs with less than kth nodes\n",
    "        # in such cases, the resulting graph will be fully connected\n",
    "        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(num_nodes, num_nodes)\n",
    "        knn_values = A # NEW\n",
    "        \n",
    "        # removing self loop\n",
    "        if num_nodes != 1:\n",
    "            knn_values = A[knns != np.arange(num_nodes)[:,None]].reshape(num_nodes,-1) # NEW\n",
    "            knns = knns[knns != np.arange(num_nodes)[:,None]].reshape(num_nodes,-1)\n",
    "    return knns, knn_values # NEW"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SuperpixDGL class for reading superpixels file and constructing graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SuperPixDGL(torch.utils.data.Dataset):\n",
    "    def __init__(self,\n",
    "                 data_dir,\n",
    "                 dataset,\n",
    "                 split,\n",
    "                 graph_format='edge_wt_only_coord'):\n",
    "        assert graph_format in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary']\n",
    "        self.split = split\n",
    "        self.graph_lists = []\n",
    "\n",
    "        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_superpixels.pkl' % split), 'rb') as f:\n",
    "            self.superpixels = pickle.load(f)\n",
    "        \n",
    "        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s.pkl' % split), 'rb') as f:\n",
    "            self.labels, self.sp_data = pickle.load(f)\n",
    "            self.sp_node_labels = self.labels\n",
    "        \n",
    "        if graph_format == 'edge_wt_region_boundary':\n",
    "            with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_rag_boundary_graphs.pkl' % split), 'rb') as f:\n",
    "                self.region_boundary_graphs = pickle.load(f)\n",
    "\n",
    "        self.graph_format = graph_format \n",
    "        self.n_samples = len(self.labels)\n",
    "        \n",
    "        self._prepare()\n",
    "    \n",
    "    def _prepare(self):\n",
    "        print(\"preparing %d graphs for the %s set...\" % (self.n_samples, self.split.upper()))\n",
    "        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []\n",
    "        for index, sample in enumerate(self.sp_data):\n",
    "            mean_px, coord = sample[:2]\n",
    "            \n",
    "            try:\n",
    "                coord = coord / self.img_size\n",
    "            except AttributeError:\n",
    "                VOC_has_variable_image_sizes = True\n",
    "                \n",
    "            if self.graph_format == 'edge_wt_coord_feat':\n",
    "                A = compute_adjacency_matrix_images(coord, mean_px) # using super-pixel locations + features\n",
    "                edges_list, edge_values_list = compute_edges_list(A) \n",
    "            elif self.graph_format == 'edge_wt_only_coord':\n",
    "                A = compute_adjacency_matrix_images(coord, mean_px, False) # using only super-pixel locations\n",
    "                edges_list, edge_values_list = compute_edges_list(A) \n",
    "            elif self.graph_format == 'edge_wt_region_boundary':\n",
    "                A, edges_list, edge_values_list = None, None, None\n",
    "\n",
    "            N_nodes = mean_px.shape[0]\n",
    "            \n",
    "            mean_px = mean_px.reshape(N_nodes, -1)\n",
    "            coord = coord.reshape(N_nodes, 2)\n",
    "            x = np.concatenate((mean_px, coord), axis=1)\n",
    "\n",
    "            if edge_values_list is not None:\n",
    "                edge_values_list = edge_values_list.reshape(-1) \n",
    "            \n",
    "            self.node_features.append(x)\n",
    "            self.edge_features.append(edge_values_list) \n",
    "            self.Adj_matrices.append(A)\n",
    "            self.edges_lists.append(edges_list)\n",
    "        \n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of graphs in the dataset.\"\"\"\n",
    "        return self.n_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        if self.graph_format == 'edge_wt_region_boundary':\n",
    "            if self.node_features[idx].shape[0] == 1:\n",
    "                # handling for 1 node where the self loop would be the only edge\n",
    "                # since, VOC Superpixels has few samples (5 samples) with only 1 node\n",
    "                g = dgl.DGLGraph()\n",
    "                g.add_nodes(self.node_features[idx].shape[0]) \n",
    "                g = dgl.add_self_loop(g)\n",
    "                # dummy edge feat since no actual edge present\n",
    "                g.edata['feat'] = torch.zeros(1, 2) # 1 edge and 2 feat dim\n",
    "                self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()\n",
    "            else:\n",
    "                g = dgl.from_networkx(self.region_boundary_graphs[idx].to_directed(),\n",
    "                                  edge_attrs=['weight', 'count'])\n",
    "                g.edata['feat'] = torch.stack((g.edata['weight'], g.edata['count']),-1)\n",
    "                del g.edata['weight'], g.edata['count']\n",
    "                self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()\n",
    "        else:\n",
    "            g = dgl.DGLGraph()\n",
    "            g.add_nodes(self.node_features[idx].shape[0])\n",
    "            for src, dsts in enumerate(self.edges_lists[idx]):\n",
    "                g.add_edges(src, dsts[dsts!=src])\n",
    "                \n",
    "        g.ndata['feat'] = torch.Tensor(self.node_features[idx])\n",
    "        \n",
    "        return g, self.sp_node_labels[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Only coordinates for knn graph construction\n",
    "This is done by setting `graph_format` option.   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Taking the test dataset only for sample visualization\n",
    "graph_format = 'edge_wt_only_coord'\n",
    "tt = time.time()\n",
    "data_only_coord_knn = SuperPixDGL(\"./\", \n",
    "                                  dataset='COCO',\n",
    "                                  split='val', \n",
    "                                  graph_format=graph_format)\n",
    "\n",
    "print(\"Time taken: {:.4f}s\".format(time.time()-tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Both coordinates and features for knn graph construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_format = 'edge_wt_coord_feat'\n",
    "tt = time.time()\n",
    "data_coord_feat_knn = SuperPixDGL(\"./\", \n",
    "                                  dataset='COCO',\n",
    "                                  split='val', \n",
    "                                  graph_format=graph_format)\n",
    "\n",
    "print(\"Time taken: {:.4f}s\".format(time.time()-tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Region Boundary based graph construction with variable edges for every node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_format = 'edge_wt_region_boundary'\n",
    "tt = time.time()\n",
    "data_region_boundary = SuperPixDGL(\"./\", \n",
    "                                  dataset='COCO',\n",
    "                                  split='val', \n",
    "                                  graph_format=graph_format)\n",
    "\n",
    "print(\"Time taken: {:.4f}s\".format(time.time()-tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Superpixels plot function definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import pdist, squareform\n",
    "from pylab import rcParams\n",
    "from skimage.segmentation import mark_boundaries\n",
    "\n",
    "def show_image(plt, idx, alpha=1.0):\n",
    "    plt.imshow(val_set.img_list[idx])\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.title.set_text(\" Original Image\")\n",
    "\n",
    "def plot_superpixels_graph(plt, data, idx, overlay=None):\n",
    "    with_edges = True\n",
    "    sp_data = data.sp_data[idx]\n",
    "    node_labels = data.sp_node_labels[idx]\n",
    "    adj_matrix = data.Adj_matrices[idx]\n",
    "    g = data[idx][0]\n",
    "    \n",
    "    Y = squareform(pdist(sp_data[1], 'euclidean'))\n",
    "    x_coord = sp_data[1] #np.flip(dataset.train.sp_data[_][1], 1)\n",
    "    # intensities = sp_data[0].mean(axis=1)\n",
    "    \n",
    "    G = nx.from_numpy_matrix(Y)\n",
    "    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))\n",
    "    rotated_pos = {node: (y,-x) for (node, (x,y)) in pos.items()} # rotate the coords by 90 degree\n",
    "    \n",
    "    if overlay is not None:\n",
    "        if overlay==\"image\":\n",
    "            plt.imshow(val_set.img_list[idx])\n",
    "        else:\n",
    "            pass\n",
    "        plt.imshow(mark_boundaries(val_set.img_list[idx], data.superpixels[idx],\n",
    "                               color=[0,1,0], outline_color=[0,1,0]))\n",
    "        rotated_pos = {node: (x,-y) for (node, (x,y)) in rotated_pos.items()} # reflect the graph on x-axis for overlaying\n",
    "    \n",
    "    edge_list = torch.stack(g.edges(),0).T.tolist()\n",
    "        \n",
    "    nx.draw_networkx_nodes(G, rotated_pos, node_color=node_labels, node_size=40) # len(intensities))\n",
    "    if with_edges and overlay==\"image\":\n",
    "        nx.draw_networkx_edges(G, rotated_pos, edge_list, edge_color='w', alpha=0.65)\n",
    "    elif with_edges and overlay != \"slic\":\n",
    "        nx.draw_networkx_edges(G, rotated_pos, edge_list, alpha=0.15)\n",
    "    \n",
    "    title = \"\"\n",
    "    \n",
    "    if data.graph_format == 'edge_wt_region_boundary':\n",
    "        title += \" `rag-boundary` graph overlay on SLIC SP\"\n",
    "        if overlay == None:\n",
    "            title = \"final `rag-boundary` graph\"\n",
    "    elif data.graph_format == 'edge_wt_only_coord':\n",
    "        title += \" `coo` graph overlay on SLIC SP\"\n",
    "        if overlay == None:\n",
    "            title = \"final `coo` graph\"\n",
    "    else:\n",
    "        title += \" `coo-feat` graph overlay on SLIC SP\"\n",
    "        if overlay == None:\n",
    "            title = \"final `coo-feat` graph\"\n",
    "    \n",
    "    if overlay == \"slic\":\n",
    "        title = \"SLIC SP (compactness=30)\"\n",
    "    \n",
    "    plt.title.set_text(title)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting sample superpixels, and graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples_plot = 4\n",
    "# sample_indices = np.random.choice(int(len(data_only_coord_knn)/2), num_samples_plot, replace=False)\n",
    "sample_indices = np.array([38])#, 16]) # Set manually\n",
    "print(sample_indices)     \n",
    "\n",
    "if not os.path.exists('./coco_viz_files'):\n",
    "    os.makedirs('./coco_viz_files')\n",
    "    \n",
    "for f_idx, idx in enumerate(sample_indices):\n",
    "    print()\n",
    "    print(\"------ Image ID ------ : \", idx)\n",
    "    print(\"Num nodes: \", data_region_boundary[idx][0].number_of_nodes())\n",
    "    print(\"Num edges Graph: edge_wt_only_coord: \", data_only_coord_knn[idx][0].number_of_edges())\n",
    "    print(\"Num edges Graph: edge_wt_coord_feat: \", data_coord_feat_knn[idx][0].number_of_edges())\n",
    "    print(\"Num edges Graph: Region boundary graph: \", data_region_boundary[idx][0].number_of_edges())\n",
    "    \n",
    "#     f = plt.figure(f_idx, figsize=(18, 3))\n",
    "#     plt1 = f.add_subplot(141)\n",
    "#     show_image(plt1, idx)\n",
    "\n",
    "#     plt2 = f.add_subplot(142)\n",
    "#     plot_superpixels_graph(plt2, data_only_coord_knn, idx, overlay=\"slic\")\n",
    "    \n",
    "#     plt3 = f.add_subplot(143)\n",
    "#     plot_superpixels_graph(plt3, data_region_boundary, idx, overlay=\"image\")\n",
    "\n",
    "#     plt4 = f.add_subplot(144)\n",
    "#     plot_superpixels_graph(plt4, data_region_boundary, idx)\n",
    "    \n",
    "#     plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "#     f.savefig('coco_viz_files/coco_'+str(idx)+'_row1.pdf', dpi=300)\n",
    "#     plt.show()\n",
    "    \n",
    "    f = plt.figure(f_idx, figsize=(20, 7))\n",
    "    plt1 = f.add_subplot(121)\n",
    "    show_image(plt1, idx)\n",
    "\n",
    "    plt2 = f.add_subplot(122)\n",
    "    plot_superpixels_graph(plt2, data_only_coord_knn, idx, overlay=\"slic\")\n",
    "    \n",
    "    plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "    f.savefig('coco_viz_files/coco_'+str(idx)+'_row1.pdf', dpi=300)\n",
    "    plt.show()\n",
    "        \n",
    "    f = plt.figure(f_idx, figsize=(20, 7))\n",
    "    plt3 = f.add_subplot(121)\n",
    "    plot_superpixels_graph(plt3, data_only_coord_knn, idx, overlay=\"image\")\n",
    "    \n",
    "    plt4 = f.add_subplot(122)\n",
    "    plot_superpixels_graph(plt4, data_only_coord_knn, idx)\n",
    "    \n",
    "    plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "    f.savefig('coco_viz_files/coco_'+str(idx)+'_row2.pdf', dpi=300)\n",
    "    plt.show()\n",
    "    \n",
    "    f = plt.figure(f_idx, figsize=(20, 7))\n",
    "    plt5 = f.add_subplot(121)\n",
    "    plot_superpixels_graph(plt5, data_coord_feat_knn, idx, overlay=\"image\")\n",
    "\n",
    "    \n",
    "    plt6 = f.add_subplot(122)\n",
    "    plot_superpixels_graph(plt6, data_coord_feat_knn, idx)\n",
    "    \n",
    "    plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "    f.savefig('coco_viz_files/coco_'+str(idx)+'_row3.pdf', dpi=300)\n",
    "    plt.show()\n",
    "    \n",
    "    f = plt.figure(f_idx, figsize=(20, 7))\n",
    "    plt7 = f.add_subplot(121)\n",
    "    plot_superpixels_graph(plt7, data_region_boundary, idx, overlay=\"image\")\n",
    "\n",
    "    plt8 = f.add_subplot(122)\n",
    "    plot_superpixels_graph(plt8, data_region_boundary, idx)\n",
    "\n",
    "    plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "    f.savefig('coco_viz_files/coco_'+str(idx)+'_row4.pdf', dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
