{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "84a8ae1e-5fa1-417a-9653-a5de6447b464",
   "metadata": {},
   "source": [
    "# Overview\n",
    "\n",
    "This is the second part of a 2 part tutorial which runs the PaRoutes benchmark.\n",
    "Running this notebook requires that part 2a is already run.\n",
    "In this half of the tutorial, we analyze the results of search algorithms.\n",
    "\n",
    "\"Analysis\" can be many things. In this notebook we focus on the time when a solution is found\n",
    "and the number of diverse solutions found.\n",
    "We also visualize the routes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a6ee7147-2dd0-4bc8-b54d-12744261fd32",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "import math\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "import pprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a3c36718-1daa-406b-b87d-6913e60509c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from syntheseus.search.graph.molset import MolSetNode, MolSetGraph\n",
    "from syntheseus.search.graph.and_or import AndNode, OrNode, AndOrGraph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9e03f8a-9e46-4c26-9cdf-903da9db2b07",
   "metadata": {},
   "source": [
    "## Step 1: load results from notebook 2a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cb398e9e-2ac7-4b5e-88ac-a6af3780881b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pickle files\n",
    "alg_name_to_graph = dict()\n",
    "for alg_name in [\"retro star\", \"mcts\"]:\n",
    "    with open(f\"./search-results-{alg_name}.pkl\", \"rb\") as f: \n",
    "        alg_name_to_graph[alg_name] = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "843e67fd-2521-485d-af19-643805800a7d",
   "metadata": {},
   "source": [
    "## Step 2: time at which a solution is found \n",
    "\n",
    "The code is written in a way where nodes keep track of their own\n",
    "creation time, so time-based measures can be computed retrospectively for\n",
    "many time measures\n",
    "(e.g. wallclock time, number of calls to reaction model).\n",
    "For any analysis involving time, we need to choose how time is measured,\n",
    "and this is left up to the user by filling in the `analysis_time` field of each node's data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "97c1fb44-3344-4031-9301-654689f6fc2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "for graph in alg_name_to_graph.values():\n",
    "    for node in graph.nodes():\n",
    "        \n",
    "        # Wallclock time: difference between this node's creation time and that of the root node\n",
    "        node.data[\"analysis_time\"] = (node.creation_time - graph.root_node.creation_time).total_seconds()\n",
    "        \n",
    "        # Could alternatively use number of calls to reaction model\n",
    "        # node.data[\"analysis_time\"] = node.data[\"num_calls_rxn_model\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd66b4b2-df67-4e78-8b30-cbae073bf9d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retro star first solution: 1.205841\n",
      "mcts first solution: 1.267164\n"
     ]
    }
   ],
   "source": [
    "# Now use a function to compute the first solution time\n",
    "from syntheseus.search.analysis.solution_time import get_first_solution_time\n",
    "for alg_name, graph in alg_name_to_graph.items():\n",
    "    print(f\"{alg_name} first solution: {get_first_solution_time(graph)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a16439c-309d-4abe-854f-3010666a7f50",
   "metadata": {},
   "source": [
    "## Step 3: extract routes\n",
    "\n",
    "We extract individual synthesis routes from the graph\n",
    "in order to later calculate their diversity.\n",
    "However, there are many possible routes in a graph,\n",
    "possibly too many to exhaustively enumerate.\n",
    "Therefore we only extract the _minimum cost_ routes,\n",
    "where the cost of each route is the sum of `node.data[\"route_cost\"]` for each\n",
    "node in the route.\n",
    "This cost could be anything: a constant,\n",
    "something based on the policy, etc.\n",
    "It is up to the user to set a route's cost.\n",
    "Here we just assign a constant cost to each node which represents a reaction\n",
    "(i.e. AndNodes and MolSetNodes).\n",
    "This means that the lowest cost routes will be the shortest routes.\n",
    "\n",
    "We also limit the maximum number of routes extracted to speed up computation time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8858c30-4e71-42bd-b4cc-69e2bf855dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "for graph in alg_name_to_graph.values():\n",
    "    for node in graph.nodes():\n",
    "        \n",
    "        if isinstance(node, (AndNode, MolSetNode)):\n",
    "            node.data[\"route_cost\"] = 1.0\n",
    "        else:\n",
    "            node.data[\"route_cost\"] = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80c86fd0-cc95-4daa-bed2-83dcb7fec0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 10000 routes for retro star\n",
      "Found 422 routes for mcts\n",
      "CPU times: user 11.3 s, sys: 54.8 ms, total: 11.4 s\n",
      "Wall time: 11.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from syntheseus.search.analysis import route_extraction\n",
    "alg_name_to_routes = dict()\n",
    "for alg_name, graph in alg_name_to_graph.items():\n",
    "    routes = list(route_extraction.iter_routes_cost_order(graph, 10_000))\n",
    "    print(f\"Found {len(routes)} routes for {alg_name}\", flush=True)\n",
    "    alg_name_to_routes[alg_name] = routes\n",
    "    del routes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f7f427c0-fb49-4e99-ad40-1eb96c2f5f91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We visualize the routes just to get a sense of what they look like\n",
    "from syntheseus.search import visualization\n",
    "\n",
    "visualization.visualize_andor(\n",
    "    alg_name_to_graph[\"retro star\"],\n",
    "    filename=\"retro star route.pdf\",\n",
    "    nodes=alg_name_to_routes[\"retro star\"][0]\n",
    ")\n",
    "\n",
    "visualization.visualize_molset(\n",
    "    alg_name_to_graph[\"mcts\"],\n",
    "    filename=\"mcts route.pdf\",\n",
    "    nodes=alg_name_to_routes[\"mcts\"][0]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c849f47-899d-44ba-8e54-a25dadeececa",
   "metadata": {},
   "source": [
    "## Step 4: calculate diversity\n",
    "\n",
    "Specifically, we estimate the _packing number_ of the route set,\n",
    "i.e. the number of distinct routes which are greater than a distance $r$ away from each other.\n",
    "Here we use a stringent form of diversity: routes which have no common reactions,\n",
    "which means that their Jaccard distance is 1.\n",
    "To do this, we use the `to_synthesis_graph` method of each graph object\n",
    "which converts them into a common format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bfdad9e4-1b92-4214-808e-7652a2e4cbea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retro star: number of distinct routes = 2\n",
      "mcts: number of distinct routes = 4\n"
     ]
    }
   ],
   "source": [
    "from syntheseus.search.analysis import diversity\n",
    "for alg_name, graph in alg_name_to_graph.items():\n",
    "    route_objects = [graph.to_synthesis_graph(nodes) for nodes in alg_name_to_routes[alg_name]]\n",
    "    packing_set = diversity.estimate_packing_number(\n",
    "        routes=route_objects,\n",
    "        distance_metric=diversity.reaction_jaccard_distance,\n",
    "        radius=0.999  # because comparison is > not >=\n",
    "    )\n",
    "    print(f\"{alg_name}: number of distinct routes = {len(packing_set)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
