{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing as mp\n",
    "import os\n",
    "import pickle\n",
    "from collections import deque\n",
    "\n",
    "import numpy as np, matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import tqdm\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph to process. Ids used in earlier runs: 'sf', 'gt2', '3d', 'to', 'taxi'.\n",
    "graph_id = '3dd'\n",
    "\n",
    "# Seed numpy's global RNG so the random landmark choice and the sampled\n",
    "# node pairs below are reproducible across Restart & Run All.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fix the edge list so that the graph is connected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125000 375000\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import networkx as nx\n",
    "\n",
    "raw_path = './data/{}.edgelist'.format(graph_id)\n",
    "connected_path = './data/{}connected.edgelist'.format(graph_id)\n",
    "\n",
    "G = nx.read_edgelist(raw_path, nodetype=int,\n",
    "                     data=(('weight',float),), create_using=nx.DiGraph())\n",
    "\n",
    "def connected_component_subgraphs(G):\n",
    "  # Yield one subgraph view per component of G. nx.connected_components\n",
    "  # rejects directed graphs, so a DiGraph uses weak connectivity instead\n",
    "  # (edge direction ignored).\n",
    "  if G.is_directed():\n",
    "    components = nx.weakly_connected_components(G)\n",
    "  else:\n",
    "    components = nx.connected_components(G)\n",
    "  for c in components:\n",
    "    yield G.subgraph(c)\n",
    "\n",
    "# Keep only the largest component so every sampled node pair has a path\n",
    "# once edge direction is dropped.\n",
    "G = max(connected_component_subgraphs(G), key=lambda g: g.number_of_nodes())\n",
    "print(G.number_of_nodes(), G.number_of_edges())\n",
    "\n",
    "# The 'Make dataset' section reads this file. Create it only if it is missing,\n",
    "# so an existing connected edge list is never overwritten.\n",
    "if not os.path.exists(connected_path):\n",
    "  nx.write_weighted_edgelist(G, connected_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make a dataset of shortest-path distances between random node pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np, matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import networkx as nx\n",
    "import multiprocessing as mp, tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the connected, weighted edge list (expected to exist at this path), then\n",
    "# drop edge direction so shortest paths may traverse an edge either way.\n",
    "G = nx.read_edgelist('./data/{}connected.edgelist'.format(graph_id), nodetype=int,\n",
    "                     data=(('weight',float),), create_using=nx.DiGraph()).to_undirected()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node ids as a numpy array, in G's node iteration order (used for random sampling).\n",
    "nodes = np.array(list(G.nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(125000, 375000)"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G.number_of_nodes(), G.number_of_edges()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<networkx.classes.graph.Graph at 0x7fbb33f0ccc0>"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keywithmaxval(d):\n",
    "  \"\"\"Return the key of dict d whose value is largest.\n",
    "\n",
    "  Ties resolve to the first such key in iteration order; an empty dict\n",
    "  raises ValueError.\n",
    "  \"\"\"\n",
    "  return max(d, key=d.get)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 45/45 [00:30<00:00,  1.47it/s]\n"
     ]
    }
   ],
   "source": [
    "N_LANDMARKS = 15\n",
    "OPTIMIZATION_TRIES = 45\n",
    "# Farthest-point landmark selection: start from one random node, then repeatedly\n",
    "# add the node farthest from the current landmarks. The deque's maxlen evicts the\n",
    "# oldest landmark once N_LANDMARKS are held, so later tries keep refining the set.\n",
    "landmarks = deque(np.random.choice(nodes, (1,), replace=False), N_LANDMARKS)\n",
    "for _ in tqdm.tqdm(range(OPTIMIZATION_TRIES)):\n",
    "  # Named apart from the `dists` list of pair distances built later on.\n",
    "  dist_to_nearest_landmark = nx.multi_source_dijkstra_path_length(G, landmarks)\n",
    "  landmarks.append(keywithmaxval(dist_to_nearest_landmark))\n",
    "# Exact distances from each final landmark to every node, read by the A* heuristic.\n",
    "landmark_dists = [nx.single_source_dijkstra_path_length(G, landmark) for landmark in landmarks]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(landmarks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "def heuristic(u, v):\n",
    "  # Landmark lower bound on d(u, v): by the triangle inequality,\n",
    "  # |d(l, u) - d(l, v)| <= d(u, v) for every landmark l on the undirected\n",
    "  # graph, so the largest such gap is an admissible A* heuristic.\n",
    "  return max(abs(dist[u] - dist[v]) for dist in landmark_dists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150000, 2)\n",
      "[[ 84038  41395]\n",
      " [119928 101619]\n",
      " [119609  86545]]\n"
     ]
    }
   ],
   "source": [
    "# Draw 155K random pairs, prune self-pairs and duplicates, then keep exactly 150K.\n",
    "N_SAMPLES = 150000\n",
    "N = 155000\n",
    "data_pairs = np.random.choice(nodes, size=(N, 2))\n",
    "data_pairs = data_pairs[data_pairs[:,0] != data_pairs[:,1]] # prevent self connections\n",
    "# Prevent duplicates while keeping the random draw order, so truncating to\n",
    "# N_SAMPLES does not depend on set/hash iteration order.\n",
    "_, first_idx = np.unique(data_pairs, axis=0, return_index=True)\n",
    "data_pairs = data_pairs[np.sort(first_idx)][:N_SAMPLES]\n",
    "assert len(data_pairs) == N_SAMPLES, 'too few unique pairs; increase N'\n",
    "\n",
    "print(data_pairs.shape)\n",
    "print(data_pairs[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mp_func(pair):\n",
    "  # Worker task: weighted shortest-path length for one (source, target) row,\n",
    "  # found with A* guided by the landmark heuristic.\n",
    "  # NOTE(review): pool workers read G and heuristic as globals, which presumably\n",
    "  # relies on the 'fork' start method; confirm on non-Linux platforms.\n",
    "  source, target = pair[0], pair[1]\n",
    "  return nx.astar_path_length(G, source=source, target=target, heuristic=heuristic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the pairs into EPOCHS chunks so progress can be reported per chunk.\n",
    "# array_split (unlike split) tolerates a pair count not divisible by EPOCHS.\n",
    "EPOCHS = 250\n",
    "data_pair_arrays = np.array_split(data_pairs, EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250/250 [40:23<00:00,  9.69s/it]\n"
     ]
    }
   ],
   "source": [
    "N_WORKERS = 10\n",
    "dists = []\n",
    "# Reuse one worker pool for every chunk instead of starting EPOCHS separate\n",
    "# pools; pool.map keeps results in input order, so dists lines up with data_pairs.\n",
    "with mp.Pool(N_WORKERS) as pool:\n",
    "  for chunk in tqdm.tqdm(data_pair_arrays):\n",
    "    dists += pool.map(mp_func, chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((150000,), (150000,), (150000,))"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = data_pairs[:,0]\n",
    "Y = data_pairs[:,1]\n",
    "D = np.array(dists)\n",
    "# Exactly one computed distance per sampled pair.\n",
    "assert X.shape == Y.shape == D.shape, 'expected one distance per sampled pair'\n",
    "X.shape, Y.shape, D.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "75.0"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Save the raw (unscaled) distances. NOTE: a later cell writes the rescaled\n",
    "# distances to this same path, overwriting this file.\n",
    "with open('./data/{}_150k.pickle'.format(graph_id), 'wb') as f:\n",
    "  pickle.dump((X, Y, D), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Reload the saved arrays. The file is produced locally by this notebook;\n",
    "# never pickle.load untrusted files (it can execute arbitrary code).\n",
    "with open('./data/{}_150k.pickle'.format(graph_id), 'rb') as f:\n",
    "  X, Y, D = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((150000,), (150000,), (150000,))"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, Y.shape, D.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "37.468626666666665"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "75.0"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale distances so their mean is TARGET_MEAN_DIST; re-running leaves D\n",
    "# numerically unchanged. NOTE(review): the reason for 50 is not stated here.\n",
    "TARGET_MEAN_DIST = 50.\n",
    "D = D * (TARGET_MEAN_DIST / np.mean(D))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "49.99999999999999"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100.08373227450379"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist the rescaled distances, replacing the raw-distance pickle at the same path.\n",
    "pickle_path = './data/{}_150k.pickle'.format(graph_id)\n",
    "with open(pickle_path, 'wb') as f:\n",
    "  pickle.dump((X, Y, D), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3d1\n"
     ]
    }
   ],
   "source": [
    "print(graph_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
