{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import dependencies\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set parameters (change manually if needed)\n",
    "all_feat = False # set True to plot all edges, False to plot only batches' maximums\n",
    "explicit_bias = False\n",
    "models = ['GT', 'GraphiT', 'SAN']\n",
    "datasets = ['ZINC', 'TOX21', 'PROTEINS']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the base-vs-train grid of plots: one row per model, one column per dataset.\n",
    "# Reads all_feat, explicit_bias, models and datasets from the parameters cell, and uses\n",
    "# matplotlib, plt, pathlib, load_spec, plot2a and plot2 from the `from utils import *` cell.\n",
    "# Writes the figure to out/plots/base_vs_train/ as both PDF and PNG.\n",
    "\n",
    "# fonttype 42 embeds TrueType fonts in PDF/PS output\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rcParams.update({'font.size': 11})\n",
    "batch_median = False # false to normalize by edge median, default false\n",
    "plot_threshold = False\n",
    "mm = {'GT': 'Graph Transformer'} # map model name in plot\n",
    "title = 'Edge features ratio'\n",
    "xlabel = 'edge' if all_feat else 'batch'\n",
    "# output file stem: the title words plus tags describing the plotted configuration\n",
    "name = '_'.join(title.replace('/','over').split() + ['sparse', 'NoPE', 'BN', 'ExplicitBias' if explicit_bias else 'MixBias', 'Edge', 'edge' if all_feat else 'batch', 'batchmed' if batch_median else 'edgemed'])\n",
    "\n",
    "scale = 6\n",
    "w = 4.7/4\n",
    "h = 3.1/3\n",
    "fig, axs = plt.subplots(\n",
    "    len(models),\n",
    "    len(datasets),\n",
    "    layout='constrained',\n",
    "    figsize=(0.4*w*len(datasets)*scale, 0.3*len(models)*scale),\n",
    "    sharex='none' if all_feat else 'col',\n",
    "    sharey='row',\n",
    "    squeeze=False,  # always return a 2-D array so axs[i,j] also works with a single model or dataset\n",
    ")\n",
    "\n",
    "for i,m in enumerate(models):\n",
    "    for j,d in enumerate(datasets):\n",
    "        ax = axs[i,j]\n",
    "        cspec = (m, d)  # fallback so the except handler can always name the failing panel\n",
    "        try:\n",
    "            ax.set_ylim(1,1e6)\n",
    "            ax.set_yscale('log')\n",
    "            # first layer: the spec tagged 'base', always loaded with 'NoBias'\n",
    "            cspec = (m, d, 'sparse', 'NoPE', 'BN', 'NoBias', 'Edge', ('Escore',), 'base', 0)\n",
    "            a = load_spec(cspec)\n",
    "            plot2a(a, feat='e', axis=ax, title=None, legend=False, batch_median=batch_median, spec=cspec, all_feat=all_feat)\n",
    "            del a  # drop the reference before loading the next spec\n",
    "            # second layer: the spec tagged 'train', drawn on the same axes with alpha=0.5\n",
    "            bias = 'ExplicitBias' if explicit_bias else 'NoBias'\n",
    "            cspec = (m, d, 'sparse', 'NoPE', 'BN', bias, 'Edge', ('Escore',), 'train', 0)\n",
    "            a = load_spec(cspec)\n",
    "            plot2(a, feat='e', axis=ax, title=None, legend=False, batch_median=batch_median, spec=cspec, alpha=0.5, all_feat=all_feat, plot_threshold=plot_threshold)\n",
    "            del a\n",
    "        except Exception as e:\n",
    "            # best effort: report which spec failed and keep filling in the remaining panels\n",
    "            print(cspec, e)\n",
    "\n",
    "# row labels on the first column, dataset titles on the first row\n",
    "for i,m in enumerate(models):\n",
    "    axs[i,0].set_ylabel(mm.get(m,m))\n",
    "for j,d in enumerate(datasets):\n",
    "    axs[0,j].set_title(d)\n",
    "    axs[-1,j].set_xlabel(xlabel)  # address the bottom row explicitly, not via a leaked loop index\n",
    "\n",
    "# one figure-level legend, built from the entries of the top-left panel\n",
    "handles, labels = axs[0,0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='center right')\n",
    "fig.suptitle(title)\n",
    "\n",
    "# rescale every panel box: x0 and width are divided by w, y0 and height are multiplied by h\n",
    "for ax in axs.flat:\n",
    "    p = ax.get_position()\n",
    "    ax.set_position([p.x0/w, p.y0*h, p.width/w, p.height*h])\n",
    "\n",
    "pathlib.Path('out/plots/base_vs_train').mkdir(parents=True, exist_ok=True)\n",
    "print(f'out/plots/base_vs_train/{name}.png')\n",
    "fig.savefig(f'out/plots/base_vs_train/{name}.pdf', format='pdf')\n",
    "fig.savefig(f'out/plots/base_vs_train/{name}.png')\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
