{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Knowledge compilation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from pyHEXgraph import HEXGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-100 fine-label name -> integer label (0..99).\n",
    "classes_idx = {'apple': 0, 'aquarium_fish': 1, 'baby': 2, 'bear': 3, 'beaver': 4, 'bed': 5, 'bee': 6, 'beetle': 7, 'bicycle': 8, 'bottle': 9, 'bowl': 10, 'boy': 11, 'bridge': 12, 'bus': 13, 'butterfly': 14, 'camel': 15, 'can': 16, 'castle': 17, 'caterpillar': 18, 'cattle': 19, 'chair': 20, 'chimpanzee': 21, 'clock': 22, 'cloud': 23, 'cockroach': 24, 'couch': 25, 'crab': 26, 'crocodile': 27, 'cup': 28, 'dinosaur': 29, 'dolphin': 30, 'elephant': 31, 'flatfish': 32, 'forest': 33, 'fox': 34, 'girl': 35, 'hamster': 36, 'house': 37, 'kangaroo': 38, 'keyboard': 39, 'lamp': 40, 'lawn_mower': 41, 'leopard': 42, 'lion': 43, 'lizard': 44, 'lobster': 45, 'man': 46, 'maple_tree': 47, 'motorcycle': 48, 'mountain': 49, 'mouse': 50, 'mushroom': 51, 'oak_tree': 52, 'orange': 53, 'orchid': 54, 'otter': 55, 'palm_tree': 56, 'pear': 57, 'pickup_truck': 58, 'pine_tree': 59, 'plain': 60, 'plate': 61, 'poppy': 62, 'porcupine': 63, 'possum': 64, 'rabbit': 65, 'raccoon': 66, 'ray': 67, 'road': 68, 'rocket': 69, 'rose': 70, 'sea': 71, 'seal': 72, 'shark': 73, 'shrew': 74, 'skunk': 75, 'skyscraper': 76, 'snail': 77, 'snake': 78, 'spider': 79, 'squirrel': 80, 'streetcar': 81, 'sunflower': 82, 'sweet_pepper': 83, 'table': 84, 'tank': 85, 'telephone': 86, 'television': 87, 'tiger': 88, 'tractor': 89, 'train': 90, 'trout': 91, 'tulip': 92, 'turtle': 93, 'wardrobe': 94, 'whale': 95, 'willow_tree': 96, 'wolf': 97, 'woman': 98, 'worm': 99}\n",
    "\n",
    "# Coarse-label names; superclasses[i] lists the fine classes of superclasses_names[i].\n",
    "superclasses_names = ['aquatic_mammals', 'fish', 'flowers', 'food_containers',\n",
    "                        'fruit_and_vegetables', 'household_electrical_devices',\n",
    "                        'household_furniture', 'insects', 'large_carnivores',\n",
    "                        'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',\n",
    "                        'large_omnivores_and_herbivores', 'medium_mammals', 'non-insect_invertebrates',\n",
    "                        'people', 'reptiles', 'small_mammals', 'trees', 'vehicles_1', 'vehicles_2']\n",
    "\n",
    "superclasses = [['beaver', 'dolphin', 'otter', 'seal', 'whale'],\n",
    "              ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],\n",
    "              ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],\n",
    "              ['bottle', 'bowl', 'can', 'cup', 'plate'],\n",
    "              ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],\n",
    "              ['clock', 'keyboard', 'lamp', 'telephone', 'television'],\n",
    "              ['bed', 'chair', 'couch', 'table', 'wardrobe'],\n",
    "              ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],\n",
    "              ['bear', 'leopard', 'lion', 'tiger', 'wolf'],\n",
    "              ['bridge', 'castle', 'house', 'road', 'skyscraper'],\n",
    "              ['cloud', 'forest', 'mountain', 'plain', 'sea'],\n",
    "              ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],\n",
    "              ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],\n",
    "              ['crab', 'lobster', 'snail', 'spider', 'worm'],\n",
    "              ['baby', 'boy', 'girl', 'man', 'woman'],\n",
    "              ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],\n",
    "              ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],\n",
    "              ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],\n",
    "              ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],\n",
    "              ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']]\n",
    "\n",
    "# Sanity checks: the hierarchy is built by pairing superclasses_names with\n",
    "# superclasses by position, so they must line up and cover every fine class once.\n",
    "assert len(superclasses) == len(superclasses_names)\n",
    "fine_labels = [c for group in superclasses for c in group]\n",
    "assert len(fine_labels) == len(set(fine_labels)), \"a fine class appears in more than one superclass\"\n",
    "assert set(fine_labels) == set(classes_idx), \"superclasses do not cover exactly the fine classes\"\n",
    "assert sorted(classes_idx.values()) == list(range(len(classes_idx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# Label hierarchy as a DAG with one edge superclass -> fine class.\n",
    "# Node insertion order (each superclass, then its children) determines the\n",
    "# positional node indices taken from enumerate(T.nodes) later, so keep it stable.\n",
    "T = nx.DiGraph()\n",
    "# strict=True raises ValueError if the two parallel lists drift out of sync.\n",
    "for sc, children in zip(superclasses_names, superclasses, strict=True):\n",
    "    T.add_node(sc)\n",
    "    for c in children:\n",
    "        T.add_edge(sc, c)  # add_edge also inserts c as a node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload a HEX graph previously written by hexg.save(\"./cifar/hex\") in the next cell.\n",
    "# NOTE(review): on a fresh checkout this path may not exist yet; run the next cell first.\n",
    "hex_path = \"./cifar/hex\"\n",
    "hexg = HEXGraph(file_path=hex_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the HEX graph from the hierarchy DAG T and persist it for later reloads.\n",
    "hex_path = \"./cifar/hex\"\n",
    "hexg = HEXGraph(T=T)\n",
    "hexg.save(hex_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Map each CIFAR-100 label index to the position of its node in T.nodes.\n",
    "# NOTE(review): assumes HEXGraph orders its nodes like T.nodes — confirm.\n",
    "i2n = {\n",
    "    classes_idx[name]: pos\n",
    "    for pos, name in enumerate(T.nodes)\n",
    "    if name in classes_idx\n",
    "}\n",
    "assert len(i2n) == len(classes_idx), \"some fine classes are missing from T\"\n",
    "\n",
    "# json.dump writes the int keys as strings; use int(k) when reloading.\n",
    "with open(\"./i2n.json\", 'w') as f:\n",
    "    json.dump(i2n, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def idx_to_state(idx, i2n, tm):\n",
    "    \"\"\"Return the column of `tm` for the node that label `idx` maps to.\n",
    "\n",
    "    Args:\n",
    "        idx: integer class label (a value of `classes_idx`).\n",
    "        i2n: label index -> node position. Keys may be ints or, if the mapping\n",
    "            was reloaded from i2n.json, their string form.\n",
    "        tm: array indexable as tm[:, node]; presumably rows are HEX states — TODO confirm.\n",
    "\n",
    "    Returns:\n",
    "        tm[:, node] for the node position looked up in `i2n`.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if neither `idx` nor `str(idx)` is a key of `i2n`.\n",
    "    \"\"\"\n",
    "    # json.dump stores int keys as strings, so fall back to str(idx).\n",
    "    node = i2n[idx] if idx in i2n else i2n[str(idx)]\n",
    "    return tm[:, node]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models.densenet import DenseNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable-parameter count of DenseNet variants across growth rates.\n",
    "# Output size = fine classes + superclasses (100 + 20 = 120 for CIFAR-100);\n",
    "# presumably one logit per hierarchy node — confirm against the training code.\n",
    "num_outputs = len(classes_idx) + len(superclasses_names)\n",
    "for gr in range(4, 20, 2):\n",
    "    model = DenseNet(growth_rate=gr, block_config=(4, 8, 16, 10), num_init_features=16, num_classes=num_outputs)\n",
    "    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"growth_rate={gr}: {n_params:,} trainable parameters\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d8a1dba90a7a3c73ffa205c3fa0146003aed819f1f43d631078b51d929ccab9c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
