{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6753a01e-a475-4fe2-936f-b1ff168bd813",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import logging\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from copy import deepcopy\n",
    "\n",
    "from eg_mcts.utils.prepare_methods import prepare_starting_molecules, prepare_mlp, prepare_egmcts_planner\n",
    "from eg_mcts.utils.smiles_process import smiles_to_fp, reaction_smarts_to_fp\n",
    "from eg_mcts.model.eg_network import EG_MLP\n",
    "from eg_mcts.utils.logger import setup_logger\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9a3b2d2-30b4-4776-b949-d8889511d31e",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "expansion_topk=50\n",
    "use_value_fn=True\n",
    "iterations=500\n",
    "logdir=\"data/temp/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5a49e1-e01f-4089-8bda-d1834146e28f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO\n",
    "# there is some weird bug for really small expansion topk (e.g. 5) when a target node gets qv=-inf assigned\n",
    "#     doesnt happen on practically relevant settings for expansion topk, but still.."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b0156a-3b61-4ae0-ae5e-625fff9d6306",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirpath = \"externals/eg-mcts/eg_mcts/\"\n",
    "gpu=-1\n",
    "starting_molecules=dirpath+'/dataset/origin_dict.csv'\n",
    "mlp_templates=dirpath+'/one_step_model/template_rules_1.dat'\n",
    "mlp_model_dump=dirpath+'/one_step_model/retro_star_value_ours.ckpt'\n",
    "save_folder=dirpath+'/saved_EG_fn'\n",
    "value_model=\"best_egn_for_emol.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6bbedcf-5569-4f83-9575-c7a5550bac93",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:%d' % gpu if gpu >= 0 else 'cpu')\n",
    "starting_mols = prepare_starting_molecules(starting_molecules)\n",
    "one_step = prepare_mlp(mlp_templates, mlp_model_dump)\n",
    "one_step.rules2idx = {v:k for k,v in one_step.idx2rules.items()}\n",
    "\n",
    "expand_fn = lambda x: one_step.run(x, topk=expansion_topk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c281634e-e57a-426b-95a2-ca8653fdced7",
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_value_fn:\n",
    "    print('using value fn')\n",
    "    model = EG_MLP(\n",
    "        n_layers=1,\n",
    "        fp_dim=4096,\n",
    "        latent_dim=256,\n",
    "        dropout_rate=0.1,\n",
    "        device=device\n",
    "    ).to(device)\n",
    "    model_f = '%s/%s' % (save_folder, value_model)\n",
    "    logging.info('Loading Experience Guidance Network from %s' % model_f)\n",
    "    model.load_state_dict(torch.load(model_f, map_location=device))\n",
    "    model.eval()\n",
    "\n",
    "    def value_fn(mol, template):\n",
    "        mol_fp = smiles_to_fp(mol, fp_dim=2048).reshape(1, -1)\n",
    "        template_fp = reaction_smarts_to_fp(template, fp_dim=2048).reshape(1, -1)\n",
    "        fp = np.hstack((mol_fp, template_fp))\n",
    "        fp = torch.FloatTensor(fp).to(device)\n",
    "        v = model(fp).item()\n",
    "        return v\n",
    "else:\n",
    "    value_fn = lambda x,y: 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54456f00-9b13-403b-8d1a-2fea8d40b1ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from utils import plot_graph\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "from networkx.drawing.nx_agraph import to_agraph\n",
    "\n",
    "def truncate_string(s, max_length):\n",
    "    return s if len(s) <= max_length else s[:max_length - 2] + '..'\n",
    "\n",
    "from typing import Any\n",
    "\n",
    "def nice_str(\n",
    "    obj: Any,\n",
    "    *,\n",
    "    indent: int = 0,\n",
    "    indent_step: int = 4,\n",
    "    float_fmt: str = \".4f\",\n",
    "    sort_keys: bool = False\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Return a nicely formatted string for any nested structure of\n",
    "    dicts, lists, tuples, sets, and basic scalars.\n",
    "    \n",
    "    float_fmt: a format spec like '.2f', ':.3g', etc.\n",
    "    sort_keys: if True, dict keys will be sorted.\n",
    "    \"\"\"\n",
    "    pad = \" \" * indent\n",
    "\n",
    "    # Scalars\n",
    "    if isinstance(obj, float):\n",
    "        return format(obj, float_fmt)\n",
    "    if isinstance(obj, (str, bool, int)) or obj is None:\n",
    "        return repr(obj)\n",
    "\n",
    "    # Mapping\n",
    "    if isinstance(obj, dict):\n",
    "        items = obj.items()\n",
    "        if sort_keys:\n",
    "            items = sorted(items, key=lambda kv: kv[0])\n",
    "        if not items:\n",
    "            return \"{}\"\n",
    "        lines = []\n",
    "        for k, v in items:\n",
    "            key_str = repr(k)\n",
    "            val_str = nice_str(v, indent=indent + indent_step,\n",
    "                               indent_step=indent_step,\n",
    "                               float_fmt=float_fmt,\n",
    "                               sort_keys=sort_keys)\n",
    "            lines.append(f\"{' ' * (indent + indent_step)}{key_str}: {val_str}\")\n",
    "        body = \"\\n\".join(lines)\n",
    "        return body + \"\\n\"\n",
    "\n",
    "    # Sequence types\n",
    "    if isinstance(obj, (list, tuple, set)):\n",
    "        if not obj:\n",
    "            return \"[]\" if isinstance(obj, list) else \\\n",
    "                   \"()\" if isinstance(obj, tuple) else \\\n",
    "                   \"set()\"\n",
    "\n",
    "        lines = []\n",
    "        for item in obj:\n",
    "            item_str = nice_str(item, indent=indent + indent_step,\n",
    "                                indent_step=indent_step,\n",
    "                                float_fmt=float_fmt,\n",
    "                                sort_keys=sort_keys)\n",
    "            lines.append(f\"{' ' * (indent + indent_step)}{item_str}\")\n",
    "        body = \",\\n\".join(lines)\n",
    "        return body + \"\\n\"\n",
    "\n",
    "    # Fallback\n",
    "    return repr(obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb7f1c6-ddb7-47b0-8a51-99d1b8534f26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_andor_graph(G, truncate_edge_strings=False, truncate_node_strings=False, edge_data=None, node_data=None):\n",
    "    P = G.copy()\n",
    "    max_edge_length = max_node_length = 10000\n",
    "    if truncate_edge_strings > 1:\n",
    "        max_edge_length = truncate_edge_strings\n",
    "    if truncate_node_strings > 1:\n",
    "        max_node_length = truncate_node_strings\n",
    "\n",
    "    or_style = dict(shape=\"ellipse\")\n",
    "    and_style = dict(shape=\"rectangle\")\n",
    "    false_style = dict(style=\"filled\", fillcolor=\"red\")\n",
    "    true_style = dict(style=\"filled\", fillcolor=\"lightgreen\")\n",
    "\n",
    "    for s,t,d in P.edges(data=True):\n",
    "        #print(s,t,k,d)\n",
    "        if \"label\" not in d and edge_data is None: \n",
    "            continue\n",
    "        if edge_data is not None:\n",
    "            d[\"label\"] = nice_str(d)\n",
    "        d[\"label\"] = \"\\n\".join([truncate_string(line, max_edge_length) for line in d[\"label\"].split(\"\\n\")])\n",
    "        d[\"fontsize\"] = 8\n",
    "\n",
    "    \n",
    "    for n, d in P.nodes(data=True):\n",
    "        if \"label\" not in d:\n",
    "            d[\"label\"] = str(n)\n",
    "        if node_data is not None:\n",
    "            d[\"label\"] = nice_str(d)\n",
    "        d[\"label\"] = \"\\n\".join([truncate_string(line, max_node_length) for line in d[\"label\"].split(\"\\n\")])\n",
    "\n",
    "        if d[\"truth_value\"]:\n",
    "            d.update(true_style)\n",
    "        else:\n",
    "            d.update(false_style)\n",
    "\n",
    "        if d[\"node_type\"] == \"and\":\n",
    "            d.update(and_style)\n",
    "        elif d[\"node_type\"] == \"or\":\n",
    "            d.update(or_style)\n",
    "        d[\"fontsize\"] = 8\n",
    "\n",
    "    plot_graph(P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ecad0d0-bcdd-441c-ac0c-aa222bbab386",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "\n",
    "ROOT_NODE = \"__ROOT__\"\n",
    "class AndOrGraph(nx.DiGraph):\n",
    "    \n",
    "    def __init__(self, is_exhaustable=True, max_expands=1, max_visits_per_target=None, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        max_visits_per_target = np.inf if max_visits_per_target is None else max_visits_per_target\n",
    "        self.max_visits_per_target = max_visits_per_target\n",
    "        self.is_exhaustable = is_exhaustable\n",
    "        self.max_expands = max_expands\n",
    "        self.targets = None\n",
    "        self.root_ands = None\n",
    "\n",
    "    def get_unsolved_targets(self):\n",
    "        # always based on the current state of the truth values\n",
    "        return [n for n in self.targets if not self.is_solved(n)]\n",
    "\n",
    "    def is_solved(self, node):\n",
    "        return self.nodes[node][\"truth_value\"]\n",
    "\n",
    "    def is_and_node(self, node):\n",
    "        return self.nodes[node][\"node_type\"] == \"and\"\n",
    "\n",
    "    def is_or_node(self, node):\n",
    "        return self.nodes[node][\"node_type\"] == \"or\"\n",
    "    \n",
    "    def get_and_nodes(self):\n",
    "        return [n for n in self if self.is_and_node(n)]\n",
    "\n",
    "    def get_or_nodes(self):\n",
    "        return [n for n in self if self.is_or_node(n)]\n",
    "    \n",
    "    def add_node(self, node, node_type=None, **data):\n",
    "        assert node_type is not None\n",
    "        super().add_node(node, node_type=node_type, truth_value=False, **data)\n",
    "\n",
    "    def add_or_node(self, node, **data):\n",
    "        # adds an OR node\n",
    "        self.add_node(node, node_type=\"or\", **data)\n",
    "        d = self.nodes[node]\n",
    "        d[\"id\"] = d.get(\"id\", len(self))\n",
    "        d[\"is_open\"] = d.get(\"is_open\", True)\n",
    "        d[\"qv\"] = d.get(\"qv\", np.nan)\n",
    "        d[\"num_expanded\"] = d.get(\"num_expanded\", 0)\n",
    "        d[\"is_exhausted\"] = d.get(\"is_exhausted\",False)\n",
    "\n",
    "    def node_was_expanded(self, node):\n",
    "        self.nodes[node][\"num_expanded\"] += 1\n",
    "        self.nodes[node][\"is_exhausted\"] = self.is_exhaustable and self.nodes[node][\"num_expanded\"] >= self.max_expands\n",
    "    \n",
    "    def add_and_node(self, src, node, children, qv=np.nan, num_visited=1, pro=0.0, **group_kwargs):\n",
    "        # adds an AND node with OR children to the given src node\n",
    "        assert src in self\n",
    "        if node is None:\n",
    "            node = len([n for n,d in self.nodes(data=True) if d[\"node_type\"] == \"and\"])\n",
    "        self.add_node(node, node_type=\"and\", qv=qv, pro=pro, num_visited=num_visited, **group_kwargs)\n",
    "        self.add_edge(src, node)\n",
    "        for child in children:\n",
    "            self.add_or_node(child)\n",
    "            self.add_edge(node, child)\n",
    "\n",
    "    def update_truth_values(self, literals=None):\n",
    "        '''\n",
    "            computes the least fixpoint for the boolean circuit expressed by this graph. \n",
    "            `literals` are all nodes you know to be true beforehand.\n",
    "            Overwrites ALL truth states in self!\n",
    "        '''\n",
    "        truths = {n: False for n in self}\n",
    "        \n",
    "        # Initialize\n",
    "        q = deque()\n",
    "        if literals:\n",
    "            for l in literals:\n",
    "                truths[l] = True\n",
    "                q.append(l)\n",
    "        \n",
    "        while q:\n",
    "            n = q.popleft()\n",
    "            for parent in self.predecessors(n):\n",
    "                if truths[parent]:\n",
    "                    continue\n",
    "                if self.nodes[parent][\"node_type\"] == \"or\":\n",
    "                    # OR: parent becomes true if any child is true\n",
    "                    if truths[n]:\n",
    "                        truths[parent] = True\n",
    "                        q.append(parent)\n",
    "                else:  # AND\n",
    "                    # AND: parent becomes true if all children are true\n",
    "                    if all(truths[c] for c in self.successors(parent)):\n",
    "                        truths[parent] = True\n",
    "                        q.append(parent)\n",
    "        \n",
    "        for n,v in truths.items():\n",
    "            d = self.nodes[n]\n",
    "            d[\"truth_value\"] = v\n",
    "\n",
    "        # we also check visit limits for our root nodes:\n",
    "        for n in self.root_ands:\n",
    "            if self.nodes[n][\"num_visited\"] >= self.max_visits_per_target:\n",
    "                self.nodes[next(self.predecessors(n))].update(dict(truth_value=True))\n",
    "   \n",
    "    def update_openness(self):\n",
    "        '''\n",
    "        A node is open unless:\n",
    "            - its true\n",
    "            - if is_exhaustable: \n",
    "                node has no path to an open node\n",
    "        --> for the last one, we again do a least fixpoint iteration:\n",
    "        set all nodes to True, that are guaranteed to be open, i.e. nodes that are:\n",
    "            - not true but not exhausted\n",
    "            - \n",
    "        Then, recursively, check for each node whether all of its children are \n",
    "        '''\n",
    "        is_open = {n: False for n,d in self.nodes(data=True) if d[\"node_type\"] == \"or\"}\n",
    "        def check_open(n):\n",
    "            check = (not (d:=self.nodes[n])[\"truth_value\"]) \n",
    "            check = check and (not d[\"is_exhausted\"] or any(any(is_open[or_child] for or_child in self.successors(and_child)) for and_child in self.successors(n)))\n",
    "            return check\n",
    "\n",
    "        changed = True\n",
    "        while changed:\n",
    "            changed = False\n",
    "            for n,v in is_open.items():\n",
    "                if v: continue\n",
    "                new_v = check_open(n)\n",
    "                is_open[n] = new_v\n",
    "                changed |= new_v\n",
    "\n",
    "        for n,v in is_open.items():\n",
    "            self.nodes[n][\"is_open\"] = v\n",
    "\n",
    "    def update_q_values_for_or_nodes(self):\n",
    "        if not self.is_exhaustable:\n",
    "            raise NotImplementedError()\n",
    "\n",
    "        to_calc = {}\n",
    "        for n,d in self.nodes(data=True):\n",
    "            if d[\"node_type\"] != \"or\": continue\n",
    "            ie = d[\"is_exhausted\"]\n",
    "            io = d[\"is_open\"]\n",
    "            it = d[\"truth_value\"] #\n",
    "            il = self.out_degree(n) > 0\n",
    "\n",
    "            if it:\n",
    "                d[\"qv\"] = 10.0\n",
    "            elif ie:\n",
    "                # we are open, need to calc\n",
    "                if io:\n",
    "                    to_calc[n] = None\n",
    "                else:\n",
    "                # we are false, exhausted and closed\n",
    "                    d[\"qv\"] = -np.inf\n",
    "            elif not ie and io:\n",
    "                # we are expandable, open and an intermediate node, need to calc\n",
    "                if il:\n",
    "                    to_calc[n] = None\n",
    "                else:\n",
    "                    # expandable, open and a leaf\n",
    "                    d[\"qv\"] = np.inf\n",
    "            else:\n",
    "                raise ValueError(\"unhandled node state for node\", n, f\"\\n\\texhausted: {ie}\\n\\tis_open: {io}\\n\\tis_true: {it}\\n\\tis_intermediate: {il}\")\n",
    "\n",
    "        # for each depending node, we gather the qvs of their children and take the max\n",
    "        # we do not directly overwrite the qv in the graph to avoid feedback loops\n",
    "        for n in to_calc:\n",
    "            to_calc[n] = max(self.nodes[c][\"qv\"] for c in self.successors(n))\n",
    "\n",
    "        for n,qv in to_calc.items():\n",
    "            self.nodes[n][\"qv\"] = qv\n",
    "\n",
    "    def add_target_nodes(self, nodes, initial_qv=0):\n",
    "        if len(self) != 0:\n",
    "            raise ValueError(\"Graph not empty\")\n",
    "        self.add_or_node(ROOT_NODE)\n",
    "        children = []\n",
    "        root_ands = []\n",
    "        for i,n in enumerate(nodes):\n",
    "            np = ROOT_NODE + str(i)\n",
    "            self.add_or_node(np, num_expanded=1)\n",
    "            self.add_or_node(n)\n",
    "            children.append(np)\n",
    "            ra = f\"ROOT_AND_{i}\"\n",
    "            self.add_and_node(np, ra, [n], qv=initial_qv)\n",
    "            root_ands.append(ra)\n",
    "        self.root_ands = root_ands\n",
    "        self.add_and_node(ROOT_NODE, \"ROOT_AND\", children, qv=0)\n",
    "\n",
    "        self.targets = nodes\n",
    "        return ROOT_NODE\n",
    "\n",
    "    def update_truth_and_proof_tree_sizes(self, literals=None):\n",
    "        \"\"\"\n",
    "        Computes truth values AND minimal proof sizes for each node.\n",
    "        Stores results in self.nodes[n][\"truth_value\"] and\n",
    "        self.nodes[n][\"proof_size\"].\n",
    "\n",
    "        Usually only needed once search is terminated and you want to find minimum proof\n",
    "        trees for your successful nodes.\n",
    "        Unless that's the case, stick with the cheaper `update_truth_values` method.\n",
    "        \"\"\"\n",
    "        from collections import deque\n",
    "    \n",
    "        truths = {n: False for n in self}\n",
    "        proof_size = {n: float(\"inf\") for n in self}\n",
    "        backpointers = {n: [] for n in self}  # store chosen children\n",
    "    \n",
    "        q = deque()\n",
    "        if literals:\n",
    "            for l in literals:\n",
    "                truths[l] = True\n",
    "                proof_size[l] = 1\n",
    "                backpointers[l] = []  # literal = leaf\n",
    "                q.append(l)\n",
    "    \n",
    "        while q:\n",
    "            n = q.popleft()\n",
    "            for parent in self.predecessors(n):\n",
    "                node_type = self.nodes[parent][\"node_type\"]\n",
    "                children = list(self.successors(parent))\n",
    "    \n",
    "                if node_type == \"or\":\n",
    "                    if truths[n]:\n",
    "                        new_size = 1 + proof_size[n]\n",
    "                        if not truths[parent] or new_size < proof_size[parent]:\n",
    "                            truths[parent] = True\n",
    "                            proof_size[parent] = new_size\n",
    "                            backpointers[parent] = [n]\n",
    "                            q.append(parent)\n",
    "    \n",
    "                else:  # AND\n",
    "                    if all(truths[c] for c in children):\n",
    "                        new_size = 1 + sum(proof_size[c] for c in children)\n",
    "                        if not truths[parent] or new_size < proof_size[parent]:\n",
    "                            truths[parent] = True\n",
    "                            proof_size[parent] = new_size\n",
    "                            backpointers[parent] = children\n",
    "                            q.append(parent)\n",
    "    \n",
    "        # write results into graph\n",
    "        for n in self:\n",
    "            self.nodes[n][\"truth_value\"] = truths[n]\n",
    "            self.nodes[n][\"proof_size\"] = proof_size[n] if truths[n] else None\n",
    "            self.nodes[n][\"backpointers\"] = backpointers[n] if truths[n] else []\n",
    "\n",
    "    def extract_min_proof_dag(self, node):\n",
    "        \"\"\"\n",
    "        Returns a NetworkX DiGraph representing the minimal proof DAG for `node`.\n",
    "        \"\"\"\n",
    "    \n",
    "        if not self.nodes[node][\"truth_value\"]:\n",
    "            return nx.DiGraph()  # empty proof\n",
    "    \n",
    "        proof = nx.DiGraph()\n",
    "        visited = set()\n",
    "    \n",
    "        def dfs(n):\n",
    "            if n in visited:\n",
    "                return\n",
    "            visited.add(n)\n",
    "            proof.add_node(n, **self.nodes[n])  # copy attrs\n",
    "            for c in self.nodes[n][\"backpointers\"]:\n",
    "                proof.add_node(c, **self.nodes[c])\n",
    "                proof.add_edge(n, c)\n",
    "                dfs(c)\n",
    "\n",
    "        dfs(node)\n",
    "        return proof"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "714523fd-b55c-4648-9c00-c8fb034dc961",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from contextlib import contextmanager\n",
    "\n",
    "@contextmanager\n",
    "def timing(label: str = \"Block\"):\n",
    "    start = time.perf_counter()\n",
    "    yield\n",
    "    end = time.perf_counter()\n",
    "    #print(f\"{label} took {end - start:.6f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f8decd5-8f23-4c9e-8c35-4b48ae40a36c",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "# retrosynthesis with HP-MCTS\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffef0700-7375-4e4a-b344-ecc40fbc07f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"externals/eg-mcts/eg_mcts/dataset/retro190.txt\") as f:\n",
    "    lines = [x.strip() for x in f.readlines()]\n",
    "retro_data = lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ced766-903d-4492-a675-e4f2b98c5841",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "target_mols = retro_data\n",
    "A = AndOrGraph(is_exhaustable=True, max_expands=1, max_visits_per_target=iterations)\n",
    "root = A.add_target_nodes(target_mols)\n",
    "print(\"Added\", len(target_mols), \"targets mols\")\n",
    "A.update_truth_values()\n",
    "A.update_openness()\n",
    "A.update_q_values_for_or_nodes()\n",
    "#visualize_andor_graph(A, truncate_node_strings=25, truncate_edge_strings=25, edge_data=1, node_data=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47ab7023-7d40-4515-802e-280443671ade",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "max_total_iterations = iterations * len(A.targets)\n",
    "\n",
    "record_history = False\n",
    "\n",
    "graphs = []\n",
    "value_model_calls = 0\n",
    "expand_model_calls = 0\n",
    "true_nodes = set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a735df1c-636f-4670-a708-08662a1d5af1",
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "with tqdm(total=max_total_iterations) as pbar:\n",
    "    for i in range(max_total_iterations):\n",
    "        if record_history:\n",
    "            graphs.append(deepcopy(A))\n",
    "        if A.is_solved(root):\n",
    "            print(\"break\", i)\n",
    "            break\n",
    "\n",
    "        # puct traversal\n",
    "        # take some root\n",
    "        next_node = root\n",
    "        path = [next_node]\n",
    "        visited_nodes = set()\n",
    "        forbidden_edges = set()\n",
    "        num_visited_parent = i + 1\n",
    "        no_more_nodes = False\n",
    "        C = 0.5\n",
    "        with timing(\"traversal\"):\n",
    "            while not no_more_nodes:\n",
    "                nd = A.nodes[next_node]\n",
    "                # check whether we are in a leaf node\n",
    "                if A.out_degree(next_node) == 0:\n",
    "                    assert not nd[\"is_exhausted\"]\n",
    "                    break\n",
    "    \n",
    "                pucts = []\n",
    "                for and_child in A.successors(next_node):\n",
    "                    # check whether this edge previously lead to a deadend:\n",
    "                    d = A.nodes[and_child]\n",
    "\n",
    "                    if (next_node, and_child) in forbidden_edges:\n",
    "                        continue\n",
    "    \n",
    "                    # check whether following this node would lead to a deadend or circle\n",
    "                    if not any(A.nodes[c][\"is_open\"] and c not in visited_nodes for c in A.successors(and_child)):\n",
    "                        continue\n",
    "                    q,p,n = d[\"qv\"], d[\"pro\"], d[\"num_visited\"]\n",
    "                    puct = q / n + C * p * num_visited_parent ** 0.5 / (n + 1)\n",
    "                    pucts.append((puct, and_child))\n",
    "                pucts = sorted(pucts, key=lambda x: x[0], reverse=True)\n",
    "        \n",
    "                for puct, and_child in pucts:\n",
    "                    next_or_node = None\n",
    "                    for or_child in A.successors(and_child):\n",
    "                        if or_child in visited_nodes or not A.nodes[or_child][\"is_open\"]:\n",
    "                            continue\n",
    "                        next_or_node = or_child\n",
    "                        break\n",
    "                    if next_or_node:\n",
    "                        break\n",
    "                else:\n",
    "                    # we did end up in a dead end here. Idea: back up and choose different path instead\n",
    "                    # path[-1] is our current deadend or node\n",
    "                    # path[-2] is the and node we took in path[-3] that lead to the deadend\n",
    "                    # --> we add the edge (path[-3], path[-2]) to the forbidden edges and\n",
    "                    #     back off to node path[-3].\n",
    "                    if len(path) == 1:\n",
    "                        print(\"backed up to the root --> ran all out of nodes to explore, exiting.\")\n",
    "                        no_more_nodes = True\n",
    "                    else:\n",
    "                        e = (path[-3],path[-2])\n",
    "                        forbidden_edges.add(e)\n",
    "                        path = path[:-2]\n",
    "                        next_node = path[-1]\n",
    "                        print(\"node\", nd[\"id\"], \"was a deadend! Added edge \", A.nodes[e[0]][\"id\"],\"-\",e[1],\"to the forbidden edges set. Resuming from\", A.nodes[next_node][\"id\"])\n",
    "                        continue\n",
    "                    \n",
    "                print(\"in node\", nd[\"id\"], \"selected\", A.nodes[next_or_node][\"id\"],\"from and_node\", and_child,f\"with puct {puct:.3f}\")\n",
    "                path.extend([and_child, or_child])\n",
    "                visited_nodes.add(or_child)\n",
    "                next_node = next_or_node\n",
    "                num_visited_parent = A.nodes[and_child][\"num_visited\"]\n",
    "        if no_more_nodes:\n",
    "            break\n",
    "        print(\"Traversed to node\", A.nodes[next_node][\"id\"])\n",
    "\n",
    "        # expand :)\n",
    "        with timing(\"expand model\"):\n",
    "            result = expand_fn(next_node)\n",
    "        A.node_was_expanded(next_node)\n",
    "        expand_model_calls += 1\n",
    "\n",
    "        with timing(\"insert + value model\"):\n",
    "            if result is None:\n",
    "                print(\"none result by expand\")\n",
    "            else:\n",
    "                reactants = result['reactants']\n",
    "                pros = result['scores']\n",
    "                if 'templates' in result.keys():\n",
    "                    templates = result['templates']\n",
    "                else:\n",
    "                    templates = result['template']\n",
    "                \n",
    "                reactant_lists = []\n",
    "                \n",
    "                for j in range(len(pros)):\n",
    "                    reactant_list = list(set(reactants[j].split('.')))\n",
    "                    reactant_lists.append(reactant_list)\n",
    "                #print(result)\n",
    "                for r, t, s in zip(reactant_lists, templates, pros):\n",
    "                    # get initial q value\n",
    "                    init_value = value_fn(next_node, t)\n",
    "                    #print(init_value)\n",
    "                    value_model_calls += 1\n",
    "                    A.add_and_node(next_node, f\"and_{len(A)}\", r, template=t, pro=s, qv=init_value, num_visited=1)\n",
    "            \n",
    "                    # update found building blocks\n",
    "                    true_nodes |= true_nodes.union(c for c in r if c in starting_mols)\n",
    "\n",
    "        with timing(\"update truth\"):\n",
    "            A.update_truth_values(literals=true_nodes)\n",
    "        \n",
    "        with timing(\"update openness\"):\n",
    "            A.update_openness()\n",
    "\n",
    "        with timing(\"update on path\"):\n",
    "            # update along path\n",
    "            for n in path[::-1]:\n",
    "                d = A.nodes[n]\n",
    "                if d[\"node_type\"] == \"and\":\n",
    "                    # update visit counter, qv\n",
    "                    if all(A.nodes[c][\"truth_value\"] for c in A.successors(n)):\n",
    "                        nqv = 10\n",
    "                    elif any(A.nodes[c][\"qv\"] == -np.inf for c in A.successors(n)):\n",
    "                        nqv = -10\n",
    "                    else:\n",
    "                        #child_qvs = sum([A.nodes[c][\"qv\"] for c in A.successors(n) if not A.nodes[c][\"is_open\"]])\n",
    "                        total_vm = 0\n",
    "                        not_expanded = 0\n",
    "                        for c in A.successors(n):\n",
    "                            nd = A.nodes[c]\n",
    "                            if nd[\"num_expanded\"] > 0 or A.nodes[c][\"truth_value\"]:\n",
    "                                total_vm += nd[\"qv\"]\n",
    "                            else:\n",
    "                                not_expanded += 1\n",
    "                        num_children = len(list(A.successors(n)))\n",
    "                        mean_qv = total_vm / (num_children - not_expanded)\n",
    "                        print(n, \"sees total_vm\", total_vm, \"ne\", not_expanded, \"#children\", num_children)\n",
    "                        \n",
    "                        nqv = (d[\"qv\"] * d[\"num_visited\"] + mean_qv) / (d[\"num_visited\"] + 1)\n",
    "                    print(\"changed qv of node\", n,\"from\", d[\"qv\"],\"to\", nqv)\n",
    "                    d[\"qv\"] = nqv\n",
    "                    d[\"num_visited\"] += 1\n",
    "                else:\n",
    "                    # update qv\n",
    "                    ie = d[\"is_exhausted\"]\n",
    "                    io = d[\"is_open\"]\n",
    "                    it = d[\"truth_value\"]\n",
    "                    il = A.out_degree(n) > 0\n",
    "            \n",
    "                    if it:\n",
    "                        nqv = 10.0\n",
    "                    elif ie:\n",
    "                        # we are open, need to calc\n",
    "                        if io:\n",
    "                            nqv = max(A.nodes[c][\"qv\"] for c in A.successors(n))\n",
    "                        else:\n",
    "                        # we are false, exhausted and closed\n",
    "                            nqv = -np.inf\n",
    "                    elif not ie and io:\n",
    "                        # we are expandable, open and an intermediate node, need to calc\n",
    "                        if il:\n",
    "                            nqv = max(A.nodes[c][\"qv\"] for c in A.successors(n))\n",
    "                        else:\n",
    "                            # expandable, open and a leaf\n",
    "                            nqv = np.inf\n",
    "                    else:\n",
    "                        raise ValueError(f\"unhandled node state for node '{n}'. \\n\\texhausted: {ie}\\n\\tis_open: {io}\\n\\tis_true: {it}\\n\\tis_intermediate: {il}\")\n",
    "                    print(\"changed qv of node\", d[\"id\"],\"from\", d[\"qv\"],\"to\", nqv)\n",
    "                    d[\"qv\"] = nqv\n",
    "\n",
    "        with timing(\"update qvals for or\"):\n",
    "            A.update_q_values_for_or_nodes()\n",
    "        # update progress\n",
    "        pbar.update(1)\n",
    "        pbar.set_postfix({\"Targets solved\": f\"{len(A.targets) - len(A.get_unsolved_targets())}/{len(A.targets)}\"})\n",
    "\n",
    "print(\"is solved:\", len(A.get_unsolved_targets())==0)\n",
    "print(\"Expand model calls\", expand_model_calls)\n",
    "print(\"value model calls\", value_model_calls)\n",
    "print(\"# and nodes\", len(A.get_and_nodes()))\n",
    "print(\"# or nodes\", len(A.get_or_nodes()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e22457-6d3d-4d0f-a3f7-d2a85a1d3cc7",
   "metadata": {},
   "source": [
    "## Compile results and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "219e4aeb-dea7-4166-acb7-537fb4c39b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the output directory from the run's hyperparameters and create it if missing.\n",
    "# NOTE: this rebinds `logdir`, so the value assigned in the parameters cell is not used from here on.\n",
    "logdir = f\"data/chem/hp/{iterations}_{expansion_topk}_{use_value_fn}/\"\n",
    "os.makedirs(logdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce0b80af-ac9e-41cd-897e-af8d9aab6b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "# Serialize the graph object `A` with dill into the run directory as A.dat.\n",
    "# The path is built by plain string concatenation, so `logdir` must end with a path separator.\n",
    "with open(logdir + \"A.dat\",\"wb\") as f:\n",
    "    dill.dump(A, f, fix_imports=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707cab69-c62b-44e0-9e9a-db34057a7fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from eg_mcts.algorithm.syn_route import SynRoute\n",
    "\n",
    "def extract_SynRoute(A, mol):\n",
    "    \"\"\"Build a SynRoute for `mol` from its minimum proof DAG in `A`.\n",
    "\n",
    "    Returns None when `mol` is not a node of `A` or its `truth_value` is falsy.\n",
    "    Otherwise the proof DAG is walked breadth-first starting at `mol`: molecules found\n",
    "    in `starting_mols` only get their value recorded, every other molecule contributes\n",
    "    one reaction (its first true successor) to the route.\n",
    "    \"\"\"\n",
    "    is_proven = mol in A and A.nodes[mol][\"truth_value\"]\n",
    "    if not is_proven:\n",
    "        return None\n",
    "\n",
    "    proof = AndOrGraph.extract_min_proof_dag(A, mol)\n",
    "    route = SynRoute(mol)\n",
    "\n",
    "    # NOTE(review): there is no visited set, so a molecule reachable through several\n",
    "    # reactions is enqueued once per incoming edge -- confirm SynRoute tolerates that.\n",
    "    pending = deque([mol])\n",
    "    while pending:\n",
    "        current = pending.popleft()\n",
    "        if current in starting_mols:\n",
    "            # building block: nothing to expand, just record its value\n",
    "            route.set_value(current, proof.nodes[current][\"qv\"])\n",
    "        else:\n",
    "            # a non-leaf molecule (or-node) is expected to have a single true reaction\n",
    "            # child in a minimum proof DAG; take the first true one\n",
    "            reaction = next(c for c in proof.successors(current) if proof.nodes[c][\"truth_value\"])\n",
    "            reactants = list(proof.successors(reaction))\n",
    "            pending.extend(reactants)\n",
    "\n",
    "            attrs = proof.nodes[reaction]\n",
    "            route.add_reaction(\n",
    "                mol=current,\n",
    "                value=proof.nodes[current][\"qv\"],\n",
    "                template=attrs[\"template\"],\n",
    "                prob=attrs[\"pro\"],\n",
    "                Q0=attrs[\"qv\"],\n",
    "                reactants=reactants,\n",
    "            )\n",
    "    return route"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10691958-19b4-4c4e-b6cf-25874ea09d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find, for each target, the minimum DAG that makes it true, a.k.a. its minimum proof DAG.\n",
    "# Judging by its name, the first call refreshes truth values and proof-tree sizes on A\n",
    "# from the known-true literals -- TODO confirm against AndOrGraph.\n",
    "# `routes` is aligned with A.targets and holds None where extract_SynRoute returns None.\n",
    "AndOrGraph.update_truth_and_proof_tree_sizes(A, literals=true_nodes)\n",
    "routes = [extract_SynRoute(A, t) for t in A.targets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90a886db-ad21-4511-a129-558461d3d958",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Summary record of this run: call counters, graph sizes, the run arguments and,\n",
    "# per target (in A.targets order), route length, serialized route and solved flag.\n",
    "result = dict(\n",
    "    # NOTE(review): `iter` mirrors `expand_model_call` -- confirm that is intended.\n",
    "    iter=expand_model_calls,\n",
    "    route_len=[r.length if r else None for r in routes],\n",
    "    expand_model_call=expand_model_calls,\n",
    "    value_model_call=value_model_calls,\n",
    "    reaction_nodes_lens=len(A.get_and_nodes()),\n",
    "    mol_nodes_lens=len(A.get_or_nodes()),\n",
    "    args=dict(\n",
    "        folder=logdir,\n",
    "        mol=A.targets,\n",
    "        # record the actual notebook parameter instead of a hard-coded True, so the\n",
    "        # stored args agree with the run (and with the directory name built from it)\n",
    "        use_value_fn=use_value_fn,\n",
    "        iterations=iterations,\n",
    "        expansion_topk=expansion_topk\n",
    "    ),\n",
    "    routes=[r.serialize() if r else None for r in routes],\n",
    "    succ=[A.is_solved(t) for t in A.targets], \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ec051bd-40f4-4122-a81a-d32f71bd9a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Write the `result` dict as JSON into the run directory (result.json).\n",
    "with open(logdir + \"result.json\", \"w\") as f:\n",
    "    json.dump(result, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcbbfaeb-753a-4a04-8cdf-237705bd6f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# sys.exit() raises SystemExit; presumably kept as a stop marker at the end of the run\n",
    "# -- TODO confirm it is still wanted now that only a markdown cell follows.\n",
    "sys.exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6fed1a-f761-48e6-ad74-b4a1c3afc032",
   "metadata": {},
   "source": [
    "# END\n",
    "***\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
