{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, pickle, sys\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "import numpy as np\n",
    "import glob\n",
    "from tqdm import tqdm\n",
    "from prettytable import PrettyTable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load every pickled proxy record found under `d`, keeping one record per\n",
    "# 'hash' value (the first record encountered for a hash wins).\n",
    "d = '../results_release/nasbench1/proxies'\n",
    "runs = []\n",
    "processed = set()\n",
    "\n",
    "# os.listdir returns names in arbitrary order; sorting makes the first-wins\n",
    "# de-duplication below reproducible across machines and re-runs.\n",
    "for f in tqdm(sorted(os.listdir(d))):\n",
    "    # `with` guarantees the handle is closed even if unpickling fails with\n",
    "    # something other than EOFError.\n",
    "    # NOTE(review): pickle.load can execute arbitrary code - only point `d`\n",
    "    # at trusted result files.\n",
    "    with open(os.path.join(d, f), 'rb') as pf:\n",
    "        # A file may contain several pickles written back to back; keep\n",
    "        # reading until the stream is exhausted.\n",
    "        while True:\n",
    "            try:\n",
    "                p = pickle.load(pf)\n",
    "            except EOFError:\n",
    "                break\n",
    "            if p['hash'] in processed:\n",
    "                continue\n",
    "            processed.add(p['hash'])\n",
    "            runs.append(p)\n",
    "\n",
    "# Accuracy data is stored as a single pickle.\n",
    "with open('../data/nasbench1_accuracy.p','rb') as f:\n",
    "    all_accur = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "423624 423624\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: number of de-duplicated runs next to the number of accuracy entries.\n",
    "print(f'{len(runs)} {len(all_accur)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results_release/nasbench1/proxies 423624\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n",
      "| Dataset | grad_norm |  snip | grasp | fisher | synflow | jacob_cov |\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n",
      "| CIFAR10 |   0.198   | 0.164 | 0.448 | 0.257  |  0.372  |   0.378   |\n",
      "+---------+-----------+-------+-------+--------+---------+-----------+\n"
     ]
    }
   ],
   "source": [
    "# Tabulate |Spearman rank correlation| between each proxy's logged measure\n",
    "# and the accuracy value looked up for the same run hash.\n",
    "print(d, len(runs))\n",
    "\n",
    "# Table header: a label column followed by one column per proxy.\n",
    "hl = ['Dataset', 'grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov']\n",
    "t = PrettyTable(hl)\n",
    "\n",
    "# One list per measure name, seeded from the keys of the first run.\n",
    "metrics = {name: [] for name in runs[0]['logmeasures'].keys()}\n",
    "acc, hashes = [], []\n",
    "\n",
    "for run in runs:\n",
    "    for name, value in run['logmeasures'].items():\n",
    "        metrics[name].append(value)\n",
    "\n",
    "    run_hash = run['hash']\n",
    "    # Only the first element of the accuracy entry is used.\n",
    "    acc.append(all_accur[run_hash][0])\n",
    "    hashes.append(run_hash)\n",
    "\n",
    "# Every header entry except the 'Dataset' label names a proxy; NaNs are\n",
    "# omitted by spearmanr and only the magnitude of the correlation is kept.\n",
    "res = [\n",
    "    round(abs(stats.spearmanr(acc, metrics[name], nan_policy='omit').correlation), 3)\n",
    "    for name in hl\n",
    "    if name != 'Dataset'\n",
    "]\n",
    "\n",
    "ds = 'CIFAR10'\n",
    "t.add_row([ds] + res)\n",
    "\n",
    "print(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
