{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Move the working directory two levels up so that the `molDistill` package import\n",
    "# and the relative `molDistill/df_metadata.csv` output path used later in this\n",
    "# notebook resolve (presumably this lands on the repository root -- confirm against\n",
    "# where the notebook is stored).\n",
    "%cd ../.."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tdc.single_pred import Tox, ADME, HTS, QM\n",
    "from tdc.utils import retrieve_label_name_list\n",
    "from tqdm import tqdm\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from molDistill.baselines.utils.tdc_dataset import correspondancy_dict, get_dataset\n",
    "\n",
    "# Descriptors for every dataset summarised in this notebook, keyed by the dataset\n",
    "# name handed to the TDC loaders. Each value is a triple laid out as\n",
    "#   (short_name, category, task_type)\n",
    "# where task_type is the tag \"reg\" or \"cls\" (presumably regression / classification).\n",
    "# Insertion order matters: rows of the summary table follow it.\n",
    "DATASET_metadata = {\n",
    "    # ---- task_type == \"reg\" ----\n",
    "    \"LD50_Zhu\": (\"LD50\", \"Tox\", \"reg\"),\n",
    "    \"Caco2_Wang\": (\"Caco2\", \"Absorption\", \"reg\"),\n",
    "    \"Lipophilicity_AstraZeneca\": (\"Lipophilicity\", \"Absorption\", \"reg\"),\n",
    "    \"Solubility_AqSolDB\": (\"Solubility\", \"Absorption\", \"reg\"),\n",
    "    \"HydrationFreeEnergy_FreeSolv\": (\"FreeSolv\", \"Absorption\", \"reg\"),\n",
    "    \"PPBR_AZ\": (\"PPBR\", \"Distribution\", \"reg\"),\n",
    "    \"VDss_Lombardo\": (\"VDss\", \"Distribution\", \"reg\"),\n",
    "    \"Half_Life_Obach\": (\"Half Life\", \"Excretion\", \"reg\"),\n",
    "    \"Clearance_Hepatocyte_AZ\": (\"Clearance (H)\", \"Excretion\", \"reg\"),\n",
    "    \"Clearance_Microsome_AZ\": (\"Clearance (M)\", \"Excretion\", \"reg\"),\n",
    "    # ---- task_type == \"cls\" ----\n",
    "    \"hERG\": (\"hERG\", \"Tox\", \"cls\"),\n",
    "    \"hERG_Karim\": (\"hERG (k)\", \"Tox\", \"cls\"),\n",
    "    \"AMES\": (\"AMES\", \"Tox\", \"cls\"),\n",
    "    \"DILI\": (\"DILI\", \"Tox\", \"cls\"),\n",
    "    \"Carcinogens_Lagunin\": (\"Carcinogens\", \"Tox\", \"cls\"),\n",
    "    \"Skin__Reaction\": (\"Skin R\", \"Tox\", \"cls\"),\n",
    "    \"Tox21\": (\"Tox21\", \"Tox\", \"cls\"),\n",
    "    \"ClinTox\": (\"ClinTox\", \"Tox\", \"cls\"),\n",
    "    \"ToxCast\": (\"ToxCast\", \"Tox\", \"cls\"),\n",
    "    \"PAMPA_NCATS\": (\"PAMPA\", \"Absorption\", \"cls\"),\n",
    "    \"HIA_Hou\": (\"HIA\", \"Absorption\", \"cls\"),\n",
    "    \"Pgp_Broccatelli\": (\"Pgp\", \"Absorption\", \"cls\"),\n",
    "    \"Bioavailability_Ma\": (\"Bioavailability\", \"Absorption\", \"cls\"),\n",
    "    \"BBB_Martins\": (\"BBB\", \"Distribution\", \"cls\"),\n",
    "    \"CYP2C19_Veith\": (\"CYP2C19\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP2D6_Veith\": (\"CYP2D6\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP3A4_Veith\": (\"CYP3A4\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP1A2_Veith\": (\"CYP1A2\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP2C9_Veith\": (\"CYP2C9\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP2C9_Substrate_CarbonMangels\": (\"CYP2C9 (s)\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP2D6_Substrate_CarbonMangels\": (\"CYP2D6 (s)\", \"Metabolism\", \"cls\"),\n",
    "    \"CYP3A4_Substrate_CarbonMangels\": (\"CYP3A4 (s)\", \"Metabolism\", \"cls\"),\n",
    "    \"HIV\": (\"HIV\", \"HTS\", \"cls\"),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Summarise every dataset listed in DATASET_metadata into one row of `df_metadata`.\n",
    "# Per-label statistics are averaged over the dataset's labels:\n",
    "#   n_samples -> mean number of rows returned by get_data() per label\n",
    "#   balanced  -> mean of |0.5 - mean(Y)| for \"cls\" tasks, mean of std(Y) otherwise\n",
    "#   n_tasks   -> number of labels that were loaded\n",
    "metadata_columns = [\n",
    "    \"dataset\", \"task_type\", \"category\", \"n_samples\", \"balanced\", \"short_name\", \"n_tasks\"\n",
    "]\n",
    "metadata_rows = []\n",
    "\n",
    "for dataset_name, (short_name, category, task_type) in DATASET_metadata.items():\n",
    "    print(dataset_name)\n",
    "    loader_cls = correspondancy_dict[dataset_name]\n",
    "    # Only datasets served by the Tox, ADME or HTS loaders are summarised here;\n",
    "    # anything else is skipped.\n",
    "    if loader_cls not in (Tox, ADME, HTS):\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        label_names = retrieve_label_name_list(dataset_name)\n",
    "    except Exception:\n",
    "        # Deliberate best effort: when no label list can be retrieved, load the\n",
    "        # dataset once with label_name=None (presumably a single-label dataset).\n",
    "        label_names = [None]\n",
    "\n",
    "    sample_counts = []\n",
    "    balance_scores = []\n",
    "    for label_name in tqdm(label_names):\n",
    "        df_task = loader_cls(name=dataset_name, label_name=label_name).get_data()\n",
    "        sample_counts.append(df_task.shape[0])\n",
    "        if task_type == \"cls\":\n",
    "            # Distance of the mean label value from 0.5.\n",
    "            balance_scores.append(abs(0.5 - df_task.Y.mean()))\n",
    "        else:\n",
    "            balance_scores.append(df_task.Y.std())\n",
    "\n",
    "    n_tasks = len(label_names)\n",
    "    metadata_rows.append({\n",
    "        \"dataset\": dataset_name,\n",
    "        \"task_type\": task_type,\n",
    "        \"category\": category,\n",
    "        \"n_samples\": sum(sample_counts) / n_tasks,\n",
    "        \"balanced\": sum(balance_scores) / n_tasks,\n",
    "        \"short_name\": short_name,\n",
    "        \"n_tasks\": n_tasks,\n",
    "    })\n",
    "    clear_output()\n",
    "\n",
    "# Build the frame once from the collected records instead of enlarging an empty\n",
    "# frame row by row; `columns=` fixes the column order even when no row was collected.\n",
    "df_metadata = pd.DataFrame(metadata_rows, columns=metadata_columns)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Write the metadata table to disk as CSV, leaving out the index column.\n",
    "metadata_csv_path = \"molDistill/df_metadata.csv\"\n",
    "df_metadata.to_csv(metadata_csv_path, index=False)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the assembled metadata table (one row per dataset that was processed above).\n",
    "df_metadata"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
