{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import itertools\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import yaml\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from elements.classes import ConceptElementDatasetCreator, ElementDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Making a dataset from scratch"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Folder that receives the figures saved further down in this notebook\n",
    "save_dir = Path(\"example_datasets\")\n",
    "save_dir.mkdir(exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Vocabulary of element properties handed to the dataset\n",
    "allowed_shapes = ['square', 'circle', 'triangle', 'plus']\n",
    "allowed_colors = ['red', 'green', 'blue']\n",
    "allowed_textures = [\"solid\", \"spots_polka\", \"stripes_diagonal\"]\n",
    "allowed = {\n",
    "    \"shapes\": allowed_shapes,\n",
    "    \"colors\": allowed_colors,\n",
    "    \"textures\": allowed_textures\n",
    "}\n",
    "\n",
    "# One (shape, color, texture) triple per class.\n",
    "# NOTE(review): None presumably leaves that property unconstrained - confirm in ElementDataset\n",
    "# NOTE(review): \"magenta\" (last entry) is not listed in allowed_colors - confirm this is intentional\n",
    "class_specs = [\n",
    "    (None, None, \"solid\"),\n",
    "    (None, \"red\", \"solid\"),\n",
    "    (None, \"blue\", \"stripes_diagonal\"),\n",
    "    (None, \"green\", \"spots_polka\"),\n",
    "    (\"circle\", None, \"solid\"),\n",
    "    (\"circle\", None, \"spots_polka\"),\n",
    "    (\"triangle\", \"green\", None),\n",
    "    (\"square\", \"blue\", None),\n",
    "    (\"triangle\", \"red\", \"stripes_diagonal\"),\n",
    "    (\"triangle\", \"blue\", \"stripes_diagonal\"),\n",
    "    (\"square\", \"green\", \"spots_polka\"),\n",
    "    (\"plus\", \"magenta\", \"spots_polka\"),\n",
    "]\n",
    "class_configs = [\n",
    "    {\"shape\": shape, \"color\": color, \"texture\": texture}\n",
    "    for shape, color, texture in class_specs\n",
    "]\n",
    "\n",
    "# NOTE(review): the meaning of the positional numbers is not visible in this notebook -\n",
    "# check the ElementDataset signature (keyword arguments would make the call self-describing)\n",
    "dataset = ElementDataset(allowed, class_configs, 1000, 224, 4, 64, 16, 42, 123)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the `config` attribute of the dataset that was just built\n",
    "dataset.config"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Preview the first 15 items in a 3 x 5 grid, each titled with its class labels\n",
    "fig, axes = plt.subplots(3, 5, figsize=(15, 9))\n",
    "for idx, ax in enumerate(axes.flatten()):\n",
    "    item = dataset.get_item(idx)\n",
    "    title = \", \".join(str(label) for label in item.class_labels)\n",
    "    ax.imshow(item.img)\n",
    "    ax.set_title(title)\n",
    "    # Hide the tick marks; pixel coordinates add nothing to the preview\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "plt.tight_layout()\n",
    "plt.savefig(save_dir / \"simple_small.png\")\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Gather the labels (second entry of every batch) of the whole dataset into one array\n",
    "dataloader = DataLoader(dataset, 32)\n",
    "labels = np.concatenate([batch[1] for batch in dataloader])\n",
    "\n",
    "print(\"No. imgs per class\")\n",
    "print(labels.sum(axis=0))\n",
    "\n",
    "print(\"No. classes per image\")\n",
    "vals, counts = np.unique(labels.sum(axis=1), return_counts=True)\n",
    "print(\", \".join(f\"{val: .0f}: {count}\" for val, count in zip(vals, counts)))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Property values from which every class combination is enumerated in the next cell\n",
    "allowed_shapes = ['square', 'circle', 'triangle', 'plus']\n",
    "allowed_colors = ['red', 'green', 'blue']\n",
    "allowed_textures = [\"solid\", \"spots_polka\", \"stripes_diagonal\"]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Cartesian product of all property values, with None as an extra choice per property\n",
    "class_configs = list(\n",
    "    itertools.product(\n",
    "        allowed_shapes + [None],\n",
    "        allowed_colors + [None],\n",
    "        allowed_textures + [None],\n",
    "    )\n",
    ")\n",
    "print(len(class_configs))\n",
    "\n",
    "# Drop every combination in which two or more of the three properties are None\n",
    "class_configs = [combo for combo in class_configs if combo.count(None) < 2]\n",
    "print(len(class_configs))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Convert each (shape, color, texture) tuple into the dict form of a class config.\n",
    "# Run this cell once per run of the previous cell: it overwrites `class_configs` in place.\n",
    "class_configs = [\n",
    "    {\"shape\": combo[0], \"color\": combo[1], \"texture\": combo[2]}\n",
    "    for combo in class_configs\n",
    "]\n",
    "class_configs"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Bundle the property vocabulary under the keys passed to ElementDataset\n",
    "allowed = {\n",
    "    \"shapes\": allowed_shapes,\n",
    "    \"colors\": allowed_colors,\n",
    "    \"textures\": allowed_textures\n",
    "}\n",
    "\n",
    "# NOTE(review): the meaning of the positional numbers is not visible in this notebook -\n",
    "# check the ElementDataset signature (keyword arguments would make the call self-describing)\n",
    "dataset = ElementDataset(allowed, class_configs, 1000, 224, 4, 64, 16, 42, 123)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the `config` attribute of the dataset built from every class combination\n",
    "dataset.config"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def plot_dataset(my_dataset, savefig=None, show_classes=False, n_rows=3, n_cols=5):\n",
    "    \"\"\"Show the first ``n_rows * n_cols`` images of a dataset in a grid.\n",
    "\n",
    "    Args:\n",
    "        my_dataset: Indexable dataset; ``my_dataset[i]`` must return ``(img, labels)``\n",
    "            where ``img`` has a ``.numpy()`` method. NOTE(review): the image is assumed\n",
    "            to be channel-first (C, H, W) because it is transposed with (1, 2, 0)\n",
    "            before plotting - confirm.\n",
    "        savefig: Optional path; when given, the figure is written there.\n",
    "        show_classes: If True, title each image with the number of labels equal to 1.\n",
    "        n_rows: Number of grid rows (default 3).\n",
    "        n_cols: Number of grid columns (default 5).\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `axes` a 2-D array even for a single row or column\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False\n",
    "    )\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        img, labels = my_dataset[i]\n",
    "        ax.imshow(img.numpy().transpose(1, 2, 0))\n",
    "        if show_classes:\n",
    "            classes = np.where(labels == 1)[0]\n",
    "            ax.set_title(f\"No. classes {len(classes)}\")\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "    fig.tight_layout()\n",
    "    if savefig is not None:\n",
    "        # Save through the figure handle rather than pyplot's implicit current figure\n",
    "        fig.savefig(savefig)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_dataset(dataset, save_dir / \"simple_all.png\", True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Gather the labels (second entry of every batch) of the whole dataset into one array\n",
    "dataloader = DataLoader(dataset, 32)\n",
    "labels = np.concatenate([batch[1] for batch in dataloader])\n",
    "\n",
    "print(\"No. imgs per class\")\n",
    "print(labels.sum(axis=0))\n",
    "\n",
    "print(\"No. classes per image\")\n",
    "vals, counts = np.unique(labels.sum(axis=1), return_counts=True)\n",
    "print(\", \".join(f\"{val: .0f}: {count}\" for val, count in zip(vals, counts)))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Creating a dataset from a config file"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Parse the YAML file for the simple dataset and display the result\n",
    "filename = Path(\"configs/simple_dataset.yaml\")\n",
    "with filename.open(\"r\") as fp:\n",
    "    config = yaml.safe_load(fp)\n",
    "config"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# This can be helpful if using lots of different datasets/classes\n",
    "def get_obj_from_str(string, reload=False):\n",
    "    \"\"\"Resolve a dotted path such as ``package.module.Name`` to the object it names.\n",
    "\n",
    "    Args:\n",
    "        string: Dotted path. The part before the last dot is imported as a module,\n",
    "            the part after it is looked up as an attribute of that module.\n",
    "        reload: If True, the module is reloaded before the attribute lookup.\n",
    "\n",
    "    Returns:\n",
    "        The attribute named by the last path component.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``string`` contains no dot.\n",
    "    \"\"\"\n",
    "    module, cls = string.rsplit(\".\", 1)\n",
    "    module_imp = importlib.import_module(module)\n",
    "    if reload:\n",
    "        # reload() returns the refreshed module, so no second import is needed\n",
    "        module_imp = importlib.reload(module_imp)\n",
    "    return getattr(module_imp, cls)\n",
    "\n",
    "\n",
    "def instantiate_from_config(config):\n",
    "    \"\"\"Build the object described by a config mapping.\n",
    "\n",
    "    Args:\n",
    "        config: Mapping with a ``target`` entry (dotted path understood by\n",
    "            ``get_obj_from_str``) and an optional ``params`` entry holding the\n",
    "            keyword arguments for the call.\n",
    "\n",
    "    Returns:\n",
    "        The result of calling the resolved target with ``**params``.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If ``target`` is missing from ``config``.\n",
    "    \"\"\"\n",
    "    if \"target\" not in config:\n",
    "        raise KeyError(\"Expected key `target` to instantiate.\")\n",
    "    # `or {}` also covers a config that lists `params` without a value (parsed as None)\n",
    "    params = config.get(\"params\") or {}\n",
    "    # NOTE: this imports and calls whatever `target` names, so only use trusted configs\n",
    "    return get_obj_from_str(config[\"target\"])(**params)\n",
    "\n",
    "\n",
    "dataset_creator = instantiate_from_config(config[\"dataset\"])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# But here's how to just do it\n",
    "# (ConceptElementDatasetCreator is imported from elements.classes in the first cell)\n",
    "dataset_creator = ConceptElementDatasetCreator(**config[\"dataset\"][\"params\"])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Class 013 is spotty squares: print its config entry,\n",
    "# then request \"013\" from the creator and plot what comes back\n",
    "print(dataset_creator.class_configs[13])\n",
    "class_dataset = dataset_creator(\"013\")\n",
    "plot_dataset(class_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Passing a concept name yields images that contain only that concept\n",
    "concept_dataset = dataset_creator(\"red\")\n",
    "plot_dataset(concept_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# A suffix on the name (here \"_top\") adds a spatial restriction\n",
    "concept_dataset = dataset_creator(\"red_top\")\n",
    "plot_dataset(concept_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# A \"random...\" name simply samples more images from the dataset config\n",
    "random_dataset = dataset_creator(\"random500_12\")\n",
    "plot_dataset(random_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Spatial restrictions work here too: append the suffix (here \"_left\")\n",
    "random_dataset = dataset_creator(\"random500_12_left\")\n",
    "plot_dataset(random_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same steps as above, this time driven by the standard dataset config\n",
    "filename = Path(\"configs/standard_dataset.yaml\")\n",
    "with filename.open(\"r\") as fp:\n",
    "    config = yaml.safe_load(fp)\n",
    "\n",
    "dataset_creator = ConceptElementDatasetCreator(**config[\"dataset\"][\"params\"])\n",
    "dataset = dataset_creator(\"random500_0\")\n",
    "plot_dataset(dataset, show_classes=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same steps as above, this time driven by the \"red equals triangle\" config\n",
    "filename = Path(\"configs/simple_red_equals_triangle_dataset.yaml\")\n",
    "with filename.open(\"r\") as fp:\n",
    "    config = yaml.safe_load(fp)\n",
    "\n",
    "dataset_creator = ConceptElementDatasetCreator(**config[\"dataset\"][\"params\"])\n",
    "dataset = dataset_creator(\"random500_0\")\n",
    "plot_dataset(dataset, show_classes=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same steps as above, this time driven by the \"red objects are triangles\" config\n",
    "filename = Path(\"configs/simple_red_objects_are_triangles_dataset.yaml\")\n",
    "with filename.open(\"r\") as fp:\n",
    "    config = yaml.safe_load(fp)\n",
    "\n",
    "dataset_creator = ConceptElementDatasetCreator(**config[\"dataset\"][\"params\"])\n",
    "dataset = dataset_creator(\"random500_0\")\n",
    "plot_dataset(dataset, show_classes=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
