{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Dynamical Systems\n",
    "\n",
    "This notebooks contains the experiments to evaluate graph edit networks on simple graph dynamical systems, namely the edit cycles, degree rules, and game of life datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameter setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import pytorch_graph_edit_networks as gen\n",
    "from torch_geometric.utils import dense_to_sparse\n",
    "import baseline_models_with_vgrnn as baseline_models\n",
    "import os\n",
    "import time\n",
    "import hep_th\n",
    "# torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "# torch.cuda.set_device(0)\n",
    "\n",
    "\n",
    "# model hyperparameters\n",
    "num_layers = 2\n",
    "dim_hid = 64\n",
    "\n",
    "# training hyperparameters\n",
    "learning_rate  = 1E-3\n",
    "weight_decay   = 1E-5\n",
    "loss_threshold = 1E-3\n",
    "#TODO run on more\n",
    "max_epochs     = 3000\n",
    "\n",
    "print_step     = 100\n",
    "\n",
    "R = 5        # number of repetitions for each experiment\n",
    "N_test = 10  # number of test time series we use to evaluate learning afterwards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# if torch.cuda.is_available():\n",
    "#     dev = \"cuda:0\"\n",
    "# else:\n",
    "#     dev = \"cpu\"\n",
    "#\n",
    "# dev"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# SETUP FUNCTIONS\n",
    "def setup_vgae(dim_in, nonlin):\n",
    "    return baseline_models.VGAE(num_layers=num_layers,\n",
    "                                dim_in=dim_in,\n",
    "                                dim_hid=dim_hid,\n",
    "                                beta=1E-3,\n",
    "                                sigma_scaling=1E-3,\n",
    "                                nonlin=nonlin)\n",
    "\n",
    "\n",
    "def setup_vgrnn(dim_in, nonlin):\n",
    "    return baseline_models.VGRNN(num_layers=num_layers,\n",
    "                                 dim_in=dim_in,\n",
    "                                 dim_hid=dim_hid)\n",
    "\n",
    "\n",
    "def setup_gen(dim_in, nonlin):\n",
    "    return gen.GEN(num_layers=num_layers,\n",
    "                   dim_in=dim_in,\n",
    "                   dim_hid=dim_hid,\n",
    "                   nonlin=nonlin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# LOSS FUNCTIONS\n",
    "loss_fun = gen.GEN_loss()\n",
    "crossent_loss_fun = gen.GEN_loss_crossent()\n",
    "def vgae_loss(model, A, X, delta, Epsilon, state=None):\n",
    "    model = model\n",
    "    B = A + Epsilon\n",
    "    # delete all outgoing and incoming edges of deleted nodes\n",
    "    B[delta < -0.5, :] = 0\n",
    "    B[:, delta < -0.5] = 0\n",
    "    loss = model.compute_loss(torch.tensor(A, dtype=torch.float),\n",
    "                              torch.tensor(B, dtype=torch.float),\n",
    "                              torch.tensor(X, dtype=torch.float))\n",
    "\n",
    "    return loss, state\n",
    "\n",
    "\n",
    "def vgrnn_loss(model, A, X, delta, Epsilon, state=None):\n",
    "    model = model\n",
    "    A = torch.tensor(A, dtype=torch.float)\n",
    "    A = A\n",
    "    edge_index, _ = dense_to_sparse(A)\n",
    "    edge_index = edge_index\n",
    "    X = torch.Tensor(X)\n",
    "    X = X\n",
    "    predicted = model(X, edge_index, hidden_in=state)\n",
    "    predicted = predicted\n",
    "    predicted, state = predicted[:-1], predicted[-1]\n",
    "    state=state\n",
    "\n",
    "    target = A + Epsilon\n",
    "    target = target\n",
    "    target[delta < -0.5, :] = 0\n",
    "    target[:, delta < -0.5] = 0\n",
    "    # print('------')\n",
    "    return model.compute_loss(*predicted, target), state\n",
    "\n",
    "\n",
    "def gen_loss_crossent(model, A, X, delta, Epsilon, state=None):\n",
    "    delta_pred, Epsilon_pred = model(torch.tensor(A, dtype=torch.float),\n",
    "                                     torch.tensor(X, dtype=torch.float))\n",
    "    loss = crossent_loss_fun(delta_pred, Epsilon_pred,\n",
    "                             torch.tensor(delta, dtype=torch.float),\n",
    "                             torch.tensor(Epsilon, dtype=torch.float),\n",
    "                             torch.tensor(A, dtype=torch.float))\n",
    "\n",
    "    return loss, state\n",
    "\n",
    "\n",
    "def gen_loss(model, A, X, delta, Epsilon, state=None):\n",
    "    delta_pred, Epsilon_pred = model(torch.tensor(A, dtype=torch.float),\n",
    "                                     torch.tensor(X, dtype=torch.float))\n",
    "    loss = loss_fun(delta_pred, Epsilon_pred,\n",
    "                    torch.tensor(delta, dtype=torch.float),\n",
    "                    torch.tensor(Epsilon, dtype=torch.float),\n",
    "                    torch.tensor(A, dtype=torch.float))\n",
    "\n",
    "    return loss, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# PREDICTION FUNCTIONS\n",
    "def vgae_pred(model, A, X, state=None):\n",
    "    B = model(torch.tensor(A, dtype=torch.float), torch.tensor(X, dtype=torch.float))\n",
    "    B = B.detach().numpy()\n",
    "    Epsilon = B - A\n",
    "    delta = np.zeros(A.shape[0])\n",
    "    delta[np.sum(B, 1) < 0.5] = -1.\n",
    "    Epsilon[delta < -0.5, :] = 0.\n",
    "    Epsilon[:, delta < -0.5] = 0.\n",
    "    return delta, Epsilon, state\n",
    "\n",
    "\n",
    "def vgrnn_pred(model, A, X, state=None):\n",
    "    model = model\n",
    "    A = torch.tensor(A, dtype=torch.float)\n",
    "    A = A\n",
    "    edge_index = dense_to_sparse(A)\n",
    "    edge_index = edge_index\n",
    "    predicted = model(torch.tensor(X), edge_index, hidden_in=state)\n",
    "    predicted = predicted\n",
    "    predicted, state = predicted[:-1], predicted[-1]\n",
    "    predicted = predicted\n",
    "    state = state\n",
    "\n",
    "    Epsilon = predicted - A\n",
    "    Epsilon = Epsilon\n",
    "    delta = np.zeros(A.shape[0])\n",
    "    delta = delta\n",
    "    delta[np.sum(predicted, 1) < 0.5] = -1.\n",
    "    Epsilon[delta < -0.5, :] = 0.\n",
    "    Epsilon[:, delta < -0.5] = 0.\n",
    "\n",
    "    return delta, Epsilon, state\n",
    "\n",
    "\n",
    "def gen_pred(model, A, X, state=None):\n",
    "    delta_pred, Epsilon_pred = model(torch.tensor(A, dtype=torch.float), torch.tensor(X, dtype=torch.float))\n",
    "    delta_pred = delta_pred.detach().numpy()\n",
    "    Epsilon_pred = Epsilon_pred.detach().numpy()\n",
    "    delta = np.zeros(A.shape[0])\n",
    "    delta[delta_pred > 0.5] = 1.\n",
    "    delta[delta_pred < -0.5] = -1.\n",
    "    Epsilon = np.zeros(A.shape)\n",
    "    Epsilon[np.logical_and(A > 0.5, Epsilon_pred < -0.5)] = -1.\n",
    "    Epsilon[np.logical_and(A < 0.5, Epsilon_pred > +0.5)] = +1.\n",
    "    return delta, Epsilon, state\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# EVALUATION FUNCTIONS\n",
    "eval_criteria = ['node_ins_recall',\n",
    "                 'node_ins_precision',\n",
    "                 'node_del_recall',\n",
    "                 'node_del_precision',\n",
    "                 'edge_ins_recall',\n",
    "                 'edge_ins_precision',\n",
    "                 'edge_del_recall',\n",
    "                 'edge_del_precision']\n",
    "# set up a function to compute precision and recall\n",
    "def prec_rec(X, Y):\n",
    "    # X is the prediction, Y is the target\n",
    "    target_insertions = Y > 0.5\n",
    "    predicted_insertions = X > 0.5\n",
    "    target_deletions = Y < -0.5\n",
    "    predicted_deletions = X < -0.5\n",
    "    # first, check the insertion recall\n",
    "    if np.sum(target_insertions) < 0.5:\n",
    "        ins_rec = 1.\n",
    "    else:\n",
    "        ins_rec  = np.mean(X[target_insertions] > 0.5)\n",
    "    # then the insertion precision\n",
    "    if np.sum(predicted_insertions) < 0.5:\n",
    "        ins_prec = 1.\n",
    "    else:\n",
    "        ins_prec = np.mean(Y[predicted_insertions] > 0.5)\n",
    "    # then the deletion recall\n",
    "    if np.sum(target_deletions) < 0.5:\n",
    "        del_rec = 1.\n",
    "    else:\n",
    "        del_rec  = np.mean(X[target_deletions] < -0.5)\n",
    "    # and finally the deletion precision\n",
    "    if np.sum(predicted_deletions) < 0.5:\n",
    "        del_prec = 1.\n",
    "    else:\n",
    "        del_prec = np.mean(Y[predicted_deletions] < -0.5)\n",
    "    return ins_rec, ins_prec, del_rec, del_prec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graph_edit_cycles\n",
    "import degree_rules\n",
    "import game_of_life\n",
    "import random\n",
    "\n",
    "\n",
    "# DATASET SETUP\n",
    "def generate_edit_cycle():\n",
    "    As, Xs, tuples = graph_edit_cycles.generate_time_series(random.randrange(3), random.randrange(12), random.randrange(4, 12))\n",
    "    deltas = []\n",
    "    Epsilons = []\n",
    "    for tpl in tuples:\n",
    "        deltas.append(tpl[0])\n",
    "        Epsilons.append(tpl[1])\n",
    "    return As, Xs, deltas, Epsilons\n",
    "\n",
    "\n",
    "def generate_degree_rules():\n",
    "    # the initial number of nodes in each graph\n",
    "    n_init = 8\n",
    "    # the maximum number of nodes that can occur in each graph during evolution\n",
    "    n_max  = n_init * 4\n",
    "    return degree_rules.generate_time_series_from_random_matrix(n_init, n_max = n_max)\n",
    "\n",
    "\n",
    "def generate_game_of_life():\n",
    "    # set hyper-parameters for the game of life random grid generation\n",
    "    grid_size = 10\n",
    "    num_shapes = 1\n",
    "    p = 0.1\n",
    "    T_max = 10\n",
    "    A, Xs, deltas = game_of_life.generate_random_time_series(grid_size, num_shapes, p, T_max)\n",
    "    As = [A] * len(Xs)\n",
    "    Epsilons = [np.zeros_like(A)] * len(Xs)\n",
    "    return As, Xs, deltas, Epsilons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# CONFIG FOR EXPERIMENTS\n",
    "\n",
    "# Models\n",
    "models = ['VGAE', 'VGRNN', 'GEN_crossent', 'GEN']\n",
    "setup_funs = [setup_vgae, setup_vgrnn, setup_gen, setup_gen]\n",
    "loss_funs = [vgae_loss, vgrnn_loss, gen_loss_crossent, gen_loss]\n",
    "pred_funs = [vgae_pred, vgrnn_pred, gen_pred, gen_pred]\n",
    "\n",
    "# Datasets\n",
    "#datasets = ['edit_cycles', 'degree_rules', 'game_of_life']\n",
    "datasets = ['game_of_life', 'degree_rules']\n",
    "#dim_ins  = [4, 32, 1]\n",
    "#dim_ins = [32, 1]\n",
    "dim_ins = [1, 32]\n",
    "#generator_funs = [generate_edit_cycle, generate_degree_rules, generate_game_of_life]\n",
    "generator_funs = [generate_game_of_life, generate_degree_rules]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Actual Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- data set game_of_life ---\n",
      "\n",
      "--- model VGAE ---\n",
      "node_ins_recall: 0.274 +- 0.0941488\n",
      "node_ins_precision: 1 +- 0\n",
      "node_del_recall: 1 +- 0\n",
      "node_del_precision: 0.0375 +- 0.00260077\n",
      "edge_ins_recall: 1 +- 0\n",
      "edge_ins_precision: 1 +- 0\n",
      "edge_del_recall: 1 +- 0\n",
      "edge_del_precision: 1 +- 0\n",
      "--- model VGRNN ---\n",
      "-- repeat 1 of 5 --\n",
      "loss avg after 100 epochs: 7.44709\n",
      "loss avg after 200 epochs: 7.40946\n",
      "loss avg after 300 epochs: 7.31923\n",
      "loss avg after 400 epochs: 7.29466\n",
      "loss avg after 500 epochs: 7.27402\n",
      "loss avg after 600 epochs: 7.22272\n",
      "loss avg after 700 epochs: 7.17407\n",
      "loss avg after 800 epochs: 7.17703\n",
      "loss avg after 900 epochs: 7.10687\n",
      "loss avg after 1000 epochs: 7.16115\n",
      "loss avg after 1100 epochs: 7.14281\n",
      "loss avg after 1200 epochs: 7.00304\n",
      "loss avg after 1300 epochs: 7.12211\n",
      "loss avg after 1400 epochs: 7.1911\n",
      "loss avg after 1500 epochs: 7.11026\n",
      "loss avg after 1600 epochs: 7.10524\n",
      "loss avg after 1700 epochs: 7.16129\n",
      "loss avg after 1800 epochs: 7.07459\n",
      "loss avg after 1900 epochs: 7.18062\n",
      "loss avg after 2000 epochs: 7.0973\n",
      "loss avg after 2100 epochs: 7.05152\n",
      "loss avg after 2200 epochs: 7.03487\n",
      "loss avg after 2300 epochs: 7.06822\n",
      "loss avg after 2400 epochs: 7.0197\n",
      "loss avg after 2500 epochs: 7.02478\n",
      "loss avg after 2600 epochs: 7.02043\n",
      "loss avg after 2700 epochs: 7.01321\n",
      "loss avg after 2800 epochs: 7.11503\n",
      "loss avg after 2900 epochs: 7.05019\n",
      "loss avg after 3000 epochs: 7.06226\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "expected scalar type Float but found Double",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-9-3b0043ee2024>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     76\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mAs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     77\u001b[0m                     \u001b[1;31m# predict the current time step with the network\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m                     \u001b[0mdelta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mEpsilon\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpred_funs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mAs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mXs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     79\u001b[0m                     \u001b[1;31m# assess node edit precision and recall\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     80\u001b[0m                     \u001b[0mresults\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mprec_rec\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdelta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdeltas\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-5-044c502019e1>\u001b[0m in \u001b[0;36mvgrnn_pred\u001b[1;34m(model, A, X, state)\u001b[0m\n\u001b[0;32m     17\u001b[0m     \u001b[0medge_index\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdense_to_sparse\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mA\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     18\u001b[0m     \u001b[0medge_index\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0medge_index\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 19\u001b[1;33m     \u001b[0mpredicted\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0medge_index\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_in\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     20\u001b[0m     \u001b[0mpredicted\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     21\u001b[0m     \u001b[0mpredicted\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mX:\\Faks\\DS\\sem4\\Reproducibility\\code\\baseline_models_with_vgrnn.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, edge_idx, hidden_in)\u001b[0m\n\u001b[0;32m    306\u001b[0m             \u001b[0mh\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_in\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    307\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 308\u001b[1;33m         \u001b[0mphi_x_t\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mphi_x\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    309\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    310\u001b[0m         \u001b[1;31m# encoder\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    139\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    140\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m             \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    142\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    143\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    102\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 103\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    104\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    105\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\repro\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mlinear\u001b[1;34m(input, weight, bias)\u001b[0m\n\u001b[0;32m   1846\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1847\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1848\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1849\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1850\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: expected scalar type Float but found Double"
     ]
    }
   ],
   "source": [
    "# iterate over all datasets\n",
    "for d in range(len(datasets)):\n",
    "    print('\\n--- data set %s ---\\n' % datasets[d])\n",
    "    # load partial runtime results if possible\n",
    "    runtimes_file = 'results/%s_runtimes.csv' % datasets[d]\n",
    "    if os.path.exists(runtimes_file):\n",
    "        runtimes = np.loadtxt(runtimes_file, skiprows = 1, delimiter = '\\t')\n",
    "    else:\n",
    "        runtimes = np.full((R, len(models)), np.nan)\n",
    "    # iterate over all models\n",
    "    for k in range(len(models)):\n",
    "        print('--- model %s ---' % models[k])\n",
    "        # load partial results if possible\n",
    "        results_file = 'results/%s_%s_results.csv' % (datasets[d], models[k])\n",
    "        curves_file  = 'results/%s_%s_learning_curves.csv' % (datasets[d], models[k])\n",
    "        if os.path.exists(results_file):\n",
    "            results = np.loadtxt(results_file, skiprows = 1, delimiter = '\\t')\n",
    "            learning_curves = np.loadtxt(curves_file, delimiter = '\\t')\n",
    "        else:\n",
    "            results = np.full((R, len(eval_criteria)), np.nan)\n",
    "            learning_curves = np.full((max_epochs, R), np.nan)\n",
    "        # iterate over experimental repeats\n",
    "        for r in range(R):\n",
    "            # check if this repeat is already evaluated; if so, skip it\n",
    "            if not np.isnan(learning_curves[0, r]):\n",
    "                continue\n",
    "            print('-- repeat %d of %d --' % (r+1, R))\n",
    "            start_time = time.time()\n",
    "            # set up model\n",
    "            if datasets[d] == 'game_of_life':\n",
    "                nonlin = torch.nn.Sigmoid()\n",
    "            else:\n",
    "                nonlin = torch.nn.ReLU()\n",
    "            model = setup_funs[k](dim_ins[d], nonlin)\n",
    "            # set up optimizer\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "            # initialize moving loss average for printing\n",
    "            loss_avg = None\n",
    "            # start training\n",
    "            for epoch in range(max_epochs):\n",
    "                optimizer.zero_grad()\n",
    "                # sample a time series from the data set\n",
    "                As, Xs, deltas, Epsilons = generator_funs[d]()\n",
    "                # compute the loss over all time steps\n",
    "                loss = 0.\n",
    "                state = None\n",
    "                for t in range(len(As)):\n",
    "                    # compute loss\n",
    "                    loss_obj, state = loss_funs[k](model, As[t], Xs[t], deltas[t], Epsilons[t], state=state)\n",
    "                    # compute gradient\n",
    "                    loss_obj.backward()\n",
    "                    # accumulate loss\n",
    "                    loss += loss_obj.item()\n",
    "                # perform an optimizer step\n",
    "                optimizer.step()\n",
    "                # store the current loss value in the learning curve\n",
    "                learning_curves[epoch, r] = loss\n",
    "                # compute a new moving average over the loss\n",
    "                if loss_avg is None:\n",
    "                    loss_avg = loss\n",
    "                else:\n",
    "                    loss_avg = loss_avg * 0.9 + 0.1 * loss\n",
    "                # print every print_step steps\n",
    "                if(epoch+1) % print_step == 0:\n",
    "                    print('loss avg after %d epochs: %g' % (epoch+1, loss_avg))\n",
    "                # stop early if the moving average is small\n",
    "                if loss_avg < loss_threshold:\n",
    "                    break\n",
    "            # perform evaluation on new time series\n",
    "            results[r, :] = 0.\n",
    "            T = 0\n",
    "            for j in range(N_test):\n",
    "                # get a random time series from the dataset\n",
    "                As, Xs, deltas, Epsilons = generator_funs[d]()\n",
    "                state = None\n",
    "                for t in range(len(As)):\n",
    "                    # predict the current time step with the network\n",
    "                    delta, Epsilon, state = pred_funs[k](model, As[t], Xs[t], state=state)\n",
    "                    # assess node edit precision and recall\n",
    "                    results[r, :4] += prec_rec(delta, deltas[t])\n",
    "                    # assess edge edit precision and recall\n",
    "                    results[r, 4:] += prec_rec(Epsilon, Epsilons[t])\n",
    "                    # TODO this is only for debugging purposes\n",
    "                    if (d != 1 and k == 2 and np.any(results[r, :] < 0.99)) or np.any(np.isnan(results[r, :])):\n",
    "                        print('delta (predicted) = %s' % str(delta))\n",
    "                        print('delta (target) = %s' % str(deltas[t]))\n",
    "                        print('Epsilon (predicted) = %s' % str(Epsilon))\n",
    "                        print('Epsilon (target) = %s' % str(Epsilons[t]))\n",
    "                        raise ValueError('stop')\n",
    "                        \n",
    "                T += len(As)\n",
    "            results[r, :] /= T\n",
    "            # store runtime\n",
    "            runtimes[r, k] = time.time() - start_time\n",
    "            np.savetxt(runtimes_file, runtimes, delimiter = '\\t', fmt = '%g', header = '\\t'.join(models), comments = '')\n",
    "            # store results\n",
    "            np.savetxt(results_file, results, delimiter = '\\t', fmt = '%g', header = '\\t'.join(eval_criteria), comments = '')\n",
    "            # store learning curves\n",
    "            np.savetxt(curves_file, learning_curves, delimiter = '\\t', fmt = '%g')\n",
    "        # print results\n",
    "        for crit in range(len(eval_criteria)):\n",
    "            print('%s: %g +- %g' % (eval_criteria[crit], np.mean(results[:, crit]), np.std(results[:, crit])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# visualize learning curves\n",
    "import matplotlib.pyplot as plt\n",
    "smoothing_steps = 10\n",
    "fig, axes = plt.subplots(ncols=1, nrows=len(datasets))\n",
    "for d in range(len(datasets)):\n",
    "    for k in range(len(models)):\n",
    "        curves_file  = 'results/%s_%s_learning_curves.csv' % (datasets[d], models[k])\n",
    "        learning_curves = np.loadtxt(curves_file, delimiter = '\\t')\n",
    "        acum = np.cumsum(np.nanmean(learning_curves, 1))\n",
    "        axes[d].semilogy((acum[smoothing_steps:] - acum[:-smoothing_steps])/smoothing_steps)\n",
    "    axes[d].set_xlabel('epoch')\n",
    "    axes[d].set_ylabel('loss')\n",
    "    axes[d].set_title(datasets[d])\n",
    "    axes[d].legend(models)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:reproducibility-challenge]",
   "language": "python",
   "name": "conda-env-reproducibility-challenge-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
