{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import models\n",
    "import utils\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append('./../dataSP/')\n",
    "import data_utils\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import time\n",
    "\n",
    "sys.path.append('./../python-astar/')\n",
    "import astar\n",
    "\n",
    "import jgrapht\n",
    "import jgrapht.generators as gen\n",
    "import jgrapht.algorithms.shortestpaths as sp\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_n = 0\n",
    "\n",
    "path_data = './cabspotting_preprocessing/'\n",
    "\n",
    "df_features = pd.read_csv(f'{path_data}features_per_trip_useful.csv')\n",
    "df_trips = pd.read_csv(f'{path_data}full_useful_trips.csv')\n",
    "df_edges = pd.read_csv(f'{path_data}graph_0010_080.csv')\n",
    "df_nodes = pd.read_csv(f'{path_data}nodes_0010_080.csv')\n",
    "df_nodes['node_sorted'] = df_nodes['node_id_new']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_trips = df_trips.iloc[:300000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_drivers = df_trips['driver'].drop_duplicates()\n",
    "selected_drivers = unique_drivers.sample(frac=0.7, random_state=seed_n)\n",
    "df_trips_train = df_trips[df_trips['driver'].isin(selected_drivers)]\n",
    "\n",
    "# Keep only trips that visit more than one distinct node.\n",
    "df_trips_train = df_trips_train[df_trips_train.groupby('trip_id_new').node_id.transform('nunique')>1]\n",
    "df_trips_train = df_trips_train.sort_values(by=['driver','trip_id_new','date_time'])\n",
    "\n",
    "# Encode the features into a NEW frame. The previous version overwrote\n",
    "# df_features in place, so running this cell a second time failed with a\n",
    "# KeyError once get_dummies had removed the 'day_of_Week' column.\n",
    "# Day codes 0-3 are collapsed into one group; 4, 5 and 6 keep their own group.\n",
    "day_groups = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 2, 6: 3}\n",
    "df_features_enc = df_features.assign(\n",
    "    day_of_Week=df_features['day_of_Week'].astype(int).map(day_groups))\n",
    "df_features_enc = pd.get_dummies(df_features_enc, columns=['day_of_Week'])\n",
    "\n",
    "# Min-max scale time_start to [0, 1] over every row of df_features\n",
    "# (not only the training trips).\n",
    "time_start_min = df_features['time_start'].min()\n",
    "time_start_max = df_features['time_start'].max()\n",
    "df_features_enc['time_start'] = (\n",
    "    (df_features['time_start'] - time_start_min) / (time_start_max - time_start_min))\n",
    "\n",
    "indices_trips = df_trips_train[['trip_id','driver','trip_id_new']].drop_duplicates()\n",
    "df_features_train = indices_trips.merge(df_features_enc, on=['trip_id','driver'], how='left')\n",
    "\n",
    "feats = ['day_of_Week_0','day_of_Week_1','day_of_Week_2','day_of_Week_3',\n",
    "         'is_Holiday','time_start']\n",
    "# Cast the one-hot columns by name; the positional iloc[:, -4:] used before\n",
    "# silently targets the wrong columns if the column order ever changes.\n",
    "dummy_cols = feats[:4]\n",
    "df_features_train[dummy_cols] = df_features_train[dummy_cols].astype(int)\n",
    "n_features = len(feats)\n",
    "n_trips_train = len(df_features_train)\n",
    "\n",
    "\n",
    "prior_M, edges_prior, M_indices = data_utils.get_prior_and_M_indices(\n",
    "    df_nodes, df_edges)\n",
    "\n",
    "# Trips and feature rows must be aligned one-to-one and in the same order.\n",
    "assert (df_trips_train.trip_id_new.unique() == df_features_train.trip_id_new.unique()).all()\n",
    "\n",
    "trip_ids = df_trips_train.trip_id_new.unique()\n",
    "\n",
    "V = M_indices.max()+1\n",
    "\n",
    "X_np = np.array(df_features_train[feats])\n",
    "node_idx_sequence_trips = df_trips_train.groupby('trip_id_new')['node_id'].apply(list)\n",
    "\n",
    "# Consecutive node pairs of each trip, i.e. the edges it traverses.\n",
    "edges_seq_original = node_idx_sequence_trips.apply(\n",
    "    lambda x: np.column_stack([x[:-1], x[1:]]))\n",
    "start_nodes_original = node_idx_sequence_trips.apply(\n",
    "    lambda x: x[0])\n",
    "end_nodes_original = node_idx_sequence_trips.apply(\n",
    "    lambda x: x[-1])\n",
    "\n",
    "# One row per trip, one column per edge in M_indices; filled in a later cell.\n",
    "edges_idx_on_original = np.zeros((len(edges_seq_original), \n",
    "                                  len(M_indices)), dtype=int)\n",
    "edges_seq_original_np = np.array(edges_seq_original)\n",
    "\n",
    "N_train = len(edges_seq_original)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing Data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████| 9970/9970 [00:08<00:00, 1214.37it/s]\n"
     ]
    }
   ],
   "source": [
    "print('Processing Data')\n",
    "# Index every directed edge (u, v) by its row in M_indices once, so each trip\n",
    "# step is a dictionary lookup instead of a full scan of the edge list.\n",
    "edge_pairs = [tuple(pair) for pair in M_indices.numpy().tolist()]\n",
    "edge_row = {pair: k for k, pair in enumerate(edge_pairs)}\n",
    "# The previous np.where(...).item() lookup failed on ambiguous matches; keep\n",
    "# that guarantee by refusing duplicated edges up front.\n",
    "assert len(edge_row) == len(edge_pairs), 'M_indices contains duplicate edges'\n",
    "\n",
    "for i in tqdm(range(len(edges_seq_original))):\n",
    "    for u, v in edges_seq_original_np[i]:\n",
    "        # A KeyError here means the trip uses an edge that is not in M_indices.\n",
    "        edges_idx_on_original[i, edge_row[(int(u), int(v))]] = 1\n",
    "\n",
    "edges_seq_original = list(edges_seq_original)\n",
    "node_idx_sequence_trips = list(node_idx_sequence_trips)\n",
    "\n",
    "# Column 0: first node of each trip, column 1: last node.\n",
    "end_to_end_nodes_original = np.vstack((\n",
    "    np.array(start_nodes_original), \n",
    "    np.array(end_nodes_original))).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "heurist = prior_M[M_indices[:,0], M_indices[:,1]]\n",
    "# Seed the generator so the random edge costs (and every path computed from\n",
    "# them in later cells) are reproducible across kernel restarts.\n",
    "cost_rng = np.random.default_rng(seed_n)\n",
    "# integers(1, 100) draws from 1..99, the same range as np.random.randint(1, 100).\n",
    "edge_costs = cost_rng.integers(1, 100, M_indices.numpy().shape[0])/10. + heurist.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float64 copy of the prior matrix, indexed as heur[current, goal] by BasicAStar.\n",
    "heur = prior_M.numpy().astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "edges = M_indices.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjacency list: source node -> list of (target node, edge cost).\n",
    "nodes = {}\n",
    "\n",
    "for (start_node, end_node), cost in zip(edges, edge_costs):\n",
    "    nodes.setdefault(start_node, []).append((end_node, cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([9195, 2])\n",
      "tensor(1190)\n"
     ]
    }
   ],
   "source": [
    "print(M_indices.shape)\n",
    "print(V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from astar import AStar\n",
    "\n",
    "class BasicAStar(AStar):\n",
    "    \"\"\"A* search over an adjacency-list graph.\n",
    "\n",
    "    `nodes` maps a node to a list of `(neighbour, cost)` tuples. A node that\n",
    "    only ever appears as an edge target may be missing from the mapping; it is\n",
    "    treated as having no outgoing edges instead of raising KeyError.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nodes):\n",
    "        self.nodes = nodes\n",
    "\n",
    "    def neighbors(self, n):\n",
    "        \"\"\"Yield the successors of `n` (nothing if `n` has no outgoing edges).\"\"\"\n",
    "        for n1, _ in self.nodes.get(n, ()):\n",
    "            yield n1\n",
    "\n",
    "    def distance_between(self, n1, n2):\n",
    "        \"\"\"Return the cost of the edge n1 -> n2, or infinity if there is none.\"\"\"\n",
    "        for n, d in self.nodes.get(n1, ()):\n",
    "            if n == n2:\n",
    "                return d\n",
    "        # Previously fell through and returned None, which breaks cost arithmetic.\n",
    "        return float('inf')\n",
    "\n",
    "    def heuristic_cost_estimate(self, current, goal):\n",
    "        \"\"\"Look the estimate up in the module-level `heur` matrix.\"\"\"\n",
    "        return heur[current, goal]\n",
    "\n",
    "    def is_goal_reached(self, current, goal):\n",
    "        \"\"\"The search stops when the current node equals the goal.\"\"\"\n",
    "        return current == goal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1190, 1190)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "heur.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.longlong"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.longlong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def heuristic(u, v):\n",
    "    return 50.\n",
    "def heuristic2(u, v):\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the same weighted directed graph twice: once for networkx, once for jgrapht.\n",
    "G = nx.DiGraph()\n",
    "Gj = jgrapht.create_graph(directed=True, weighted=True)\n",
    "\n",
    "# jgrapht gets plain Python ints as vertex ids and explicit edge ids.\n",
    "vxss = [int(vertex) for vertex in np.arange(0, V)]\n",
    "Gj.add_vertices_from(vxss)\n",
    "\n",
    "for edge_id, ((tail, head), cost) in enumerate(zip(edges, edge_costs)):\n",
    "    G.add_edge(tail, head, weight=cost)\n",
    "    Gj.add_edge(int(tail), int(head), weight=cost, edge=edge_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4439055919647217\n",
      "0.19105195999145508\n",
      "0.3467121124267578\n"
     ]
    }
   ],
   "source": [
    "N_sim = 100\n",
    "# Random (start, end) pairs. They are not used by the timings below, which\n",
    "# all solve the same fixed query so the three solvers are directly comparable.\n",
    "se = np.random.randint(0, V, (N_sim, 2))\n",
    "src, dst = 0, 1000\n",
    "\n",
    "# perf_counter is monotonic and high-resolution; time.time() can jump with\n",
    "# wall-clock adjustments and is not meant for measuring short intervals.\n",
    "start_time = time.perf_counter()\n",
    "for i in range(N_sim):\n",
    "    solution_astar = nx.astar_path(G, src, dst, heuristic, weight='weight')\n",
    "end_time_astar = time.perf_counter()\n",
    "\n",
    "for i in range(N_sim):\n",
    "    solution_astar2 = sp.a_star(Gj, src, dst, heuristic, use_bidirectional=False)\n",
    "end_time_astar2 = time.perf_counter()\n",
    "\n",
    "for i in range(N_sim):\n",
    "    solution_dij = nx.dijkstra_path(G, src, dst, weight='weight')\n",
    "end_time_dij = time.perf_counter()\n",
    "\n",
    "# Seconds for N_sim runs of: networkx A*, jgrapht A*, networkx Dijkstra.\n",
    "print(end_time_astar-start_time)\n",
    "print(end_time_astar2-end_time_astar)\n",
    "print(end_time_dij-end_time_astar2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_dij == solution_astar2.vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_dij == solution_astar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 18,\n",
       " 44,\n",
       " 60,\n",
       " 87,\n",
       " 115,\n",
       " 162,\n",
       " 179,\n",
       " 202,\n",
       " 225,\n",
       " 276,\n",
       " 332,\n",
       " 394,\n",
       " 440,\n",
       " 474,\n",
       " 544,\n",
       " 614,\n",
       " 659,\n",
       " 704,\n",
       " 782,\n",
       " 828,\n",
       " 915,\n",
       " 947,\n",
       " 1000]"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_dij"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 18,\n",
       " 44,\n",
       " 60,\n",
       " 87,\n",
       " 115,\n",
       " 162,\n",
       " 179,\n",
       " 202,\n",
       " 225,\n",
       " 276,\n",
       " 332,\n",
       " 394,\n",
       " 440,\n",
       " 474,\n",
       " 510,\n",
       " 554,\n",
       " 601,\n",
       " 663,\n",
       " 711,\n",
       " 773,\n",
       " 868,\n",
       " 898,\n",
       " 947,\n",
       " 1000]"
      ]
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_astar2.vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 18,\n",
       " 44,\n",
       " 60,\n",
       " 87,\n",
       " 115,\n",
       " 162,\n",
       " 179,\n",
       " 202,\n",
       " 225,\n",
       " 276,\n",
       " 332,\n",
       " 394,\n",
       " 440,\n",
       " 474,\n",
       " 510,\n",
       " 554,\n",
       " 601,\n",
       " 663,\n",
       " 711,\n",
       " 773,\n",
       " 868,\n",
       " 898,\n",
       " 947,\n",
       " 1000]"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_astar2.vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[338,\n",
       " 369,\n",
       " 376,\n",
       " 440,\n",
       " 429,\n",
       " 488,\n",
       " 493,\n",
       " 494,\n",
       " 555,\n",
       " 604,\n",
       " 622,\n",
       " 639,\n",
       " 668,\n",
       " 671,\n",
       " 667,\n",
       " 709]"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "solution_astar2.vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E\n",
      "======================================================================\n",
      "ERROR: /home/alan/ (unittest.loader._FailedTest)\n",
      "----------------------------------------------------------------------\n",
      "AttributeError: module '__main__' has no attribute '/home/alan/'\n",
      "\n",
      "----------------------------------------------------------------------\n",
      "Ran 1 test in 0.001s\n",
      "\n",
      "FAILED (errors=1)\n"
     ]
    },
    {
     "ename": "SystemExit",
     "evalue": "True",
     "output_type": "error",
     "traceback": [
      "An exception has occurred, use %tb to see the full traceback.\n",
      "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/alan/Desktop/envs/pao_env/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3377: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
      "  warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
     ]
    }
   ],
   "source": [
    "import unittest\n",
    "\n",
    "\n",
    "class _UnitHeuristicAStar(BasicAStar):\n",
    "    \"\"\"BasicAStar with a constant heuristic.\n",
    "\n",
    "    The test graphs use string labels, which cannot index the numeric `heur`\n",
    "    matrix that BasicAStar.heuristic_cost_estimate reads.\n",
    "    \"\"\"\n",
    "\n",
    "    def heuristic_cost_estimate(self, current, goal):\n",
    "        return 1\n",
    "\n",
    "\n",
    "class BasicTests(unittest.TestCase):\n",
    "\n",
    "    def test_bestpath(self):\n",
    "        \"\"\"ensure that we take the shortest path, and not the path with less elements.\n",
    "           the path with less elements is A -> B with a distance of 100\n",
    "           the shortest path is A -> C -> D -> B with a distance of 60\n",
    "        \"\"\"\n",
    "        nodes = {'A': [('B', 100), ('C', 20)],\n",
    "                 'C': [('D', 20)], \n",
    "                 'D': [('B', 20)]}\n",
    "\n",
    "        path = _UnitHeuristicAStar(nodes).astar('A', 'B')\n",
    "        self.assertIsNotNone(path)\n",
    "        if path:\n",
    "            path = list(path)\n",
    "            self.assertEqual(4, len(path))\n",
    "            for i, n in enumerate('ACDB'):\n",
    "                self.assertEqual(n, path[i])\n",
    "\n",
    "    def test_issue_15(self):\n",
    "        \"\"\"This test case reproduces https://github.com/jrialland/python-astar/issues/15.\n",
    "        B has no neighbors, therefore the computation should return None and not raise an exception.\n",
    "        \"\"\"\n",
    "        node = {\n",
    "            'A': [('B', 200000)],\n",
    "            'C': [('D', 200000)],\n",
    "            'D': [('E', 200000)],\n",
    "            'E': [('F', 200000)],\n",
    "            'B': [],\n",
    "            'F': []\n",
    "        }\n",
    "        path = _UnitHeuristicAStar(node).astar('A', 'D')\n",
    "        self.assertIsNone(path)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # argv: stop unittest from parsing the kernel's own command line as test names.\n",
    "    # exit=False: do not raise SystemExit inside the notebook once the run finishes.\n",
    "    unittest.main(argv=['first-arg-is-ignored'], exit=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = jgrapht.create_graph(directed=False, weighted=True)\n",
    "gen.barabasi_albert(g, 3, 3, 10, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = random.Random(17)\n",
    "for e in g.edges:\n",
    "    g.set_edge_weight(e, 100 * rng.random())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_JGraphTGraph-EdgeSet(<Swig Object of type 'void *' at 0x7f9664ec4330>)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {0={0,1}, 1={0,2}, 2={1,2}, 3={3,0}, 4={3,1}, 5={3,2}, 6={4,1}, 7={4,0}, 8={4,3}, 9={5,0}, 10={5,1}, 11={5,3}, 12={6,5}, 13={6,0}, 14={6,1}, 15={7,1}, 16={7,3}, 17={7,0}, 18={8,0}, 19={8,7}, 20={8,4}, 21={9,1}, 22={9,7}, 23={9,3}})\n"
     ]
    }
   ],
   "source": [
    "print(g)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree = sp.dijkstra(g, source_vertex=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = tree.get_path(8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "path start: 6\n",
      "path end: 8\n",
      "path edges: [13, 18]\n",
      "path vertices: [6, 0, 8]\n",
      "path weight: 37.96278817981964\n"
     ]
    }
   ],
   "source": [
    "print('path start: {}'.format(path.start_vertex))\n",
    "print('path end: {}'.format(path.end_vertex))\n",
    "print('path edges: {}'.format(path.edges))\n",
    "print('path vertices: {}'.format(path.vertices))\n",
    "print('path weight: {}'.format(path.weight))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pao_env",
   "language": "python",
   "name": "pao_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
