{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8e7ee7b-f8e5-49d1-b2ad-244baa279e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirpath = \"externals/eg-mcts/eg_mcts\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4155741d-d1e5-46f6-95f6-fea3ecc4aedc",
   "metadata": {},
   "source": [
    "# Prepare data\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33cd1c74-0eec-4bcd-8c73-9de1e12fcfc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "if not os.path.exists(dirpath + \"/dataset/origin_dict.dat\"):\n",
    "    starting_mols = set(list(pd.read_csv(dirpath+\"/dataset/origin_dict.csv\")['mol']))\n",
    "    with open(dirpath+\"/dataset/origin_dict.dat\", \"wb\") as f:\n",
    "        pickle.dump(starting_mols, f)\n",
    "else:\n",
    "    with open(dirpath+\"/dataset/origin_dict.dat\",\"rb\") as f:\n",
    "        starting_mols =  pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1418ab6d-1b1a-4e40-af49-6b82225c8a89",
   "metadata": {},
   "source": [
    "# Code\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af7e08d-19b9-4524-baed-b3c4ec33b53c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import logging\n",
    "import time\n",
    "import json\n",
    "\n",
    "from eg_mcts.utils.prepare_methods import prepare_mlp, prepare_egmcts_planner, prepare_starting_molecules\n",
    "from eg_mcts.utils.smiles_process import smiles_to_fp, reaction_smarts_to_fp\n",
    "from eg_mcts.model.eg_network import EG_MLP\n",
    "from eg_mcts.utils.logger import setup_logger\n",
    "from mlp_retrosyn.mlp_inference import MLPModel\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "class RSPlanner:\n",
    "    def __init__(self,\n",
    "                 gpu=-1,\n",
    "                 expansion_topk=50,\n",
    "                 iterations=500,\n",
    "                 use_value_fn=False,\n",
    "                 starting_molecules=dirpath+'/dataset/origin_dict.csv',\n",
    "                 mlp_templates=dirpath+'/one_step_model/template_rules_1.dat',\n",
    "                 mlp_model_dump=dirpath+'/one_step_model/retro_star_value_ours.ckpt',\n",
    "                 save_folder=dirpath+'/saved_EG_fn',\n",
    "                 value_model='bets_EGN.pt',\n",
    "                 viz=False,\n",
    "                 viz_dir='viz'):\n",
    "\n",
    "        device = torch.device('cuda:%d' % gpu if gpu >= 0 else 'cpu')\n",
    "        if isinstance(starting_molecules, set):\n",
    "            starting_mols = starting_molecules\n",
    "        else:\n",
    "            starting_mols = prepare_starting_molecules(starting_molecules)\n",
    "        self.use_value_fn = use_value_fn\n",
    "        one_step = MLPModel(mlp_model_dump, mlp_templates, device=gpu)\n",
    "        #one_step = prepare_mlp(mlp_templates, mlp_model_dump)\n",
    "\n",
    "        if use_value_fn:\n",
    "            print('use_fn')\n",
    "            model = EG_MLP(\n",
    "                n_layers=1,\n",
    "                fp_dim=4096,\n",
    "                latent_dim=256,\n",
    "                dropout_rate=0.1,\n",
    "                device=device\n",
    "            ).to(device)\n",
    "            model_f = '%s/%s' % (save_folder, value_model)\n",
    "            logging.info('Loading Experience Guidance Network from %s' % model_f)\n",
    "            model.load_state_dict(torch.load(model_f, map_location=device))\n",
    "            model.eval()\n",
    "\n",
    "            def value_fn(mol, template):\n",
    "                mol_fp = smiles_to_fp(mol, fp_dim=2048).reshape(1, -1)\n",
    "                template_fp = reaction_smarts_to_fp(template, fp_dim=2048).reshape(1, -1)\n",
    "                fp = np.hstack((mol_fp, template_fp))\n",
    "                fp = torch.FloatTensor(fp).to(device)\n",
    "                v = model(fp).item()\n",
    "                return v\n",
    "        else:\n",
    "            value_fn = lambda x,y: 0.5\n",
    "\n",
    "        self.plan_handle = prepare_egmcts_planner(\n",
    "            one_step=one_step,\n",
    "            value_fn=value_fn,\n",
    "            starting_mols=starting_mols,\n",
    "            expansion_topk=expansion_topk,\n",
    "            iterations=iterations,\n",
    "            viz=viz,\n",
    "            viz_dir=viz_dir\n",
    "        )\n",
    "\n",
    "    def plan(self, target_mol, target_molid = 0):\n",
    "        t0 = time.time()\n",
    "        succ, route, msg, expressions = self.plan_handle(target_mol, target_molid)\n",
    "\n",
    "        result = {\n",
    "            'succ': succ,\n",
    "            'time': time.time() - t0,\n",
    "            'iter': msg[0],\n",
    "            'routes': route.serialize() if succ else None,\n",
    "            'route_len': route.length if succ else None,\n",
    "            'expand_model_call': msg[1],\n",
    "            'value_model_call': 0,\n",
    "            'reaction_nodes_lens': msg[3],\n",
    "            'mol_nodes_lens': msg[4]\n",
    "        }\n",
    "\n",
    "        if self.use_value_fn:\n",
    "            result['value_model_call'] = msg[2]\n",
    "        return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85270849-6735-4015-a73b-c5cb06771047",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "import time\n",
    "\n",
    "from functools import partial\n",
    "from multiprocess import Pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a57589c4-674f-4047-bafa-f3b00c9d9ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_batch(batch, folder, use_value_fn, iterations, expansion_topk, force=False):\n",
    "    # check whether our batch actually needs any processing\n",
    "    all_there = True\n",
    "    for i, mol in batch:\n",
    "        mol_folder = f\"{folder}/{i}/\"\n",
    "        if os.path.exists(mol_folder + \"result.json\") and not force:\n",
    "            continue\n",
    "        else:\n",
    "            all_there = False\n",
    "            break\n",
    "\n",
    "    if all_there:\n",
    "        return\n",
    "\n",
    "    planner = RSPlanner(\n",
    "        gpu=-1,\n",
    "        use_value_fn=use_value_fn,\n",
    "        iterations=iterations,\n",
    "        expansion_topk=expansion_topk,\n",
    "        value_model=\"best_egn_for_emol.pt\",\n",
    "        starting_molecules=starting_mols,\n",
    "        viz=False,\n",
    "    )\n",
    "\n",
    "    for i, mol in batch:\n",
    "        mol_folder = f\"{folder}/{i}/\"\n",
    "        if os.path.exists(mol_folder + \"result.json\") and not force:\n",
    "            continue\n",
    "        os.makedirs(mol_folder, exist_ok=True)\n",
    "        r = planner.plan(mol, target_molid=i)\n",
    "        r[\"args\"] = dict(folder=folder, mol=mol, use_value_fn=use_value_fn, iterations=iterations, expansion_topk=expansion_topk)\n",
    "        with open(mol_folder + \"result.json\",\"w\") as f:\n",
    "            json.dump(r, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26690a13-0e56-46a3-8649-058b2b44aabc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"externals/eg-mcts/eg_mcts/dataset/retro190.txt\") as f:\n",
    "    lines = [x.strip() for x in f.readlines()]\n",
    "retro_data = lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e6a18ac-a027-405e-b701-a22307aedc7a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "num_workers = 32 #min(len(retro_data), os.cpu_count())\n",
    "batches = [x.tolist() for x in np.array_split([(i,y) for i,y in enumerate(retro_data)], num_workers)]\n",
    "\n",
    "for iterations in [10,50,100,200,300,400,500]:\n",
    "    for expansion_topk in [50]:\n",
    "        use_value_fn = True\n",
    "        \n",
    "        folder = f\"data/chem/eg-mcts/{iterations}_{expansion_topk}_{use_value_fn}/\"\n",
    "        \n",
    "        \n",
    "        plan_handle = partial(run_batch, iterations=iterations, expansion_topk=expansion_topk, \n",
    "                              use_value_fn=use_value_fn, folder=folder, force=False)\n",
    "        \n",
    "        with Pool(num_workers) as p:\n",
    "            p.map(plan_handle, batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e79bd80-25f4-4145-86d7-98f338b8925f",
   "metadata": {},
   "outputs": [],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a999c6b-239b-4dbe-abe8-525155964fe5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
