{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ProtoPNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import hydra\n",
    "import torch\n",
    "import shutil\n",
    "import warnings\n",
    "from torch.optim import Adam\n",
    "from omegaconf import OmegaConf\n",
    "from utils import check_dir\n",
    "from tes_gnnNets import *\n",
    "from dataset import get_dataset, get_dataloader\n",
    "import torch.nn.functional as F\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "import logging\n",
    "import networkx as nx\n",
    "import torch_geometric\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML, display\n",
    "\n",
    "# Stretch the notebook container to the full browser width so wide outputs are not clipped.\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize\n",
    "from omegaconf import OmegaConf\n",
    "# Clear the global Hydra instance before calling initialize() below.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(config_path=\"tes_config\", job_name=\"test_app\")\n",
    "# NOTE(review): the num_basis_per_class override targets `ba_2motifs` while `datasets=NCI1` is selected,\n",
    "# and the checkpoint paths used below embed `models.param.<dataset_name>.num_basis_per_class=10` - confirm the key.\n",
    "cfg = compose(config_name=\"config\", overrides=[\"datasets=NCI1\", \"seed=0\",\"models.param.ba_2motifs.num_basis_per_class=10\"]) # use the same overrides as the training run\n",
    "#print(OmegaConf.to_yaml(cfg))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config.models.gnn_saving_dir = 'tes_checkpoints'\n",
    "config.models.param = config.models.param[config.datasets.dataset_name]\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', index=config.device_id)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                      dataset_name=config.datasets.dataset_name)\n",
    "dataset.data.x = dataset.data.x.float()\n",
    "dataset.data.y = dataset.data.y.squeeze().long()\n",
    "if config.models.param.graph_classification:\n",
    "    dataloader_params = {'batch_size': 1,\n",
    "                         'random_split_flag': config.datasets.random_split_flag,\n",
    "                         'data_split_ratio': config.datasets.data_split_ratio,\n",
    "                         'seed': config.seed}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls tes_checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                          dataset_name=config.datasets.dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = get_dataloader(dataset, **dataloader_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_gnnNets(dataset.num_node_features, dataset.num_classes, config.models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f\"outputs_tes/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gcn,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gcn_{config.models.param.num_basis_per_class}_3l_best.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the checkpoint written by the training run; the overrides embedded in it must match that run.\n",
    "checkpoint_path = f\"outputs_tes_old/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gcn,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gcn_{config.models.param.num_basis_per_class}_3l_best.pth\"\n",
    "# The model is never moved off the CPU in this notebook, so map tensors saved on a GPU to the CPU;\n",
    "# without map_location the load fails on machines that lack the device the checkpoint was saved from.\n",
    "saved = torch.load(checkpoint_path, map_location='cpu')\n",
    "state_dict = saved['net']\n",
    "basis_data = saved['basis_data']\n",
    "model.load_state_dict(state_dict)\n",
    "model.basis_data = basis_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.basis_concepts #these are the vectors of the prototypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.classifier_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "for b in dataloader['test']:\n",
    "    print(b.y, model(b.x.float(), b.edge_index)[0].argmax(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = dataloader['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.classifier_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi'] = 300\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "basis_id = 0  # index of the prototype to visualize\n",
    "data = model.basis_data[basis_id].to('cpu')\n",
    "\n",
    "# One gradient-free forward pass is enough; the third output is indexed below as [node, prototype].\n",
    "with torch.no_grad():\n",
    "    logit, embs, sims = model(data=Batch.from_data_list([data]))\n",
    "proto_sims = sims[:, basis_id]\n",
    "node_id = proto_sims.argmax().item()  # node with the highest similarity to the prototype\n",
    "nodes = set(torch_geometric.utils.k_hop_subgraph(node_id, 3, data.edge_index)[0].tolist())\n",
    "\n",
    "# Min-max scale the similarities; a constant vector maps to zeros instead of dividing by zero.\n",
    "x = proto_sims.detach().numpy()\n",
    "span = x.max() - x.min()\n",
    "x = (x - x.min()) / span if span > 0 else np.zeros_like(x)\n",
    "# Only the 3-hop neighbourhood of the best-matching node is coloured; every other node gets 0.\n",
    "cmap = [x[i] if i in nodes else 0 for i in range(data.num_nodes)]\n",
    "g = torch_geometric.utils.to_networkx(data, to_undirected=True)\n",
    "pos = nx.kamada_kawai_layout(g)\n",
    "nx.draw(g, node_color=cmap, pos=pos, cmap='Blues', edgecolors='black', ax=ax)\n",
    "# Redraw the best-matching node on top with a red outline.\n",
    "nx.draw(g.subgraph([node_id]), pos=pos, node_color='#12306b', cmap='Blues', edgecolors='red', ax=ax)\n",
    "\n",
    "fig.tight_layout(pad=5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_config_ex(dataset=\"ba_2motifs\", seed=3, num_basis_per_class=10, config_name=\"tes\", model_name=\"gcn\"):\n",
    "    \"\"\"Compose the Hydra config of one training run and restore its trained model.\n",
    "\n",
    "    Args:\n",
    "        dataset: Option of the Hydra ``datasets`` config group to select.\n",
    "        seed: Seed of the run; passed to the dataloader and used in the checkpoint path.\n",
    "        num_basis_per_class: Currently unused; kept so existing callers keep working.\n",
    "        config_name: Prefix of the config directory (``<config_name>_config``) and of\n",
    "            ``config.models.gnn_saving_dir`` (``<config_name>_checkpoints``).\n",
    "        model_name: Option of the Hydra ``models`` config group to select.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(config, dataset, dataloader, model)``. The model carries the checkpoint\n",
    "        weights and ``basis_data``, lives on the CPU and is in eval mode.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the composed config is not a graph-classification setup.\n",
    "    \"\"\"\n",
    "    hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "    initialize(config_path=config_name + \"_\" + \"config\" , job_name=\"test_app\")\n",
    "    config = compose(config_name=\"config\", overrides=[f\"datasets={dataset}\", f\"seed={seed}\", f\"models={model_name}\"])\n",
    "    config.models.gnn_saving_dir = config_name + \"_\" + 'checkpoints'\n",
    "    # Narrow the per-dataset parameter table down to the entry of the selected dataset.\n",
    "    config.models.param = config.models.param[config.datasets.dataset_name]\n",
    "\n",
    "    if not config.models.param.graph_classification:\n",
    "        # Only the graph-classification dataloader settings are defined below.\n",
    "        raise ValueError(\"set_config_ex only supports graph-classification configs\")\n",
    "\n",
    "    # Load the dataset once and keep the dtype normalisation (reloading it afterwards would discard it).\n",
    "    graphs = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                         dataset_name=config.datasets.dataset_name)\n",
    "    graphs.data.x = graphs.data.x.float()\n",
    "    graphs.data.y = graphs.data.y.squeeze().long()\n",
    "\n",
    "    dataloader_params = {'batch_size': 1,\n",
    "                         'random_split_flag': config.datasets.random_split_flag,\n",
    "                         'data_split_ratio': config.datasets.data_split_ratio,\n",
    "                         'seed': config.seed}\n",
    "    dataloader = get_dataloader(graphs, **dataloader_params)\n",
    "    model = get_gnnNets(graphs.num_node_features, graphs.num_classes, config.models)\n",
    "\n",
    "    checkpoint_path = f\"outputs_tes/datasets={config.datasets.dataset_name},device_id=2,models=gcn,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gcn_{config.models.param.num_basis_per_class}_3l_latest.pth\"\n",
    "    # The model stays on the CPU, so map the stored tensors there; a hard-coded 'cuda:0'\n",
    "    # makes the load fail on machines without a GPU.\n",
    "    saved = torch.load(checkpoint_path, map_location='cpu')\n",
    "    model.load_state_dict(saved['net'])\n",
    "    model.basis_data = saved['basis_data']\n",
    "    # The restored model is only used for evaluation and explanation in this notebook.\n",
    "    model.eval()\n",
    "    return config, graphs, dataloader, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `dataset_name` and the seed loop variable are only defined in the next cell, so reuse the config composed above.\n",
    "config, dataset, dataloader, model = set_config_ex(dataset=config.datasets.dataset_name, seed=config.seed, model_name=\"gcn\", config_name=\"tes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l = []  # test accuracy of every seed whose checkpoint could be loaded\n",
    "seeds = [8971, 85688, 9467, 32830, 28689, 94845, 69840, 50883, 74177, 79585, 1055, 75631, 6825, 93188, 95426, 54514, 31467, 70597, 71149, 81994]\n",
    "dataset_name = 'MUTAG'\n",
    "for i in seeds:\n",
    "    try:\n",
    "        config, dataset, dataloader, model = set_config_ex(dataset=dataset_name, seed=i, model_name=\"gcn\", config_name=\"tes\")\n",
    "    except Exception as err:\n",
    "        # Report skipped seeds: silently dropped runs would bias the mean/std computed below.\n",
    "        print(f\"seed {i}: skipped ({err})\")\n",
    "        continue\n",
    "    model.eval()\n",
    "    accs = []\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader['test']:\n",
    "            batch_preds, _, _ = model(batch.x.float(), batch.edge_index)\n",
    "            accs.append(batch_preds.argmax(-1) == batch.y)\n",
    "    test_acc = torch.cat(accs, dim=-1).float().mean().item()\n",
    "    l.append(test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.std(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain(model, data, sparsity=0.8, c=None, protop=True):\n",
    "    \"\"\"Score every node of ``data`` by its contribution to one class via that class' prototypes.\n",
    "\n",
    "    For each prototype of the explained class, the node with the highest similarity to it is\n",
    "    located and ``classifier_weight * similarity`` is added to every node of that node's 3-hop\n",
    "    neighbourhood. The scores are min-max normalised and everything below the ``sparsity``\n",
    "    percentile is set to zero.\n",
    "\n",
    "    Args:\n",
    "        model: Callable as ``model(x, edge_index)`` returning ``(logits, embeddings, similarities)``;\n",
    "            must expose ``num_basis_per_class`` and ``classifier_weights``.\n",
    "        data: Graph object with ``x`` and ``edge_index``.\n",
    "        sparsity: Fraction in [0, 1]; scores below this percentile are zeroed.\n",
    "        c: Class to explain; defaults to the class predicted by ``model``.\n",
    "        protop: Unused; kept for backward compatibility.\n",
    "\n",
    "    Returns:\n",
    "        ``numpy.ndarray`` with one score per node.\n",
    "    \"\"\"\n",
    "    out, embs, cosines = model(data.x.float(), data.edge_index)\n",
    "    y_hat = out.argmax(-1).item() if c is None else c\n",
    "    n_basis = model.num_basis_per_class\n",
    "    first = n_basis * y_hat  # column of the first prototype that belongs to class y_hat\n",
    "    mask = np.zeros(len(data.x))\n",
    "\n",
    "    # For every prototype of the class: index of the node that matches it best.\n",
    "    max_nodes = cosines[:, first:first + n_basis].argmax(axis=0).tolist()\n",
    "    for offset, node in enumerate(max_nodes):\n",
    "        basis = first + offset\n",
    "        try:\n",
    "            to_mask = torch_geometric.utils.k_hop_subgraph(node, 3, data.edge_index)[0].tolist()\n",
    "            cosines_i = cosines[node, basis].item()\n",
    "            mask[to_mask] += model.classifier_weights[basis, y_hat].detach().numpy()*cosines_i\n",
    "        except Exception as err:\n",
    "            # Best effort: a prototype that cannot be scored is skipped, but no longer silently.\n",
    "            warnings.warn(f\"explain: prototype {basis} skipped ({err})\")\n",
    "\n",
    "    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)\n",
    "    threshold = np.percentile(mask, 100 * sparsity)\n",
    "    mask[mask < threshold] = 0\n",
    "    return mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[0].y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi'] = 300\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "# `dataset` is reassigned by the seed loop above and may hold fewer than 501 graphs, so clamp the index.\n",
    "graph_idx = min(500, len(dataset) - 1)\n",
    "data = dataset[graph_idx]\n",
    "with torch.no_grad():\n",
    "    logit, embs, sims = model(data=Batch.from_data_list([data]))\n",
    "\n",
    "# Probability (in percent) that the model assigns to class 0.\n",
    "print((logit.softmax(-1)[0,0]).item()*100)\n",
    "# Signed node relevance: class-0 explanation minus class-1 explanation.\n",
    "e = explain(model, data,c=0) - explain(model, data,c=1)\n",
    "g = torch_geometric.utils.to_networkx(data, to_undirected=True)\n",
    "pos = nx.kamada_kawai_layout(g)\n",
    "\n",
    "labels = {node: str(node) for node in g.nodes()}\n",
    "nx.draw(g, node_color=e, pos=pos, cmap='coolwarm', vmin=-1, vmax=1,edgecolors='black', ax=ax, labels=labels)\n",
    "\n",
    "fig.tight_layout(pad=5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "e.round()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Data\n",
    "from torch_geometric.utils import subgraph\n",
    "import matplotlib as mpl\n",
    "\n",
    "# Keep every node except those whose relevance rounds to +1 (the class-0 side of `e`).\n",
    "cond = torch.tensor(e.round() <= 0)\n",
    "print(cond)\n",
    "masked_edge_index, _ = subgraph(cond, data.edge_index, relabel_nodes=True, num_nodes=data.x.shape[0])\n",
    "# No hand-made `batch` is passed: `torch.zeros_like(cond)` is boolean and has one entry per node of the\n",
    "# unmasked graph, and Batch.from_data_list builds the assignment vector itself.\n",
    "data_mask = Data(x=data.x[cond], y=data.y, edge_index=masked_edge_index, edge_weight=None)\n",
    "\n",
    "mpl.rcParams['figure.dpi'] = 300\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "with torch.no_grad():\n",
    "    logit, embs, sims = model(data=Batch.from_data_list([data_mask]))\n",
    "print((logit.softmax(-1)[0,0]).item()*100)\n",
    "# Stored under a new name so that `e` (the input of this cell) survives and the cell can be re-run.\n",
    "e_masked = explain(model, data_mask,c=0)-explain(model, data_mask,c=1)\n",
    "print(e_masked)\n",
    "g = torch_geometric.utils.to_networkx(data_mask, to_undirected=True)\n",
    "\n",
    "pos = nx.kamada_kawai_layout(g)\n",
    "labels = {node: str(node) for node in g.nodes()}\n",
    "nx.draw(g, node_color=e_masked, pos=pos, cmap='coolwarm', vmin=-1, vmax=1,edgecolors='black', ax=ax, labels=labels)\n",
    "\n",
    "fig.tight_layout(pad=5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def duplicate_and_attach(graph, mask):\n",
    "    \"\"\"Return a copy of ``graph`` with one extra node cloned from a masked node.\n",
    "\n",
    "    A node is drawn uniformly from the nodes selected by ``mask``; its feature row is appended\n",
    "    as a new node, and the new node is connected in both directions to a second node drawn\n",
    "    uniformly from the same selection.\n",
    "\n",
    "    Args:\n",
    "        graph: Object with tensor attributes ``x`` and ``edge_index`` (shape ``[2, E]``) and,\n",
    "            optionally, ``batch``. It is deep-copied and never modified.\n",
    "        mask: Tensor with one entry per node; non-zero entries mark the candidate nodes.\n",
    "\n",
    "    Returns:\n",
    "        The perturbed copy, or an unmodified copy when ``mask`` selects no node.\n",
    "    \"\"\"\n",
    "    import copy\n",
    "    graph = copy.deepcopy(graph)\n",
    "    candidates = torch.nonzero(mask).view(-1)\n",
    "\n",
    "    if len(candidates) == 0:\n",
    "        # Nothing can be duplicated.\n",
    "        return graph\n",
    "\n",
    "    # Two independent draws: the node to clone and the node the clone is attached to.\n",
    "    source = candidates[torch.randint(len(candidates), (1,)).item()].item()\n",
    "    anchor = candidates[torch.randint(len(candidates), (1,)).item()].item()\n",
    "\n",
    "    new_node = graph.x.shape[0]\n",
    "    graph.x = torch.cat([graph.x, graph.x[source].clone().view(1, -1)], dim=0)\n",
    "\n",
    "    # Add the edge in both directions (the PyG convention for undirected graphs); a single\n",
    "    # anchor->new edge would not let the new node send messages to any existing node.\n",
    "    new_edges = graph.edge_index.new_tensor([[anchor, new_node], [new_node, anchor]])\n",
    "    graph.edge_index = torch.cat([graph.edge_index, new_edges], dim=1)\n",
    "\n",
    "    # Extend the batch vector, if there is one, keeping its dtype and device.\n",
    "    if getattr(graph, 'batch', None) is not None:\n",
    "        graph.batch = torch.cat([graph.batch, graph.batch.new_zeros(1)])\n",
    "\n",
    "    return graph\n",
    "import copy\n",
    "def expl_res(model, dataloader):\n",
    "    \"\"\"Measure fidelity and stability of the prototype explanations on the test split.\n",
    "\n",
    "    For every test graph the predicted class is explained with ``explain`` at sparsity 0.9:\n",
    "\n",
    "    * fidelity is the drop of the predicted-class probability once the explanation nodes are removed;\n",
    "    * stability is the node-wise agreement between that explanation and the explanation of a\n",
    "      perturbed copy of the graph produced by ``duplicate_and_attach``.\n",
    "\n",
    "    Graphs that raise during evaluation are reported with ``print`` and skipped.\n",
    "\n",
    "    Args:\n",
    "        model: Callable as ``model(x, edge_index)`` whose first output holds the class logits.\n",
    "        dataloader: Mapping with a ``'test'`` entry yielding one graph per batch.\n",
    "\n",
    "    Returns:\n",
    "        ``(mean_fidelity, stability_accuracy, explanation_accuracy)``. This function reads no\n",
    "        ground-truth node labels, so ``explanation_accuracy`` is always NaN; the other two are NaN\n",
    "        when no graph could be evaluated.\n",
    "    \"\"\"\n",
    "    from torch_geometric.data import Data\n",
    "    from torch_geometric.utils import subgraph\n",
    "\n",
    "    fidelity = []\n",
    "    expl_stab_true = []\n",
    "    expl_stab_pred = []\n",
    "    for data in dataloader['test']:\n",
    "        try:\n",
    "            probs = model(data.x.float(), data.edge_index)[0].softmax(-1)\n",
    "            c = probs.argmax(-1).item()\n",
    "            expl = torch.tensor(explain(model, data, c=c, sparsity=0.9)) > 0\n",
    "            keep = ~expl  # nodes that are NOT part of the explanation\n",
    "\n",
    "            batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=data.x.device)\n",
    "            data_orig = Data(x=data.x, y=data.y, edge_index=data.edge_index, edge_weight=None, batch=batch)\n",
    "            masked_edge_index, _ = subgraph(keep, data_orig.edge_index, relabel_nodes=True, num_nodes=data_orig.x.shape[0])\n",
    "            data_mask = Data(x=data.x[keep], y=data.y, edge_index=masked_edge_index, edge_weight=None, batch=batch[keep])\n",
    "            masked_probs = model(data_mask.x.float(), data_mask.edge_index)[0].softmax(-1)\n",
    "            # Compare the probability of the explained class c (a fixed class 0 flips the sign\n",
    "            # of the score for every graph predicted as another class).\n",
    "            fidelity.append(probs[0, c].item() - masked_probs[0, c].item())\n",
    "\n",
    "            # duplicate_and_attach works on its own deep copy, so data_orig is left untouched.\n",
    "            data_stab = duplicate_and_attach(data_orig, keep)\n",
    "            expl_stab = torch.tensor(explain(model, data_stab, c=c, sparsity=0.9)) > 0\n",
    "            # Compare only the original nodes; the slice also stays aligned when no node was added.\n",
    "            expl_stab_true.extend(expl_stab[:len(expl)].tolist())\n",
    "            expl_stab_pred.extend(expl.tolist())\n",
    "        except Exception as e: print(e)\n",
    "\n",
    "    mean_fidelity = np.mean(fidelity) if fidelity else float('nan')\n",
    "    stability = accuracy_score(expl_stab_true, expl_stab_pred) if expl_stab_true else float('nan')\n",
    "    return mean_fidelity, stability, float('nan')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expl_res(model, dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expl_res1(model, dataloader):\n",
    "    \"\"\"Measure fidelity, stability and accuracy of the prototype explanations on the test split.\n",
    "\n",
    "    For every test graph the predicted class is explained with ``explain`` at sparsity 0.9:\n",
    "\n",
    "    * fidelity is the drop of the predicted-class probability once the explanation nodes are removed;\n",
    "    * stability re-explains a sub-graph that keeps all explanation nodes and each other node with\n",
    "      probability 0.8, and compares it node-wise with the original explanation;\n",
    "    * accuracy compares the explanation with ``data.node_label`` when that attribute exists.\n",
    "\n",
    "    Args:\n",
    "        model: Callable as ``model(x, edge_index)`` whose first output holds the class logits.\n",
    "        dataloader: Mapping with a ``'test'`` entry yielding one graph per batch.\n",
    "\n",
    "    Returns:\n",
    "        ``(mean_fidelity, stability_accuracy, explanation_accuracy)``; a value is NaN when nothing\n",
    "        was collected for it (e.g. no graph carries ``node_label``).\n",
    "    \"\"\"\n",
    "    from torch_geometric.data import Data\n",
    "    from torch_geometric.utils import subgraph\n",
    "\n",
    "    fidelity = []\n",
    "    expl_true = []\n",
    "    expl_pred = []\n",
    "    expl_stab_true = []\n",
    "    expl_stab_pred = []\n",
    "    for data in dataloader['test']:\n",
    "        probs = model(data.x.float(), data.edge_index)[0].softmax(-1)\n",
    "        c = probs.argmax(-1).item()\n",
    "        expl = torch.tensor(explain(model, data, c=c, sparsity=0.9)) > 0\n",
    "        keep = ~expl  # nodes that are NOT part of the explanation\n",
    "\n",
    "        batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=data.x.device)\n",
    "        data_orig = Data(x=data.x, y=data.y, edge_index=data.edge_index, edge_weight=None, batch=batch)\n",
    "        masked_edge_index, _ = subgraph(keep, data_orig.edge_index, relabel_nodes=True, num_nodes=data_orig.x.shape[0])\n",
    "        data_mask = Data(x=data.x[keep], y=data.y, edge_index=masked_edge_index, edge_weight=None, batch=batch[keep])\n",
    "        masked_probs = model(data_mask.x.float(), data_mask.edge_index)[0].softmax(-1)\n",
    "        # Compare the probability of the explained class c rather than a fixed class 0.\n",
    "        fidelity.append(probs[0, c].item() - masked_probs[0, c].item())\n",
    "\n",
    "        # One draw per node. The parentheses matter: `expl == 0 & r` parses as `expl == (0 & r)`,\n",
    "        # which selects every node and turns the perturbation into a no-op.\n",
    "        keep_draw = (torch.rand(data.x.shape[0]) > 0.2).to(expl.device)\n",
    "        stab_mask = expl | (keep & keep_draw)\n",
    "        stab_edge_index, _ = subgraph(stab_mask, data_orig.edge_index, relabel_nodes=True, num_nodes=data_orig.x.shape[0])\n",
    "        data_stab = Data(x=data.x[stab_mask], y=data.y, edge_index=stab_edge_index, edge_weight=None, batch=batch[stab_mask])\n",
    "        # Explain the perturbed graph (not the original one) so the two explanations can differ.\n",
    "        expl_stab = torch.tensor(explain(model, data_stab, c=c, sparsity=0.9)) > 0\n",
    "        expl_stab_true.extend(expl_stab.tolist())\n",
    "        expl_stab_pred.extend(expl[stab_mask].tolist())\n",
    "\n",
    "        # Ground-truth node labels are optional; graphs without them are left out of the accuracy.\n",
    "        node_label = getattr(data, 'node_label', None)\n",
    "        if node_label is not None:\n",
    "            for a, b in zip(expl, node_label):\n",
    "                expl_true.append(b.item())\n",
    "                expl_pred.append(a.item())\n",
    "\n",
    "    mean_fidelity = np.mean(fidelity) if fidelity else float('nan')\n",
    "    stability = accuracy_score(expl_stab_true, expl_stab_pred) if expl_stab_true else float('nan')\n",
    "    accuracy = accuracy_score(expl_true, expl_pred) if expl_true else float('nan')\n",
    "    return mean_fidelity, stability, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = []   # fidelity per seed\n",
    "acs = []  # explanation accuracy per seed\n",
    "ss = []   # explanation stability per seed\n",
    "l = []    # test accuracy per seed\n",
    "dataset_name='NCI1'\n",
    "for i in seeds:\n",
    "    try:\n",
    "        config, dataset, dataloader, model = set_config_ex(dataset=dataset_name, seed=i, model_name=\"gcn\", config_name=\"tes\")\n",
    "    except Exception as err:\n",
    "        # Bound to `err`, not `e`: `e` holds the node relevances of the plotting cells and\n",
    "        # Python deletes the bound name when the except block ends.\n",
    "        print(err)\n",
    "        continue\n",
    "\n",
    "    model.eval()\n",
    "    f, s, a = expl_res(model, dataloader)\n",
    "    accs = []\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader['test']:\n",
    "            batch_preds, _, _ = model(batch.x.float(), batch.edge_index)\n",
    "            accs.append(batch_preds.argmax(-1) == batch.y)\n",
    "    test_acc = torch.cat(accs, dim=-1).float().mean().item()\n",
    "    l.append(test_acc)\n",
    "    ss.append(s)\n",
    "    fs.append(f)\n",
    "    acs.append(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = [f for il,f in zip(l, fs) if il>0.8]\n",
    "acs = [a for il,a in zip(l, acs) if il>0.8]\n",
    "ss = [s for il,s in zip(l, ss) if il>0.8]\n",
    "l = [il for il in l if il>0.8]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = [f for il,f in zip(l, fs) ]\n",
    "acs = [a for il,a in zip(l, acs)]\n",
    "ss = [s for il,s in zip(l, ss) ]\n",
    "l = [il for il in l ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(ss)/len(ss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.std(ss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(fs[:3]+fs[4:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(ss), np.std(ss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(acs), np.std(acs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(l), np.std(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
