{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
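    "# Requires the `umap-learn` package: pip install umap-learn\n",
    "# (do NOT `pip install umap` -- that is an unrelated PyPI package that does not provide umap.UMAP)\n",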
    "import os, sys\n",
    "import time\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "import torch\n",
    "import umap\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.manifold import TSNE\n",
    "from rdkit.Chem import AllChem as Chem\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import random\n",
    "import yaml\n",
    "random.seed(16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fp(list_of_smi):\n",
    "    \"\"\"Compute Morgan fingerprints for every SMILES entry that RDKit can process.\n",
    "\n",
    "    Args:\n",
    "        list_of_smi: Iterable of SMILES strings. Entries for which\n",
    "            ``Chem.MolFromSmiles`` returns ``None`` or raises (e.g. NaN/None\n",
    "            placeholders), or whose fingerprint computation raises, are skipped.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(fingerprints, smi_to_keep)`` of two parallel lists in input\n",
    "        order: the radius-2 Morgan bit vectors and the entries they came from.\n",
    "    \"\"\"\n",
    "    fingerprints = []\n",
    "    smi_to_keep = []\n",
    "    for smi in list_of_smi:\n",
    "        try:\n",
    "            # Parse inside the try so one malformed entry cannot abort the whole batch.\n",
    "            mol = Chem.MolFromSmiles(smi)\n",
    "            if mol is None:\n",
    "                # An unparsable SMILES is reported as None rather than an exception.\n",
    "                continue\n",
    "            fprint = Chem.GetMorganFingerprintAsBitVect(mol, 2, useFeatures=False)\n",
    "        except Exception:\n",
    "            # Skip this entry; unlike a bare except, KeyboardInterrupt still propagates.\n",
    "            continue\n",
    "        # Append both together so the two lists stay aligned without an index lookup.\n",
    "        fingerprints.append(fprint)\n",
    "        smi_to_keep.append(smi)\n",
    "\n",
    "    return fingerprints, smi_to_keep\n",
    "\n",
    "def get_embedding(data, n_neighbors=10, min_dist=0.5, metric='correlation', random_state=16):\n",
    "    \"\"\"Standardise the features, then embed the samples with UMAP.\n",
    "\n",
    "    Args:\n",
    "        data: Array-like accepted by ``StandardScaler().fit_transform``.\n",
    "        n_neighbors: Forwarded to ``umap.UMAP``.\n",
    "        min_dist: Forwarded to ``umap.UMAP``.\n",
    "        metric: Forwarded to ``umap.UMAP``.\n",
    "        random_state: Forwarded to ``umap.UMAP``.\n",
    "\n",
    "    The defaults reproduce the settings that used to be hard-coded, so\n",
    "    ``get_embedding(data)`` behaves exactly as before.\n",
    "\n",
    "    Returns:\n",
    "        The output of ``umap.UMAP(...).fit_transform`` on the scaled data\n",
    "        (the number of components is left at UMAP's default).\n",
    "    \"\"\"\n",
    "    data_scaled = StandardScaler().fit_transform(data)\n",
    "\n",
    "    reducer = umap.UMAP(n_neighbors=n_neighbors,\n",
    "                        min_dist=min_dist,\n",
    "                        metric=metric,\n",
    "                        random_state=random_state)\n",
    "    embedding = reducer.fit_transform(data_scaled)\n",
    "\n",
    "    return embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_umap(embedding, lim_origin, save_path='MolecularChemicalSpatialDistribution.png', dpi=300):\n",
    "    \"\"\"Scatter-plot an embedding, distinguishing real from generated molecules.\n",
    "\n",
    "    Args:\n",
    "        embedding: Array indexed as ``embedding[:, 0]`` / ``embedding[:, 1]``.\n",
    "            The first ``lim_origin`` rows are drawn as real molecules, the\n",
    "            remaining rows as generated molecules.\n",
    "        lim_origin: Number of leading rows that belong to the real set.\n",
    "        save_path: Where the figure is written. The default is the previously\n",
    "            hard-coded file name; pass a distinct path per call, otherwise each\n",
    "            call overwrites the figure saved by the previous one.\n",
    "        dpi: Resolution passed to ``savefig``.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(14, 10))\n",
    "\n",
    "    xs, ys = embedding[:, 0], embedding[:, 1]\n",
    "    # The right margin is wider (+1.5) than the others (+/-0.5); the legend sits upper right.\n",
    "    ax.set_xlim([np.min(xs) - 0.5, np.max(xs) + 1.5])\n",
    "    ax.set_ylim([np.min(ys) - 0.5, np.max(ys) + 0.5])\n",
    "\n",
    "    labelsize = 20\n",
    "    # NOTE(review): SimHei is kept from the original; where that font is not installed\n",
    "    # matplotlib is expected to warn and substitute its default font -- confirm it is needed.\n",
    "    ax.set_xlabel('UMAP1', fontsize=labelsize, fontproperties='SimHei')\n",
    "    ax.set_ylabel('UMAP2', fontsize=labelsize, fontproperties='SimHei')\n",
    "\n",
    "    # Hide the right and top spines\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "\n",
    "    ax.scatter(xs[:lim_origin], ys[:lim_origin],\n",
    "               lw=0, c='#8b8d94', label='Real Molecule', alpha=0.85, s=180,\n",
    "               marker='o')\n",
    "\n",
    "    ax.scatter(xs[lim_origin:], ys[lim_origin:],\n",
    "               lw=0, c='#0d0be9', label='Generate Molecule', alpha=1.0, s=180,\n",
    "               marker='^')\n",
    "\n",
    "    # frameon=False hides the legend frame, so no frame alpha/edge styling is applied.\n",
    "    ax.legend(prop={'size': labelsize}, loc='upper right', markerscale=1.50,\n",
    "              scatteryoffsets=[0.5, 0.5, 0.5], frameon=False, labelspacing=1.00,\n",
    "              handlelength=0.5)\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    # Save before show(): the figure may be discarded once it has been displayed.\n",
    "    fig.savefig(save_path, dpi=dpi)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of CrossDocked training molecules vs. molecules sampled from the generator.\n",
    "start = time.time()\n",
    "random.seed(1234)\n",
    "N_SAMPLES = 1000  # upper bound on molecules drawn from each set\n",
    "train_path='../data_crossdocked/final_filter_train.yaml'\n",
    "gen3_path='../save/pre/crossdocked/char/hgnn/2025_01_05_20/sample_300_30_True_1_1_1741635931/metrics_each.csv'\n",
    "\n",
    "with open(train_path, 'r') as f:\n",
    "    # NOTE(review): yaml.safe_load is the recommended loader for files that are not\n",
    "    # fully trusted -- confirm whether full_load is actually required here.\n",
    "    train_smi = list(yaml.full_load(f).values())\n",
    "\n",
    "df_g3 = pd.read_csv(gen3_path)\n",
    "# Drop missing values so NaN placeholders never reach RDKit.\n",
    "smiles_generate = df_g3['smiles'].dropna().tolist()\n",
    "\n",
    "# random.sample raises ValueError when asked for more items than exist, so cap the request.\n",
    "smiles_origin = random.sample(train_smi, min(N_SAMPLES, len(train_smi)))\n",
    "smiles_generate = random.sample(smiles_generate, min(N_SAMPLES, len(smiles_generate)))\n",
    "\n",
    "origin_fp, origin_smiles = get_fp(smiles_origin)\n",
    "origin_fp = np.array(origin_fp)\n",
    "generate_fp, generate_smiles = get_fp(smiles_generate)\n",
    "generate_fp = np.array(generate_fp)\n",
    "\n",
    "# Real molecules first, generated second: draw_umap splits the rows at lim_origin.\n",
    "all_data = np.concatenate([origin_fp, generate_fp], axis=0)\n",
    "embedding = get_embedding(all_data)\n",
    "\n",
    "# Count after fingerprinting, because get_fp may have dropped some entries.\n",
    "lim_origin = origin_fp.shape[0]\n",
    "draw_umap(embedding, lim_origin)\n",
    "end = time.time()\n",
    "print(f'UMAP PROJECTION DONE in {end - start:.04} seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of the protein-ligand reference molecules vs. the MMF2Drug generated molecules.\n",
    "start = time.time()\n",
    "random.seed(1234)\n",
    "N_SAMPLES = 1000       # upper bound on molecules drawn from each set\n",
    "MIN_SMILES_LEN = 20    # generated SMILES must be longer than this to be kept\n",
    "\n",
    "origin_data = pd.read_csv('protein_ligand.csv')\n",
    "# Drop missing values so NaN placeholders never reach RDKit.\n",
    "smiles_origin = origin_data['SMILES'].dropna().tolist()\n",
    "\n",
    "generate_data = pd.read_csv('MMF2Drug_all.csv')\n",
    "gen_data = generate_data['SMILES'].dropna()\n",
    "# Cheap length test first; only the remaining candidates are parsed by RDKit.\n",
    "smiles_generate = [\n",
    "    smiles for smiles in gen_data\n",
    "    if len(smiles) > MIN_SMILES_LEN and Chem.MolFromSmiles(smiles) is not None\n",
    "]\n",
    "\n",
    "# random.sample raises ValueError when asked for more items than exist, so cap the request.\n",
    "smiles_origin = random.sample(smiles_origin, min(N_SAMPLES, len(smiles_origin)))\n",
    "smiles_generate = random.sample(smiles_generate, min(N_SAMPLES, len(smiles_generate)))\n",
    "\n",
    "origin_fp, origin_smiles = get_fp(smiles_origin)\n",
    "origin_fp = np.array(origin_fp)\n",
    "generate_fp, generate_smiles = get_fp(smiles_generate)\n",
    "generate_fp = np.array(generate_fp)\n",
    "\n",
    "# Real molecules first, generated second: draw_umap splits the rows at lim_origin.\n",
    "all_data = np.concatenate([origin_fp, generate_fp], axis=0)\n",
    "embedding = get_embedding(all_data)\n",
    "\n",
    "# Count after fingerprinting, because get_fp may have dropped some entries.\n",
    "lim_origin = origin_fp.shape[0]\n",
    "draw_umap(embedding, lim_origin)\n",
    "end = time.time()\n",
    "print(f'UMAP PROJECTION DONE in {end - start:.04} seconds')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
