{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from rdkit import Chem\n",
    "import numpy as np\n",
    "from rdkit.Chem import MACCSkeys\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from scipy.stats import gaussian_kde\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the reference SMILES table from two sources: the values of a YAML mapping\n",
    "# and a plain-text file that is read line by line.\n",
    "train_path = '../data_crossdocked/final_filter_train.yaml'\n",
    "\n",
    "# NOTE(review): yaml.full_load can build more than plain scalars and containers; if this\n",
    "# file only holds a simple mapping, yaml.safe_load is the safer choice - confirm.\n",
    "with open(train_path, 'r') as f:\n",
    "    train_smi = list(yaml.full_load(f).values())\n",
    "\n",
    "# NOTE(review): absolute, machine-specific path; the notebook will not run elsewhere\n",
    "# until this points at a shared or configurable data directory.\n",
    "train_path2 = '/home/nic/Code/HGNN-GPT/GPT-last-new-2/datasets/chembl_smi_len_2.smi'\n",
    "with open(train_path2, 'r') as f:\n",
    "    train_smi2 = f.readlines()\n",
    "\n",
    "# Strip newlines and skip blank lines (e.g. a trailing empty line) so that no empty\n",
    "# string ends up in the table as if it were a molecule.\n",
    "chembl_smi = [line.strip() for line in train_smi2 if line.strip()]\n",
    "\n",
    "df_train = pd.DataFrame({'smiles': train_smi + chembl_smi})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read metrics_each.csv from two sample directories and keep only the 'smiles' column of each.\n",
    "gen3_path = '../save/pre/crossdocked/char/hgnn/2025_01_05_20/sample_300_30_True_1_1_1741635931/metrics_each.csv'\n",
    "gen4_path = '../save/pre/crossdocked/char/all/2025_01_12_16/sample_300_30_True_1_1_1743065493/metrics_each.csv'\n",
    "\n",
    "df_g3 = pd.read_csv(gen3_path)[['smiles']]\n",
    "df_g4 = pd.read_csv(gen4_path)[['smiles']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smi_to_maccs(smiles):\n",
    "    \"\"\"Return the MACCS-keys fingerprint of a SMILES string as a list of 0/1 ints.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    smiles : str\n",
    "        SMILES string to parse with RDKit.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of int or None\n",
    "        One 0/1 integer per character of the fingerprint bit string, or None when\n",
    "        the input is not a string or RDKit cannot parse it.\n",
    "    \"\"\"\n",
    "    # Missing table entries arrive as non-string values (pandas uses float NaN);\n",
    "    # report them the same way as an unparsable SMILES.\n",
    "    if not isinstance(smiles, str):\n",
    "        return None\n",
    "    molecule = Chem.MolFromSmiles(smiles)\n",
    "    if molecule is None:\n",
    "        return None\n",
    "    bit_string = MACCSkeys.GenMACCSKeys(molecule).ToBitString()\n",
    "    return [int(bit) for bit in bit_string]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a 'MACCS' column to each table by running smi_to_maccs over its 'smiles' column.\n",
    "for frame in (df_train, df_g3, df_g4):\n",
    "    frame['MACCS'] = frame['smiles'].apply(smi_to_maccs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# smi_to_maccs returns None for SMILES it cannot parse; drop those entries so that\n",
    "# PCA receives a rectangular matrix instead of a list mixing vectors and None.\n",
    "pca = PCA(n_components=2)\n",
    "\n",
    "# Fit the 2-component projection on the training fingerprints only ...\n",
    "X1_pca = pca.fit_transform(df_train['MACCS'].dropna().tolist())\n",
    "\n",
    "# ... and project both generated sets with that same fitted model.\n",
    "X3_pca = pca.transform(df_g3['MACCS'].dropna().tolist())\n",
    "X4_pca = pca.transform(df_g4['MACCS'].dropna().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "fig_path = '../figure/'\n",
    "# Make sure the output directory exists before saving into it.\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "\n",
    "plt.rcParams['font.sans-serif'] = ['Noto Sans CJK']\n",
    "# plt.rcParams['font.sans-serif'] = ['SimHei']\n",
    "plt.rcParams['axes.unicode_minus'] = False\n",
    "\n",
    "# Overlay the df_train projection (hollow grey circles) with the df_g3 projection (blue triangles).\n",
    "fig, ax = plt.subplots(figsize=(8, 7), dpi=300)\n",
    "ax.scatter(X1_pca[:, 0], X1_pca[:, 1], label='Drug Set 1', edgecolor='grey', facecolor='none', marker='o', alpha=0.2, s=25)\n",
    "ax.scatter(X3_pca[:, 0], X3_pca[:, 1], label='Drug Set 2', color='blue', marker='^', alpha=0.2, s=25)\n",
    "# ax.scatter(X2_pca[:, 0], X2_pca[:, 1], label='Drug Set 3', color='red', marker='s', alpha=0.5, s=25)\n",
    "\n",
    "ax.set_xlabel('The first principal component', color='black')\n",
    "ax.set_ylabel('The second principal component', color='black')\n",
    "# ax.legend()\n",
    "# ax.grid(True)\n",
    "ax.tick_params(colors='black')\n",
    "\n",
    "# os.path.join avoids the doubled slash that plain string formatting produced here.\n",
    "fig.savefig(os.path.join(fig_path, 'train-3.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the two generated sets in the PCA plane: df_g3 as hollow blue triangles,\n",
    "# df_g4 as pink squares.\n",
    "fig, ax = plt.subplots(figsize=(8, 7), dpi=300)\n",
    "ax.scatter(\n",
    "    X3_pca[:, 0], X3_pca[:, 1],\n",
    "    label='Drug Set 1', edgecolor='blue', facecolor='none', marker='^', alpha=0.4, s=25,\n",
    ")\n",
    "ax.scatter(\n",
    "    X4_pca[:, 0], X4_pca[:, 1],\n",
    "    label='Drug Set 3', color='pink', marker='s', alpha=0.3, s=25,\n",
    ")\n",
    "# ax.scatter(X2_pca[:, 0], X2_pca[:, 1], label='Drug Set 3', color='red', marker='s', alpha=0.5, s=25)\n",
    "\n",
    "ax.set_xlabel('The first principal component', color='black')\n",
    "ax.set_ylabel('The second principal component', color='black')\n",
    "# ax.legend()\n",
    "# ax.grid(True)\n",
    "ax.tick_params(colors='black')\n",
    "\n",
    "fig.savefig(f'{fig_path}/3-4.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_path = '../figure/pca/'\n",
    "\n",
    "# root_path = 'Ch31_E51'\n",
    "# file_path = f'../checkpoints/{root_path}/sampled_molecules.out'\n",
    "\n",
    "\n",
    "def _read_smi(path):\n",
    "    \"\"\"Read a headerless file with pandas and name its single column 'smi'.\"\"\"\n",
    "    frame = pd.read_csv(path, header=None)\n",
    "    frame.columns = ['smi']\n",
    "    return frame\n",
    "\n",
    "\n",
    "# Reference set plus the sampled molecules of two checkpoints.\n",
    "df_train = _read_smi('../dataset/chembl31_custcleaned.smi')\n",
    "df_g3 = _read_smi('../checkpoints/Ch31_conAll_ais1024_e10/sampled_molecules.out')\n",
    "df_g4 = _read_smi('../checkpoints/Con_AIS1024E10_ComLoss_3/sampled_molecules.out')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additional set: headerless file whose single column is named 'smi' after loading.\n",
    "oldseed_path = '../dataset/oldseed5.smi'\n",
    "df_a = pd.read_csv(oldseed_path, header=None)\n",
    "df_a.columns = ['smi']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this re-defines the helper of the same name from an earlier cell;\n",
    "# consider keeping a single definition near the top of the notebook.\n",
    "def smi_to_maccs(smiles):\n",
    "    \"\"\"Return the MACCS-keys fingerprint of a SMILES string as a list of 0/1 ints.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    smiles : str\n",
    "        SMILES string to parse with RDKit.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of int or None\n",
    "        One 0/1 integer per character of the fingerprint bit string, or None when\n",
    "        the input is not a string or RDKit cannot parse it.\n",
    "    \"\"\"\n",
    "    # Missing table entries arrive as non-string values (pandas uses float NaN);\n",
    "    # report them the same way as an unparsable SMILES.\n",
    "    if not isinstance(smiles, str):\n",
    "        return None\n",
    "    molecule = Chem.MolFromSmiles(smiles)\n",
    "    if molecule is None:\n",
    "        return None\n",
    "    bit_string = MACCSkeys.GenMACCSKeys(molecule).ToBitString()\n",
    "    return [int(bit) for bit in bit_string]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fingerprint every SMILES; entries that cannot be parsed come back as None.\n",
    "df_train['MACCS'] = df_train['smi'].apply(smi_to_maccs)\n",
    "df_g3['MACCS'] = df_g3['smi'].apply(smi_to_maccs)\n",
    "df_g4['MACCS'] = df_g4['smi'].apply(smi_to_maccs)\n",
    "\n",
    "\n",
    "# scaler = StandardScaler()\n",
    "# X1_scaled = scaler.fit_transform(df_train['MACCS'].tolist())\n",
    "# X3_scaled = scaler.transform(df_g3['MACCS'].tolist())\n",
    "# X4_scaled = scaler.transform(df_g4['MACCS'].tolist())\n",
    "\n",
    "# Drop the None fingerprints so that PCA receives a rectangular matrix instead of a\n",
    "# list mixing vectors and None.\n",
    "pca = PCA(n_components=2)\n",
    "X1_pca = pca.fit_transform(df_train['MACCS'].dropna().tolist())\n",
    "\n",
    "# Project both sampled sets with the model fitted on df_train.\n",
    "X3_pca = pca.transform(df_g3['MACCS'].dropna().tolist())\n",
    "X4_pca = pca.transform(df_g4['MACCS'].dropna().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fingerprint df_a and project it with the already-fitted PCA model.\n",
    "df_a['MACCS'] = df_a['smi'].apply(smi_to_maccs)\n",
    "\n",
    "# Unparsable SMILES produce None; drop them so transform receives a rectangular matrix.\n",
    "X2_pca = pca.transform(df_a['MACCS'].dropna().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Make sure the output directory exists before saving into it.\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "\n",
    "plt.rcParams['font.sans-serif'] = ['SimHei']\n",
    "plt.rcParams['axes.unicode_minus'] = False\n",
    "\n",
    "# Overlay the df_train projection (hollow grey circles) with the df_g3 projection (blue triangles).\n",
    "fig, ax = plt.subplots(figsize=(8, 7), dpi=300)\n",
    "ax.scatter(X1_pca[:, 0], X1_pca[:, 1], label='Drug Set 1', edgecolor='grey', facecolor='none', marker='o', alpha=0.2, s=25)\n",
    "ax.scatter(X3_pca[:, 0], X3_pca[:, 1], label='Drug Set 2', color='blue', marker='^', alpha=0.2, s=25)\n",
    "# ax.scatter(X2_pca[:, 0], X2_pca[:, 1], label='Drug Set 3', color='red', marker='s', alpha=0.5, s=25)\n",
    "\n",
    "ax.set_xlabel('The first principal component', color='black')\n",
    "ax.set_ylabel('The second principal component', color='black')\n",
    "# ax.legend()\n",
    "# ax.grid(True)\n",
    "ax.tick_params(colors='black')\n",
    "\n",
    "# os.path.join avoids the doubled slash that plain string formatting produced here.\n",
    "fig.savefig(os.path.join(fig_path, 'train-3.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the two sampled sets in the PCA plane: df_g3 as hollow blue triangles,\n",
    "# df_g4 as pink squares.\n",
    "fig, ax = plt.subplots(figsize=(8, 7), dpi=300)\n",
    "ax.scatter(\n",
    "    X3_pca[:, 0], X3_pca[:, 1],\n",
    "    label='Drug Set 1', edgecolor='blue', facecolor='none', marker='^', alpha=0.4, s=25,\n",
    ")\n",
    "ax.scatter(\n",
    "    X4_pca[:, 0], X4_pca[:, 1],\n",
    "    label='Drug Set 3', color='pink', marker='s', alpha=0.3, s=25,\n",
    ")\n",
    "# ax.scatter(X2_pca[:, 0], X2_pca[:, 1], label='Drug Set 3', color='red', marker='s', alpha=0.5, s=25)\n",
    "\n",
    "ax.set_xlabel('The first principal component', color='black')\n",
    "ax.set_ylabel('The second principal component', color='black')\n",
    "# ax.legend()\n",
    "# ax.grid(True)\n",
    "ax.tick_params(colors='black')\n",
    "\n",
    "fig.savefig(f'{fig_path}/3-4.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
