{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import dependencies\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set parameters (change manually if needed)\n",
    "# Worker count for the plotting thread pool: leave one core free, but never drop\n",
    "# below 1 -- cpu_count() - 1 is 0 on a single-core machine and a pool needs >= 1 worker.\n",
    "N_CPU = max(1, multiprocessing.cpu_count() - 1)\n",
    "# Selects which logged runs get plotted: 'ExplicitBias' specs when True, 'NoBias' otherwise.\n",
    "explicit_bias = False\n",
    "datasets = ['ZINC', 'TOX21', 'PROTEINS']\n",
    "models = ['GT', 'SAN', 'GraphiT']\n",
    "\n",
    "# With explicit bias only the 'train' state is processed; otherwise 'base' and 'train'.\n",
    "if explicit_bias:\n",
    "    states = ['train']\n",
    "else:\n",
    "    states = ['base', 'train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate plots\n",
    "# For every (dataset, model, state) combination: load the logged run, extract the\n",
    "# per-layer values and hand each layer to ks_ll_plot_w on a thread pool. Whatever the\n",
    "# workers return is collected in `res`. A failing combination is reported and skipped\n",
    "# so the remaining ones still run.\n",
    "\n",
    "# fonttype 42 makes matplotlib emit TrueType (Type 42) fonts in PDF/PS output\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rcParams.update({'font.size': 12})\n",
    "\n",
    "# limit data to 50000 edges per layer (proteins has more)\n",
    "MAX_EDGES = 50000\n",
    "# the only spec field that depends on explicit_bias; constant for the whole run\n",
    "bias_tag = 'ExplicitBias' if explicit_bias else 'NoBias'\n",
    "\n",
    "res = []\n",
    "for dataset in datasets:\n",
    "    for model in models:\n",
    "        for state in states:\n",
    "            if explicit_bias:\n",
    "                key = (model, dataset, 'exp_bias', state)\n",
    "            else:\n",
    "                key = (model, dataset, state)\n",
    "            print(key)\n",
    "            try:\n",
    "                # perf_counter is monotonic, so the elapsed time is immune to clock changes\n",
    "                tic = time.perf_counter()\n",
    "                spec = (model, dataset, 'sparse', 'NoPE', 'BN', bias_tag, 'Edge', ('Escore',), state, 0)\n",
    "                a = load_spec(spec) # avoid loading multiple logs in parallel\n",
    "                data_raw = get_values(a, use_abs=False, flatten=batch_median)\n",
    "                del a  # drop the loaded log before plotting starts\n",
    "                params = [(data_raw[layer][:MAX_EDGES, :], key, layer) for layer in range(len(data_raw))]\n",
    "                # ThreadPool rejects processes < 1, which min() alone would produce for\n",
    "                # an empty data_raw or N_CPU == 0\n",
    "                n_workers = max(1, min(N_CPU, len(data_raw)))\n",
    "                with multiprocessing.pool.ThreadPool(processes=n_workers) as pool:\n",
    "                    for item in pool.imap_unordered(ks_ll_plot_w, params):\n",
    "                        res.append(item)\n",
    "                del data_raw\n",
    "                plt.close()\n",
    "                toc = time.perf_counter()\n",
    "                print(toc-tic)\n",
    "            except Exception as e:\n",
    "                # best effort: keep going, but say which combination failed and how\n",
    "                print(f'{key} failed: {type(e).__name__}: {e}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
