{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Dynamical Systems\n",
    "\n",
    "This notebook contains the experiments to evaluate graph edit networks on simple graph dynamical systems, namely the edit cycles, degree rules, and game of life datasets.\n",
    "\n",
    "This notebook can use generators that take care of unique items. Such a generator first creates the time series for the test set, such that no graph occurs twice across all series. From there, it creates the training set, such that no graph from the test set is included."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameter setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "sys.path.append('../pytorch_graph_edit_networks')\n",
    "sys.path.append('../additional_baselines')\n",
    "sys.path.append('../hep_th')\n",
    "\n",
    "import pytorch_graph_edit_networks as gen\n",
    "import baseline_models\n",
    "import hep_th\n",
    "\n",
    "import os\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
    "\n",
    "# model hyperparameters\n",
    "num_layers = 2\n",
    "dim_hid = 64\n",
    "\n",
    "# training hyperparameters\n",
    "learning_rate  = 1E-3\n",
    "weight_decay   = 1E-5\n",
    "loss_threshold = 1E-3\n",
    "max_epochs     = 10000\n",
    "print_step     = 1000\n",
    "\n",
    "# the number of repetitions for each experiment\n",
    "R = 3\n",
    "# the number of test time series we use to evaluate learning afterwards\n",
    "N_test = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
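    "# set up model names; setup_funs, loss_funs, and pred_funs below are matched to this list by index,\n",
    "# so all four lists must be changed together when only a subset of models is run\n",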
    "models = ['VGAE', 'GEN_crossent', 'GEN']\n",
    "#models = ['GEN_crossent']\n",
    "\n",
    "# set up functions to initialize the models\n",
    "def setup_vgae(dim_in, nonlin):\n",
    "    \"\"\"Create a new VGAE baseline with input dimensionality dim_in and nonlinearity nonlin.\n",
    "\n",
    "    Depth and hidden size are taken from the notebook-level num_layers and\n",
    "    dim_hid settings; beta and sigma_scaling are both fixed to 1E-3.\n",
    "    \"\"\"\n",
    "    vgae_kwargs = dict(\n",
    "        num_layers = num_layers,\n",
    "        dim_in = dim_in,\n",
    "        dim_hid = dim_hid,\n",
    "        beta = 1E-3,\n",
    "        sigma_scaling = 1E-3,\n",
    "        nonlin = nonlin,\n",
    "    )\n",
    "    return baseline_models.VGAE(**vgae_kwargs)\n",
    "def setup_gen(dim_in, nonlin):\n",
    "    \"\"\"Create a new graph edit network with input dimensionality dim_in and nonlinearity nonlin.\n",
    "\n",
    "    Depth and hidden size are taken from the notebook-level num_layers and\n",
    "    dim_hid settings.\n",
    "    \"\"\"\n",
    "    gen_kwargs = dict(\n",
    "        num_layers = num_layers,\n",
    "        dim_in = dim_in,\n",
    "        dim_hid = dim_hid,\n",
    "        nonlin = nonlin,\n",
    "    )\n",
    "    return gen.GEN(**gen_kwargs)\n",
    "setup_funs = [setup_vgae, setup_gen, setup_gen]\n",
    "# set up functions to compute the loss\n",
    "loss_fun = gen.GEN_loss()\n",
    "crossent_loss_fun = gen.GEN_loss_crossent()\n",
    "def vgae_loss(model, A, X, delta, Epsilon):\n",
    "    B = A + Epsilon\n",
    "    # delete all outgoing and incoming edges of deleted nodes\n",
    "    B[delta < -0.5, :] = 0\n",
    "    B[:, delta < -0.5] = 0\n",
    "    return model.compute_loss(torch.tensor(A, dtype=torch.float), torch.tensor(B, dtype=torch.float), torch.tensor(X, dtype=torch.float))\n",
    "def gen_loss_crossent(model, A, X, delta, Epsilon):\n",
    "    \"\"\"Compute the crossentropy GEN loss for a single time step.\n",
    "\n",
    "    The model is applied to the float tensor versions of A and X; its two\n",
    "    outputs are then compared to delta and Epsilon (given A) via the\n",
    "    notebook-level crossent_loss_fun.\n",
    "    \"\"\"\n",
    "    as_float = lambda arr: torch.tensor(arr, dtype=torch.float)\n",
    "    node_scores, edge_scores = model(as_float(A), as_float(X))\n",
    "    return crossent_loss_fun(node_scores, edge_scores, as_float(delta), as_float(Epsilon), as_float(A))\n",
    "def gen_loss(model, A, X, delta, Epsilon):\n",
    "    \"\"\"Compute the GEN loss for a single time step.\n",
    "\n",
    "    The model is applied to the float tensor versions of A and X; its two\n",
    "    outputs are then compared to delta and Epsilon (given A) via the\n",
    "    notebook-level loss_fun.\n",
    "    \"\"\"\n",
    "    as_float = lambda arr: torch.tensor(arr, dtype=torch.float)\n",
    "    node_scores, edge_scores = model(as_float(A), as_float(X))\n",
    "    return loss_fun(node_scores, edge_scores, as_float(delta), as_float(Epsilon), as_float(A))\n",
    "loss_funs = [vgae_loss, gen_loss_crossent, gen_loss]\n",
    "# set up prediction functions\n",
    "def vgae_pred(model, A, X):\n",
    "    B = model(torch.tensor(A, dtype=torch.float), torch.tensor(X, dtype=torch.float))\n",
    "    B = B.detach().numpy()\n",
    "    Epsilon = B - A\n",
    "    delta = np.zeros(A.shape[0])\n",
    "    delta[np.sum(B, 1) < 0.5] = -1.\n",
    "    Epsilon[delta < -0.5, :] = 0.\n",
    "    Epsilon[:, delta < -0.5] = 0.\n",
    "    return delta, Epsilon\n",
    "def gen_pred(model, A, X):\n",
    "    delta_pred, Epsilon_pred = model(torch.tensor(A, dtype=torch.float), torch.tensor(X, dtype=torch.float))\n",
    "    delta_pred = delta_pred.detach().numpy()\n",
    "    Epsilon_pred = Epsilon_pred.detach().numpy()\n",
    "    delta = np.zeros(A.shape[0])\n",
    "    delta[delta_pred > 0.5] = 1.\n",
    "    delta[delta_pred < -0.5] = -1.\n",
    "    Epsilon = np.zeros(A.shape)\n",
    "    Epsilon[np.logical_and(A > 0.5, Epsilon_pred < -0.5)] = -1.\n",
    "    Epsilon[np.logical_and(A < 0.5, Epsilon_pred > +0.5)] = +1.\n",
    "    return delta, Epsilon\n",
    "pred_funs = [vgae_pred, gen_pred, gen_pred]\n",
    "\n",
    "eval_criteria = ['node_ins_recall',\n",
    "                 'node_ins_precision',\n",
    "                 'node_del_recall',\n",
    "                 'node_del_precision',\n",
    "                 'edge_ins_recall',\n",
    "                 'edge_ins_precision',\n",
    "                 'edge_del_recall',\n",
    "                 'edge_del_precision']\n",
    "# set up a function to compute precision and recall\n",
    "def prec_rec(X, Y):\n",
    "    # X is the prediction, Y is the target\n",
    "    target_insertions = Y > 0.5\n",
    "    predicted_insertions = X > 0.5\n",
    "    target_deletions = Y < -0.5\n",
    "    predicted_deletions = X < -0.5\n",
    "    # first, check the insertion recall\n",
    "    if np.sum(target_insertions) < 0.5:\n",
    "        ins_rec = 1.\n",
    "    else:\n",
    "        ins_rec  = np.mean(X[target_insertions] > 0.5)\n",
    "    # then the insertion precision\n",
    "    if np.sum(predicted_insertions) < 0.5:\n",
    "        ins_prec = 1.\n",
    "    else:\n",
    "        ins_prec = np.mean(Y[predicted_insertions] > 0.5)\n",
    "    # then the deletion recall\n",
    "    if np.sum(target_deletions) < 0.5:\n",
    "        del_rec = 1.\n",
    "    else:\n",
    "        del_rec  = np.mean(X[target_deletions] < -0.5)\n",
    "    # and finally the deletion precision\n",
    "    if np.sum(predicted_deletions) < 0.5:\n",
    "        del_prec = 1.\n",
    "    else:\n",
    "        del_prec = np.mean(Y[predicted_deletions] < -0.5)\n",
    "    return ins_rec, ins_prec, del_rec, del_prec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('../graph_edit')\n",
    "sys.path.append('../degree_rules')\n",
    "sys.path.append('../game_of_life')\n",
    "\n",
    "\n",
    "import graph_edit_cycles\n",
    "import degree_rules\n",
    "import game_of_life\n",
    "import random\n",
    "\n",
    "# Here you change the dataset for each experiment\n",
    "\n",
    "# original experiments\n",
    "datasets = ['edit_cycles', 'degree_rules', 'game_of_life']\n",
    "dim_ins  = [4, 32, 1]\n",
    "\n",
    "# datasets for experiments in 4.2.1\n",
    "#datasets = ['degree_rules', 'degree_rules_erdos', 'degree_rules_conf']\n",
    "#dim_ins  = [32, 32, 32]\n",
    "\n",
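    "# datasets for experiments in section 4.2, where we evaluated on the test set of size 100\n",
    "# (dim_ins and generator_funs are matched to datasets by index and must be adapted accordingly)\n",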
    "#datasets = ['edit_cycles_test100', 'degree_rules_test100', 'game_of_life_test100']\n",
    "\n",
    "# set up a generative function for each data set\n",
    "def generate_edit_cycle():\n",
    "    \"\"\"Sample one edit cycles time series with randomly drawn arguments.\n",
    "\n",
    "    The three arguments passed to graph_edit_cycles.generate_time_series are\n",
    "    drawn from range(3), range(12), and range(4, 12), respectively. Each\n",
    "    returned edit tuple is split into its first entry (delta) and its second\n",
    "    entry (Epsilon). Returns As, Xs, deltas, Epsilons.\n",
    "    \"\"\"\n",
    "    As, Xs, edit_tuples = graph_edit_cycles.generate_time_series(random.randrange(3), random.randrange(12), random.randrange(4, 12))\n",
    "    deltas = [edit[0] for edit in edit_tuples]\n",
    "    Epsilons = [edit[1] for edit in edit_tuples]\n",
    "    return As, Xs, deltas, Epsilons\n",
    "def generate_degree_rules():\n",
    "    \"\"\"Sample one degree rules time series via degree_rules.generate_time_series_from_random_matrix.\n",
    "\n",
    "    Uses 8 initial nodes and allows at most 32 nodes during evolution.\n",
    "    \"\"\"\n",
    "    # the initial number of nodes in each graph\n",
    "    initial_nodes = 8\n",
    "    # the maximum number of nodes that can occur in each graph during evolution\n",
    "    node_limit = 4 * initial_nodes\n",
    "    return degree_rules.generate_time_series_from_random_matrix(initial_nodes, n_max = node_limit)\n",
    "\n",
    "def generate_degree_rules_erdos_renyi():\n",
    "    \"\"\"Sample one degree rules time series via degree_rules.generate_time_series_from_erdos_reny.\n",
    "\n",
    "    Uses 8 initial nodes and allows at most 32 nodes during evolution.\n",
    "    \"\"\"\n",
    "    # the initial number of nodes in each graph\n",
    "    initial_nodes = 8\n",
    "    # the maximum number of nodes that can occur in each graph during evolution\n",
    "    node_limit = 4 * initial_nodes\n",
    "    return degree_rules.generate_time_series_from_erdos_reny(initial_nodes, n_max = node_limit)\n",
    "\n",
    "def generate_degree_rules_configuration():\n",
    "    \"\"\"Sample one degree rules time series via degree_rules.generate_time_series_from_configuration_model.\n",
    "\n",
    "    Uses 8 initial nodes and allows at most 32 nodes during evolution.\n",
    "    \"\"\"\n",
    "    # the initial number of nodes in each graph\n",
    "    initial_nodes = 8\n",
    "    # the maximum number of nodes that can occur in each graph during evolution\n",
    "    node_limit = 4 * initial_nodes\n",
    "    return degree_rules.generate_time_series_from_configuration_model(initial_nodes, n_max = node_limit)\n",
    "\n",
    "def generate_degree_unique(test_set):\n",
    "    \"\"\"Sample one degree rules time series via degree_rules.generate_unique_time_series.\n",
    "\n",
    "    test_set is handed to the helper unchanged; presumably it holds the test\n",
    "    graphs that must not occur in the sampled series (confirm against the helper).\n",
    "    Uses 8 initial nodes and allows at most 32 nodes during evolution.\n",
    "    \"\"\"\n",
    "    # the initial number of nodes in each graph\n",
    "    initial_nodes = 8\n",
    "    # the maximum number of nodes that can occur in each graph during evolution\n",
    "    node_limit = 4 * initial_nodes\n",
    "    return degree_rules.generate_unique_time_series(initial_nodes, test_set, n_max = node_limit)\n",
    "\n",
    "def generate_game_of_life():\n",
    "    \"\"\"Sample one game of life time series via game_of_life.generate_random_time_series.\n",
    "\n",
    "    The single adjacency matrix returned by the helper is repeated for every\n",
    "    time step, and the edge edits are all-zero matrices of the same shape.\n",
    "    Returns As, Xs, deltas, Epsilons with one entry per element of Xs.\n",
    "    \"\"\"\n",
    "    # set hyper-parameters for the game of life random grid generation\n",
    "    grid_size, num_shapes, p, T_max = 10, 1, 0.1, 10\n",
    "    A, Xs, deltas = game_of_life.generate_random_time_series(grid_size, num_shapes, p, T_max)\n",
    "    num_steps = len(Xs)\n",
    "    # the same matrix objects are reused for all time steps\n",
    "    return [A] * num_steps, Xs, deltas, [np.zeros_like(A)] * num_steps\n",
    "\n",
    "def generate_game_of_life_with_my_shapes():\n",
    "    \"\"\"Sample one game of life time series with use_my_shapes = True.\n",
    "\n",
    "    The single adjacency matrix returned by the helper is repeated for every\n",
    "    time step, and the edge edits are all-zero matrices of the same shape.\n",
    "    Returns As, Xs, deltas, Epsilons with one entry per element of Xs.\n",
    "    \"\"\"\n",
    "    # set hyper-parameters for the game of life random grid generation\n",
    "    grid_size, num_shapes, p, T_max = 10, 1, 0.1, 10\n",
    "    A, Xs, deltas = game_of_life.generate_random_time_series(grid_size, num_shapes, p, T_max, use_my_shapes = True)\n",
    "    num_steps = len(Xs)\n",
    "    # the same matrix objects are reused for all time steps\n",
    "    return [A] * num_steps, Xs, deltas, [np.zeros_like(A)] * num_steps\n",
    "\n",
    "def generate_game_of_life_unique(test_set):\n",
    "    \"\"\"Sample one game of life time series via game_of_life.generate_unique_time_series.\n",
    "\n",
    "    test_set is handed to the helper unchanged; presumably it holds the test\n",
    "    graphs that must not occur in the sampled series (confirm against the helper).\n",
    "    The single adjacency matrix returned by the helper is repeated for every\n",
    "    time step, and the edge edits are all-zero matrices of the same shape.\n",
    "    Returns As, Xs, deltas, Epsilons with one entry per element of Xs.\n",
    "    \"\"\"\n",
    "    # set hyper-parameters for the game of life random grid generation\n",
    "    grid_size, num_shapes, p = 10, 1, 0.1\n",
    "    # NOTE(review): T_max is defined like in the other generators but is not\n",
    "    # passed to the helper below - confirm whether that is intended\n",
    "    T_max = 10\n",
    "    A, Xs, deltas = game_of_life.generate_unique_time_series(test_set, grid_size=grid_size, num_shapes = num_shapes, p=p)\n",
    "    num_steps = len(Xs)\n",
    "    # the same matrix objects are reused for all time steps\n",
    "    return [A] * num_steps, Xs, deltas, [np.zeros_like(A)] * num_steps\n",
    "\n",
    "# change generator_funs for the experiments you want to run\n",
    "\n",
    "# original experiments\n",
    "generator_funs = [generate_edit_cycle, generate_degree_rules, generate_game_of_life]\n",
    "\n",
    "# experiments in 4.2.1\n",
    "#generator_funs = [generate_degree_rules, generate_degree_rules_erdos_renyi, generate_degree_rules_configuration]\n",
    "\n",
    "# experiments in 4.2 - unique train/test\n",
    "#generator_funs = [generate_degree_unique, generate_game_of_life_unique]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Actual Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- data set game_of_life_test100 ---\n",
      "\n",
      "--- model VGAE ---\n",
      "-- repeat 1 of 5 --\n",
      "loss avg after 1000 epochs: 23450.6\n",
      "loss avg after 2000 epochs: 23504.7\n",
      "loss avg after 3000 epochs: 23516.9\n",
      "loss avg after 4000 epochs: 23300.6\n",
      "loss avg after 5000 epochs: 23433.7\n",
      "loss avg after 6000 epochs: 22965.7\n",
      "loss avg after 7000 epochs: 23341.4\n",
      "loss avg after 8000 epochs: 23550.8\n",
      "loss avg after 9000 epochs: 23209.4\n",
      "loss avg after 10000 epochs: 23760.6\n",
      "-- repeat 2 of 5 --\n",
      "loss avg after 1000 epochs: 23713.5\n",
      "loss avg after 2000 epochs: 24066\n",
      "loss avg after 3000 epochs: 23346.1\n",
      "loss avg after 4000 epochs: 23532.3\n",
      "loss avg after 5000 epochs: 23399.6\n",
      "loss avg after 6000 epochs: 23498.9\n",
      "loss avg after 7000 epochs: 24051.4\n",
      "loss avg after 8000 epochs: 23119.2\n",
      "loss avg after 9000 epochs: 23348.1\n",
      "loss avg after 10000 epochs: 23529\n",
      "-- repeat 3 of 5 --\n",
      "loss avg after 1000 epochs: 23456.1\n",
      "loss avg after 2000 epochs: 23678.3\n",
      "loss avg after 3000 epochs: 23347.1\n",
      "loss avg after 4000 epochs: 23283.2\n",
      "loss avg after 5000 epochs: 23313.2\n",
      "loss avg after 6000 epochs: 23633.8\n",
      "loss avg after 7000 epochs: 23456.4\n",
      "loss avg after 8000 epochs: 23532.1\n",
      "loss avg after 9000 epochs: 23559\n",
      "loss avg after 10000 epochs: 23685.9\n",
      "-- repeat 4 of 5 --\n",
      "loss avg after 1000 epochs: 23513.7\n",
      "loss avg after 2000 epochs: 23423\n",
      "loss avg after 3000 epochs: 23744.5\n",
      "loss avg after 4000 epochs: 23088\n",
      "loss avg after 5000 epochs: 23321.5\n",
      "loss avg after 6000 epochs: 23341.2\n",
      "loss avg after 7000 epochs: 22813.1\n",
      "loss avg after 8000 epochs: 23159.6\n",
      "loss avg after 9000 epochs: 23385\n",
      "loss avg after 10000 epochs: 23253.2\n",
      "-- repeat 5 of 5 --\n",
      "loss avg after 1000 epochs: 29673.8\n",
      "loss avg after 2000 epochs: 25806.2\n",
      "loss avg after 3000 epochs: 24520.3\n",
      "loss avg after 4000 epochs: 24260.4\n",
      "loss avg after 5000 epochs: 24299.7\n",
      "loss avg after 6000 epochs: 24407.1\n",
      "loss avg after 7000 epochs: 24478.1\n",
      "loss avg after 8000 epochs: 24064.6\n",
      "loss avg after 9000 epochs: 23974.3\n",
      "loss avg after 10000 epochs: 24246.6\n",
      "node_ins_recall: 0.2378 +- 0.0443775\n",
      "node_ins_precision: 1 +- 0\n",
      "node_del_recall: 1 +- 0\n",
      "node_del_precision: 0.0373376 +- 0.00177807\n",
      "edge_ins_recall: 1 +- 0\n",
      "edge_ins_precision: 0.978 +- 0.044\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 0.9924 +- 0.0152\n",
      "--- model GEN_crossent ---\n",
      "-- repeat 1 of 5 --\n",
      "loss avg after 1000 epochs: 53.6687\n",
      "loss avg after 2000 epochs: 36.4805\n",
      "loss avg after 3000 epochs: 34.1886\n",
      "loss avg after 4000 epochs: 33.3431\n",
      "loss avg after 5000 epochs: 32.6936\n",
      "loss avg after 6000 epochs: 32.0854\n",
      "loss avg after 7000 epochs: 32.797\n",
      "loss avg after 8000 epochs: 32.9699\n",
      "loss avg after 9000 epochs: 34.859\n",
      "loss avg after 10000 epochs: 36.5474\n",
      "-- repeat 2 of 5 --\n",
      "loss avg after 1000 epochs: 142.656\n",
      "loss avg after 2000 epochs: 71.7779\n",
      "loss avg after 3000 epochs: 62.6712\n",
      "loss avg after 4000 epochs: 70.55\n",
      "loss avg after 5000 epochs: 63.008\n",
      "loss avg after 6000 epochs: 62.4791\n",
      "loss avg after 7000 epochs: 56.5943\n",
      "loss avg after 8000 epochs: 63.0959\n",
      "loss avg after 9000 epochs: 57.7374\n",
      "loss avg after 10000 epochs: 53.2294\n",
      "-- repeat 3 of 5 --\n",
      "loss avg after 1000 epochs: 97.5732\n",
      "loss avg after 2000 epochs: 58.8255\n",
      "loss avg after 3000 epochs: 44.0564\n",
      "loss avg after 4000 epochs: 41.798\n",
      "loss avg after 5000 epochs: 76.6003\n",
      "loss avg after 6000 epochs: 65.8614\n",
      "loss avg after 7000 epochs: 65.1144\n",
      "loss avg after 8000 epochs: 66.7487\n",
      "loss avg after 9000 epochs: 67.7362\n",
      "loss avg after 10000 epochs: 69.5962\n",
      "-- repeat 4 of 5 --\n",
      "loss avg after 1000 epochs: 77.1816\n",
      "loss avg after 2000 epochs: 65.0758\n",
      "loss avg after 3000 epochs: 71.0145\n",
      "loss avg after 4000 epochs: 64.1501\n",
      "loss avg after 5000 epochs: 64.1683\n",
      "loss avg after 6000 epochs: 69.2985\n",
      "loss avg after 7000 epochs: 67.2521\n",
      "loss avg after 8000 epochs: 67.7631\n",
      "loss avg after 9000 epochs: 58.2579\n",
      "loss avg after 10000 epochs: 53.5439\n",
      "-- repeat 5 of 5 --\n",
      "loss avg after 1000 epochs: 96.4891\n",
      "loss avg after 2000 epochs: 70.0287\n",
      "loss avg after 3000 epochs: 70.3067\n",
      "loss avg after 4000 epochs: 66.8127\n",
      "loss avg after 5000 epochs: 68.3633\n",
      "loss avg after 6000 epochs: 65.6934\n",
      "loss avg after 7000 epochs: 73.8226\n",
      "loss avg after 8000 epochs: 67.511\n",
      "loss avg after 9000 epochs: 63.6876\n",
      "loss avg after 10000 epochs: 65.3798\n",
      "node_ins_recall: 0.427353 +- 0.288787\n",
      "node_ins_precision: 0.9971 +- 0.0058\n",
      "node_del_recall: 0.55754 +- 0.292743\n",
      "node_del_precision: 1 +- 0\n",
      "edge_ins_recall: 1 +- 0\n",
      "edge_ins_precision: 1 +- 0\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 1 +- 0\n",
      "--- model GEN ---\n",
      "-- repeat 1 of 5 --\n",
      "loss avg after 1000 epochs: 84.7315\n",
      "loss avg after 2000 epochs: 87.9547\n",
      "loss avg after 3000 epochs: 103.524\n",
      "loss avg after 4000 epochs: 88.2653\n",
      "loss avg after 5000 epochs: 92.6138\n",
      "loss avg after 6000 epochs: 77.3263\n",
      "loss avg after 7000 epochs: 70.4089\n",
      "loss avg after 8000 epochs: 125.846\n",
      "loss avg after 9000 epochs: 72.9331\n",
      "loss avg after 10000 epochs: 39.2898\n",
      "-- repeat 2 of 5 --\n",
      "loss avg after 1000 epochs: 5267.67\n",
      "loss avg after 2000 epochs: 2242.79\n",
      "loss avg after 3000 epochs: 646.145\n",
      "loss avg after 4000 epochs: 142.909\n",
      "loss avg after 5000 epochs: 67.5886\n",
      "loss avg after 6000 epochs: 69.1627\n",
      "loss avg after 7000 epochs: 76.7299\n",
      "loss avg after 8000 epochs: 67.3888\n",
      "loss avg after 9000 epochs: 60.7577\n",
      "loss avg after 10000 epochs: 72.9347\n",
      "-- repeat 3 of 5 --\n",
      "loss avg after 1000 epochs: 1610.18\n",
      "loss avg after 2000 epochs: 873.294\n",
      "loss avg after 3000 epochs: 262.639\n",
      "loss avg after 4000 epochs: 95.4437\n",
      "loss avg after 5000 epochs: 60.0424\n",
      "loss avg after 6000 epochs: 55.2865\n",
      "loss avg after 7000 epochs: 70.62\n",
      "loss avg after 8000 epochs: 71.461\n",
      "loss avg after 9000 epochs: 68.9292\n",
      "loss avg after 10000 epochs: 62.812\n",
      "-- repeat 4 of 5 --\n",
      "loss avg after 1000 epochs: 90.8698\n",
      "loss avg after 2000 epochs: 32.294\n",
      "loss avg after 3000 epochs: 30.4883\n",
      "loss avg after 4000 epochs: 24.4969\n",
      "loss avg after 5000 epochs: 24.4781\n",
      "loss avg after 6000 epochs: 32.8521\n",
      "loss avg after 7000 epochs: 18.945\n",
      "loss avg after 8000 epochs: 25.2102\n",
      "loss avg after 9000 epochs: 22.2606\n",
      "loss avg after 10000 epochs: 17.1423\n",
      "-- repeat 5 of 5 --\n",
      "loss avg after 1000 epochs: 14474.1\n",
      "loss avg after 2000 epochs: 94.8084\n",
      "loss avg after 3000 epochs: 74.9167\n",
      "loss avg after 4000 epochs: 53.6447\n",
      "loss avg after 5000 epochs: 25.9589\n",
      "loss avg after 6000 epochs: 21.0071\n",
      "loss avg after 7000 epochs: 13.1302\n",
      "loss avg after 8000 epochs: 77.7806\n",
      "loss avg after 9000 epochs: 25.925\n",
      "loss avg after 10000 epochs: 11.9503\n",
      "node_ins_recall: 0.439796 +- 0.252476\n",
      "node_ins_precision: 0.934165 +- 0.130921\n",
      "node_del_recall: 0.532146 +- 0.323201\n",
      "node_del_precision: 0.988704 +- 0.015553\n",
      "edge_ins_recall: 1 +- 0\n",
      "edge_ins_precision: 0.9992 +- 0.00116619\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 1 +- 0\n",
      "\n",
      "--- data set degree_rules_test100 ---\n",
      "\n",
      "--- model VGAE ---\n",
      "node_ins_recall: 0.131112 +- 0.00807867\n",
      "node_ins_precision: 1 +- 0\n",
      "node_del_recall: 0.990428 +- 0.00574358\n",
      "node_del_precision: 0.925123 +- 0.0409349\n",
      "edge_ins_recall: 0.835711 +- 0.0687808\n",
      "edge_ins_precision: 0.854662 +- 0.155212\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 0.848273 +- 0.173499\n",
      "--- model GEN_crossent ---\n",
      "node_ins_recall: 0.999914 +- 0.000105757\n",
      "node_ins_precision: 0.999094 +- 0.0010823\n",
      "node_del_recall: 0.999906 +- 0.0001888\n",
      "node_del_precision: 0.999373 +- 0.00113407\n",
      "edge_ins_recall: 0.905562 +- 0.0822205\n",
      "edge_ins_precision: 0.969919 +- 0.0181965\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 0.990602 +- 0.00591954\n",
      "--- model GEN ---\n",
      "node_ins_recall: 0.999822 +- 0.000317515\n",
      "node_ins_precision: 0.997579 +- 0.00471465\n",
      "node_del_recall: 1 +- 0\n",
      "node_del_precision: 0.995943 +- 0.0053812\n",
      "edge_ins_recall: 0.891646 +- 0.138224\n",
      "edge_ins_precision: 0.981655 +- 0.0231837\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 0.989806 +- 0.0132954\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "# iterate over all datasets\n",
    "import pandas as pd\n",
    "\n",
    "# set unique to True if you want to evaluate on the fixed, unique test set\n",
    "unique = False\n",
    "\n",
    "# partial results, runtimes, and learning curves are stored as CSV files in\n",
    "# this directory; create it up front so that a finished training run is not\n",
    "# lost because np.savetxt cannot write into a missing directory\n",
    "os.makedirs('results', exist_ok = True)\n",
    "\n",
    "if unique:\n",
    "    # definitions from generator so we can use the same for creating the test set\n",
    "    df = pd.DataFrame(eval_criteria, columns = ['eval'])\n",
    "    n_init = 8\n",
    "    n_max = 8*4\n",
    "    grid_size = 10\n",
    "    num_shapes = 1\n",
    "    p = 0.1\n",
    "    T_max = 10\n",
    "\n",
    "    # uncomment the test set that belongs to the experiment you want to run\n",
    "    #test_set, unique_test_set = degree_rules.create_test_set(N_test, n_init, n_max = n_max)\n",
    "    test_set, unique_test_set = game_of_life.create_test_set(N_test)\n",
    "\n",
    "\n",
    "for d in range(len(datasets)):\n",
    "    print('\\n--- data set %s ---\\n' % datasets[d])\n",
    "    # load partial runtime results if possible; ndmin = 2 keeps the\n",
    "    # (repeats x models) layout even if the file has a single row or column\n",
    "    runtimes_file = 'results/%s_runtimes.csv' % datasets[d]\n",
    "    if os.path.exists(runtimes_file):\n",
    "        runtimes = np.loadtxt(runtimes_file, skiprows = 1, delimiter = '\\t', ndmin = 2)\n",
    "    else:\n",
    "        runtimes = np.full((R, len(models)), np.nan)\n",
    "    # iterate over all models\n",
    "    for k in range(len(models)):\n",
    "        print('--- model %s ---' % models[k])\n",
    "        # load partial results if possible; ndmin = 2 keeps the 2D layout\n",
    "        # (repeats x criteria and epochs x repeats) even for a single repeat\n",
    "        results_file = 'results/%s_%s_results.csv' % (datasets[d], models[k])\n",
    "        curves_file  = 'results/%s_%s_learning_curves.csv' % (datasets[d], models[k])\n",
    "        if os.path.exists(results_file):\n",
    "            results = np.loadtxt(results_file, skiprows = 1, delimiter = '\\t', ndmin = 2)\n",
    "            learning_curves = np.loadtxt(curves_file, delimiter = '\\t', ndmin = 2)\n",
    "        else:\n",
    "            results = np.full((R, len(eval_criteria)), np.nan)\n",
    "            learning_curves = np.full((max_epochs, R), np.nan)\n",
    "        # iterate over experimental repeats\n",
    "        for r in range(R):\n",
    "            # check if this repeat is already evaluated; if so, skip it\n",
    "            if not np.isnan(learning_curves[0, r]):\n",
    "                continue\n",
    "            print('-- repeat %d of %d --' % (r+1, R))\n",
    "            start_time = time.time()\n",
    "            # set up model; every game of life variant (e.g. 'game_of_life_test100')\n",
    "            # gets the sigmoid, which an exact comparison with 'game_of_life' would miss\n",
    "            if datasets[d].startswith('game_of_life'):\n",
    "                nonlin = torch.nn.Sigmoid()\n",
    "            else:\n",
    "                nonlin = torch.nn.ReLU()\n",
    "            model = setup_funs[k](dim_ins[d], nonlin)\n",
    "            # set up optimizer\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "            # initialize moving loss average for printing\n",
    "            loss_avg = None\n",
    "            # start training\n",
    "            for epoch in range(max_epochs):\n",
    "                optimizer.zero_grad()\n",
    "                # sample a time series from the data set\n",
    "                if unique:\n",
    "                    As, Xs, deltas, Epsilons = generator_funs[d](unique_test_set)\n",
    "                else:\n",
    "                    As, Xs, deltas, Epsilons = generator_funs[d]()\n",
    "                # compute the loss over all time steps; the gradients of all\n",
    "                # time steps are accumulated before the optimizer step\n",
    "                loss = 0.\n",
    "                for t in range(len(As)):\n",
    "                    # compute loss\n",
    "                    loss_obj = loss_funs[k](model, As[t], Xs[t], deltas[t], Epsilons[t])\n",
    "                    # compute gradient\n",
    "                    loss_obj.backward()\n",
    "                    # accumulate loss\n",
    "                    loss += loss_obj.item()\n",
    "                # perform an optimizer step\n",
    "                optimizer.step()\n",
    "                # store the current loss value in the learning curve\n",
    "                learning_curves[epoch, r] = loss\n",
    "                # compute a new moving average over the loss\n",
    "                if loss_avg is None:\n",
    "                    loss_avg = loss\n",
    "                else:\n",
    "                    loss_avg = loss_avg * 0.9 + 0.1 * loss\n",
    "                # print every print_step steps\n",
    "                if(epoch+1) % print_step == 0:\n",
    "                    print('loss avg after %d epochs: %g' % (epoch+1, loss_avg))\n",
    "                # stop early if the moving average is small; the remaining\n",
    "                # entries of this learning curve stay nan\n",
    "                if loss_avg < loss_threshold:\n",
    "                    break\n",
    "            # perform evaluation on new time series\n",
    "            results[r, :] = 0.\n",
    "            T = 0\n",
    "            for j in range(N_test):\n",
    "                # get the j-th test series or a random time series from the dataset\n",
    "                if unique:\n",
    "                    As, Xs, deltas, Epsilons = test_set[j]\n",
    "                    if len(As) != len(Xs):\n",
    "                        print(\"#########\")\n",
    "                else:\n",
    "                    As, Xs, deltas, Epsilons = generator_funs[d]()\n",
    "                for t in range(len(As)):\n",
    "                    # predict the current time step with the network\n",
    "                    delta, Epsilon = pred_funs[k](model, As[t], Xs[t])\n",
    "                    # assess node edit precision and recall\n",
    "                    results[r, :4] += prec_rec(delta, deltas[t])\n",
    "                    # assess edge edit precision and recall\n",
    "                    results[r, 4:] += prec_rec(Epsilon, Epsilons[t])\n",
    "\n",
    "                T += len(As)\n",
    "            # average all criteria over the total number of evaluated time steps\n",
    "            results[r, :] /= T\n",
    "            # store runtime\n",
    "            runtimes[r, k] = time.time() - start_time\n",
    "            np.savetxt(runtimes_file, runtimes, delimiter = '\\t', fmt = '%g', header = '\\t'.join(models), comments = '')\n",
    "            # store results\n",
    "            np.savetxt(results_file, results, delimiter = '\\t', fmt = '%g', header = '\\t'.join(eval_criteria), comments = '')\n",
    "            # store learning curves\n",
    "            np.savetxt(curves_file, learning_curves, delimiter = '\\t', fmt = '%g')\n",
    "        # print results as mean +- standard deviation over all repeats\n",
    "        res_collect = []\n",
    "        for crit in range(len(eval_criteria)):\n",
    "            summary = '%g +- %g' % (np.mean(results[:, crit]), np.std(results[:, crit]))\n",
    "            res_collect.append(summary)\n",
    "            print('%s: %s' % (eval_criteria[crit], summary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize learning curves: one panel per dataset, one curve per model,\n",
    "# each showing the loss averaged over repeats and smoothed over smoothing_steps epochs\n",
    "import matplotlib.pyplot as plt\n",
    "smoothing_steps = 10\n",
    "# squeeze = False always yields a 2D array of axes, so the indexing below\n",
    "# also works if only a single dataset is configured\n",
    "fig, axes = plt.subplots(ncols=1, nrows=len(datasets), squeeze=False)\n",
    "for d in range(len(datasets)):\n",
    "    ax = axes[d, 0]\n",
    "    for k in range(len(models)):\n",
    "        curves_file  = 'results/%s_%s_learning_curves.csv' % (datasets[d], models[k])\n",
    "        # ndmin = 2 keeps the (epochs x repeats) layout even for a single repeat\n",
    "        learning_curves = np.loadtxt(curves_file, delimiter = '\\t', ndmin = 2)\n",
    "        # moving average via differences of the cumulative sum\n",
    "        acum = np.cumsum(np.nanmean(learning_curves, 1))\n",
    "        ax.semilogy((acum[smoothing_steps:] - acum[:-smoothing_steps])/smoothing_steps)\n",
    "    ax.set_xlabel('epoch')\n",
    "    ax.set_ylabel('loss')\n",
    "    ax.set_title(datasets[d])\n",
    "    ax.legend(models)\n",
    "# avoid overlapping titles and axis labels of the stacked panels\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root] *",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
