{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "spiritual-search",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "from torch_geometric.data import DataLoader, Batch, Data\n",
    "from embedding import GNN\n",
    "from downstream import MLP\n",
    "from tagexplainer import TAGExplainer, MLPExplainer\n",
    "from loader import MoleculeDataset\n",
    "from splitters import scaffold_split\n",
    "from dig.xgraph.evaluation import XCollector\n",
    "\n",
    "# GPU index 2 is the configured default. Fall back to the default CUDA device, then\n",
    "# to CPU, so the notebook also runs on hosts that expose fewer than three GPUs.\n",
    "if torch.cuda.device_count() > 2:\n",
    "    device = torch.device(\"cuda:2\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "statutory-monitor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Molecule dataset and shuffled loader (batches of 256) used for explainer training.\n",
    "pre_train_dataset = MoleculeDataset(\n",
    "    \"dataset/zinc_standard_agent\",\n",
    "    dataset='zinc_standard_agent',\n",
    ")\n",
    "train_loader = DataLoader(pre_train_dataset, batch_size=256, shuffle=True)\n",
    "\n",
    "# 5-layer GIN encoder with 600-d embeddings, restored from the saved checkpoint on CPU.\n",
    "embed_model = GNN(num_layer=5, emb_dim=600, JK='last', drop_ratio=0, gnn_type='gin')\n",
    "embed_model.load_state_dict(\n",
    "    torch.load('ckpts_model/chem_pretrained_contextpred.pth', map_location='cpu')\n",
    ")\n",
    "\n",
    "# Explainer built around the encoder; later cells reuse it as `enc_explainer`.\n",
    "enc_explainer = TAGExplainer(\n",
    "    embed_model,\n",
    "    embed_dim=600,\n",
    "    device=device,\n",
    "    explain_graph=True,\n",
    "    grad_scale=0.2,\n",
    "    coff_size=0.05,\n",
    "    coff_ent=0.002,\n",
    "    loss_type='JSE',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "closed-windows",
   "metadata": {},
   "source": [
    "#### To train the explainer, uncomment the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "union-survivor",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7813/7813 [23:56<00:00,  5.44it/s, loss=0.0217, log=75.2979, 1.0000, -21.9580, 0.0079]  \n"
     ]
    }
   ],
   "source": [
    "# Training is disabled by default: the next cell loads a saved explainer checkpoint from\n",
    "# 'ckpts_explainer/explain_mol_twostage.pt'. Uncomment both lines below to retrain the\n",
    "# explainer for one epoch and overwrite that checkpoint.\n",
    "# enc_explainer.train_explainer_graph(train_loader, lr=0.0001, epochs=1)\n",
    "# torch.save(enc_explainer.explainer.state_dict(), 'ckpts_explainer/explain_mol_twostage.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "massive-straight",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Deserialize onto CPU (as the other checkpoint loads in this notebook do) so the file\n",
    "# loads even when the GPU it was saved from is unavailable; load_state_dict then copies\n",
    "# the weights into the explainer's existing parameters.\n",
    "state_dict = torch.load('ckpts_explainer/explain_mol_twostage.pt', map_location='cpu')\n",
    "enc_explainer.explainer.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "raised-patrol",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_task(idx):\n",
    "    \"\"\"Build a dataset transform that keeps only the label of task `idx`.\n",
    "\n",
    "    The returned callable rebuilds each graph as a fresh `Data` object that carries\n",
    "    `x`, `edge_index` and `edge_attr` unchanged, with `y` reduced to the slice\n",
    "    `y[idx:idx+1]` cast to long.\n",
    "    \"\"\"\n",
    "    task_slice = slice(idx, idx + 1)\n",
    "\n",
    "    def transform(data):\n",
    "        single_label = data.y[task_slice].long()\n",
    "        return Data(\n",
    "            x=data.x,\n",
    "            edge_index=data.edge_index,\n",
    "            edge_attr=data.edge_attr,\n",
    "            y=single_label,\n",
    "        )\n",
    "\n",
    "    return transform\n",
    "\n",
    "def get_dataset(name, task=0):\n",
    "    \"\"\"Return the training portion of the scaffold split of dataset `name`.\n",
    "\n",
    "    Graphs are read from `dataset/<name>` with labels narrowed to `task` via\n",
    "    `get_task`; SMILES strings come from column 0 of\n",
    "    `dataset/<name>/processed/smiles.csv` (read without a header row). The split\n",
    "    uses fractions 0.8 / 0.1 / 0.1 and only the first (training) subset is returned.\n",
    "    \"\"\"\n",
    "    root = f\"dataset/{name}\"\n",
    "    dataset = MoleculeDataset(root, dataset=name, transform=get_task(task))\n",
    "    smiles_frame = pd.read_csv(f'{root}/processed/smiles.csv', header=None)\n",
    "    smiles_list = smiles_frame[0].tolist()\n",
    "    train_dataset, _, _ = scaffold_split(\n",
    "        dataset,\n",
    "        smiles_list,\n",
    "        null_value=0,\n",
    "        frac_train=0.8,\n",
    "        frac_valid=0.1,\n",
    "        frac_test=0.1,\n",
    "    )\n",
    "    return train_dataset\n",
    "\n",
    "def get_results(task_name, top_k, pos=True):\n",
    "    \"\"\"Explain training graphs of `task_name` and print fidelity and sparsity.\n",
    "\n",
    "    Relies on the notebook-level `enc_explainer` and `device`.\n",
    "\n",
    "    Args:\n",
    "        task_name: dataset name; also selects the downstream MLP checkpoint\n",
    "            `ckpts_model/downstream_<task_name>_contextpred.pth`.\n",
    "        top_k: forwarded unchanged to the explainer call.\n",
    "        pos: when True, graphs whose label is < 0 are skipped; when False,\n",
    "            graphs whose label is >= 0 are skipped.\n",
    "\n",
    "    Graphs without any edge are always skipped. Nothing is returned; the two\n",
    "    values reported for each metric are printed once the loop has finished.\n",
    "    \"\"\"\n",
    "    train_dataset = get_dataset(task_name)\n",
    "    checkpoint = f'ckpts_model/downstream_{task_name}_contextpred.pth'\n",
    "    mlp_model = MLP(num_layer=2, emb_dim=600, hidden_dim=600)\n",
    "    mlp_model.load_state_dict(torch.load(checkpoint, map_location='cpu'))\n",
    "    mlp_explainer = MLPExplainer(mlp_model, device)\n",
    "\n",
    "    collector = XCollector()\n",
    "    loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=1)\n",
    "    for graph_idx, data in enumerate(loader):\n",
    "        # Keep only the requested label sign, and only graphs with at least one edge.\n",
    "        if bool(data.y < 0) == pos or data.edge_index.shape[1] == 0:\n",
    "            continue\n",
    "\n",
    "        # '\\r' keeps the progress message on a single, continuously rewritten line.\n",
    "        print(f'explain graph {graph_idx}...', end='\\r')\n",
    "        _walks, masks, related_preds = enc_explainer(\n",
    "            data.to(device), mlp_explainer, top_k=top_k, mask_mode='split')\n",
    "        collector.collect_data(masks, related_preds)\n",
    "\n",
    "    fid, fid_std = collector.fidelity\n",
    "    spa, spa_std = collector.sparsity\n",
    "\n",
    "    report = (f'Fidelity: {fid:.4f} +/- {fid_std:.4f}\\n'\n",
    "              f'Sparsity: {spa:.4f} +/- {spa_std:.4f}')\n",
    "    print()\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "humanitarian-smoke",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 965...\n",
      "Fidelity: 0.3782 +/- 0.2934\n",
      "Sparsity: 0.9026 +/- 0.0278\n"
     ]
    }
   ],
   "source": [
    "# BACE: explain training graphs whose label is not negative (pos defaults to True), top_k=5.\n",
    "get_results('bace', top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "normal-story",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 32795...\n",
      "Fidelity: 0.5952 +/- 0.3200\n",
      "Sparsity: 0.8806 +/- 0.0582\n"
     ]
    }
   ],
   "source": [
    "# HIV: explain training graphs whose label is not negative (pos defaults to True), top_k=5.\n",
    "get_results('hiv', top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dental-confidentiality",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 1140...\n",
      "Fidelity: 0.4067 +/- 0.3226\n",
      "Sparsity: 0.8545 +/- 0.0751\n"
     ]
    }
   ],
   "source": [
    "# SIDER: explain training graphs whose label is not negative (pos defaults to True), top_k=5.\n",
    "get_results('sider', top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "reported-filling",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 1193...\n",
      "Fidelity: 0.1878 +/- 0.1543\n",
      "Sparsity: 0.7205 +/- 0.1162\n"
     ]
    }
   ],
   "source": [
    "# BBBP: pos=False restricts the run to training graphs with a negative label; top_k=10.\n",
    "get_results('bbbp', top_k=10, pos=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "confidential-carbon",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
