{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HEP-Th performance experiments\n",
    "\n",
    "In this notebook we check whether the HEP-Th experiment is capable of optimizing its loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import sys\n",
    "import pandas as pd\n",
    "\n",
    "sys.path.append('../pytorch_graph_edit_networks')\n",
    "sys.path.append('../hep_th')\n",
    "\n",
    "import pytorch_graph_edit_networks as gen\n",
    "import hep_th\n",
    "\n",
    "# network hyperparameters\n",
    "num_layers = 3\n",
    "dim_hid = 64\n",
    "nonlin = torch.nn.Tanh()\n",
    "\n",
    "# two edge-filter options and the names used to label them\n",
    "# (presumably True = edge filtering without a fixed limit and 538 = a fixed limit -- TODO confirm)\n",
    "filter_options = [True, 538]\n",
    "model_names = ['flex_filter', 'const_filter']\n",
    "comp_names  = ['forward', 'backward']\n",
    "# result column headers: 'sizes' followed by one '<model>_<component>' entry per combination\n",
    "headers = ['sizes'] + [\n",
    "    '%s_%s' % (model_name, comp_name)\n",
    "    for model_name in model_names\n",
    "    for comp_name in comp_names\n",
    "]\n",
    "\n",
    "# largest value of `past` for which settings are generated\n",
    "max_past = 12\n",
    "\n",
    "# every (year, month) pair from January 1992 up to and including April 2003\n",
    "month_tuples = [\n",
    "    (year, month)\n",
    "    for year in range(1992, 2003+1)\n",
    "    for month in range(1, (4 if year == 2003 else 12)+1)\n",
    "]\n",
    "\n",
    "# all experimental settings as (year, month, past) tuples, grouped by increasing `past`;\n",
    "# for a given `past`, only months at index >= past in month_tuples are used\n",
    "settings = [\n",
    "    month_tuples[t] + (past,)\n",
    "    for past in range(1, max_past+1)\n",
    "    for t in range(past, len(month_tuples))\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# optimizer hyperparameters\n",
    "learning_rate  = 1E-3\n",
    "weight_decay   = 1E-5\n",
    "\n",
    "# value passed to GEN as filter_edge_edits\n",
    "# (presumably the edge-filter limit -- TODO confirm against pytorch_graph_edit_networks)\n",
    "fee = 538\n",
    "\n",
    "# number of passes over all experimental settings\n",
    "num_epochs = 1\n",
    "\n",
    "# `settings` (list of (year, month, past) tuples) is built once in the setup cell and\n",
    "# reused here rather than being reconstructed from a copy of the same loops.\n",
    "net = gen.GEN(num_layers = num_layers, dim_in = 1, dim_hid = dim_hid, nonlin = nonlin, filter_edge_edits = fee)\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "\n",
    "# loss value after every optimization step, in the order the settings are visited\n",
    "learning_curve2 = []\n",
    "for epoch in range(num_epochs):\n",
    "    print(f'epoch: {epoch}')\n",
    "    # wrap the list itself (not the enumerate object) so tqdm knows the total length\n",
    "    for idx, (year, month, past) in enumerate(tqdm(settings)):\n",
    "        optimizer.zero_grad()\n",
    "        loss = hep_th.compute_loss(net, year, month, past = past)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        loss_value = loss.item()\n",
    "        learning_curve2.append(loss_value)\n",
    "        # global step counter across epochs: use the number of settings, not the\n",
    "        # length of a single (year, month, past) tuple\n",
    "        step = epoch * len(settings) + idx\n",
    "        if step % 100 == 0:\n",
    "            print(loss_value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# positions in `settings` at which the third tuple entry (`past`) differs from the\n",
    "# entry before it; the very first setting is compared against 0\n",
    "breakpoints = [\n",
    "    pos\n",
    "    for pos, setting in enumerate(settings)\n",
    "    if setting[2] != (settings[pos - 1][2] if pos > 0 else 0)\n",
    "]\n",
    "print(breakpoints)\n",
    "\n",
    "# slices of the loss curve between consecutive breakpoints, i.e. the steps trained with\n",
    "# past = 3, past = 7 and past = 12; indexing learning_curve2 with positions from\n",
    "# `settings` assumes one recorded loss per setting (a single epoch)\n",
    "taus = {\n",
    "    'tau_3' : learning_curve2[breakpoints[2]:breakpoints[3]],\n",
    "    'tau_7' : learning_curve2[breakpoints[6]:breakpoints[7]],\n",
    "    'tau_12' : learning_curve2[breakpoints[11]:],\n",
    "}\n",
    "tau_3, tau_7, tau_12 = taus['tau_3'], taus['tau_7'], taus['tau_12']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# one column per tau curve, one row per step within that curve; only the first 123 rows\n",
    "# are kept (the curves have different lengths -- presumably 123 trims them to a common\n",
    "# range, TODO confirm)\n",
    "df = (\n",
    "    pd.DataFrame.from_dict(taus, orient='index')\n",
    "    .T\n",
    "    .iloc[:123, :]\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# reshape to long format: the row index becomes the 'iteration' column, the former column\n",
    "# names become the tau label and the cell values become 'loss'\n",
    "long_columns = ['iteration', r'$\\tau$', 'loss']\n",
    "melted = df.reset_index().melt(id_vars=['index'])\n",
    "melted.columns = long_columns\n",
    "melted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Matplotlib 3.6 deprecated the bundled 'seaborn' style name in favour of 'seaborn-v0_8'\n",
    "# and later releases removed the old name; use whichever this installation provides.\n",
    "style_name = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'\n",
    "plt.style.use(style_name)\n",
    "# render all figure text with LaTeX (requires a working LaTeX installation)\n",
    "plt.rcParams['text.usetex'] = True\n",
    "\n",
    "# one loss curve per tau value over the within-curve iteration index, saved as a PNG\n",
    "fig, ax = plt.subplots(figsize=(6,3))\n",
    "sns.lineplot(data=melted, x='iteration', y='loss', hue=r'$\\tau$', ax=ax)\n",
    "plt.savefig('LossCurves.png', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "repro",
   "language": "python",
   "display_name": "repro"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}