{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for preparing and saving VOC2011 (VOC) graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the `generate_vocsuperpixels_raw.ipynb` notebook inside the current directory\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to DGL format and save with pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Report the working directory; relative paths used in this notebook resolve against it.\n",
    "working_dir = os.getcwd()\n",
    "print(working_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2\n",
    "\n",
    "from superpixels import VOCSegDatasetDGL \n",
    "\n",
    "# from data.data import LoadData\n",
    "from torch.utils.data import DataLoader\n",
    "# from data.superpixels import VOCSegDataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'VOC'\n",
    "\n",
    "# Every graph format name listed for this notebook, kept for reference.\n",
    "ALL_GRAPH_FORMATS = ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary']\n",
    "# Formats actually built below. Assign ALL_GRAPH_FORMATS here to build all three.\n",
    "graph_format = ['edge_wt_region_boundary']\n",
    "\n",
    "# One VOCSegDatasetDGL per entry of graph_format, in the same order.\n",
    "dataset = []\n",
    "for gf in graph_format:\n",
    "    start = time.time()\n",
    "    data = VOCSegDatasetDGL(DATASET_NAME, gf, slic_compactness=30)\n",
    "    print('Time (sec):', time.time() - start)\n",
    "    dataset.append(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_histo_graphs(dataset, title):\n",
    "    \"\"\"Plot a 20-bin histogram of per-graph node counts and print nb/min/max.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of samples whose first element exposes number_of_nodes().\n",
    "        title: title shown above the histogram.\n",
    "    \"\"\"\n",
    "    # Swap in graph[0].number_of_edges() to histogram edge counts instead.\n",
    "    graph_sizes = [graph[0].number_of_nodes() for graph in dataset]\n",
    "    plt.figure(1)\n",
    "    plt.hist(graph_sizes, bins=20)\n",
    "    plt.title(title)\n",
    "    plt.show()\n",
    "    if not graph_sizes:\n",
    "        # min/max are undefined for an empty split; report the count only.\n",
    "        print('nb/min/max :', 0, None, None)\n",
    "        return\n",
    "    # Builtin min/max keep the counts exact (no float32 tensor round-trip).\n",
    "    print('nb/min/max :', len(graph_sizes), int(min(graph_sizes)), int(max(graph_sizes)))\n",
    "    \n",
    "# Histogram each split of the first built graph format.\n",
    "for split_name in ('train', 'val', 'test'):\n",
    "    plot_histo_graphs(getattr(dataset[0], split_name), split_name + 'set')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_dataset = dataset[0]\n",
    "voc_splits = (first_dataset.train, first_dataset.val, first_dataset.test)\n",
    "\n",
    "# Number of samples in each split, in (train, val, test) order.\n",
    "for split in voc_splits:\n",
    "    print(len(split))\n",
    "\n",
    "# First sample of each split, in the same order.\n",
    "for split in voc_splits:\n",
    "    print(split[0])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare train, test and val pickles for PyG data source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dump_voc_pyg_source(dataset, graph_format):\n",
    "    vallist = []\n",
    "    for data in dataset.val:\n",
    "        # print(data)\n",
    "        x = data[0].ndata['feat'] #x\n",
    "        edge_attr = data[0].edata['feat'] #edge_attr\n",
    "        edge_index = torch.stack(data[0].edges(), 0) #edge_index\n",
    "        y = data[1] #y\n",
    "        vallist.append((x, edge_attr, edge_index, y))\n",
    "\n",
    "    trainlist = []\n",
    "    for data in dataset.train:\n",
    "        # print(data)\n",
    "        x = data[0].ndata['feat'] #x\n",
    "        edge_attr = data[0].edata['feat'] #edge_attr\n",
    "        edge_index = torch.stack(data[0].edges(), 0) #edge_index\n",
    "        y = data[1] #y\n",
    "        trainlist.append((x, edge_attr, edge_index, y))\n",
    "\n",
    "    testlist = []\n",
    "    for data in dataset.test:\n",
    "        # print(data)\n",
    "        x = data[0].ndata['feat'] #x\n",
    "        edge_attr = data[0].edata['feat'] #edge_attr\n",
    "        edge_index = torch.stack(data[0].edges(), 0) #edge_index\n",
    "        y = data[1] #y\n",
    "        testlist.append((x, edge_attr, edge_index, y))\n",
    "        \n",
    "    print(len(trainlist), len(vallist), len(testlist))\n",
    "    \n",
    "    pyg_source_dir = './voc_superpixels_'+graph_format\n",
    "    if not os.path.exists(pyg_source_dir):\n",
    "        os.makedirs(pyg_source_dir)\n",
    "    \n",
    "    start = time.time()\n",
    "    with open(pyg_source_dir+'/train.pickle','wb') as f:\n",
    "        pickle.dump(trainlist,f)\n",
    "    print('Time (sec):',time.time() - start) # 1.84s\n",
    "    \n",
    "    start = time.time()\n",
    "    with open(pyg_source_dir+'/val.pickle','wb') as f:\n",
    "        pickle.dump(vallist,f)\n",
    "    print('Time (sec):',time.time() - start) # 0.29s\n",
    "    \n",
    "    start = time.time()\n",
    "    with open(pyg_source_dir+'/test.pickle','wb') as f:\n",
    "        pickle.dump(testlist,f)\n",
    "    print('Time (sec):',time.time() - start) # 0.44s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset was filled in graph_format order, so pair each built dataset with its format name.\n",
    "for gf, voc_dataset in zip(graph_format, dataset):\n",
    "    dump_voc_pyg_source(voc_dataset, gf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split sizes of the first built graph format, in (val, train, test) order.\n",
    "tuple(len(split) for split in (dataset[0].val, dataset[0].train, dataset[0].test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split sizes (val, train, test) of the second built graph format.\n",
    "# dataset holds one entry per built format, so show nothing instead of raising\n",
    "# IndexError when fewer than two formats were built.\n",
    "(len(dataset[1].val), len(dataset[1].train), len(dataset[1].test)) if len(dataset) > 1 else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split sizes (val, train, test) of the third built graph format.\n",
    "# All three lengths now come from dataset[2] (the train size was read from dataset[0]).\n",
    "# Shows nothing instead of raising IndexError when fewer than three formats were built.\n",
    "(len(dataset[2].val), len(dataset[2].train), len(dataset[2].test)) if len(dataset) > 2 else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): looks like the (val, train, test) sizes pasted from an earlier run of the\n",
    "# cells above -- confirm, and consider moving this to a markdown cell.\n",
    "(1428, 8498, 1430)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
