{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50182d0a-eae8-4367-8e4c-b2d6bc2cf233",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from syn_dataset import SynGraphDataset\n",
    "from spmotif_dataset import *\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool\n",
    "from utils import *\n",
    "from sklearn.model_selection import train_test_split\n",
    "import shutil\n",
    "import glob\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import pickle\n",
    "import json\n",
    "import io\n",
    "from model import GIN, GINTELL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f64c7a32-6b0a-41ba-9492-8a67ac95184d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'MUTAG'\n",
    "seed = 0\n",
    "def get_best_baseline_path(dataset_name):\n",
    "    l = glob.glob(f'results/{dataset_name}/*/results.json')\n",
    "    fl = [json.load(open(f)) for f in l]\n",
    "    df = pd.DataFrame(fl)\n",
    "    if df.shape[0] == 0: return None\n",
    "    df['fname'] = l\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False')]\n",
    "    fname = df.iloc[-1]['fname']\n",
    "    fname = fname.replace('/results.json', '')\n",
    "    return fname\n",
    "\n",
    "def get_best_path(dataset_name):\n",
    "    l = glob.glob(f'results_logic/{dataset_name}/*/*/results.json')\n",
    "    fl = [json.load(open(f)) for f in l]\n",
    "    df = pd.DataFrame(fl)\n",
    "    if df.shape[0] == 0: return None\n",
    "    df['fname'] = l\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False')]\n",
    "    fname = df.iloc[-1]['fname']\n",
    "    fname = fname.replace('/results.json', '')\n",
    "    return fname\n",
    "\n",
    "\n",
    "results_path = os.path.join(get_best_path(dataset_name), str(seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87b42ad-8cef-46a2-9d0c-446a00c3339f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "data = pickle.load(open(os.path.join(results_path, 'data.pkl'), 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773bca1e-4aa8-47be-b0c4-f907e0e85335",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = json.load(open(os.path.join(results_path, 'args.json'), 'r'))\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98affae0-19b0-4438-b8c5-838573a9b9cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1bf73c-9170-433e-892b-4b2dc9e609fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(dataset_name)\n",
    "num_classes = dataset.num_classes\n",
    "num_features = dataset.num_features\n",
    "num_layers = 5\n",
    "hidden_dim = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11ae031e-8a0c-440d-a66f-bd6887e54a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = list(range(len(dataset)))\n",
    "train_indices, val_test_indices = train_test_split(indices, test_size=0.2,\n",
    "shuffle=True, stratify=dataset.data.y, random_state=1)\n",
    "\n",
    "val_indices = val_test_indices[:len(val_test_indices)//2]\n",
    "test_indices = val_test_indices[len(val_test_indices)//2:]\n",
    "\n",
    "train_dataset = dataset[train_indices]\n",
    "val_dataset = dataset[val_indices]\n",
    "test_dataset = dataset[test_indices]\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c5e27c-e80e-48d6-b714-2d01b7357e90",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)\n",
    "model_tell.dropout = lambda x:x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d68fa885-a6cd-4709-8609-b5da6c1b4d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch_scatter\n",
    "def scatter_sum(x, edge_index):\n",
    "    # Get target nodes (i.e., the nodes receiving the messages)\n",
    "    target_nodes = edge_index[1]\n",
    "    \n",
    "    # Perform scatter sum\n",
    "    out = torch_scatter.scatter_add(x[edge_index[0]], target_nodes, dim=0, dim_size=x.size(0))\n",
    "\n",
    "    return out\n",
    "\n",
    "@torch.no_grad()\n",
    "def forward_with_activations(self, x, edge_index, batch, *args, **kwargs):\n",
    "    returns = []\n",
    "    x = self.input_bnorm(x)\n",
    "    xs = []\n",
    "    for i, conv in enumerate(self.convs):\n",
    "        ret = {}\n",
    "        ret['x'] = torch.hstack([x, 1-x])\n",
    "        ret['x_sum'] = scatter_sum(ret['x'], edge_index)\n",
    "        ret['x_bin'] = conv.nn[0].phi_in(ret['x_sum']) >= 0.5\n",
    "        x = conv(torch.hstack([x, 1-x]), edge_index)\n",
    "        xs.append(x)\n",
    "        ret['y'] = x\n",
    "        ret['y_bin'] = x>=0.5\n",
    "        returns.append(ret)\n",
    "    ret = {}\n",
    "    x_mean = global_mean_pool(torch.hstack(xs), batch)\n",
    "    x_max = global_max_pool(torch.hstack(xs), batch)\n",
    "    x_sum = global_add_pool(torch.hstack(xs), batch)\n",
    "    x = torch.hstack([x_mean, x_max, x_sum])\n",
    "    x = self.output_bnorm(x)\n",
    "    ret['x'] = torch.hstack([x, 1-x])\n",
    "    ret['x_bin'] = self.fc.phi_in(ret['x']) >= 0.5\n",
    "    x = self.fc(torch.hstack([x, 1-x]))\n",
    "    ret['y'] = x\n",
    "    ret['y_bin'] = x>=0.5\n",
    "    returns.append(ret)\n",
    "    return x, returns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf5cd16-457e-4ee6-98b8-1a635ddbe5cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def inverse_sigmoid(x):\n",
    "    \"\"\"Computes the inverse of the sigmoid function (logit function).\"\"\"\n",
    "    return torch.log(x / (1 - x))\n",
    "\n",
    "torch.no_grad()\n",
    "def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):\n",
    "    w = w.clone()\n",
    "    t_in = t_in.clone()\n",
    "    t_out = t_out.clone()\n",
    "    t_out = t_out.item()\n",
    "    ordering_scores = w\n",
    "    sorted_idxs = torch.argsort(ordering_scores, 0, descending=True)\n",
    "    mask = w > 1e-5\n",
    "    if activations is not None:\n",
    "        mask = mask & (activations.sum(0) >= min_support)\n",
    "    total_result = set()\n",
    "\n",
    "    # Filter and sort indices based on the mask\n",
    "    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]\n",
    "    if idxs_to_visit.numel() == 0:\n",
    "        return total_result\n",
    "\n",
    "    # Sort weights based on the filtered indices\n",
    "    sorted_weights = w[idxs_to_visit]\n",
    "    current_combination = []\n",
    "    result = set()\n",
    "\n",
    "    def find_logic_rules_recursive(index, current_sum):\n",
    "        # Stop if the maximum number of rules has been reached\n",
    "        if len(result) >= max_rules:\n",
    "            return\n",
    "\n",
    "        if len(current_combination) > max_rule_len:\n",
    "            return\n",
    "\n",
    "        # Check if the current combination satisfies the condition\n",
    "        if current_sum >= t_out:\n",
    "            c = idxs_to_visit[current_combination].cpu().detach().tolist()\n",
    "            c = tuple(sorted(c))\n",
    "            result.add(c)\n",
    "            return\n",
    "\n",
    "        # Prune if remaining weights can't satisfy t_out\n",
    "        remaining_max_sum = current_sum + sorted_weights[index:].sum()\n",
    "        if remaining_max_sum < t_out:\n",
    "            return\n",
    "\n",
    "        # Explore further combinations\n",
    "        for i in range(index, idxs_to_visit.shape[0]):\n",
    "            # Prune based on activations if provided\n",
    "            if activations is not None and len(current_combination) > 0 and activations[:, idxs_to_visit[current_combination + [i]]].all(-1).sum().item() < min_support:\n",
    "                continue\n",
    "\n",
    "            current_combination.append(i)\n",
    "            find_logic_rules_recursive(i + 1, current_sum + sorted_weights[i])\n",
    "            current_combination.pop()\n",
    "\n",
    "    # Start the recursive process\n",
    "    find_logic_rules_recursive(0, 0)\n",
    "    return result\n",
    "\n",
    "\n",
    "def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=100, min_support=1, out_threshold=0.5):\n",
    "    ws = self.weight\n",
    "    t_in = self.phi_in.t\n",
    "    t_out = -self.b + inverse_sigmoid(torch.tensor(out_threshold))\n",
    "\n",
    "    rules = []\n",
    "    if feature is None:\n",
    "        features = range(self.out_features)\n",
    "    else:\n",
    "        features = [feature]\n",
    "    for i in features:\n",
    "        w = ws[i].to('cpu')\n",
    "        ti = t_in.to('cpu')\n",
    "        to = t_out[i].to('cpu')\n",
    "        rules.append(find_logic_rules(w, ti, to, activations, max_rule_len, max_rules, min_support))\n",
    "\n",
    "    return rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c01931ec-3ce4-4273-9049-f7420dd06c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules = extract_rules(model_tell.fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4acf570-8573-4511-951f-f7801e96c8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7a5be4-9cd8-4c25-bbe1-852ba88f77b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_map = []\n",
    "for pos_neg in ['pos', 'neg']:\n",
    "    for readout in ['mean', 'max', 'sum']:\n",
    "        for l in range(num_layers):\n",
    "            for d in range(hidden_dim):\n",
    "                # feat_map.append(f'{pos_neg}_{readout}_{l}_{d}')\n",
    "                feat_map.append((pos_neg, readout, l, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4432b6c2-e0a0-49db-9745-b827dd96256e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from torch_geometric.data import Batch\n",
    "tell_explainer_results = []\n",
    "for data in tqdm(test_dataset):\n",
    "    data = data.to(device)\n",
    "    data = Batch.from_data_list([data])\n",
    "    data.x = data.x.float()\n",
    "    pred, rets =forward_with_activations(model_tell, data.x.float(), data.edge_index, data.batch)\n",
    "    pred_c = pred.argmax(-1).item()\n",
    "    rules = extract_rules(model_tell.fc)\n",
    "    soft_mask = torch.zeros(data.x.shape[0]).to(device)\n",
    "    for c, class_rules in enumerate(rules):\n",
    "        for rule in class_rules:\n",
    "            # print(rule)\n",
    "            # print(rule_acts[:,rule])\n",
    "            \n",
    "            # if not rule_activated: continue\n",
    "            for literal in rule:\n",
    "                # print(literal)\n",
    "                pos_neg, agg, layer, i = feat_map[literal]\n",
    "                acts = rets[layer]['y_bin'][:,i]\n",
    "                m = torch.zeros_like(soft_mask)\n",
    "                if agg == 'max':\n",
    "                    m[acts>=acts.max()] = (1 if pred_c==c else -1)*acts.max()*rets[-1]['x_bin'][:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                elif agg == 'sum':\n",
    "                    m=(1 if pred_c==c else -1)*acts*rets[-1]['x_bin'][:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                else:\n",
    "                    m=(1 if pred_c==c else -1)*acts*rets[-1]['x_bin'][:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                m_=torch.zeros_like(m)\n",
    "                for i in range(len(m)):\n",
    "                    if m[i] > 0:\n",
    "                        try:\n",
    "                            subset, _, _, _ = k_hop_subgraph(i, 1, data.edge_index.cpu())\n",
    "                            m_[subset] += m[i]\n",
    "                        except:\n",
    "                            m_[i] = m[i]\n",
    "                soft_mask+=m_    \n",
    "\n",
    "    # print(soft_mask)\n",
    "    soft_mask = soft_mask.detach().cpu()\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64860008-fbd9-4fe4-a0ab-6cfaed14df4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.utils import subgraph, add_self_loops\n",
    "from torch_geometric.data import Data,Batch\n",
    "\n",
    "def calculate_fidelity(data, node_mask, model, remove_nodes=True, top_k=None):\n",
    "    data = Batch.from_data_list([data])\n",
    "    device = next(model.parameters()).device\n",
    "    data = data.to(device)\n",
    "\n",
    "    # Compute sparsity\n",
    "    total_nodes = data.x.shape[0]\n",
    "    sparsity = 1 - node_mask.sum().item() / total_nodes if total_nodes > 0 else 0\n",
    "\n",
    "    # Get original predictions\n",
    "    original_pred = model(data.x, data.edge_index, data.batch)\n",
    "    original_pred = F.softmax(original_pred, dim=1)\n",
    "    label = original_pred.argmax(-1).item()\n",
    "    \n",
    "    # Apply node mask (InvFidelity)\n",
    "    if remove_nodes:\n",
    "        masked_edge_index, _ = subgraph(node_mask == 0, edge_index=data.edge_index, num_nodes=data.x.size(0), relabel_nodes=True)\n",
    "        n_nodes = (node_mask == 0).sum()\n",
    "        new_data = Batch.from_data_list([Data(x=data.x[node_mask == 0], edge_index=masked_edge_index)])\n",
    "        masked_pred = model(new_data.x, new_data.edge_index, new_data.batch)\n",
    "    else:\n",
    "        masked_x = data.x.clone()\n",
    "        masked_x[node_mask==1] = 0\n",
    "        new_data = Batch.from_data_list([Data(x=masked_x, edge_index=data.edge_index)])\n",
    "        masked_pred = model(new_data.x, new_data.edge_index, new_data.batch)\n",
    "    masked_pred = F.softmax(masked_pred, dim=1)\n",
    "\n",
    "    # Keep only important nodes (Fidelity)\n",
    "    if remove_nodes:\n",
    "        masked_edge_index, _ = subgraph(node_mask == 1, edge_index=data.edge_index, num_nodes=data.x.size(0), relabel_nodes=True)\n",
    "        n_nodes = (node_mask == 1).sum()\n",
    "        new_data = Batch.from_data_list([Data(x=data.x[node_mask == 1], edge_index=masked_edge_index)])\n",
    "        retained_pred = model(new_data.x, new_data.edge_index, new_data.batch)\n",
    "    else:\n",
    "        masked_x = data.x.clone()\n",
    "        masked_x[node_mask==0] = 0\n",
    "        new_data = Batch.from_data_list([Data(x=masked_x, edge_index=data.edge_index)])\n",
    "        retained_pred = model(new_data.x, new_data.edge_index, new_data.batch)\n",
    "    retained_pred = F.softmax(retained_pred, dim=1)\n",
    "\n",
    "\n",
    "    # Compute Fidelity+ and Fidelity-\n",
    "    inv_fidelity = (original_pred[:, label] - \n",
    "                                 retained_pred[:, label]).mean().item()\n",
    "\n",
    "    fidelity = (original_pred[:, label] - \n",
    "                                 masked_pred[:, label]).mean().item()\n",
    "\n",
    "    # inv_fidelity = (original_pred.argmax(-1) != \n",
    "    #                              retained_pred.argmax(-1)).float().item()\n",
    "\n",
    "    # fidelity = (original_pred.argmax(-1)  != \n",
    "    #                              masked_pred.argmax(-1) ).float().item()\n",
    "\n",
    "    n_fidelity = inv_fidelity*sparsity\n",
    "    n_inv_fidelity = inv_fidelity*(1-sparsity)\n",
    "    \n",
    "    # Compute HFidelity (harmonic mean of Fidelity+ and Fidelity-)\n",
    "    hfidelity = ((1+n_fidelity) * (1-n_inv_fidelity)) / (2 + n_fidelity - n_inv_fidelity) if (1 + n_fidelity - n_inv_fidelity) != 0 else 0\n",
    "\n",
    "    return {\n",
    "        \"Fidelity\": fidelity,\n",
    "        \"InvFidelity\": inv_fidelity,\n",
    "        \"HFidelity\": hfidelity\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea86183d-e7e9-4845-bc21-43ad0bacfee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def generate_hard_masks(soft_mask):\n",
    "#     sparsity_levels = torch.arange(0.5,1, 0.05)\n",
    "#     hard_masks = []\n",
    "#     for sparsity in sparsity_levels:\n",
    "#         threshold = np.percentile(soft_mask, sparsity * 100)\n",
    "#         hard_mask = (soft_mask > threshold).int()\n",
    "#         if hard_mask.sum() == 0:\n",
    "#             hard_mask = (soft_mask > soft_mask.min()).int()\n",
    "#         hard_masks.append(hard_mask)\n",
    "#     return list(zip(sparsity_levels, hard_masks))\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "def generate_hard_masks(soft_mask):\n",
    "    soft_mask_flat = soft_mask.flatten()\n",
    "    total_elements = soft_mask_flat.numel()\n",
    "    sparsity_levels = torch.arange(0.5, 1.0, 0.05)\n",
    "    hard_masks = []\n",
    "\n",
    "    # Get sorted indices (ascending: lowest values first)\n",
    "    sorted_indices = torch.argsort(soft_mask_flat)\n",
    "\n",
    "    for sparsity in sparsity_levels:\n",
    "        num_to_mask = int(sparsity.item() * total_elements)\n",
    "        mask_flat = torch.ones_like(soft_mask_flat, dtype=torch.int)\n",
    "        \n",
    "        if num_to_mask >= total_elements:\n",
    "            num_to_mask = total_elements - 1  # keep at least one element\n",
    "        \n",
    "        # Zero out the lowest `num_to_mask` elements\n",
    "        mask_flat[sorted_indices[:num_to_mask]] = 0\n",
    "        \n",
    "        # Reshape to original shape\n",
    "        hard_mask = mask_flat.view_as(soft_mask)\n",
    "        hard_masks.append(hard_mask)\n",
    "\n",
    "    return list(zip(sparsity_levels, hard_masks))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00cedda2-cbec-4289-9f4a-4d87435a485c",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = generate_hard_masks(soft_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb659de-b9e7-46e6-8d7f-722b191f48ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "calculate_fidelity(data, masks[0][1], model_tell, remove_nodes=True, top_k=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365ab951-c0a0-4d3c-9cb3-3b3d162b6010",
   "metadata": {},
   "outputs": [],
   "source": [
    "tell_explainer_results = []\n",
    "for data in tqdm(test_dataset):\n",
    "    try:\n",
    "        data = data.to(device)\n",
    "        data.x = data.x.float()\n",
    "        pred_tell, rule_acts, layers_acts = model_tell(data.x.float(), data.edge_index, activations=True)\n",
    "        \n",
    "        pred = model(data.x.float(), data.edge_index)\n",
    "        # rule_acts = rule_acts>0.5\n",
    "        r = {\n",
    "            'data': data,\n",
    "            'pred': pred.softmax(-1).detach().cpu().numpy(),\n",
    "            'res':{}\n",
    "        }\n",
    "        pred_c = r['pred'].argmax(-1).item()\n",
    "        rules = extract_rules(model_tell.fc)\n",
    "        soft_mask = torch.zeros(data.x.shape[0]).to(device)\n",
    "        for c, class_rules in enumerate(rules):\n",
    "            for rule in class_rules:\n",
    "                # print(rule)\n",
    "                # print(rule_acts[:,rule])\n",
    "                \n",
    "                # if not rule_activated: continue\n",
    "                for literal in rule:\n",
    "                    # print(literal)\n",
    "                    agg, layer, i = feat_map[literal]\n",
    "                    acts = layers_acts[layer][:,i]\n",
    "                    m = torch.zeros_like(soft_mask)\n",
    "                    if agg == 'max':\n",
    "                        m[acts>=acts.max()] = (1 if pred_c==c else -1)*acts.max()*rule_acts[:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                    elif agg == 'sum':\n",
    "                        m=(1 if pred_c==c else -1)*acts*rule_acts[:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                    else:\n",
    "                        m=(1 if pred_c==c else -1)*acts*rule_acts[:,literal].item()*model_tell.fc.weight[c,literal]\n",
    "                    m_=torch.zeros_like(m)\n",
    "                    for i in range(len(m)):\n",
    "                        if m[i] > 0:\n",
    "                            try:\n",
    "                                subset, _, _, _ = k_hop_subgraph(i, 1, data.edge_index.cpu())\n",
    "                                m_[subset] += m[i]\n",
    "                            except:\n",
    "                                m_[i] = m[i]\n",
    "                    soft_mask+=m_    \n",
    "\n",
    "        # print(soft_mask)\n",
    "        soft_mask = soft_mask.detach().cpu()\n",
    "        r['soft_mask'] = soft_mask\n",
    "        hard_masks = generate_hard_masks(soft_mask)\n",
    "        for sparsity, hard_mask in hard_masks:\n",
    "            sparsity = sparsity.item()\n",
    "            r['res'][sparsity] = calculate_fidelity(data, hard_mask, model)\n",
    "            r['res'][sparsity]['hard_mask'] = hard_mask\n",
    "        r['res_topk'] = {\n",
    "            1: calculate_fidelity_topk(data, soft_mask, model,1),\n",
    "            3: calculate_fidelity_topk(data, soft_mask, model,3),\n",
    "            5: calculate_fidelity_topk(data, soft_mask, model,5)\n",
    "        }\n",
    "        tell_explainer_results.append(r)\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
