{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ProtoPNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import hydra\n",
    "import torch\n",
    "import shutil\n",
    "import warnings\n",
    "from torch.optim import Adam\n",
    "from omegaconf import OmegaConf\n",
    "from utils import check_dir\n",
    "from tes_gnnNets import *\n",
    "from dataset import get_dataset, get_dataloader\n",
    "import torch.nn.functional as F\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "import logging\n",
    "import networkx as nx\n",
    "import torch_geometric\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize\n",
    "from omegaconf import OmegaConf\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(config_path=\"tes_config\", job_name=\"test_app\")\n",
    "cfg = compose(config_name=\"config\", overrides=[\"datasets=ba_2motifs\", \"seed=0\",\"models.param.ba_2motifs.num_basis_per_class=10\"]) #here put the same settings used for the train\n",
    "#print(OmegaConf.to_yaml(cfg))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config.models.gnn_saving_dir = 'tes_checkpoints'\n",
    "config.models.param = config.models.param[config.datasets.dataset_name]\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', index=config.device_id)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                      dataset_name=config.datasets.dataset_name)\n",
    "dataset.data.x = dataset.data.x.float()\n",
    "dataset.data.y = dataset.data.y.squeeze().long()\n",
    "if config.models.param.graph_classification:\n",
    "    dataloader_params = {'batch_size': 1,\n",
    "                         'random_split_flag': config.datasets.random_split_flag,\n",
    "                         'data_split_ratio': config.datasets.data_split_ratio,\n",
    "                         'seed': config.seed}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls tes_checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                          dataset_name=config.datasets.dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = get_dataloader(dataset, **dataloader_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_gnnNets(dataset.num_node_features, dataset.num_classes, config.models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f\"outputs_tes/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gcn,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gcn_{config.models.param.num_basis_per_class}_3l_best.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "saved = torch.load(f\"outputs_tes/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gin,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gin_{config.models.param.num_basis_per_class}_5l_best.pth\") #here the path to the saved model\n",
    "state_dict = saved['net']\n",
    "basis_data = saved['basis_data']\n",
    "model.load_state_dict(state_dict)\n",
    "model.basis_data = basis_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !ls \"outputs_tes/datasets=ba_2motifs,device_id=0,models.param.ba_2motifs.num_basis_per_class=10,models=gin,seed=0/tes_checkpoints/0/ba_2motifs_from_None\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_config_ex(dataset=\"ba_2motifs\", seed=0, num_basis_per_class=10, config_name=\"tes\", model_name=\"gin\"):\n",
    "    hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "    initialize(config_path=config_name + \"_\" + \"config\" , job_name=\"test_app\")\n",
    "    cfg = compose(config_name=\"config\", overrides=[f\"datasets={dataset}\", f\"seed={seed}\", f\"models={model_name}\"])\n",
    "#     print(OmegaConf.to_yaml(cfg))\n",
    "    config = cfg\n",
    "    config.models.gnn_saving_dir = config_name + \"_\" + 'checkpoints'\n",
    "    config.models.param = config.models.param[config.datasets.dataset_name]\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        device = torch.device('cuda', index=config.device_id)\n",
    "    else:\n",
    "        device = torch.device('cpu')\n",
    "\n",
    "    dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                          dataset_name=config.datasets.dataset_name)\n",
    "    dataset.data.x = dataset.data.x.float()\n",
    "    dataset.data.y = dataset.data.y.squeeze().long()\n",
    "    if config.models.param.graph_classification:\n",
    "        dataloader_params = {'batch_size': 1,\n",
    "                             'random_split_flag': config.datasets.random_split_flag,\n",
    "                             'data_split_ratio': config.datasets.data_split_ratio,\n",
    "                             'seed': config.seed}\n",
    "    dataset = get_dataset(dataset_root=config.datasets.dataset_root,\n",
    "                          dataset_name=config.datasets.dataset_name)\n",
    "    dataloader = get_dataloader(dataset, **dataloader_params)\n",
    "    model = get_gnnNets(dataset.num_node_features, dataset.num_classes, config.models)\n",
    "    saved = torch.load(f\"outputs_tes/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gin,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gin_{config.models.param.num_basis_per_class}_5l_latest.pth\", map_location='cuda:0') #here the path to the saved model\n",
    "    # saved = torch.load(f\"outputs_tes_old/datasets={config.datasets.dataset_name},device_id=0,models.param.{config.datasets.dataset_name}.num_basis_per_class=10,models=gcn,seed={config.seed}/tes_checkpoints/{config.seed}/{config.datasets.dataset_name}_from_None/gcn_{config.models.param.num_basis_per_class}_3l_latest.pth\", map_location='cuda:0') #here the path to the saved model\n",
    "    state_dict = saved['net']\n",
    "    basis_data = saved['basis_data']\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.basis_data = basis_data\n",
    "    return config, dataset, dataloader, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name='bbbp'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain(model, data, sparsity=0.8, c=None, protop=True):\n",
    "    out, embs, cosines=model(data.x.float(), data.edge_index)    \n",
    "    y_hat=c\n",
    "    if y_hat is None:\n",
    "        y_hat=out.argmax(-1).item()\n",
    "    mask=np.zeros(len(data.x))\n",
    "    if True:\n",
    "        max_nodes = cosines[:,model.num_basis_per_class*y_hat:model.num_basis_per_class*(y_hat+1)].argmax(axis=0).tolist()\n",
    "        for i in range(model.num_basis_per_class*y_hat,model.num_basis_per_class*(y_hat+1)):\n",
    "            try:\n",
    "                to_mask = torch_geometric.utils.k_hop_subgraph(max_nodes[i%model.num_basis_per_class], 3, data.edge_index)[0].tolist()\n",
    "                cosines_i = cosines[max_nodes[i%model.num_basis_per_class],i].item()\n",
    "                mask[to_mask] += model.classifier_weights[i, y_hat].detach().numpy()*cosines_i\n",
    "            except:pass\n",
    "    else:\n",
    "        for i in range(len(data.x)):\n",
    "            mask[i]+= cosines[i,model.num_basis_per_class*y_hat:model.num_basis_per_class*(y_hat+1)].sum().item()\n",
    "    # if protop: mask=torch.log(mask+1)/(mask+1e-8)\n",
    "    # mask = mask*model.classifier_weights[i]\n",
    "    mask=(mask-mask.min())/(mask.max()-mask.min()+1e-8)\n",
    "    # p=np.percentile(mask, 100*sparsity)\n",
    "    # mask[mask<p]=0\n",
    "    # edge_mask=np.zeros(data.edge_index.shape[1])\n",
    "    # for i, (a,b) in enumerate(data.edge_index.T):\n",
    "    #     edge_mask[i]=(mask[a]+mask[b])/2\n",
    "    # edge_mask[edge_mask>0]=1\n",
    "    return mask#, edge_mask\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.utils import subgraph, add_self_loops\n",
    "from torch_geometric.data import Data,Batch\n",
    "\n",
    "def calculate_fidelity(data, node_mask, model, remove_nodes=True, top_k=None):\n",
    "    data = Batch.from_data_list([data])\n",
    "    device = next(model.parameters()).device\n",
    "    data = data.to(device)\n",
    "\n",
    "    # Compute sparsity\n",
    "    total_nodes = data.x.shape[0]\n",
    "    sparsity = 1 - node_mask.sum().item() / total_nodes if total_nodes > 0 else 0\n",
    "\n",
    "    # Get original predictions\n",
    "    original_pred = model(data)[0]\n",
    "    original_pred = F.softmax(original_pred, dim=1)\n",
    "    label = original_pred.argmax(-1).item()\n",
    "    \n",
    "    # Apply node mask (InvFidelity)\n",
    "    if remove_nodes:\n",
    "        masked_edge_index, _ = subgraph(node_mask == 0, edge_index=data.edge_index, num_nodes=data.x.size(0), relabel_nodes=True)\n",
    "        n_nodes = (node_mask == 0).sum()\n",
    "        new_data = Batch.from_data_list([Data(x=data.x[node_mask == 0], edge_index=masked_edge_index)])\n",
    "        masked_pred = model(new_data)[0]\n",
    "    else:\n",
    "        masked_x = data.x.clone()\n",
    "        masked_x[node_mask==1] = 0\n",
    "        new_data = Batch.from_data_list([Data(x=masked_x, edge_index=data.edge_index)])\n",
    "        masked_pred = model(new_data)[0]\n",
    "    masked_pred = F.softmax(masked_pred, dim=1)\n",
    "\n",
    "    # Keep only important nodes (Fidelity)\n",
    "    if remove_nodes:\n",
    "        masked_edge_index, _ = subgraph(node_mask == 1, edge_index=data.edge_index, num_nodes=data.x.size(0), relabel_nodes=True)\n",
    "        n_nodes = (node_mask == 1).sum()\n",
    "        new_data = Batch.from_data_list([Data(x=data.x[node_mask == 1], edge_index=masked_edge_index)])\n",
    "        retained_pred = model(new_data)[0]\n",
    "    else:\n",
    "        masked_x = data.x.clone()\n",
    "        masked_x[node_mask==0] = 0\n",
    "        new_data = Batch.from_data_list([Data(x=masked_x, edge_index=data.edge_index)])\n",
    "        retained_pred = model(new_data)[0]\n",
    "    retained_pred = F.softmax(retained_pred, dim=1)\n",
    "\n",
    "\n",
    "    # # Compute Fidelity+ and Fidelity-\n",
    "    # inv_fidelity = (original_pred[:, label] - \n",
    "    #                              retained_pred[:, label]).mean().item()\n",
    "\n",
    "    # fidelity = (original_pred[:, label] - \n",
    "    #                              masked_pred[:, label]).mean().item()\n",
    "\n",
    "    inv_fidelity = (original_pred.argmax(-1) != \n",
    "                                 retained_pred.argmax(-1)).float().item()\n",
    "\n",
    "    fidelity = (original_pred.argmax(-1)  != \n",
    "                                 masked_pred.argmax(-1) ).float().item()\n",
    "\n",
    "    n_fidelity = inv_fidelity*sparsity\n",
    "    n_inv_fidelity = inv_fidelity*(1-sparsity)\n",
    "    \n",
    "    # Compute HFidelity (harmonic mean of Fidelity+ and Fidelity-)\n",
    "    hfidelity = ((1+n_fidelity) * (1-n_inv_fidelity)) / (2 + n_fidelity - n_inv_fidelity) if (1 + n_fidelity - n_inv_fidelity) != 0 else 0\n",
    "\n",
    "    return {\n",
    "        \"Fidelity\": fidelity,\n",
    "        \"InvFidelity\": inv_fidelity,\n",
    "        \"HFidelity\": hfidelity\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def generate_hard_masks(soft_mask):\n",
    "#     sparsity_levels = torch.arange(0.5,1, 0.05)\n",
    "#     hard_masks = []\n",
    "#     for sparsity in sparsity_levels:\n",
    "#         threshold = np.percentile(soft_mask, sparsity * 100)\n",
    "#         hard_mask = (soft_mask > threshold).int()\n",
    "#         if hard_mask.sum() == 0:\n",
    "#             hard_mask = (soft_mask > soft_mask.min()).int()\n",
    "#         hard_masks.append(hard_mask)\n",
    "#     return list(zip(sparsity_levels, hard_masks))\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "def generate_hard_masks(soft_mask):\n",
    "    soft_mask_flat = soft_mask.flatten()\n",
    "    total_elements = soft_mask_flat.numel()\n",
    "    sparsity_levels = torch.arange(0.5, 1.0, 0.05)\n",
    "    hard_masks = []\n",
    "\n",
    "    # Get sorted indices (ascending: lowest values first)\n",
    "    sorted_indices = torch.argsort(soft_mask_flat)\n",
    "\n",
    "    for sparsity in sparsity_levels:\n",
    "        num_to_mask = int(sparsity.item() * total_elements)\n",
    "        mask_flat = torch.ones_like(soft_mask_flat, dtype=torch.int)\n",
    "        \n",
    "        if num_to_mask >= total_elements:\n",
    "            num_to_mask = total_elements - 1  # keep at least one element\n",
    "        \n",
    "        # Zero out the lowest `num_to_mask` elements\n",
    "        mask_flat[sorted_indices[:num_to_mask]] = 0\n",
    "        \n",
    "        # Reshape to original shape\n",
    "        hard_mask = mask_flat.view_as(soft_mask)\n",
    "        hard_masks.append(hard_mask)\n",
    "\n",
    "    return list(zip(sparsity_levels, hard_masks))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config, dataset, dataloader, model = set_config_ex(dataset=dataset_name, seed=0, model_name=\"gin\", config_name=\"tes\")\n",
    "l = []\n",
    "seeds = [0,1,2,3,4]\n",
    "dataset_name = 'bbbp'\n",
    "for i in seeds:\n",
    "    try:\n",
    "        config, dataset, dataloader, model = set_config_ex(dataset=dataset_name, seed=i, model_name=\"gin\", config_name=\"tes\")\n",
    "    except:\n",
    "        continue\n",
    "    # preds, accs = [], []\n",
    "    # for batch in dataloader['test']:\n",
    "    #     batch = batch\n",
    "    #     batch_preds, _, _ = model(batch.x.float(), batch.edge_index)\n",
    "    #     batch_preds = batch_preds.argmax(-1)\n",
    "    #     preds.append(batch_preds)\n",
    "    #     accs.append(batch_preds == batch.y)\n",
    "    # preds = torch.cat(preds, dim=-1)\n",
    "    # test_acc = torch.cat(accs, dim=-1).float().mean().item()\n",
    "    seed_expl_res = []\n",
    "    for data in tqdm.tqdm(dataloader['test']):\n",
    "        if data.y==0: continue\n",
    "        try:\n",
    "            data.x = data.x.float()\n",
    "            out = model(data=data)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            continue\n",
    "        logit, embs, sims = out\n",
    "        c = logit.argmax(-1).item()\n",
    "        mask = explain(model, data,c=c)\n",
    "        # except:\n",
    "        #     continue\n",
    "        masks = generate_hard_masks(torch.tensor(mask).cpu().detach())\n",
    "        d = {}\n",
    "        for m in masks:\n",
    "            d[m[0].item()] = calculate_fidelity(data, m[1], model)\n",
    "        seed_expl_res.append(pd.DataFrame(d))\n",
    "    l.append(sum(seed_expl_res)/len(seed_expl_res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import get_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l = [np.array(x) for x in l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(l, axis=0)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.std(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "l = []\n",
    "for data in tqdm.tqdm(dataloader['test']):\n",
    "    try:\n",
    "        out = model(data=b)\n",
    "        logit, embs, sims = out\n",
    "        c = logit.argmax(-1).item()\n",
    "        mask = explain(model, data,c=c)\n",
    "    except:\n",
    "        continue\n",
    "    masks = generate_hard_masks(torch.tensor(mask).cpu().detach())\n",
    "    d = {}\n",
    "    for m in masks:\n",
    "        d[m[0].item()] = calculate_fidelity(data, m[1], model)\n",
    "    l.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "l = [pd.DataFrame(x) for x in l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(l)/len(l)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
