{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SIT-GNN vs. GCN vs. DBGNN\n",
    "\n",
    "## Data Sets\n",
    "\n",
    "The data sets used in the paper are located in the `data` folder.\n",
    "Set `classes` to the number of classes of the chosen data set, as given in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix prepended to every data path below.\n",
    "root = '' # adapt to your path, if needed\n",
    "\n",
    "# Input files: path data, node labels and predefined cross-validation splits.\n",
    "path_data = root+\"data/synthetic_paths.ngram\"\n",
    "label_data = root+\"data/synthetic_labels.txt\"\n",
    "splits_data = root+\"data/synthetic_splits_cv.json\"\n",
    "# Number of node classes; passed to every model as `num_classes`.\n",
    "classes = 2\n",
    "\n",
    "# Have a look at the data folder for the remaining data sets. Classes are given in the paper."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Environment\n",
    "\n",
    "Use the devcontainer to install the required dependencies.\n",
    "The following imports are needed to run the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import SGD\n",
    "from torch.nn import ModuleList, Module, Linear\n",
    "from torch.nn.functional import one_hot\n",
    "from torch_geometric.nn import GCNConv, MessagePassing\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.transforms import RemoveDuplicatedEdges\n",
    "\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "from sklearn.metrics import balanced_accuracy_score\n",
    "\n",
    "from tqdm import trange\n",
    "import pathpyG as pp\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data\n",
    "\n",
    "### Creation of De Bruijn graphs\n",
    "\n",
    "Path data and labels are read and converted to higher-order De Bruijn graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_labels_to_y(labels, num_nodes, node_index_to_id):\n",
    "    classes = sorted(set(labels.values()))\n",
    "    \n",
    "    y = torch.zeros((num_nodes,), dtype=torch.int32)\n",
    "    y[:] = -1\n",
    "\n",
    "    for index in range(num_nodes):\n",
    "        y[index] = classes.index(labels[node_index_to_id[index]])\n",
    "\n",
    "    assert not -1 in y, \"not all nodes have labels\"\n",
    "\n",
    "    return y, classes\n",
    "\n",
    "\n",
    "def load_labels_from_file(path, num_nodes, node_index_to_id, separator = \"\\t\", label_index = -1):\n",
    "    \"\"\"Read node labels from a delimited text file and encode them as class indices.\n",
    "\n",
    "    Each non-blank line is split on `separator`; the first field is the node id and\n",
    "    the field at `label_index` is its label (kept as a string).\n",
    "\n",
    "    Args:\n",
    "        path: path of the label file.\n",
    "        num_nodes: number of nodes to encode.\n",
    "        node_index_to_id: mapping from node index to node id.\n",
    "        separator: field separator of the file.\n",
    "        label_index: position of the label field within a line.\n",
    "\n",
    "    Returns:\n",
    "        The (y, classes) tuple produced by `map_labels_to_y`.\n",
    "    \"\"\"\n",
    "    labels = {}\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            line = line.replace(\"\\n\", \"\")\n",
    "            # a blank (e.g. trailing) line would otherwise register '' as a node id\n",
    "            # and as an additional class, shifting all class indices\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            fields = line.split(separator)\n",
    "            labels[fields[0]] = fields[label_index]\n",
    "\n",
    "    return map_labels_to_y(labels, num_nodes, node_index_to_id)\n",
    "    \n",
    "\n",
    "def load_labels_from_csv(path, num_nodes, node_index_to_id):\n",
    "    \"\"\"Read node labels from a comma-separated file and encode them as class indices.\n",
    "\n",
    "    The first column is the node id, the second column the integer label.\n",
    "\n",
    "    Args:\n",
    "        path: path of the CSV file.\n",
    "        num_nodes: number of nodes to encode.\n",
    "        node_index_to_id: mapping from node index to node id.\n",
    "\n",
    "    Returns:\n",
    "        The (y, classes) tuple produced by `map_labels_to_y`.\n",
    "    \"\"\"\n",
    "    labels = {}\n",
    "    with open(path, newline='') as csvfile:\n",
    "        for row in csv.reader(csvfile, delimiter=','):\n",
    "            # csv.reader yields an empty list for blank lines; row[0] would raise IndexError\n",
    "            if not row:\n",
    "                continue\n",
    "            labels[row[0]] = int(row[1])\n",
    "\n",
    "    return map_labels_to_y(labels, num_nodes, node_index_to_id)\n",
    "\n",
    "def load_paths(path: str) -> pp.PathData:\n",
    "    \"\"\"Load the path data stored at `path` via `pp.PathData.from_csv`.\"\"\"\n",
    "    return pp.PathData.from_csv(path)\n",
    "\n",
    "\n",
    "def build_graph(paths: pp.PathData, k = 1) -> Data:\n",
    "    \"\"\"Build the order-`k` graph of `paths` as a PyG `Data` object.\n",
    "\n",
    "    Duplicated edges are merged by summing their `edge_weight`. The node id/index\n",
    "    mappings of the higher-order graph and the order `k` are attached to the result\n",
    "    as `node_id_to_index`, `node_index_to_id` and `order`.\n",
    "    \"\"\"\n",
    "    higher_order_graph = pp.HigherOrderGraph(paths, order=k, node_id=paths.node_id)\n",
    "\n",
    "    pyg_data = higher_order_graph.to_pyg_data()\n",
    "    pyg_data.edge_index = pyg_data.edge_index.to(torch.int64)\n",
    "    # merge duplicated edges, adding up their weights\n",
    "    pyg_data = RemoveDuplicatedEdges(key=[\"edge_weight\"], reduce=\"add\")(pyg_data)\n",
    "\n",
    "    # bookkeeping needed later to relate nodes of different orders\n",
    "    pyg_data.node_id_to_index = higher_order_graph.node_id_to_index\n",
    "    pyg_data.node_index_to_id = higher_order_graph.node_index_to_id\n",
    "    pyg_data.order = k\n",
    "    return pyg_data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculation of Significance Scores\n",
    "\n",
    "The significance scores are calculated for the given higher-order graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lazily initialised Julia bridge: the runtime is started on the first call of\n",
    "# lhygecdf and the two Julia callables are cached in these module-level globals.\n",
    "_julia_is_enabled = False\n",
    "hyge = None\n",
    "lcdf = None\n",
    "\n",
    "def lhygecdf(k: torch.Tensor, s: torch.Tensor, f: torch.Tensor, n: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Log-CDF of a hypergeometric distribution, evaluated with Julia's Distributions package.\n",
    "\n",
    "    Args:\n",
    "        k: scalar tensor; point at which the log-CDF is evaluated.\n",
    "        s, f, n: scalar tensors, truncated to int and passed in this order to Julia's\n",
    "            `Hypergeometric` (presumably successes, failures, draws - confirm against\n",
    "            the Distributions.jl documentation).\n",
    "\n",
    "    Returns:\n",
    "        A new tensor holding the log-CDF value.\n",
    "    \"\"\"\n",
    "    global _julia_is_enabled\n",
    "    global hyge\n",
    "    global lcdf\n",
    "\n",
    "    if not _julia_is_enabled:\n",
    "        print(\"enabling julia...\", end='')\n",
    " \n",
    "        # the Julia runtime must be created before anything is imported from `julia.*`\n",
    "        from julia.api import Julia\n",
    "        jl = Julia(compiled_modules=False)\n",
    "        from julia.Distributions import Hypergeometric, logcdf\n",
    "        \n",
    "        hyge = Hypergeometric\n",
    "        lcdf = logcdf\n",
    "\n",
    "        _julia_is_enabled = True\n",
    "        print(\" Done\")\n",
    "    \n",
    "    hyge_inst = hyge(int(s.item()), int(f.item()), int(n.item()))\n",
    "    return torch.tensor(lcdf(hyge_inst, k.item()))\n",
    "\n",
    "def calculates_sig_scores(num_nodes: int, edge_index: torch.Tensor, edge_weights: torch.Tensor, tolerance: float = 1e-5, verbose: bool = False, redistribute=True) -> tuple:\n",
    "    \"\"\"\n",
    "    assumption: no duplicated edges\n",
    "    to ensure, use:\n",
    "    transform = RemoveDuplicatedEdges(key=[\"edge_weight\"], reduce=\"add\")\n",
    "    data = transform(data)\n",
    "    \"\"\"\n",
    "\n",
    "    def weighted_in_deg(edge_weights):\n",
    "        return torch.zeros((num_nodes,), dtype=edge_weights.dtype, device=edge_weights.device).scatter_add_(dim=0, index=edge_index[1, :], src=edge_weights)\n",
    "\n",
    "    def weighted_out_deg(edge_weights):\n",
    "        return torch.zeros((num_nodes,), dtype=edge_weights.dtype, device=edge_weights.device).scatter_add_(dim=0, index=edge_index[0, :], src=edge_weights)\n",
    "\n",
    "    f_v_out = weighted_out_deg(edge_weights)\n",
    "    f_v_in  = weighted_in_deg(edge_weights)\n",
    "    \n",
    "    m = f_v_out.sum()\n",
    "    M = torch.matmul(f_v_out.unsqueeze(1), f_v_in.unsqueeze(0)).sum()\n",
    "    print(m, M)\n",
    "    \n",
    "    xi_vw = f_v_out[edge_index[0, :]] * f_v_in[edge_index[1, :]]\n",
    "    \n",
    "    if redistribute:\n",
    "        for i in range(5000):\n",
    "            # expectation for in-degrees    \n",
    "            f_hat_v_in = weighted_in_deg(xi_vw) / m * M / xi_vw.sum()\n",
    "            # correction of in-degrees\n",
    "            xi_vw = xi_vw * (f_v_in / f_hat_v_in)[edge_index[1, :]].nan_to_num(1)\n",
    "            # expectation for out-degrees\n",
    "            f_hat_v_out = weighted_out_deg(xi_vw) / m * M / xi_vw.sum()\n",
    "            # correction for out-degrees\n",
    "            xi_vw = xi_vw * (f_v_out / f_hat_v_out)[edge_index[0, :]].nan_to_num(1)\n",
    "\n",
    "            rmse = ((f_v_out - f_hat_v_out)**2).sum().sqrt().item() + ((f_v_in - f_hat_v_in)**2).sum().sqrt().item()\n",
    "            \n",
    "            if verbose:\n",
    "                print(rmse)\n",
    "            \n",
    "            if  rmse < tolerance:\n",
    "                break\n",
    "        if verbose:\n",
    "            print(f\"Optimized for {i+1} iterations.\")\n",
    "    \n",
    "    pval = torch.zeros_like(xi_vw)\n",
    "    for i in range(len(xi_vw)):\n",
    "        pval[i] = lhygecdf(edge_weights[i], m, xi_vw.sum() - m, xi_vw[i]).exp()\n",
    "\n",
    "    return pval, xi_vw"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete Import of Data\n",
    "\n",
    "We import the path data, build the De Bruijn graph and calculate the significance scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_data(paths: str, labels: str, max_order: int = 1, verbose: bool = False, redistribute=True) -> Data:\n",
    "    \"\"\"Load paths and labels and assemble one `Data` object used by all models.\n",
    "\n",
    "    For every order k in 1..max_order a graph is built from the paths and its edge\n",
    "    significance scores are computed. The first-order graph and the graph of order\n",
    "    `max_order` are then combined into a single object.\n",
    "\n",
    "    Args:\n",
    "        paths: path of the file holding the path data.\n",
    "        labels: path of the label file; files ending in '.txt' are read with\n",
    "            `load_labels_from_file`, all others with `load_labels_from_csv`.\n",
    "        max_order: highest graph order to build (must be >= 1).\n",
    "        verbose: print progress messages.\n",
    "        redistribute: forwarded to `calculates_sig_scores`.\n",
    "\n",
    "    Returns:\n",
    "        `Data` with first-order and highest-order edge indices, edge weights and\n",
    "        significance scores (`edge_attr`, `edge_attr_higher_order`), one-hot node\n",
    "        features `x` / `x_h`, the labels `y`, two bipartite index tensors relating\n",
    "        highest-order and first-order nodes, and the per-order node id/index mappings.\n",
    "    \"\"\"\n",
    "    assert max_order >= 1\n",
    "\n",
    "    if verbose:\n",
    "        print('Reading paths...')\n",
    "    path_data = load_paths(paths)\n",
    "\n",
    "    graph_data = []\n",
    "    for k in range(1, max_order+1):\n",
    "        if verbose:\n",
    "            print(f'Building graph data with order {k}...')\n",
    "        data = build_graph(path_data, k = k)\n",
    "\n",
    "        if verbose:\n",
    "            print(f'--- Calculating Significance Scores for order {k}...')\n",
    "        data.pval, data.xi = calculates_sig_scores(data.num_nodes, data.edge_index, data.edge_weight, redistribute=redistribute, verbose=verbose)\n",
    "\n",
    "        graph_data.append(data)\n",
    "\n",
    "    if verbose:\n",
    "        print('Reading labels ...')\n",
    "    # labels are defined on the first-order nodes\n",
    "    if labels.endswith('.txt'):\n",
    "        y, classes = load_labels_from_file(labels, graph_data[0].num_nodes, graph_data[0].node_index_to_id)\n",
    "    else:\n",
    "        y, classes = load_labels_from_csv(labels, graph_data[0].num_nodes, graph_data[0].node_index_to_id)\n",
    "\n",
    "    # per-order lists; index 0 is the first-order graph, index -1 the highest order\n",
    "    edge_index = [data.edge_index for data in graph_data]\n",
    "    edge_weight = [data.edge_weight for data in graph_data]\n",
    "    edge_attr = [data.pval for data in graph_data]\n",
    "\n",
    "    num_nodes = [data.num_nodes for data in graph_data]\n",
    "\n",
    "    node_idx_to_id = [data.node_index_to_id for data in graph_data]\n",
    "    node_id_to_idx = [data.node_id_to_index for data in graph_data]\n",
    "\n",
    "    def build_x_one_hot(num_nodes):\n",
    "        return one_hot(torch.arange(num_nodes), num_classes = num_nodes).to(torch.float)\n",
    "\n",
    "    x = build_x_one_hot(num_nodes[0])\n",
    "    x_h = build_x_one_hot(num_nodes[-1])\n",
    "\n",
    "    def build_bipartite_mapping_low_to_high(id2idx_low, idx2id_high):\n",
    "        # one edge per higher-order node: (index of its first element, higher-order index)\n",
    "        return torch.tensor([[id2idx_low[idx2id_high[idx_high][0]], idx_high] for idx_high in idx2id_high.keys()], dtype=torch.long).t()\n",
    "\n",
    "\n",
    "    def build_bipartite_mapping_high_to_low(id2idx_low, idx2id_high, pos=-1):\n",
    "        # one edge per higher-order node: (higher-order index, index of its element at `pos`)\n",
    "        return torch.tensor([[idx_high, id2idx_low[idx2id_high[idx_high][pos]]] for idx_high in idx2id_high.keys()], dtype=torch.long).t()\n",
    "\n",
    "    data = Data(\n",
    "        edge_index                  = edge_index[0],\n",
    "        edge_index_higher_order     = edge_index[-1],\n",
    "        bipartite_edge_index        = build_bipartite_mapping_high_to_low(node_id_to_idx[0], node_idx_to_id[-1], pos=0),\n",
    "        bipartite_edge_index_start  = build_bipartite_mapping_low_to_high(node_id_to_idx[0], node_idx_to_id[-1]),\n",
    "        edge_weight                 = edge_weight[0],\n",
    "        edge_weight_higher_order    = edge_weight[-1],\n",
    "        x                           = x,\n",
    "        x_h                         = x_h,\n",
    "        y                           = y.long(),\n",
    "        num_nodes                   = num_nodes[0],\n",
    "        num_ho_nodes                = num_nodes[-1],\n",
    "        # significance scores:\n",
    "        edge_attr                   = edge_attr[0].unsqueeze(1),\n",
    "        edge_attr_higher_order      = edge_attr[-1].unsqueeze(1),\n",
    "        node_id_to_idx              = node_id_to_idx,\n",
    "        node_idx_to_id              = node_idx_to_id,\n",
    "    )\n",
    "    return data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models\n",
    "\n",
    "### GCN (Kipf & Welling)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Graph convolutional network as proposed in:\n",
    "'Semi-Supervised Classification with Graph Convolutional Networks'\n",
    "Paper: https://arxiv.org/abs/1609.02907\n",
    "Website: https://tkipf.github.io/graph-convolutional-networks/\n",
    "\"\"\"\n",
    "\n",
    "class GCN(Module):\n",
    "    \"\"\"Stack of GCNConv layers on the first-order graph followed by a linear classifier.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_classes,\n",
    "        num_features,\n",
    "        hidden_dims,\n",
    "        p_dropout=0.0\n",
    "        ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.num_features = num_features\n",
    "        self.num_classes = num_classes\n",
    "        self.hidden_dims = hidden_dims\n",
    "        self.p_dropout = p_dropout\n",
    "\n",
    "        # first-order layers: one input convolution ...\n",
    "        convolutions = [GCNConv(self.num_features, self.hidden_dims[0])]\n",
    "        # ... followed by one convolution per pair of consecutive hidden sizes\n",
    "        for in_dim, out_dim in zip(self.hidden_dims[:-1], self.hidden_dims[1:]):\n",
    "            convolutions.append(GCNConv(in_dim, out_dim))\n",
    "        self.first_order_layers = ModuleList(convolutions)\n",
    "\n",
    "        # linear classification head\n",
    "        self.lin = torch.nn.Linear(self.hidden_dims[-1], num_classes)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return per-node class scores; reads `x`, `edge_index` and `edge_weight` of `data`.\"\"\"\n",
    "        hidden = data.x\n",
    "\n",
    "        # dropout precedes every convolution and the final linear layer\n",
    "        for conv in self.first_order_layers:\n",
    "            hidden = F.dropout(hidden, p=self.p_dropout, training=self.training)\n",
    "            hidden = F.elu(conv(hidden, data.edge_index, data.edge_weight))\n",
    "        hidden = F.dropout(hidden, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        return self.lin(hidden)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DBGNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "DBGNN model as proposed in \n",
    "'De Bruijn goes Neural: Causality-Aware Graph Neural Networks for Time Series Data on Dynamic Graphs'\n",
    "Paper: https://arxiv.org/abs/2209.08311\n",
    "Presentation: https://www.youtube.com/watch?v=IezbzMMp9QM\n",
    "\"\"\"\n",
    "\n",
    "class DBGNNBipartiteGraphOperator(MessagePassing):\n",
    "    \"\"\"Message passing over a bipartite index with 'add' aggregation.\n",
    "\n",
    "    Both members of the input pair are passed through their own linear layer before\n",
    "    propagation; a message is the sum of the two endpoint representations.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_ch, out_ch):\n",
    "        super().__init__('add')\n",
    "        self.lin1 = Linear(in_ch, out_ch)\n",
    "        self.lin2 = Linear(in_ch, out_ch)\n",
    "\n",
    "    def forward(self, x, bipartite_index, N, M):\n",
    "        first, second = x[0], x[1]\n",
    "        projected = (self.lin1(first), self.lin2(second))\n",
    "        return self.propagate(bipartite_index, size=(N, M), x=projected)\n",
    "\n",
    "    def message(self, x_i, x_j):\n",
    "        # the argument names x_i / x_j must stay as they are: propagate() fills them by name\n",
    "        return x_i + x_j\n",
    "\n",
    "class DBGNN(Module):\n",
    "    \"\"\"Two parallel GCNConv stacks (first-order and higher-order graph) joined by a bipartite layer.\n",
    "\n",
    "    `num_features` holds the input sizes as [first-order, higher-order]. All but the last\n",
    "    entry of `hidden_dims` size the convolutions; the last entry is the output size of\n",
    "    the bipartite layer, which feeds the linear classifier.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_classes,\n",
    "        num_features,\n",
    "        hidden_dims,\n",
    "        p_dropout=0.0\n",
    "        ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.num_features = num_features\n",
    "        self.num_classes = num_classes\n",
    "        self.hidden_dims = hidden_dims\n",
    "        self.p_dropout = p_dropout\n",
    "\n",
    "        # input convolutions (higher-order first, then first-order)\n",
    "        self.higher_order_layers = ModuleList()\n",
    "        self.first_order_layers = ModuleList()\n",
    "        self.higher_order_layers.append(GCNConv(self.num_features[1], self.hidden_dims[0]))\n",
    "        self.first_order_layers.append(GCNConv(self.num_features[0], self.hidden_dims[0]))\n",
    "\n",
    "        # remaining convolutions; the last hidden size is reserved for the bipartite layer\n",
    "        for in_dim, out_dim in zip(self.hidden_dims[:-2], self.hidden_dims[1:-1]):\n",
    "            self.higher_order_layers.append(GCNConv(in_dim, out_dim))\n",
    "            self.first_order_layers.append(GCNConv(in_dim, out_dim))\n",
    "\n",
    "        self.bipartite_layer = DBGNNBipartiteGraphOperator(self.hidden_dims[-2], self.hidden_dims[-1])\n",
    "\n",
    "        # linear classification head\n",
    "        self.lin = torch.nn.Linear(self.hidden_dims[-1], num_classes)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return per-node class scores for the first-order nodes of `data`.\"\"\"\n",
    "        low = data.x\n",
    "        high = data.x_h\n",
    "\n",
    "        # first-order convolutions\n",
    "        for conv in self.first_order_layers:\n",
    "            low = F.dropout(low, p=self.p_dropout, training=self.training)\n",
    "            low = F.elu(conv(low, data.edge_index, data.edge_weight))\n",
    "        low = F.dropout(low, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        # higher-order convolutions\n",
    "        for conv in self.higher_order_layers:\n",
    "            high = F.dropout(high, p=self.p_dropout, training=self.training)\n",
    "            high = F.elu(conv(high, data.edge_index_higher_order, data.edge_weight_higher_order))\n",
    "        high = F.dropout(high, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        # bipartite message passing between the two node sets\n",
    "        merged = self.bipartite_layer((high, low), data.bipartite_edge_index, high.size(0), low.size(0))\n",
    "        merged = F.dropout(F.elu(merged), p=self.p_dropout, training=self.training)\n",
    "\n",
    "        return self.lin(merged)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SIT-GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import MessagePassing\n",
    "import torch\n",
    "\n",
    "class BipartiteGraphOperator(MessagePassing):\n",
    "    \"\"\"Message passing over a bipartite index with 'add' aggregation.\n",
    "\n",
    "    Source features are always projected by a linear layer. Target features are\n",
    "    projected as well unless `in_ch_target` is 0; in that case the targets start\n",
    "    from zeros and their number is inferred from the largest target index.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_ch_source, in_ch_target, out_ch):\n",
    "        super(BipartiteGraphOperator, self).__init__('add')\n",
    "        self.lin_source = Linear(in_ch_source, out_ch)\n",
    "        self.lin_target = Linear(in_ch_target, out_ch) if in_ch_target > 0 else None\n",
    "\n",
    "    def forward(self, x_source, x_target, bipartite_index):\n",
    "        \"\"\"Propagate from source to target nodes along `bipartite_index` (row 0: sources, row 1: targets).\n",
    "\n",
    "        `x_target` is ignored when the operator was built with `in_ch_target` 0.\n",
    "        \"\"\"\n",
    "        x_source = self.lin_source(x_source)\n",
    "        if self.lin_target is None:\n",
    "            num_targets = int(bipartite_index[1, :].max()) + 1\n",
    "            # allocate directly on the source's device instead of on the CPU with a copy afterwards\n",
    "            x_target = torch.zeros((num_targets, x_source.size(1)), dtype=x_source.dtype, device=x_source.device)\n",
    "        else:\n",
    "            x_target = self.lin_target(x_target)\n",
    "        return self.propagate(bipartite_index, size=(x_source.size(0), x_target.size(0)), x=(x_source, x_target))\n",
    "\n",
    "    def message(self, x_i, x_j):\n",
    "        # the argument names x_i / x_j must stay as they are: propagate() fills them by name\n",
    "        return x_i + x_j\n",
    "    \n",
    "\n",
    "class SIT_GNN(Module):\n",
    "    \"\"\"Two parallel GCNConv stacks weighted by edge significance scores, joined by bipartite layers.\n",
    "\n",
    "    A first bipartite layer derives the higher-order input features from the first-order\n",
    "    features `x`. Both stacks use `edge_attr` / `edge_attr_higher_order` as edge weights.\n",
    "    A second bipartite layer merges both stacks before the linear classifier.\n",
    "    `num_features` must contain exactly one entry; `num_edge_features` is only stored.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_classes,\n",
    "        num_features,\n",
    "        num_edge_features,\n",
    "        hidden_dims,\n",
    "        p_dropout=0.0\n",
    "        ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.num_features = num_features\n",
    "        self.num_edge_features = num_edge_features\n",
    "        self.num_classes = num_classes\n",
    "        self.hidden_dims = hidden_dims\n",
    "        self.p_dropout = p_dropout\n",
    "\n",
    "        assert len(num_features) == 1\n",
    "        in_dim = self.num_features[0]\n",
    "\n",
    "        # input layer for the higher-order nodes (no target features, hence 0)\n",
    "        self.bipartite_layer_start = BipartiteGraphOperator(in_dim, 0, in_dim)\n",
    "\n",
    "        # input convolutions (higher-order first, then first-order)\n",
    "        self.higher_order_layers = ModuleList()\n",
    "        self.first_order_layers = ModuleList()\n",
    "        self.higher_order_layers.append(GCNConv(in_dim, self.hidden_dims[0]))\n",
    "        self.first_order_layers.append(GCNConv(in_dim, self.hidden_dims[0]))\n",
    "\n",
    "        # remaining convolutions; the last hidden size is reserved for the bipartite layer\n",
    "        for prev_dim, next_dim in zip(self.hidden_dims[:-2], self.hidden_dims[1:-1]):\n",
    "            self.higher_order_layers.append(GCNConv(prev_dim, next_dim))\n",
    "            self.first_order_layers.append(GCNConv(prev_dim, next_dim))\n",
    "\n",
    "        self.bipartite_layer = BipartiteGraphOperator(self.hidden_dims[-2], self.hidden_dims[-2], self.hidden_dims[-1])\n",
    "\n",
    "        # linear classification head\n",
    "        self.lin = torch.nn.Linear(self.hidden_dims[-1], num_classes)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return per-node class scores for the first-order nodes of `data`.\"\"\"\n",
    "        low = data.x\n",
    "\n",
    "        # bipartite message passing: higher-order input features from first-order features\n",
    "        low_dropped = F.dropout(low, p=self.p_dropout, training=self.training)\n",
    "        high = F.elu(self.bipartite_layer_start(low_dropped, None, data.bipartite_edge_index_start))\n",
    "\n",
    "        # the significance scores act as edge weights in both stacks\n",
    "        scores = data.edge_attr\n",
    "        scores_higher_order = data.edge_attr_higher_order\n",
    "\n",
    "        # first-order convolutions\n",
    "        for conv in self.first_order_layers:\n",
    "            low = F.dropout(low, p=self.p_dropout, training=self.training)\n",
    "            low = F.elu(conv(low, data.edge_index, edge_weight=scores))\n",
    "        low = F.dropout(low, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        # higher-order convolutions\n",
    "        for conv in self.higher_order_layers:\n",
    "            high = F.dropout(high, p=self.p_dropout, training=self.training)\n",
    "            high = F.elu(conv(high, data.edge_index_higher_order, edge_weight=scores_higher_order))\n",
    "        high = F.dropout(high, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        # bipartite message passing merging both stacks\n",
    "        merged = F.elu(self.bipartite_layer(high, low, data.bipartite_edge_index))\n",
    "        merged = F.dropout(merged, p=self.p_dropout, training=self.training)\n",
    "\n",
    "        return self.lin(merged)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Procedure\n",
    "\n",
    "We use 10-fold cross-validation. Within each fold, the epoch with the lowest validation loss is selected,\n",
    "and the test performance at that epoch is reported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(data, splits, build_model, epochs, seed):\n",
    "    \"\"\"Train and test a model over 10 folds and print mean and std of the balanced accuracy.\n",
    "\n",
    "    In every fold a fresh model is trained for `epochs` full-batch steps. After each\n",
    "    step the validation loss is computed; the test score of the epoch with the lowest\n",
    "    validation loss is the result of the fold.\n",
    "\n",
    "    Args:\n",
    "        data: graph data object providing `y` and whatever the model reads.\n",
    "        splits: sequence with one mapping per fold holding 'train_index', 'valid_index'\n",
    "            and 'test_index', or None to draw stratified random splits.\n",
    "        build_model: zero-argument callable returning a new model.\n",
    "        epochs: number of training epochs per fold.\n",
    "        seed: seed applied to torch, random and numpy at the start of every fold.\n",
    "\n",
    "    Returns:\n",
    "        None; the summary is printed.\n",
    "    \"\"\"\n",
    "\n",
    "    # calculate weights for loss function\n",
    "    def calculate_class_weights(y: torch.Tensor):\n",
    "        class_counts = y.to(torch.float).unique(sorted=True, return_counts=True)[1]\n",
    "        return 1.0 - class_counts / class_counts.sum()\n",
    "\n",
    "    # fold-independent work is done once: device transfer, loss function, CPU copy of the labels\n",
    "    data = data.to(device)\n",
    "    loss_fn = CrossEntropyLoss(calculate_class_weights(data.y).to(device))\n",
    "    labels_cpu = data.y.cpu()\n",
    "\n",
    "    results = []\n",
    "\n",
    "    for fold in range(10):\n",
    "\n",
    "        # reset seeds used in training\n",
    "        torch.manual_seed(seed)\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "\n",
    "        model = build_model()\n",
    "        optimizer = SGD(model.parameters(), lr=0.001, weight_decay=0.0005, momentum=0.9)\n",
    "\n",
    "        # obtain indices from splits\n",
    "        if splits is None:\n",
    "            # the random state depends on the fold; with a fixed state all folds\n",
    "            # would be evaluated on the very same split\n",
    "            outer = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=seed + fold)\n",
    "            train_valid_index, test_index = next(outer.split(np.zeros(len(labels_cpu)), labels_cpu))\n",
    "\n",
    "            inner = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed + fold)\n",
    "            train_pos, valid_pos = next(inner.split(np.zeros(len(train_valid_index)), labels_cpu[train_valid_index]))\n",
    "            train_index = train_valid_index[train_pos]\n",
    "            valid_index = train_valid_index[valid_pos]\n",
    "        else:\n",
    "            test_index = splits[fold]['test_index']\n",
    "            valid_index = splits[fold]['valid_index']\n",
    "            train_index = splits[fold]['train_index']\n",
    "\n",
    "        best_valid_loss = float('inf')\n",
    "        result = 0.0\n",
    "\n",
    "        # train this trial\n",
    "        for epoch in trange(epochs, position=0, desc=f\"learning fold {fold+1}\"):\n",
    "            model.train()\n",
    "\n",
    "            output = model(data)\n",
    "            loss = loss_fn(output[train_index], data.y[train_index])\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # model selection: no autograd graph is needed and none is kept alive\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                output = model(data)\n",
    "                valid_loss = loss_fn(output[valid_index], data.y[valid_index]).item()\n",
    "\n",
    "            if valid_loss < best_valid_loss:\n",
    "                best_valid_loss = valid_loss\n",
    "                _, pred = output.max(dim=1)\n",
    "                result = balanced_accuracy_score(data.y[test_index].cpu(), pred[test_index].cpu())\n",
    "\n",
    "        results.append(result)\n",
    "\n",
    "    print(f\"Balanced Accuracy: {np.mean(results)} ± {np.std(results)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute Experiments\n",
    "\n",
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(8388608.) tensor(7.0369e+13)\n",
      "enabling julia..."
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: import of MainInclude.eval into Main conflicts with an existing identifier; ignored.\n",
      "WARNING: could not import MainInclude.include into Main\n",
      "WARNING: import of MainInclude.eval into Main conflicts with an existing identifier; ignored.\n",
      "WARNING: could not import MainInclude.include into Main\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Done\n",
      "tensor(4194304.) tensor(1.7592e+13)\n"
     ]
    }
   ],
   "source": [
    "# Build the first- and second-order graphs incl. significance scores and move the result to `device`.\n",
    "data = build_data(path_data, label_data, max_order=2, verbose=False, redistribute=True).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predefined cross-validation folds, stored under the key 'splits' of the JSON file.\n",
    "with open(splits_data, 'r') as splits_file:\n",
    "    splits = json.load(splits_file)['splits']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Models and Run Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SIT-GNN:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "learning fold 1: 100%|████████████████████████████████████| 5000/5000 [00:57<00:00, 87.66it/s]\n",
      "learning fold 2: 100%|████████████████████████████████████| 5000/5000 [00:52<00:00, 94.46it/s]\n",
      "learning fold 3: 100%|████████████████████████████████████| 5000/5000 [00:52<00:00, 94.44it/s]\n",
      "learning fold 4: 100%|████████████████████████████████████| 5000/5000 [00:50<00:00, 99.09it/s]\n",
      "learning fold 5: 100%|████████████████████████████████████| 5000/5000 [00:57<00:00, 86.41it/s]\n",
      "learning fold 6: 100%|████████████████████████████████████| 5000/5000 [00:57<00:00, 87.53it/s]\n",
      "learning fold 7: 100%|████████████████████████████████████| 5000/5000 [00:54<00:00, 91.24it/s]\n",
      "learning fold 8: 100%|███████████████████████████████████| 5000/5000 [00:46<00:00, 106.84it/s]\n",
      "learning fold 9: 100%|███████████████████████████████████| 5000/5000 [00:44<00:00, 113.06it/s]\n",
      "learning fold 10: 100%|██████████████████████████████████| 5000/5000 [00:45<00:00, 110.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Balanced Accuracy: 1.0 ± 0.0\n",
      "\n",
      "GCN:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "learning fold 1: 100%|███████████████████████████████████| 5000/5000 [00:25<00:00, 197.09it/s]\n",
      "learning fold 2: 100%|███████████████████████████████████| 5000/5000 [00:23<00:00, 214.19it/s]\n",
      "learning fold 3: 100%|███████████████████████████████████| 5000/5000 [00:25<00:00, 198.00it/s]\n",
      "learning fold 4: 100%|███████████████████████████████████| 5000/5000 [00:31<00:00, 160.54it/s]\n",
      "learning fold 5: 100%|███████████████████████████████████| 5000/5000 [00:30<00:00, 163.02it/s]\n",
      "learning fold 6: 100%|███████████████████████████████████| 5000/5000 [00:30<00:00, 166.21it/s]\n",
      "learning fold 7: 100%|███████████████████████████████████| 5000/5000 [00:29<00:00, 167.33it/s]\n",
      "learning fold 8: 100%|███████████████████████████████████| 5000/5000 [00:29<00:00, 171.37it/s]\n",
      "learning fold 9: 100%|███████████████████████████████████| 5000/5000 [00:26<00:00, 186.52it/s]\n",
      "learning fold 10: 100%|██████████████████████████████████| 5000/5000 [00:29<00:00, 167.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Balanced Accuracy: 0.5 ± 0.31622776601683794\n",
      "\n",
      "DBGNN:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "learning fold 1: 100%|███████████████████████████████████| 5000/5000 [00:49<00:00, 100.17it/s]\n",
      "learning fold 2: 100%|███████████████████████████████████| 5000/5000 [00:44<00:00, 113.17it/s]\n",
      "learning fold 3: 100%|███████████████████████████████████| 5000/5000 [00:42<00:00, 116.74it/s]\n",
      "learning fold 4: 100%|███████████████████████████████████| 5000/5000 [00:41<00:00, 120.21it/s]\n",
      "learning fold 5: 100%|███████████████████████████████████| 5000/5000 [00:43<00:00, 113.80it/s]\n",
      "learning fold 6: 100%|███████████████████████████████████| 5000/5000 [00:41<00:00, 120.60it/s]\n",
      "learning fold 7: 100%|███████████████████████████████████| 5000/5000 [00:43<00:00, 115.16it/s]\n",
      "learning fold 8: 100%|███████████████████████████████████| 5000/5000 [00:43<00:00, 113.82it/s]\n",
      "learning fold 9: 100%|███████████████████████████████████| 5000/5000 [00:49<00:00, 102.03it/s]\n",
      "learning fold 10: 100%|██████████████████████████████████| 5000/5000 [00:46<00:00, 107.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Balanced Accuracy: 0.5 ± 0.31622776601683794\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def sit_gnn():\n",
    "    \"\"\"Return a new SIT-GNN on `device`, sized from the global `data`.\"\"\"\n",
    "    model = SIT_GNN(\n",
    "        num_classes=classes,\n",
    "        num_features=[data.x.size(-1)],\n",
    "        num_edge_features=[data.edge_attr.size(-1), data.edge_attr_higher_order.size(-1)],\n",
    "        hidden_dims=[32, 32, 16],\n",
    "        p_dropout=0.4,\n",
    "    )\n",
    "    return model.to(device)\n",
    "\n",
    "\n",
    "def dbgnn():\n",
    "    \"\"\"Return a new DBGNN on `device`, sized from the global `data`.\"\"\"\n",
    "    model = DBGNN(\n",
    "        num_classes=classes,\n",
    "        num_features=[data.x.size(-1), data.x_h.size(-1)],\n",
    "        hidden_dims=[32, 32, 16],\n",
    "        p_dropout=0.4,\n",
    "    )\n",
    "    return model.to(device)\n",
    "\n",
    "\n",
    "def gcn():\n",
    "    \"\"\"Return a new GCN on `device`, sized from the global `data`.\"\"\"\n",
    "    model = GCN(\n",
    "        num_classes=classes,\n",
    "        num_features=data.x.size(-1),\n",
    "        hidden_dims=[32, 32, 16],\n",
    "        p_dropout=0.4,\n",
    "    )\n",
    "    return model.to(device)\n",
    "\n",
    "\n",
    "# run all three models under the same protocol: same splits, epochs and seed\n",
    "for header, factory in [(\"SIT-GNN:\", sit_gnn), (\"\\nGCN:\", gcn), (\"\\nDBGNN:\", dbgnn)]:\n",
    "    print(header)\n",
    "    evaluate(data, splits, factory, epochs=5000, seed=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
