{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Uni-Mol Molecular Property Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Licenses**\n",
    "\n",
    "Copyright (c) DP Technology.\n",
    "\n",
    "This source code is licensed under the MIT license found in the\n",
    "LICENSE file in the root directory of this source tree.\n",
    "\n",
    "**Citations**\n",
    "\n",
    "Please cite the following papers if you use this notebook:\n",
    "\n",
    "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n",
    "ChemRxiv (2022)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preparation (SMILES, label to .lmdb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import lmdb\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "from tqdm import tqdm\n",
    "from rdkit.Chem import AllChem\n",
    "from rdkit import RDLogger\n",
    "RDLogger.DisableLog('rdApp.*')  \n",
    "import warnings\n",
    "warnings.filterwarnings(action='ignore')\n",
    "from multiprocessing import Pool\n",
    "\n",
    "\n",
    "def smi2_2Dcoords(smi):\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    AllChem.Compute2DCoords(mol)\n",
    "    coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n",
    "    len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n",
    "    return coordinates\n",
    "\n",
    "\n",
    "def smi2_3Dcoords(smi,cnt):\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    coordinate_list=[]\n",
    "    for seed in range(cnt):\n",
    "        try:\n",
    "            res = AllChem.EmbedMolecule(mol, randomSeed=seed)  # will random generate conformer with seed equal to -1. else fixed random seed.\n",
    "            if res == 0:\n",
    "                try:\n",
    "                    AllChem.MMFFOptimizeMolecule(mol)       # some conformer can not use MMFF optimize\n",
    "                    coordinates = mol.GetConformer().GetPositions()\n",
    "                except:\n",
    "                    print(\"Failed to generate 3D, replace with 2D\")\n",
    "                    coordinates = smi2_2Dcoords(smi)            \n",
    "                    \n",
    "            elif res == -1:\n",
    "                mol_tmp = Chem.MolFromSmiles(smi)\n",
    "                AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n",
    "                mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n",
    "                try:\n",
    "                    AllChem.MMFFOptimizeMolecule(mol_tmp)       # some conformer can not use MMFF optimize\n",
    "                    coordinates = mol_tmp.GetConformer().GetPositions()\n",
    "                except:\n",
    "                    print(\"Failed to generate 3D, replace with 2D\")\n",
    "                    coordinates = smi2_2Dcoords(smi) \n",
    "        except:\n",
    "            print(\"Failed to generate 3D, replace with 2D\")\n",
    "            coordinates = smi2_2Dcoords(smi) \n",
    "\n",
    "        assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n",
    "        coordinate_list.append(coordinates.astype(np.float32))\n",
    "    return coordinate_list\n",
    "\n",
    "\n",
    "def inner_smi2coords(content):\n",
    "    smi = content[0]\n",
    "    target = content[1:]\n",
    "    cnt = 10 # conformer num,all==11, 10 3d + 1 2d\n",
    "\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    if len(mol.GetAtoms()) > 400:\n",
    "        coordinate_list =  [smi2_2Dcoords(smi)] * (cnt+1)\n",
    "        print(\"atom num >400,use 2D coords\",smi)\n",
    "    else:\n",
    "        coordinate_list = smi2_3Dcoords(smi,cnt)\n",
    "        coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]  # after add H \n",
    "    return pickle.dumps({'atoms': atoms, \n",
    "    'coordinates': coordinate_list, \n",
    "    'mol':mol,'smi': smi, 'target': target}, protocol=-1)\n",
    "\n",
    "\n",
    "def smi2coords(content):\n",
    "    try:\n",
    "        return inner_smi2coords(content)\n",
    "    except:\n",
    "        print(\"failed smiles: {}\".format(content[0]))\n",
    "        return None\n",
    "\n",
    "\n",
    "def write_lmdb(inpath='./', outpath='./', nthreads=16):\n",
    "\n",
    "    df = pd.read_csv(os.path.join(inpath))\n",
    "    sz = len(df)\n",
    "    train, valid, test = df[:int(sz*0.8)], df[int(sz*0.8):int(sz*0.9)], df[int(sz*0.9):]\n",
    "    for name, content_list in [('train.lmdb', zip(*[train[c].values.tolist() for c in train])),\n",
    "                                ('valid.lmdb', zip(*[valid[c].values.tolist() for c in valid])),\n",
    "                                ('test.lmdb', zip(*[test[c].values.tolist() for c in test]))]:\n",
    "        os.makedirs(outpath, exist_ok=True)\n",
    "        output_name = os.path.join(outpath, name)\n",
    "        try:\n",
    "            os.remove(output_name)\n",
    "        except:\n",
    "            pass\n",
    "        env_new = lmdb.open(\n",
    "            output_name,\n",
    "            subdir=False,\n",
    "            readonly=False,\n",
    "            lock=False,\n",
    "            readahead=False,\n",
    "            meminit=False,\n",
    "            max_readers=1,\n",
    "            map_size=int(100e9),\n",
    "        )\n",
    "        txn_write = env_new.begin(write=True)\n",
    "        with Pool(nthreads) as pool:\n",
    "            i = 0\n",
    "            for inner_output in tqdm(pool.imap(smi2coords, content_list)):\n",
    "                if inner_output is not None:\n",
    "                    txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n",
    "                    i += 1\n",
    "            print('{} process {} lines'.format(name, i))\n",
    "            txn_write.commit()\n",
    "            env_new.close()\n",
    "\n",
    "write_lmdb(inpath='mol_property_demo.csv', outpath='./demo', nthreads=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetuning (based on pretraining)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path='./'  # replace to your data path\n",
    "save_dir='./save_demo'  # replace to your save path\n",
    "MASTER_PORT=10086\n",
    "n_gpu=1\n",
    "dict_name='dict.txt'\n",
    "weight_path='./weights/mol_pre_no_h_220816.pt'  # replace to your ckpt path\n",
    "task_name='demo'  # data folder name\n",
    "task_num=2\n",
    "loss_func='finetune_cross_entropy'\n",
    "lr=1e-4\n",
    "batch_size=32\n",
    "epoch=5\n",
    "dropout=0.1\n",
    "warmup=0.06\n",
    "local_batch_size=32\n",
    "only_polar=0 # -1 all h; 0 no h\n",
    "conf_size=11\n",
    "seed=0\n",
    "metric=\"valid_agg_auc\"\n",
    "update_freq=batch_size / local_batch_size\n",
    "\n",
    "!cp ../example_data/molecule/$dict_name $data_path\n",
    "!export NCCL_ASYNC_ERROR_HANDLING=1\n",
    "!export OMP_NUM_THREADS=1\n",
    "!python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --task-name $task_name --user-dir ../unimol --train-subset train --valid-subset valid \\\n",
    "       --conf-size $conf_size \\\n",
    "       --num-workers 8 --ddp-backend=c10d \\\n",
    "       --dict-name $dict_name \\\n",
    "       --task mol_finetune --loss $loss_func --arch unimol_base  \\\n",
    "       --classification-head-name $task_name --num-classes $task_num \\\n",
    "       --optimizer adam --adam-betas '(0.9, 0.99)' --adam-eps 1e-6 --clip-norm 1.0 \\\n",
    "       --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\\\n",
    "       --update-freq $update_freq --seed $seed \\\n",
    "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
    "       --log-interval 100 --log-format simple \\\n",
    "       --validate-interval 1 --keep-last-epochs 10 \\\n",
    "       --finetune-from-model $weight_path \\\n",
    "       --best-checkpoint-metric $metric --patience 20 \\\n",
    "       --save-dir $save_dir --only-polar $only_polar \\\n",
    "       --maximize-best-checkpoint-metric\n",
    "# --maximize-best-checkpoint-metric, for classification task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path='./'  # replace to your data path\n",
    "results_path='./infer_demo'  # replace to your results path\n",
    "weight_path='./save_demo/checkpoint_best.pt'  # replace to your ckpt path\n",
    "batch_size=32\n",
    "task_name='demo' # data folder name \n",
    "task_num=2\n",
    "loss_func='finetune_cross_entropy'\n",
    "dict_name='dict.txt'\n",
    "conf_size=11\n",
    "only_polar=0\n",
    "\n",
    "!cp ../example_data/molecule/$dict_name $data_path\n",
    "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --task-name $task_name --valid-subset test \\\n",
    "       --results-path $results_path \\\n",
    "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
    "       --task mol_finetune --loss $loss_func --arch unimol_base \\\n",
    "       --classification-head-name $task_name --num-classes $task_num \\\n",
    "       --dict-name $dict_name --conf-size $conf_size \\\n",
    "       --only-polar $only_polar  \\\n",
    "       --path $weight_path  \\\n",
    "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
    "       --log-interval 50 --log-format simple "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read inference results (.pkl to .csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_csv_results(predict_path, csv_path):\n",
    "    predict = pd.read_pickle(predict_path)\n",
    "    smi_list, predict_list = [], []\n",
    "    for batch in predict:\n",
    "        sz = batch[\"bsz\"]\n",
    "        for i in range(sz):\n",
    "            smi_list.append(batch[\"smi_name\"][i])\n",
    "            predict_list.append(batch[\"prob\"][i][1].cpu().tolist())\n",
    "    predict_df = pd.DataFrame({\"SMILES\": smi_list, \"predict_prob\": predict_list})\n",
    "    predict_df = predict_df.groupby(\"SMILES\")[\"predict_prob\"].mean().reset_index()\n",
    "    predict_df.to_csv(csv_path,index=False)\n",
    "    return predict_df\n",
    "\n",
    "predict_path='./infer_demo/save_demo_test.out.pkl'  # replace to your results path\n",
    "csv_path='./infer_demo/demo_results.csv'\n",
    "predict_df = get_csv_results(predict_path, csv_path)\n",
    "predict_df.info(), predict_df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
