{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization of the superpixel images, graphs, edge connections and node labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Superpixel source\n",
    "Superpixels are generated using the notebook `'./scripts/PascalVOC/generate_data_voc.ipynb'`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "from scipy.spatial.distance import cdist\n",
    "import scipy.io as sio\n",
    "from scipy import ndimage\n",
    "import numpy as np\n",
    "\n",
    "import dgl\n",
    "import torch\n",
    "from torch.utils import data\n",
    "from PIL import Image\n",
    "import time\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "# Global matplotlib font size applied to every figure drawn in this notebook\n",
    "matplotlib.rcParams.update({'font.size': 22})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare VOC_SBD Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and unpack the SBD benchmark only when the extracted data is missing.\n",
    "# The loader in this notebook reads './benchmark_RELEASE', so the archive is\n",
    "# extracted into the current directory, which is also where wget saves it.\n",
    "if not os.path.isdir('./benchmark_RELEASE'):\n",
    "    # Re-use an archive left over from an earlier, interrupted run\n",
    "    if not os.path.isfile('./benchmark.tgz'):\n",
    "        print('downloading..')\n",
    "        !wget http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz\n",
    "    !tar -xzf benchmark.tgz -C ./\n",
    "else:\n",
    "    print('File already downloaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base directory expected to contain the extracted 'benchmark_RELEASE' folder\n",
    "root = './'\n",
    "# Number of segmentation classes, i.e. ids 0-20 of the color map below\n",
    "num_classes = 21\n",
    "# NOTE(review): presumably the mask value of pixels to ignore; not referenced by the visible cells - confirm\n",
    "ignore_label = 255\n",
    "\n",
    "\"\"\"\n",
    "color map\n",
    "0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable,\n",
    "12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor\n",
    "\"\"\"\n",
    "\n",
    "class VOC_SBD_Images(data.Dataset):\n",
    "    \"\"\"In-memory (image, mask) pairs for one split of the SBD benchmark.\n",
    "\n",
    "    All images and class-segmentation masks listed in the split file are decoded\n",
    "    once at construction time and kept as NumPy arrays.\n",
    "\n",
    "    Args:\n",
    "        mode: Split to load, 'train' or 'val'.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `mode` is neither 'train' nor 'val'.\n",
    "        RuntimeError: If the split file lists no images.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mode):\n",
    "        self.imgs = self.read_dataset(mode)\n",
    "        if len(self.imgs) == 0:\n",
    "            raise RuntimeError('Found 0 images, please check the data set')\n",
    "        self.mode = mode\n",
    "        self.img_list = []\n",
    "        self.mask_list = []\n",
    "        self._pack_images_masks()\n",
    "\n",
    "    def read_dataset(self, mode):\n",
    "        \"\"\"Return the (image_path, mask_path) tuples listed in '<mode>.txt'.\"\"\"\n",
    "        # in this paper, we train on the train set and evaluate on the val set\n",
    "        assert mode in ['train', 'val']\n",
    "        dataset_dir = os.path.join(root, 'benchmark_RELEASE', 'dataset')\n",
    "        img_path = os.path.join(dataset_dir, 'img')\n",
    "        mask_path = os.path.join(dataset_dir, 'cls')\n",
    "\n",
    "        # The split file is named after the mode; `with` guarantees it is closed after reading\n",
    "        with open(os.path.join(dataset_dir, mode + '.txt')) as split_file:\n",
    "            data_list = [l.strip('\\n') for l in split_file]\n",
    "\n",
    "        return [(os.path.join(img_path, it + '.jpg'), os.path.join(mask_path, it + '.mat'))\n",
    "                for it in data_list]\n",
    "\n",
    "    def _pack_images_masks(self):\n",
    "        \"\"\"Decode every (image, mask) pair of `self.imgs` into `img_list` / `mask_list`.\"\"\"\n",
    "        for img_path, mask_path in self.imgs:\n",
    "            # The context manager releases the image file once its pixels are copied\n",
    "            with Image.open(img_path) as img:\n",
    "                self.img_list.append(np.array(img.convert('RGB')))\n",
    "\n",
    "            mask = sio.loadmat(mask_path)['GTcls']['Segmentation'][0][0]\n",
    "            # Stored as a C-contiguous uint8 array\n",
    "            self.mask_list.append(np.ascontiguousarray(mask.astype(np.uint8)))\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the `(image, mask)` arrays at `index`.\"\"\"\n",
    "        return self.img_list[index], self.mask_list[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of (image, mask) pairs listed for the split.\"\"\"\n",
    "        return len(self.imgs)\n",
    "\n",
    "# Loads every image and mask of the val split into memory; the plotting helpers read val_set.img_list\n",
    "val_set = VOC_SBD_Images('val')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Function definitions for graph construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigma(dists, kth=8):\n",
    "    \"\"\"Per-row kernel width derived from the smallest distances of each row.\n",
    "\n",
    "    For every row the `kth + 1` smallest entries are summed and divided by `kth`.\n",
    "    Returns a column vector of shape (num_rows, 1) with a small epsilon added\n",
    "    so that the width is never zero.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Partition each row so its kth + 1 smallest entries come first, then keep only those\n",
    "        nearest = np.partition(dists, kth, axis=-1)[:, kth::-1]\n",
    "        width = nearest.sum(axis=1).reshape((nearest.shape[0], 1)) / kth\n",
    "    except ValueError:\n",
    "        # graphs with num_nodes less than kth: fall back to a unit width\n",
    "        # (this value is irrelevant since not used for final compute_edge_list)\n",
    "        n_rows = dists.shape[0]\n",
    "        width = np.ones((n_rows, 1), dtype=int)\n",
    "\n",
    "    return width + 1e-8  # epsilon avoids a zero width\n",
    "\n",
    "def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8):\n",
    "    \"\"\"Dense, symmetric Gaussian-kernel adjacency between superpixels.\n",
    "\n",
    "    Args:\n",
    "        coord: Superpixel coordinates; reshaped to (num_nodes, 2).\n",
    "        feat: Per-node feature matrix, only read when `use_feat` is true.\n",
    "        use_feat: Add the feature-distance term to the kernel exponent.\n",
    "        kth: Neighbour count forwarded to `sigma` for the kernel widths.\n",
    "\n",
    "    Returns:\n",
    "        (num_nodes, num_nodes) array with a zero diagonal.\n",
    "    \"\"\"\n",
    "    coord = coord.reshape(-1, 2)\n",
    "    # Coordinate distance, scaled per row by its kernel width\n",
    "    c_dist = cdist(coord, coord)\n",
    "    exponent = (c_dist / sigma(c_dist, kth)) ** 2\n",
    "\n",
    "    if use_feat:\n",
    "        # Feature distance contributes a second, independently scaled term\n",
    "        f_dist = cdist(feat, feat)\n",
    "        exponent = exponent + (f_dist / sigma(f_dist, kth)) ** 2\n",
    "\n",
    "    A = np.exp(-exponent)\n",
    "\n",
    "    # The per-row scaling makes the kernel asymmetric; average with the transpose\n",
    "    A = 0.5 * (A + A.T)\n",
    "    # No self-similarity on the diagonal\n",
    "    A[np.diag_indices_from(A)] = 0\n",
    "    return A\n",
    "\n",
    "\n",
    "def compute_edges_list(A, kth=8+1):\n",
    "    \"\"\"Select neighbour indices and edge weights for every node of adjacency `A`.\n",
    "\n",
    "    Args:\n",
    "        A: Square (num_nodes, num_nodes) similarity matrix.\n",
    "        kth: Graphs with more than `kth` nodes get `kth - 1` neighbours per node\n",
    "            (8 with the default); smaller graphs become fully connected.\n",
    "\n",
    "    Returns:\n",
    "        (knns, knn_values): neighbour indices and the matching entries of `A`,\n",
    "        one row per node. A single-node graph keeps its only (self) entry.\n",
    "    \"\"\"\n",
    "    num_nodes = A.shape[0]\n",
    "    new_kth = num_nodes - kth\n",
    "\n",
    "    if num_nodes > kth:\n",
    "        # After partitioning, columns from new_kth onwards hold the largest similarities;\n",
    "        # the slice [new_kth:-1] keeps kth - 1 of them\n",
    "        knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]\n",
    "        # Gather the values by index so they are guaranteed to line up with knns\n",
    "        knn_values = np.take_along_axis(A, knns, axis=-1)\n",
    "    else:\n",
    "        # handling for graphs with less than kth nodes\n",
    "        # in such cases, the resulting graph will be fully connected\n",
    "        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(num_nodes, num_nodes)\n",
    "        knn_values = A\n",
    "\n",
    "        # removing self loop\n",
    "        if num_nodes != 1:\n",
    "            off_diagonal = knns != np.arange(num_nodes)[:, None]\n",
    "            knn_values = A[off_diagonal].reshape(num_nodes, -1)\n",
    "            knns = knns[off_diagonal].reshape(num_nodes, -1)\n",
    "    return knns, knn_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `SuperPixDGL` class for reading the superpixel files and constructing graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SuperPixDGL(torch.utils.data.Dataset):\n",
    "    \"\"\"Superpixel graphs of one VOC split, built from pre-computed pickle files.\n",
    "\n",
    "    Args:\n",
    "        data_dir: Directory containing the `VOC_500sp_*` pickle files.\n",
    "        dataset: Dataset name; accepted but not used by this class.\n",
    "        split: Split name embedded in the pickle file names.\n",
    "        graph_format: 'edge_wt_only_coord' (kNN graph from superpixel coordinates),\n",
    "            'edge_wt_coord_feat' (kNN graph from coordinates and mean pixel features) or\n",
    "            'edge_wt_region_boundary' (graphs loaded from the rag-boundary pickle).\n",
    "        slic_compactness: Compactness value embedded in the pickle file names.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `graph_format` is not one of the three supported values.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 data_dir,\n",
    "                 dataset,\n",
    "                 split,\n",
    "                 graph_format='edge_wt_only_coord',\n",
    "                 slic_compactness=30):\n",
    "        assert graph_format in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary']\n",
    "        self.split = split\n",
    "        self.graph_lists = []\n",
    "\n",
    "        # All input files share this prefix\n",
    "        prefix = 'VOC_500sp_%scmpt_%s' % (str(slic_compactness), split)\n",
    "\n",
    "        # NOTE(security): pickle.load can execute arbitrary code; only load files you generated or trust\n",
    "        with open(os.path.join(data_dir, prefix + '_superpixels.pkl'), 'rb') as f:\n",
    "            self.superpixels = pickle.load(f)\n",
    "\n",
    "        with open(os.path.join(data_dir, prefix + '.pkl'), 'rb') as f:\n",
    "            self.labels, self.sp_data = pickle.load(f)\n",
    "            self.sp_node_labels = self.labels\n",
    "\n",
    "        if graph_format == 'edge_wt_region_boundary':\n",
    "            with open(os.path.join(data_dir, prefix + '_rag_boundary_graphs.pkl'), 'rb') as f:\n",
    "                self.region_boundary_graphs = pickle.load(f)\n",
    "\n",
    "        self.graph_format = graph_format\n",
    "        self.n_samples = len(self.labels)\n",
    "\n",
    "        self._prepare()\n",
    "\n",
    "    def _prepare(self):\n",
    "        \"\"\"Build node features and, for the kNN formats, adjacency matrices and edge lists.\"\"\"\n",
    "        print(\"preparing %d graphs for the %s set...\" % (self.n_samples, self.split.upper()))\n",
    "        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []\n",
    "\n",
    "        # This class never sets `img_size`, so coordinates are left unscaled;\n",
    "        # they are divided by it only when the attribute has been provided.\n",
    "        img_size = getattr(self, 'img_size', None)\n",
    "\n",
    "        for sample in self.sp_data:\n",
    "            mean_px, coord = sample[:2]\n",
    "            if img_size is not None:\n",
    "                coord = coord / img_size\n",
    "\n",
    "            if self.graph_format == 'edge_wt_region_boundary':\n",
    "                # Edges come from the pre-computed graphs in __getitem__\n",
    "                A, edges_list, edge_values_list = None, None, None\n",
    "            else:\n",
    "                # 'edge_wt_coord_feat' uses super-pixel locations + features,\n",
    "                # 'edge_wt_only_coord' uses only super-pixel locations\n",
    "                use_feat = self.graph_format == 'edge_wt_coord_feat'\n",
    "                A = compute_adjacency_matrix_images(coord, mean_px, use_feat)\n",
    "                edges_list, edge_values_list = compute_edges_list(A)\n",
    "                edge_values_list = edge_values_list.reshape(-1)\n",
    "\n",
    "            # Node feature = mean pixel values followed by the two coordinates\n",
    "            N_nodes = mean_px.shape[0]\n",
    "            x = np.concatenate((mean_px.reshape(N_nodes, -1), coord.reshape(N_nodes, 2)), axis=1)\n",
    "\n",
    "            self.node_features.append(x)\n",
    "            self.edge_features.append(edge_values_list)\n",
    "            self.Adj_matrices.append(A)\n",
    "            self.edges_lists.append(edges_list)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of graphs in the dataset.\"\"\"\n",
    "        return self.n_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return `(graph, node_labels)` for sample `idx`; the DGL graph is rebuilt on every call.\"\"\"\n",
    "        num_nodes = self.node_features[idx].shape[0]\n",
    "\n",
    "        if self.graph_format == 'edge_wt_region_boundary':\n",
    "            if num_nodes == 1:\n",
    "                # handling for 1 node where the self loop would be the only edge\n",
    "                # since, VOC Superpixels has few samples (5 samples) with only 1 node\n",
    "                g = dgl.DGLGraph()\n",
    "                g.add_nodes(num_nodes)\n",
    "                g = dgl.add_self_loop(g)\n",
    "                # dummy edge feat since no actual edge present\n",
    "                g.edata['feat'] = torch.zeros(1, 2) # 1 edge and 2 feat dim\n",
    "            else:\n",
    "                g = dgl.from_networkx(self.region_boundary_graphs[idx].to_directed(),\n",
    "                                      edge_attrs=['weight', 'count'])\n",
    "                # Merge the two edge attributes into a single 2-dim edge feature\n",
    "                g.edata['feat'] = torch.stack((g.edata['weight'], g.edata['count']), -1)\n",
    "                del g.edata['weight'], g.edata['count']\n",
    "            # The adjacency of this format is only known once the graph exists\n",
    "            self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()\n",
    "        else:\n",
    "            g = dgl.DGLGraph()\n",
    "            g.add_nodes(num_nodes)\n",
    "            for src, dsts in enumerate(self.edges_lists[idx]):\n",
    "                # Skip self loops\n",
    "                g.add_edges(src, dsts[dsts != src])\n",
    "\n",
    "        g.ndata['feat'] = torch.Tensor(self.node_features[idx])\n",
    "\n",
    "        return g, self.sp_node_labels[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Only coordinates for knn graph construction\n",
    "This is done by setting the `graph_format` option to `'edge_wt_only_coord'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the val split is loaded here, for sample visualization\n",
    "graph_format = 'edge_wt_only_coord'\n",
    "tt = time.time()\n",
    "data_only_coord_knn = SuperPixDGL(data_dir=\"./\", dataset='VOC', split='val', graph_format=graph_format)\n",
    "\n",
    "# Wall-clock time spent loading the pickles and preparing the graphs\n",
    "print(\"Time taken: {:.4f}s\".format(time.time() - tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Both coordinates and features for knn graph construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kNN graphs built from superpixel coordinates and features (val split)\n",
    "graph_format = 'edge_wt_coord_feat'\n",
    "tt = time.time()\n",
    "data_coord_feat_knn = SuperPixDGL(data_dir=\"./\", dataset='VOC', split='val', graph_format=graph_format)\n",
    "\n",
    "# Wall-clock time spent loading the pickles and preparing the graphs\n",
    "print(\"Time taken: {:.4f}s\".format(time.time() - tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Region Boundary based graph construction with variable edges for every node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graphs loaded from the pre-computed rag-boundary pickle (val split)\n",
    "graph_format = 'edge_wt_region_boundary'\n",
    "tt = time.time()\n",
    "data_region_boundary = SuperPixDGL(data_dir=\"./\", dataset='VOC', split='val', graph_format=graph_format)\n",
    "\n",
    "# Wall-clock time spent loading the pickles and preparing the node features\n",
    "print(\"Time taken: {:.4f}s\".format(time.time() - tt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Superpixel plotting function definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import pdist, squareform\n",
    "from pylab import rcParams\n",
    "from skimage.segmentation import mark_boundaries\n",
    "\n",
    "def show_image(plt, idx, alpha=1.0):\n",
    "    \"\"\"Draw image `idx` of the validation set on the given axes, with the axis hidden.\n",
    "\n",
    "    Args:\n",
    "        plt: Matplotlib axes to draw on (named `plt`, but callers pass an `add_subplot` result).\n",
    "        idx: Index into `val_set.img_list`.\n",
    "        alpha: Opacity forwarded to `imshow`.\n",
    "    \"\"\"\n",
    "    plt.imshow(val_set.img_list[idx], alpha=alpha)\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.title.set_text(\" Original Image\")\n",
    "\n",
    "def plot_superpixels_graph(plt, data, idx, overlay=None):\n",
    "    \"\"\"Draw the superpixel graph of sample `idx` on the given axes.\n",
    "\n",
    "    Args:\n",
    "        plt: Matplotlib axes to draw on (named `plt`, but callers pass an `add_subplot` result).\n",
    "        data: `SuperPixDGL` dataset; its `sp_data`, `sp_node_labels`, `superpixels`,\n",
    "            `graph_format` and indexed DGL graph are read.\n",
    "        idx: Sample index, shared by `data` and `val_set`.\n",
    "        overlay: None draws only the graph. Any other value first draws the image with\n",
    "            the superpixel boundaries marked and flips the graph onto it; \"image\" then\n",
    "            draws white edges and \"slic\" draws no edges at all.\n",
    "    \"\"\"\n",
    "    sp_data = data.sp_data[idx]\n",
    "    node_labels = data.sp_node_labels[idx]\n",
    "    g = data[idx][0]\n",
    "\n",
    "    x_coord = sp_data[1]\n",
    "    # G supplies the nodes for drawing; the edges actually drawn are those in edge_list below.\n",
    "    # from_numpy_array is used because from_numpy_matrix was removed in NetworkX 3.0.\n",
    "    G = nx.from_numpy_array(squareform(pdist(x_coord, 'euclidean')))\n",
    "    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))\n",
    "    rotated_pos = {node: (y, -x) for (node, (x, y)) in pos.items()} # rotate the coords by 90 degree\n",
    "\n",
    "    if overlay is not None:\n",
    "        # The boundary-marked image fully covers the axes, so the plain image is not drawn separately\n",
    "        plt.imshow(mark_boundaries(val_set.img_list[idx], data.superpixels[idx],\n",
    "                                   color=[0,1,0], outline_color=[0,1,0]))\n",
    "        rotated_pos = {node: (x, -y) for (node, (x, y)) in rotated_pos.items()} # reflect the graph on x-axis for overlaying\n",
    "\n",
    "    edge_list = torch.stack(g.edges(), 0).T.tolist()\n",
    "\n",
    "    # Draw on the axes that was passed in rather than on whichever axes is current\n",
    "    nx.draw_networkx_nodes(G, rotated_pos, node_color=node_labels, node_size=40, ax=plt)\n",
    "    if overlay == \"image\":\n",
    "        nx.draw_networkx_edges(G, rotated_pos, edgelist=edge_list, edge_color='w', alpha=0.65, ax=plt)\n",
    "    elif overlay != \"slic\":\n",
    "        nx.draw_networkx_edges(G, rotated_pos, edgelist=edge_list, alpha=0.15, ax=plt)\n",
    "\n",
    "    # Short name of the graph construction shown in the title\n",
    "    graph_names = {'edge_wt_region_boundary': 'rag-boundary', 'edge_wt_only_coord': 'coo'}\n",
    "    graph_name = graph_names.get(data.graph_format, 'coo-feat')\n",
    "\n",
    "    if overlay == \"slic\":\n",
    "        title = \"SLIC SP (compactness=30)\"\n",
    "    elif overlay is None:\n",
    "        title = \"final `%s` graph\" % graph_name\n",
    "    else:\n",
    "        title = \" `%s` graph overlay on SLIC SP\" % graph_name\n",
    "\n",
    "    plt.title.set_text(title)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting sample superpixels and graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples_plot = 4\n",
    "# sample_indices = np.random.choice(int(len(data_only_coord_knn)/2), num_samples_plot, replace=False)\n",
    "sample_indices = np.array([117]) #0, 165, 312])\n",
    "print(sample_indices)\n",
    "\n",
    "# exist_ok makes the cell safe to re-run\n",
    "os.makedirs('./voc_viz_files', exist_ok=True)\n",
    "\n",
    "for idx in sample_indices:\n",
    "    print()\n",
    "    print(\"------ Image ID ------ : \", idx)\n",
    "    print(\"Num nodes: \", data_region_boundary[idx][0].number_of_nodes())\n",
    "    print(\"Num edges Graph: edge_wt_only_coord: \", data_only_coord_knn[idx][0].number_of_edges())\n",
    "    print(\"Num edges Graph: edge_wt_coord_feat: \", data_coord_feat_knn[idx][0].number_of_edges())\n",
    "    print(\"Num edges Graph: Region boundary graph: \", data_region_boundary[idx][0].number_of_edges())\n",
    "\n",
    "    # One saved figure per row; each entry is (left panel, right panel), and every\n",
    "    # panel is a callable that draws on the axes it receives.\n",
    "    rows = [\n",
    "        (lambda ax: show_image(ax, idx),\n",
    "         lambda ax: plot_superpixels_graph(ax, data_only_coord_knn, idx, overlay=\"slic\")),\n",
    "        (lambda ax: plot_superpixels_graph(ax, data_only_coord_knn, idx, overlay=\"image\"),\n",
    "         lambda ax: plot_superpixels_graph(ax, data_only_coord_knn, idx)),\n",
    "        (lambda ax: plot_superpixels_graph(ax, data_coord_feat_knn, idx, overlay=\"image\"),\n",
    "         lambda ax: plot_superpixels_graph(ax, data_coord_feat_knn, idx)),\n",
    "        (lambda ax: plot_superpixels_graph(ax, data_region_boundary, idx, overlay=\"image\"),\n",
    "         lambda ax: plot_superpixels_graph(ax, data_region_boundary, idx)),\n",
    "    ]\n",
    "\n",
    "    for row_no, (draw_left, draw_right) in enumerate(rows, start=1):\n",
    "        # A fresh figure per row, so axes from an earlier row can never leak into this one\n",
    "        f = plt.figure(figsize=(20, 7))\n",
    "        draw_left(f.add_subplot(121))\n",
    "        draw_right(f.add_subplot(122))\n",
    "\n",
    "        plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "        f.savefig('voc_viz_files/voc_'+str(idx)+'_row'+str(row_no)+'.pdf', dpi=300)\n",
    "        plt.show()\n",
    "        # Release the figure explicitly; backends that do not close on show would keep it alive\n",
    "        plt.close(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
