{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/sam/bf-nn\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sam/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    }
   ],
   "source": [
    "cd /home/sam/bf-nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "from src.dataset import *\n",
    "import torch\n",
    "from torch.nn import Linear, Parameter, ReLU, LeakyReLU\n",
    "import torch.nn as nn\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree\n",
    "from src.models import *\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 1, 1, 2, 2, 3, 0, 1, 2, 3],\n",
      "        [1, 0, 2, 1, 3, 2, 0, 1, 2, 3]])\n",
      "tensor([[2.9134],\n",
      "        [2.9134],\n",
      "        [6.3483],\n",
      "        [6.3483],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000]])\n",
      "tensor([[0, 1, 1, 2, 2, 3, 0, 1, 2, 3],\n",
      "        [1, 0, 2, 1, 3, 2, 0, 1, 2, 3]])\n",
      "tensor([[4.2814],\n",
      "        [4.2814],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [2.2058],\n",
      "        [2.2058],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000],\n",
      "        [0.0000]])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: for each i, build a construct_k_path graph from a weight\n",
    "# vector that is zero except at entries 0 and i + 1, then print the\n",
    "# converted instance's edge_index and edge_attr.\n",
    "K = 2\n",
    "for i in range(K):\n",
    "    a, b = np.random.uniform(low=1.0, high=10.0, size=2)\n",
    "    edge_weights = np.zeros(K + 1)\n",
    "    edge_weights[0], edge_weights[i + 1] = a, b\n",
    "    graph = construct_k_path(K + 2, edge_weights, s=0)\n",
    "    data = nx_to_bf_instance(graph, m=3, start=0)\n",
    "    print(data.edge_index)\n",
    "    print(data.edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_test_dataset(K):\n",
    "    \"\"\"Build a small hand-crafted set of instances and return it as a list.\n",
    "\n",
    "    In order, the list holds: K cross graphs (all path weights 0, crossing\n",
    "    weights [1, 1], one per crossing position), K path graphs whose weight\n",
    "    vector is zero except at entries 0 and i + 1, one single-edge path and\n",
    "    one construct_k_cycle graph.\n",
    "\n",
    "    Args:\n",
    "        K: number of cross graphs and of random path graphs to generate.\n",
    "\n",
    "    Returns:\n",
    "        list of the 2 * K + 2 objects returned by nx_to_bf_instance.\n",
    "    \"\"\"\n",
    "    # Every instance used to be assigned to a local and dropped, so the\n",
    "    # function returned None; collect them instead.\n",
    "    dataset = []\n",
    "\n",
    "    # Cross edge graphs\n",
    "    edge_set_weights = np.zeros(K)\n",
    "    for i in range(K):\n",
    "        graph = construct_cross(K, edge_set_weights, i, [1, 1], s=0)\n",
    "        dataset.append(nx_to_bf_instance(graph, m=2, start=0))\n",
    "\n",
    "    # paths\n",
    "    for i in range(K):\n",
    "        edge_weights = np.zeros(K + 1)\n",
    "        a, b = np.random.uniform(low=1.0, high=10.0, size=2)\n",
    "        edge_weights[0] = a\n",
    "        edge_weights[i + 1] = b\n",
    "        graph = construct_k_path(K + 2, edge_weights, 0)\n",
    "        # NOTE(review): start=1 here while every other case uses start=0 - confirm intended.\n",
    "        dataset.append(nx_to_bf_instance(graph, m=2, start=1))\n",
    "\n",
    "    # single edge\n",
    "    graph = construct_k_path(2, [1.0], 0)\n",
    "    dataset.append(nx_to_bf_instance(graph, m=1, start=0))\n",
    "\n",
    "    # two edges\n",
    "    graph = construct_k_cycle(3, [1.0, 0.0], 0.0)\n",
    "    dataset.append(nx_to_bf_instance(graph, m=1, start=0))\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = construct_extended_small_graph_dataset(2, size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16\n"
     ]
    }
   ],
   "source": [
    "print(len(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_cross(num_edges, edges, cross_idx, cross_edge_weights, s=0):\n",
    "    \"\"\"Build two parallel paths joined by a pair of crossing edges.\n",
    "\n",
    "    Nodes 0..num_edges form the first path and nodes\n",
    "    num_edges + 1..2 * num_edges + 1 the second; edge i of both paths has\n",
    "    weight edges[i]. Two crossing edges are then added: node cross_idx of the\n",
    "    first path to node cross_idx + 1 of the second, and node cross_idx + 1 of\n",
    "    the first path to node cross_idx of the second.\n",
    "\n",
    "    Args:\n",
    "        num_edges: number of edges in each path.\n",
    "        edges: per-edge weights; entries 0..num_edges - 1 are read.\n",
    "        cross_idx: position along the paths where the crossing is placed.\n",
    "        cross_edge_weights: weights of the two crossing edges, in the order\n",
    "            they are added; a one-element sequence is used for both.\n",
    "        s: source node; its 'attr' is 0.0, every other node gets 1000.\n",
    "\n",
    "    Returns:\n",
    "        nx.Graph with a 'weight' attribute on edges and 'attr' on nodes.\n",
    "    \"\"\"\n",
    "    offset = num_edges + 1  # label of the second path's first node\n",
    "    G = nx.Graph()\n",
    "    for i in range(num_edges):\n",
    "        G.add_edge(i, i + 1, weight=edges[i])\n",
    "        G.add_edge(offset + i, offset + i + 1, weight=edges[i])\n",
    "\n",
    "    # The second crossing edge used to reuse cross_edge_weights[0], so the\n",
    "    # second entry of the argument was silently ignored.\n",
    "    first_weight = cross_edge_weights[0]\n",
    "    second_weight = cross_edge_weights[1] if len(cross_edge_weights) > 1 else first_weight\n",
    "    G.add_edge(cross_idx, offset + cross_idx + 1, weight=first_weight)\n",
    "    G.add_edge(cross_idx + 1, offset + cross_idx, weight=second_weight)\n",
    "\n",
    "    nx.set_node_attributes(G, values=1000, name='attr')\n",
    "    G.nodes[s]['attr'] = 0.0\n",
    "    return G\n",
    "\n",
    "def construct_full_cross_dataset(path_length, cross_w_1, cross_w_2):\n",
    "    dataset = []\n",
    "    for j in range(path_length ):\n",
    "        edge_weights = np.abs( np.random.normal(loc=0.0, scale=1.0, size=path_length + 1))\n",
    "        H = construct_cross(path_length , edge_weights, j, cross_w_1)\n",
    "        H_bar = construct_cross(path_length , edge_weights, j, cross_w_2)\n",
    "        num_nodes = len(H.nodes)\n",
    "\n",
    "        data_H  = nx_to_bf_instance(H, num_nodes, start=0)\n",
    "        dataset.append(data_H)\n",
    "        #print(data_H.edge_index)\n",
    "\n",
    "        data_H_bar = nx_to_bf_instance(H_bar, num_nodes, start=0)\n",
    "        dataset.append(data_H_bar)\n",
    "        #print(data_H_bar.edge_index)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "\n",
    "def construct_ktrain_dataset(K, sz=5):\n",
    "    dataset = []\n",
    "    \n",
    "    calc_total = 2 * (K * (K + 1)/2) + 2 * K \n",
    "    print(calc_total)\n",
    "    runs = int(sz / calc_total) + 1\n",
    "    for _ in range(runs):\n",
    "        for i in trange(1, K + 1):\n",
    "            cross_data = construct_full_cross_dataset(i, [0, 0], [1, 1])\n",
    "            dataset.extend(cross_data)\n",
    "            print(i, len(cross_data))\n",
    "            edge_weights = np.abs( np.random.normal(loc=0.0, scale=1.0, size=i + 1))\n",
    "            dual_path = construct_two_paths(i + 1, edge_weights)\n",
    "            data = nx_to_bf_instance(dual_path, i + 1)\n",
    "            \n",
    "            dataset.append(data)\n",
    "            ew = np.zeros(i)\n",
    "            \n",
    "            ew[0] = np.random.uniform(low=1.0, high=30.0)\n",
    "            path = construct_k_path(i + 1, ew)\n",
    "            data = nx_to_bf_instance(path, i)\n",
    "            dataset.append(data)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:00<00:00, 240.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 2\n",
      "2 4\n",
      "3 6\n",
      "4 8\n",
      "tensor([[0.0000e+00],\n",
      "        [4.2007e-01],\n",
      "        [1.4730e+00],\n",
      "        [1.0000e+03],\n",
      "        [1.0000e+03]])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "data = construct_ktrain_dataset(4, sz=1)\n",
    "print(data[2].y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_cross(num_edges, edges, cross_idx, cross_edge_weights, s=0):\n",
    "    \"\"\"Build two parallel paths joined by a pair of crossing edges.\n",
    "\n",
    "    Nodes 0..num_edges form the first path and nodes\n",
    "    num_edges + 1..2 * num_edges + 1 the second; edge i of both paths has\n",
    "    weight edges[i]. Two crossing edges are then added: node cross_idx of the\n",
    "    first path to node cross_idx + 1 of the second, and node cross_idx + 1 of\n",
    "    the first path to node cross_idx of the second.\n",
    "\n",
    "    Args:\n",
    "        num_edges: number of edges in each path.\n",
    "        edges: per-edge weights; entries 0..num_edges - 1 are read.\n",
    "        cross_idx: position along the paths where the crossing is placed.\n",
    "        cross_edge_weights: weights of the two crossing edges, in the order\n",
    "            they are added; a one-element sequence is used for both.\n",
    "        s: source node; its 'attr' is 0.0, every other node gets 1000.\n",
    "\n",
    "    Returns:\n",
    "        nx.Graph with a 'weight' attribute on edges and 'attr' on nodes.\n",
    "    \"\"\"\n",
    "    offset = num_edges + 1  # label of the second path's first node\n",
    "    G = nx.Graph()\n",
    "    for i in range(num_edges):\n",
    "        G.add_edge(i, i + 1, weight=edges[i])\n",
    "        G.add_edge(offset + i, offset + i + 1, weight=edges[i])\n",
    "\n",
    "    # The second crossing edge used to reuse cross_edge_weights[0], so the\n",
    "    # second entry of the argument was silently ignored.\n",
    "    first_weight = cross_edge_weights[0]\n",
    "    second_weight = cross_edge_weights[1] if len(cross_edge_weights) > 1 else first_weight\n",
    "    G.add_edge(cross_idx, offset + cross_idx + 1, weight=first_weight)\n",
    "    G.add_edge(cross_idx + 1, offset + cross_idx, weight=second_weight)\n",
    "\n",
    "    nx.set_node_attributes(G, values=1000, name='attr')\n",
    "    G.nodes[s]['attr'] = 0.0\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 1)\n",
      "(1, 2)\n",
      "(1, 5)\n",
      "(3, 4)\n",
      "(4, 5)\n",
      "(4, 2)\n"
     ]
    }
   ],
   "source": [
    "G = construct_cross(2, [1, 1, 1], 1, [1, 1], s=0)\n",
    "for edge in G.edges:\n",
    "    print(edge)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 0, 2, 1, 2, 1, 1, 0, 1, 2],\n",
       "        [1, 0, 2, 0, 2, 1, 1, 1, 0, 1, 2]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "construct_full_cross_dataset(1, [0, 0], [1, 1])[0].edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualLayer(MessagePassing):\n",
    "    \"\"\"Message-passing layer whose update step also sees the node's own features.\n",
    "\n",
    "    message: act(aggregation_layer([x_j, edge_attr]))\n",
    "    forward: act(update_layer([aggregated messages, x]))\n",
    "\n",
    "    Because forward concatenates along dim=1, update_layer must accept the\n",
    "    aggregation output width plus the node feature width.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, aggregation_config, update_config, act='ReLU', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            aggregation_config: keyword arguments for initialize_mlp (message MLP).\n",
    "            update_config: keyword arguments for initialize_mlp (update MLP).\n",
    "            act: class name of the activation, e.g. 'ReLU' or 'LeakyReLU'.\n",
    "            **kwargs: accepted for call-site compatibility and ignored.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if act names neither a notebook global nor a torch.nn class.\n",
    "        \"\"\"\n",
    "        super().__init__()  # no aggr given: MessagePassing's default is used\n",
    "\n",
    "        self.aggregation_layer = initialize_mlp(**aggregation_config)\n",
    "        self.update_layer = initialize_mlp(**update_config)\n",
    "\n",
    "        # Names already in the notebook namespace resolve exactly as before;\n",
    "        # anything else falls back to torch.nn so that e.g. 'Tanh' works\n",
    "        # without having to import it into the kernel first.\n",
    "        act_cls = globals().get(act)\n",
    "        if act_cls is None:\n",
    "            act_cls = getattr(nn, act, None)\n",
    "        if act_cls is None:\n",
    "            raise KeyError(act)\n",
    "        self.act = act_cls()\n",
    "\n",
    "    def forward(self, x, edge_index, edge_attr):\n",
    "        \"\"\"Aggregate neighbour messages, append x, and apply the update MLP.\"\"\"\n",
    "        from_aggregation = self.propagate(edge_index, x=x, edge_attr=edge_attr)\n",
    "        update_ = torch.cat((from_aggregation, x), dim=1)\n",
    "        return self.act(self.update_layer(update_))\n",
    "\n",
    "    def message(self, x_j, edge_attr):\n",
    "        \"\"\"Per-edge message computed from the sender features and the edge attribute.\"\"\"\n",
    "        return self.act(self.aggregation_layer(torch.cat((x_j, edge_attr), dim=-1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ResidualLayer.forward concatenates the aggregated message with the node's\n",
    "# own feature column, so the update MLP's input width is the aggregation\n",
    "# output width plus 1 (x has a single feature), i.e. 6. The previous value\n",
    "# of 10 cannot match that concatenated width.\n",
    "aggregation_parameters = {'input': 2, 'hidden': 10, 'output': 5, 'layers': 4}\n",
    "update_parameters = {'input': aggregation_parameters['output'] + 1, 'hidden': 10, 'output': 6, 'layers': 2}\n",
    "\n",
    "layer = ResidualLayer(aggregation_parameters, update_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 5])\n",
      "torch.Size([2, 6])\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (2x5 and 10x10)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[11], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m G \u001b[39m=\u001b[39m construct_k_path(\u001b[39m2\u001b[39m, [\u001b[39m2\u001b[39m])\n\u001b[1;32m      2\u001b[0m data \u001b[39m=\u001b[39m nx_to_bf_instance(G, \u001b[39m2\u001b[39m, start \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m layer(data\u001b[39m.\u001b[39;49mx, data\u001b[39m.\u001b[39;49medge_index, data\u001b[39m.\u001b[39;49medge_attr)\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[9], line 15\u001b[0m, in \u001b[0;36mResidualLayer.forward\u001b[0;34m(self, x, edge_index, edge_attr)\u001b[0m\n\u001b[1;32m     13\u001b[0m update_ \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat((from_aggregation, x), dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m     14\u001b[0m \u001b[39mprint\u001b[39m(update_\u001b[39m.\u001b[39msize())\n\u001b[0;32m---> 15\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mact(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mupdate_layer(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpropagate(edge_index, x\u001b[39m=\u001b[39;49mx, edge_attr\u001b[39m=\u001b[39;49medge_attr)))\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/container.py:215\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    213\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m    214\u001b[0m     \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 215\u001b[0m         \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m    216\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/py10-coreset/lib/python3.10/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mlinear(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x5 and 10x10)"
     ]
    }
   ],
   "source": [
    "G = construct_k_path(2, [2])\n",
    "data = nx_to_bf_instance(G, 2, start = 1)\n",
    "\n",
    "layer(data.x, data.edge_index, data.edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BFModel(width=2, depth=2, bias=True, l0_regularizer=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BellmanFordStep(x=[2, 1], edge_index=[2, 4], edge_attr=[4, 1], y=[2, 1])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[23.6599],\n",
       "        [23.6599]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G = construct_k_path(2, [2])\n",
    "data1 = nx_to_bf_instance(G, 2, start = 1)\n",
    "print(data1)\n",
    "\n",
    "model(data1.x, data1.edge_index, data1.edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.3906, grad_fn=<AddBackward0>)\n",
      "tensor(3.7470, grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "for layer in model.module_list:\n",
    "    print(layer.get_l0_regularization_term())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class L0Linear(nn.Module):\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(L0Linear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "\n",
    "        self.weight = Parameter(torch.Tensor(out_features, in_features))\n",
    "        self.include_bias = bias\n",
    "        if bias:\n",
    "            self.bias = Parameter(torch.Tensor(out_features))\n",
    "        self.mu_weight = nn.Parameter(torch.normal(mean=torch.zeros((out_features, in_features)), \n",
    "                                                   std=torch.ones((out_features, in_features))))\n",
    "        self.mu_bias = nn.Parameter(torch.normal(mean=torch.zeros(out_features), \n",
    "                                                 std=torch.ones(out_features)))\n",
    "    \n",
    "    def forward(self, x):\n",
    "        if self.training:\n",
    "            epsilon_weight = torch.normal(mean=torch.zeros((self.out_features, self.in_features)), \n",
    "                                          std = torch.ones((self.out_features, self.in_features)))\n",
    "            epsilon_bias = torch.normal(mean=torch.zeros((self.out_features)), \n",
    "                                        std = torch.ones((self.out_features)))\n",
    "        else: \n",
    "            epsilon_weight = torch.zeros((self.width, 2))\n",
    "        z_weight = nn.functional.relu(torch.minimum(self.mu_weight + epsilon_weight, \n",
    "                                                    torch.ones((self.out_features, self.in_features))))\n",
    "        z_bias = nn.functional.relu(torch.minimum(self.mu_bias + epsilon_bias, \n",
    "                                                    torch.ones(self.out_features)))\n",
    "        out = torch.matmul(x, (self.weight * z_weight).T) + (self.bias * z_bias)\n",
    "        return out\n",
    "\n",
    "class SingleLayerArbitraryWidthBFLayer(MessagePassing):\n",
    "    \"\"\"Min-aggregation message-passing layer with one hidden layer of configurable width.\n",
    "\n",
    "    message: act(W_1([x_j, edge_attr]))      (2 -> width)\n",
    "    forward: act(W_2(min-aggregated messages))  (width -> 1)\n",
    "\n",
    "    Args:\n",
    "        width: hidden width shared by W_1's output and W_2's input.\n",
    "        bias: whether W_1 and W_2 carry a bias term.\n",
    "        act: class name of the activation, e.g. 'ReLU'.\n",
    "        l0_regularizer: use L0Linear instead of torch.nn.Linear for W_1 and W_2.\n",
    "        **kwargs: accepted for call-site compatibility and ignored.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if act names neither a notebook global nor a torch.nn class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 width=2,\n",
    "                 bias=False,\n",
    "                 act='ReLU',\n",
    "                 l0_regularizer=False,\n",
    "                 **kwargs):\n",
    "        super().__init__(aggr='min')\n",
    "        # W_1's output width used to be hard-coded to 2 in the L0 branch, which\n",
    "        # mismatched W_2's input for any width other than 2.\n",
    "        linear_cls = L0Linear if l0_regularizer else Linear\n",
    "        self.W_1 = linear_cls(in_features=2, out_features=width, bias=bias)\n",
    "        self.W_2 = linear_cls(in_features=width, out_features=1, bias=bias)\n",
    "\n",
    "        # Names already in the notebook namespace resolve exactly as before;\n",
    "        # anything else falls back to torch.nn.\n",
    "        act_cls = globals().get(act)\n",
    "        if act_cls is None:\n",
    "            act_cls = getattr(nn, act, None)\n",
    "        if act_cls is None:\n",
    "            raise KeyError(act)\n",
    "        self.act = act_cls()\n",
    "\n",
    "        self.bias = bool(bias)\n",
    "\n",
    "    def fixed_init(self):\n",
    "        \"\"\"Set every weight (and bias, if enabled) of W_1 and W_2 to 1.0.\n",
    "\n",
    "        NOTE(review): later cells call fixed_init() but the class did not define\n",
    "        it. The all-ones fill reproduces the parameter values recorded in this\n",
    "        notebook's outputs after that call - confirm against src.models.\n",
    "        \"\"\"\n",
    "        torch.nn.init.constant_(self.W_1.weight, 1.0)\n",
    "        torch.nn.init.constant_(self.W_2.weight, 1.0)\n",
    "\n",
    "        if self.bias:\n",
    "            torch.nn.init.constant_(self.W_1.bias, 1.0)\n",
    "            torch.nn.init.constant_(self.W_2.bias, 1.0)\n",
    "\n",
    "    def random_init(self):\n",
    "        \"\"\"Draw weights (and biases, if enabled) from N(0, 1).\"\"\"\n",
    "        torch.nn.init.normal_(self.W_1.weight, 0.0, 1.0)\n",
    "        torch.nn.init.normal_(self.W_2.weight, 0.0, 1.0)\n",
    "\n",
    "        if self.bias:\n",
    "            torch.nn.init.normal_(self.W_1.bias, 0.0, 1.0)\n",
    "            torch.nn.init.normal_(self.W_2.bias, 0.0, 1.0)\n",
    "\n",
    "    def random_positive_init(self):\n",
    "        \"\"\"Draw weights (and biases, if enabled) from U(0, 1).\"\"\"\n",
    "        torch.nn.init.uniform_(self.W_1.weight, a=0.0, b=1.0)\n",
    "        torch.nn.init.uniform_(self.W_2.weight, a=0.0, b=1.0)\n",
    "\n",
    "        if self.bias:\n",
    "            torch.nn.init.uniform_(self.W_1.bias, a=0.0, b=1.0)\n",
    "            torch.nn.init.uniform_(self.W_2.bias, a=0.0, b=1.0)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_attr):\n",
    "        \"\"\"Apply W_2 and the activation to the min-aggregated messages.\"\"\"\n",
    "        return self.act(self.W_2(self.propagate(edge_index, x=x, edge_attr=edge_attr)))\n",
    "\n",
    "    def message(self, x_j, edge_attr):\n",
    "        \"\"\"Per-edge message computed from the sender feature and the edge attribute.\"\"\"\n",
    "        return self.act(self.W_1(torch.cat((x_j, edge_attr), dim=-1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BellmanFordStep(x=[2, 1], edge_index=[2, 4], edge_attr=[4, 1], y=[2, 1])\n"
     ]
    }
   ],
   "source": [
    "G = construct_k_path(2, [2])\n",
    "data1 = nx_to_bf_instance(G, 2, start = 1)\n",
    "print(data1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "l0layer = SingleLayerArbitraryWidthBFLayer(width = 2, \n",
    "                                         bias=True, \n",
    "                                         act='ReLU', \n",
    "                                         l0_regularizer=True)\n",
    "\n",
    "l0layer.fixed_init()\n",
    "\n",
    "layer = SingleLayerArbitraryWidthBFLayer(width = 2, \n",
    "                                         bias=True, \n",
    "                                         act='ReLU', \n",
    "                                         l0_regularizer=False)\n",
    "layer.fixed_init()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[3.],\n",
       "        [7.]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer(data1.x, data1.edge_index, data1.edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [0.]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l0layer(data1.x, data1.edge_index, data1.edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Parameter containing:\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]], requires_grad=True), Parameter containing:\n",
      "tensor([1., 1.], requires_grad=True), Parameter containing:\n",
      "tensor([[-0.2874, -1.2583],\n",
      "        [-0.1910,  0.7927]], requires_grad=True), Parameter containing:\n",
      "tensor([-0.0817, -1.4638], requires_grad=True), Parameter containing:\n",
      "tensor([[1., 1.]], requires_grad=True), Parameter containing:\n",
      "tensor([1.], requires_grad=True), Parameter containing:\n",
      "tensor([[-0.1801, -0.4578]], requires_grad=True), Parameter containing:\n",
      "tensor([0.9373], requires_grad=True)]\n"
     ]
    }
   ],
   "source": [
    "print(list(l0layer.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "vec = torch.tensor([[0, 1], [1, 1], [3, 1]], dtype=torch.float32)\n",
    "l0linear = L0Linear(in_features=2, out_features=5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2.7497e-34, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.5649e-41],\n",
       "        [2.7497e-34, 1.7433e-09, 4.8864e-10, 1.3706e-09, 4.5649e-41],\n",
       "        [2.7497e-34, 5.2298e-09, 1.4659e-09, 4.1118e-09, 4.5649e-41]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l0linear(vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[39], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m val \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[39massert\u001b[39;00m val \u001b[39m==\u001b[39m \u001b[39m1\u001b[39m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "val = 0\n",
    "assert val == 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_node_path(edge_val, beta = 100):\n",
    "    x = torch.tensor([[0], [edge_val], [beta]], dtype=torch.float)\n",
    "    edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2], \n",
    "                               [1, 0, 2, 1, 0, 1, 2]])\n",
    "    edge_attr = torch.unsqueeze(torch.tensor([edge_val, edge_val, 0, 0, 0, 0, 0], dtype=torch.float), dim=1)\n",
    "    y = torch.tensor([[0], [edge_val], [edge_val]], dtype=torch.float)\n",
    "    return x, edge_index, edge_attr, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_m_path_dataset(m, sz=5):\n",
    "    dataset = []\n",
    "    for k in range(1, m):\n",
    "        if k > 1:\n",
    "            combinations = list(itertools.combinations(range(5*k), 2))\n",
    "       # sz = len(combinations) if k > 1 else 2\n",
    "        for i in range(sz):\n",
    "            edge_weight = np.zeros(k)\n",
    "            if k == 1:\n",
    "                edge_weight[k - 1] = i\n",
    "            else:\n",
    "                edge_weight[0] = combinations[i][0]\n",
    "                edge_weight[k - 1] = combinations[i][1]\n",
    "            G = construct_k_path(k + 1, edge_weight, s=0)\n",
    "            data = nx_to_bf_instance(G, k, start = 0)\n",
    "            dataset.append(data)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  0.],\n",
      "        [200.]]) tensor([[0.],\n",
      "        [2.]]) tensor([[2.],\n",
      "        [2.],\n",
      "        [0.],\n",
      "        [0.]])\n"
     ]
    }
   ],
   "source": [
    "data = construct_m_path_dataset(5)\n",
    "print(data[2].x, data[2].y, data[2].edge_attr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'G' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mprint\u001b[39m(G[\u001b[39m0\u001b[39m][\u001b[39m1\u001b[39m][\u001b[39m'\u001b[39m\u001b[39mweight\u001b[39m\u001b[39m'\u001b[39m])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'G' is not defined"
     ]
    }
   ],
   "source": [
    "print(G[0][1]['weight'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Linear, Parameter, ReLU\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree\n",
    "\n",
    "class BFLayer(MessagePassing):\n",
    "    \"\"\"Message-passing layer with min aggregation.\n",
    "\n",
    "    Each edge message is ReLU(W_1([x_j, edge_attr])); messages are combined\n",
    "    with aggr='min', and the layer output is ReLU(w_2 * aggregated + b_2).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(aggr='min')\n",
    "        # Maps the concatenated [x_j, edge_attr] pair (2 features) to 1 value.\n",
    "        self.W_1 = Linear(2, 1)\n",
    "        self.relu = ReLU()\n",
    "        # Learnable scalar scale and shift applied after aggregation,\n",
    "        # both drawn from torch.rand.\n",
    "        self.w_2 = Parameter(torch.rand(1))\n",
    "        self.b_2 = Parameter(torch.rand(1))\n",
    "\n",
    "    def forward(self, x, edge_index, edge_attr):\n",
    "        \"\"\"Propagate over edge_index, then apply the scale, shift and ReLU.\"\"\"\n",
    "        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)\n",
    "        shifted = self.w_2 * aggregated + self.b_2\n",
    "        return self.relu(shifted)\n",
    "\n",
    "    def message(self, x_j, edge_attr):\n",
    "        \"\"\"Return ReLU(W_1(.)) of x_j and edge_attr joined on the last axis.\"\"\"\n",
    "        # The joined last axis must hold 2 features to match Linear(2, 1).\n",
    "        joined = torch.cat((x_j, edge_attr), dim=-1)\n",
    "        projected = self.W_1(joined)\n",
    "        return self.relu(projected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_bf = BFLayer()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([[20], [0], [20]], dtype=torch.float)\n",
    "edge_indx = torch.tensor([[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]])\n",
    "edge_attr = torch.unsqueeze(torch.tensor([5, 5, 10, 10, 0, 0, 0], dtype=torch.float), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 5.],\n",
      "        [ 5.],\n",
      "        [10.],\n",
      "        [10.],\n",
      "        [ 0.],\n",
      "        [ 0.],\n",
      "        [ 0.]])\n"
     ]
    }
   ],
   "source": [
    "print(edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 5.],\n",
       "        [ 0.],\n",
       "        [10.]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 133,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn_bf(x, edge_indx, edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py10-coreset",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e071539cccef4c0bb4ed46693789f6471484fd3e421d82529124a3bb2524ec50"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
