{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import numpy as np\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "import warnings\n",
    "import yaml\n",
    "import glob\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem.rdMolAlign import CalcRMS\n",
    "from easydict import EasyDict\n",
    "import json\n",
    "import re\n",
    "import csv\n",
    "import pandas as pd\n",
    "import shutil\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations used by the cells below.\n",
    "\n",
    "# YAML file whose top-level keys are read as pocket names.\n",
    "pocket_path = './data_crossdocked/test.yaml'\n",
    "\n",
    "# CSV with `pocket_name` and `affinity` columns.\n",
    "# NOTE(review): absolute, machine-specific path; it will not resolve on another host.\n",
    "ori_vina_path = '/home/nic/Code/HGNN-GPT/GPT-last-new-2/crossdocked/dock_result2/pocket_vina.csv'\n",
    "\n",
    "# Run 3: docking results (JSON) and the sampled-molecule YAML of the same run.\n",
    "json_file3 = './dock_file_save/crossdocked/2025_01_05_20_1741635931/dock_result/dock_dict.json'\n",
    "smiles_yaml3 = './save/pre/crossdocked/char/hgnn/2025_01_05_20/sample_300_30_True_1_1_1741635931/1ai4_A_rec_1ai5_mnp_lig_tt_docked_0_pocket10_sampled_temp1.yaml'\n",
    "\n",
    "# Run 4: docking results (JSON) and the sampled-molecule YAML of the same run.\n",
    "json_file4 = './dock_file_save/crossdocked/2025_01_12_16_1743065493/dock_result/dock_dict.json'\n",
    "smiles_yaml4 = './save/pre/crossdocked/char/all/2025_01_12_16/sample_300_30_True_1_1_1743065493/1ai4_A_rec_1ai5_mnp_lig_tt_docked_0_pocket10_sampled_temp1.yaml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every input file once; the following cells only use the in-memory objects.\n",
    "\n",
    "# NOTE(review): yaml.safe_load is the recommended loader for plain-data files;\n",
    "# full_load is kept so that what gets parsed does not change.\n",
    "with open(pocket_path, 'r') as f:\n",
    "    pocket_dict = yaml.full_load(f)\n",
    "pocket_names = list(pocket_dict.keys())\n",
    "\n",
    "# Map each `pocket_name` of the CSV to its `affinity`, parsed as float.\n",
    "# The csv module asks for files opened with newline='' so that line breaks\n",
    "# inside quoted fields are interpreted correctly.\n",
    "with open(ori_vina_path, 'r', newline='') as f:\n",
    "    ori_vina = {\n",
    "        row['pocket_name']: float(row['affinity'])\n",
    "        for row in csv.DictReader(f)\n",
    "    }\n",
    "\n",
    "# Docking results of both runs.\n",
    "with open(json_file3, 'r') as f:\n",
    "    dock_data3 = json.load(f)\n",
    "\n",
    "with open(json_file4, 'r') as f:\n",
    "    dock_data4 = json.load(f)\n",
    "\n",
    "# Values of each sampled-molecule YAML, in file order. Later cells pair them by\n",
    "# position with the docking affinities and use them as repeat counts.\n",
    "with open(smiles_yaml3, 'r') as f:\n",
    "    mol_dict3 = yaml.full_load(f)\n",
    "mol_num3 = list(mol_dict3.values())\n",
    "\n",
    "with open(smiles_yaml4, 'r') as f:\n",
    "    mol_dict4 = yaml.full_load(f)\n",
    "mol_num4 = list(mol_dict4.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the values read from each run's YAML file, each followed by its length.\n",
    "print(mol_num3, len(mol_num3), sep='\\n')\n",
    "print(mol_num4, len(mol_num4), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_affinity(one_pocket_name, dock_data):\n",
    "    \"\"\"Return the mode-0 docking affinities recorded for one pocket.\n",
    "\n",
    "    Everything before the last underscore of a ``dock_data`` key is treated as\n",
    "    the pocket name. For every key of the requested pocket, the affinity of the\n",
    "    first record whose ``mode_id`` equals 0 is collected; keys without such a\n",
    "    record are skipped.\n",
    "\n",
    "    Args:\n",
    "        one_pocket_name: Pocket name to select.\n",
    "        dock_data: Mapping from key to an iterable of record dicts that may\n",
    "            carry ``mode_id`` and ``affinity`` entries.\n",
    "\n",
    "    Returns:\n",
    "        List of affinities in ``dock_data`` iteration order. An entry is\n",
    "        ``None`` when the selected record has no ``affinity`` field.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If no key of ``one_pocket_name`` has a ``mode_id`` 0 record.\n",
    "    \"\"\"\n",
    "    affinities = []\n",
    "    for key, records in dock_data.items():\n",
    "        # rpartition yields the text before the last underscore ('' if none).\n",
    "        if key.rpartition('_')[0] != one_pocket_name:\n",
    "            continue\n",
    "        for record in records:\n",
    "            if record.get('mode_id') == 0:\n",
    "                affinities.append(record.get('affinity'))\n",
    "                break\n",
    "\n",
    "    if not affinities:\n",
    "        raise KeyError(\n",
    "            f'no mode_id 0 docking record found for pocket {one_pocket_name!r}'\n",
    "        )\n",
    "    return affinities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pocket whose docking results are analysed below.\n",
    "one_pocket_name = '1ai4_A_rec_1ai5_mnp_lig_tt_docked_0_pocket10'\n",
    "\n",
    "# Relative path of that pocket's PDB file (not referenced by any other cell here).\n",
    "one_pocket_name_path = 'PAC_ECOLX_27_846_0/1ai4_A_rec_1ai5_mnp_lig_tt_docked_0_pocket10.pdb'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mode-0 affinities of the selected pocket for run 3, then its size.\n",
    "unique_affinity_3 = get_all_affinity(one_pocket_name, dock_data3)\n",
    "print(unique_affinity_3, len(unique_affinity_3), sep='\\n')\n",
    "\n",
    "# Same for run 4.\n",
    "unique_affinity_4 = get_all_affinity(one_pocket_name, dock_data4)\n",
    "print(unique_affinity_4, len(unique_affinity_4), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat every run-3 affinity `count` times, pairing affinities with the\n",
    "# values of mol_dict3 by position.\n",
    "# NOTE(review): this assumes both sequences are in the same order - confirm.\n",
    "if len(unique_affinity_3) != len(mol_num3):\n",
    "    # zip() stops at the shorter input, so a mismatch silently drops entries.\n",
    "    warnings.warn(\n",
    "        f'run 3: {len(unique_affinity_3)} affinities but {len(mol_num3)} counts; '\n",
    "        'surplus entries are ignored and the pairing may be misaligned'\n",
    "    )\n",
    "\n",
    "complete_data3 = [\n",
    "    affinity\n",
    "    for affinity, count in zip(unique_affinity_3, mol_num3)\n",
    "    for _ in range(count)\n",
    "]\n",
    "\n",
    "print(len(complete_data3))\n",
    "print(complete_data3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat every run-4 affinity `count` times, pairing affinities with the\n",
    "# values of mol_dict4 by position.\n",
    "# NOTE(review): this assumes both sequences are in the same order - confirm.\n",
    "if len(unique_affinity_4) != len(mol_num4):\n",
    "    # zip() stops at the shorter input, so a mismatch silently drops entries.\n",
    "    warnings.warn(\n",
    "        f'run 4: {len(unique_affinity_4)} affinities but {len(mol_num4)} counts; '\n",
    "        'surplus entries are ignored and the pairing may be misaligned'\n",
    "    )\n",
    "\n",
    "complete_data4 = [\n",
    "    affinity\n",
    "    for affinity, count in zip(unique_affinity_4, mol_num4)\n",
    "    for _ in range(count)\n",
    "]\n",
    "\n",
    "print(len(complete_data4))\n",
    "print(complete_data4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plot of the run-3 Vina scores, annotated with mean and median.\n",
    "df = pd.DataFrame(complete_data3, columns=['PHy2Mol'])\n",
    "palette = sns.color_palette(['#a1dab4'])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "# Draw on the axes created above instead of relying on pyplot's current axes.\n",
    "sns.boxplot(\n",
    "    data=df,\n",
    "    width=0.2,\n",
    "    showmeans=True,\n",
    "    palette=palette,\n",
    "    meanprops={\n",
    "        'markerfacecolor': 'white',\n",
    "        'markeredgecolor': 'black',\n",
    "        'markersize': 10,\n",
    "    },\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_ylabel('vina  score', fontsize=12)\n",
    "ax.set_xticklabels(['PHy2Mol'])\n",
    "ax.set_xlim(-0.3, 0.3)\n",
    "\n",
    "# Label the mean slightly above and the median slightly below their values,\n",
    "# to the right of the box.\n",
    "means = df['PHy2Mol'].mean()\n",
    "ax.text(0.12, means + 0.1, f'mean={means:.2f}', ha='left', va='center', fontsize=10, color='black')\n",
    "\n",
    "median_value = df['PHy2Mol'].median()\n",
    "ax.text(0.12, median_value - 0.1, f'median={median_value:.2f}', ha='left', va='center', fontsize=10, color='black')\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs('./figure', exist_ok=True)\n",
    "fig.savefig('./figure/3_vina_score.png', dpi=500)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side box plots of the run-3 and run-4 Vina scores. pandas pads the\n",
    "# shorter column with NaN when the two lists differ in length.\n",
    "df = pd.concat([pd.Series(complete_data3), pd.Series(complete_data4)], axis=1)\n",
    "df.columns = ['PHy2Mol', 'MB2Mol']\n",
    "palette = {'PHy2Mol': '#a1dab4', 'MB2Mol': '#5dbfe9'}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 7))\n",
    "# Draw on the axes created above instead of relying on pyplot's current axes.\n",
    "sns.boxplot(\n",
    "    data=df,\n",
    "    width=0.3,\n",
    "    showmeans=True,\n",
    "    palette=palette,\n",
    "    meanprops={\n",
    "        'markerfacecolor': 'white',\n",
    "        'markeredgecolor': 'black',\n",
    "        'markersize': 8,\n",
    "    },\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_ylabel('vina  score', fontsize=12)\n",
    "ax.set_xticklabels(['PHy2Mol', 'MB2Mol'])\n",
    "\n",
    "# Label each box with its mean (slightly above) and median (slightly below);\n",
    "# box i is drawn at x position i.\n",
    "means = df.mean()\n",
    "for i, mean in enumerate(means):\n",
    "    ax.text(i + 0.18, mean + 0.1, f'mean={mean:.2f}', ha='left', va='center', fontsize=10, color='black')\n",
    "\n",
    "medians = df.median()\n",
    "for i, median in enumerate(medians):\n",
    "    ax.text(i + 0.18, median - 0.1, f'median={median:.2f}', ha='left', va='center', fontsize=10, color='black')\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs('./figure', exist_ok=True)\n",
    "fig.savefig('./figure/3-4_vina_score.png', dpi=500)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HGNN-GPT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
