{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c70e6d8f-d238-40f0-911d-2a57728809b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from torch import nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from syn_dataset import SynGraphDataset\n",
    "from spmotif_dataset import *\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool\n",
    "from utils import *\n",
    "from sklearn.model_selection import train_test_split\n",
    "import shutil\n",
    "import glob\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import pickle\n",
    "import json\n",
    "import io\n",
    "from model_node import GIN, GINTELL\n",
    "from train_baseline_node import test_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf8f4f3-6497-407d-844a-63b5ff2f4aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'BaShapes'\n",
    "seed = 2\n",
    "\n",
    "\n",
    "def _best_run_dir(pattern):\n",
    "    \"\"\"Return the directory of the best run among the `results.json` files matching `pattern`.\n",
    "\n",
    "    Each matching file becomes one DataFrame row. Rows are sorted by `val_acc_mean` ascending\n",
    "    and by `val_acc_std` / `test_acc_std` descending, restricted to paths containing\n",
    "    'nogumbel=False', and the last row wins (highest mean accuracy, then lowest spread).\n",
    "    Returns None when nothing matches or no run survives the 'nogumbel=False' filter.\n",
    "    \"\"\"\n",
    "    paths = glob.glob(pattern)\n",
    "    rows = []\n",
    "    for path in paths:\n",
    "        with open(path) as fh:  # close each file instead of leaking the handle\n",
    "            rows.append(json.load(fh))\n",
    "    df = pd.DataFrame(rows)\n",
    "    if df.shape[0] == 0:\n",
    "        return None\n",
    "    df['fname'] = paths\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False', regex=False)]\n",
    "    # The filter can leave no rows even when files exist; df.iloc[-1] would then raise IndexError.\n",
    "    if df.shape[0] == 0:\n",
    "        return None\n",
    "    # dirname is separator-agnostic, unlike stripping a hard-coded '/results.json' suffix.\n",
    "    return os.path.dirname(df.iloc[-1]['fname'])\n",
    "\n",
    "\n",
    "def get_best_baseline_path(dataset_name):\n",
    "    \"\"\"Best baseline run directory, searched in results/<dataset_name>/*/results.json (None if absent).\"\"\"\n",
    "    return _best_run_dir(f'results/{dataset_name}/*/results.json')\n",
    "\n",
    "\n",
    "def get_best_path(dataset_name):\n",
    "    \"\"\"Best logic-model run directory, searched in results_logic/<dataset_name>/*/*/results.json (None if absent).\"\"\"\n",
    "    return _best_run_dir(f'results_logic/{dataset_name}/*/*/results.json')\n",
    "\n",
    "\n",
    "best_run_dir = get_best_path(dataset_name)\n",
    "if best_run_dir is None:\n",
    "    raise FileNotFoundError(f'No results_logic run with nogumbel=False found for {dataset_name}')\n",
    "results_path = os.path.join(best_run_dir, str(seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ae3d2d-6f6c-4cea-bbe1-a694dbc42efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# NOTE: pickle.load can execute arbitrary code - only load result files you produced yourself.\n",
    "with open(os.path.join(results_path, 'data.pkl'), 'rb') as fh:  # closes the file after loading\n",
    "    data = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc57df2-60f0-4942-bdee-fa82429010df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training arguments saved alongside the selected run; the file is closed after parsing.\n",
    "with open(os.path.join(results_path, 'args.json'), 'r') as fh:\n",
    "    args = json.load(fh)\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dcc9c46-5431-40c6-bdf2-8ea0d923ea8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# is_available must be *called*: the bare function object is always truthy, which would\n",
    "# select CUDA even on machines without a GPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a534b1de-729d-4740-a8bd-6b809bbbe849",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset and record the sizes used by the rest of the notebook.\n",
    "dataset = get_dataset(dataset_name)\n",
    "num_classes, num_features = dataset.num_classes, dataset.num_features\n",
    "# Used by feat_map to decode fc input indices; presumably must match the loaded model's\n",
    "# conv count and width - TODO confirm (they are not read from `args`).\n",
    "num_layers, hidden_dim = 5, 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09b8ad6-773a-4e90-9839-17c19b6a1e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# best.pt holds the whole pickled model object (not a state_dict); torch.load unpickles it,\n",
    "# so only load checkpoints you trust.\n",
    "model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)\n",
    "# model_tell.dropout = lambda x:x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fad2c355-8f8a-4edf-969e-672e67a9bd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the loaded checkpoint on the validation and test node masks\n",
    "# (test_epoch is assumed to return accuracy on the masked nodes - see train_baseline_node).\n",
    "val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b31350b-5ceb-409b-adb8-4a290a0131e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test score of the loaded checkpoint before any pruning\n",
    "test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656dd205-0e17-47d3-ae63-29d1bd07b102",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node labels of dataset[0] (the targets used by train_sparsity_epoch)\n",
    "dataset[0].y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fc6bff2-2d4e-4d25-864a-127716e9ad5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def inverse_sigmoid(x):\n",
    "    \"\"\"Logit: map a probability tensor `x` in (0, 1) back to the real line, log(x / (1 - x)).\"\"\"\n",
    "    odds = x / (1 - x)\n",
    "    return odds.log()\n",
    "\n",
    "@torch.no_grad()  # decorator (the bare call had no effect): the search needs no autograd graph\n",
    "def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):\n",
    "    \"\"\"Enumerate conjunctions of inputs whose summed weights reach the output threshold.\n",
    "\n",
    "    Candidates are inputs with weight > 1e-5 that, when `activations` is given, are active on at\n",
    "    least `min_support` rows. They are visited in descending weight order and a depth-first\n",
    "    search collects index sets whose weight sum is >= `t_out`.\n",
    "\n",
    "    Args:\n",
    "        w: 1-D weight tensor, one entry per input feature.\n",
    "        t_in: unused; kept for call compatibility.\n",
    "        t_out: one-element tensor, the threshold the weight sum must reach.\n",
    "        activations: optional boolean (rows x features) tensor used for support pruning.\n",
    "        max_rule_len: maximum number of literals in a rule.\n",
    "        max_rules: stop once this many rules have been collected.\n",
    "        min_support: minimum number of rows on which all literals of a rule are active.\n",
    "\n",
    "    Returns:\n",
    "        A set of sorted tuples of input-feature indices, one tuple per rule.\n",
    "    \"\"\"\n",
    "    t_out = t_out.item()\n",
    "    sorted_idxs = torch.argsort(w, 0, descending=True)\n",
    "    mask = w > 1e-5\n",
    "    if activations is not None:\n",
    "        mask = mask & (activations.sum(0) >= min_support)\n",
    "\n",
    "    # Candidate feature indices, still in descending-weight order.\n",
    "    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]\n",
    "    result = set()\n",
    "    if idxs_to_visit.numel() == 0:\n",
    "        return result\n",
    "\n",
    "    sorted_weights = w[idxs_to_visit]\n",
    "    n_candidates = idxs_to_visit.shape[0]\n",
    "    current_combination = []  # positions into idxs_to_visit\n",
    "\n",
    "    def find_logic_rules_recursive(index, current_sum):\n",
    "        if len(result) >= max_rules:\n",
    "            return\n",
    "        if current_sum >= t_out:\n",
    "            c = idxs_to_visit[current_combination].cpu().tolist()\n",
    "            result.add(tuple(sorted(c)))\n",
    "            return\n",
    "        # A rule may hold at most max_rule_len literals: stop extending here instead of\n",
    "        # building an over-long combination only to discard it one level deeper.\n",
    "        if len(current_combination) >= max_rule_len:\n",
    "            return\n",
    "        # Prune when even adding every remaining (positive) weight cannot reach t_out.\n",
    "        if current_sum + sorted_weights[index:].sum() < t_out:\n",
    "            return\n",
    "        for i in range(index, n_candidates):\n",
    "            if activations is not None and current_combination:\n",
    "                support = activations[:, idxs_to_visit[current_combination + [i]]].all(-1).sum().item()\n",
    "                if support < min_support:\n",
    "                    continue\n",
    "            current_combination.append(i)\n",
    "            find_logic_rules_recursive(i + 1, current_sum + sorted_weights[i])\n",
    "            current_combination.pop()\n",
    "\n",
    "    find_logic_rules_recursive(0, 0)\n",
    "    return result\n",
    "\n",
    "\n",
    "def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=100, min_support=1, out_threshold=0.5):\n",
    "    \"\"\"Extract threshold rules for one or all output units of the TELL-style layer `self`.\n",
    "\n",
    "    For output unit k the weight sum must reach inverse_sigmoid(out_threshold) - b[k];\n",
    "    find_logic_rules is run on row k of `self.weight` (moved to CPU).\n",
    "\n",
    "    Returns:\n",
    "        A list with one set of rules per requested unit (a single-element list when\n",
    "        `feature` is given).\n",
    "    \"\"\"\n",
    "    weights = self.weight\n",
    "    in_thresholds = self.phi_in.t\n",
    "    out_thresholds = -self.b + inverse_sigmoid(torch.tensor(out_threshold))\n",
    "    units = [feature] if feature is not None else range(self.out_features)\n",
    "    return [\n",
    "        find_logic_rules(\n",
    "            weights[k].to('cpu'),\n",
    "            in_thresholds.to('cpu'),\n",
    "            out_thresholds[k].to('cpu'),\n",
    "            activations, max_rule_len, max_rules, min_support,\n",
    "        )\n",
    "        for k in units\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282826b6-c20e-4ac5-90ea-4579dcb42b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data, Batch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_activations(batch_ids, batch, attr):\n",
    "    \"\"\"Draw each graph in `batch_ids` side by side, outlining nodes whose `attr` equals 1 in red.\n",
    "\n",
    "    Args:\n",
    "        batch_ids: one graph id of `batch`, or a list of them.\n",
    "        batch: batched graph providing `batch`, `edge_index` and per-graph labels `y`.\n",
    "        attr: per-node values aligned with `batch.batch`; nodes where it equals 1 get a red border.\n",
    "\n",
    "    The networkx graph is built from intra-graph edges only, so isolated nodes are not drawn;\n",
    "    node labels are the node positions within their graph.\n",
    "    \"\"\"\n",
    "    if not isinstance(batch_ids, list):\n",
    "        batch_ids = [batch_ids]\n",
    "    # squeeze=False always returns a 2-D array of Axes, so a single subplot needs no numpy\n",
    "    # special-casing (and no reliance on `np` leaking in through a star import).\n",
    "    fig, axs = plt.subplots(1, len(batch_ids), figsize=(16*len(batch_ids), 8), squeeze=False)\n",
    "    for i, batch_id in enumerate(batch_ids):\n",
    "        node_mask = batch.batch == batch_id  # nodes belonging to graph `batch_id`\n",
    "        node_indices = torch.nonzero(node_mask, as_tuple=True)[0]\n",
    "        graph_attr = attr[node_mask]  # computed once per graph, not once per node\n",
    "\n",
    "        edge_src_in = batch.batch[batch.edge_index[0]] == batch_id\n",
    "        edge_dst_in = batch.batch[batch.edge_index[1]] == batch_id\n",
    "        subgraph_edges = batch.edge_index[:, edge_src_in & edge_dst_in]\n",
    "\n",
    "        # Relabel global node ids to 0..n-1 within this graph.\n",
    "        node_mapping = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(node_indices)}\n",
    "        G = nx.Graph()\n",
    "        G.add_edges_from((node_mapping[u], node_mapping[v]) for u, v in subgraph_edges.t().tolist())\n",
    "        nx.set_node_attributes(G, {v: k for k, v in node_mapping.items()}, \"original_id\")\n",
    "\n",
    "        node_borders = [\"red\" if graph_attr[node] == 1 else \"black\" for node in G.nodes]\n",
    "        node_colors = [\"lightblue\"] * len(node_borders)\n",
    "\n",
    "        ax = axs[0, i]\n",
    "        pos = nx.kamada_kawai_layout(G)\n",
    "        nx.draw(G, pos, node_color=node_colors, edgecolors=node_borders, node_size=700, with_labels=True, ax=ax)\n",
    "        ax.set_title(f\"Class = {batch.y[batch_id]}\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56862307-99d9-4bf1-844a-33d77a317d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_sparsity_epoch(model_tell, data, mask, device, optimizer, num_classes, conv_reg=1, fc_reg=1):\n",
    "    \"\"\"Run one full-batch step of classification loss plus sparsity/entropy regularisation.\n",
    "\n",
    "    Args:\n",
    "        model_tell: model exposing `convs` (TELL layer at `conv.nn[0]`) and an `fc` TELL head.\n",
    "        data: graph with `x`, `edge_index` and node labels `y`.\n",
    "        mask: boolean (or index) tensor selecting the nodes whose labels enter the loss and the\n",
    "            accuracy, e.g. `train_mask`; None uses every node.\n",
    "        device: device the graph is moved to.\n",
    "        optimizer: optimizer over the parameters being sparsified.\n",
    "        num_classes: number of classes for the one-hot BCE target.\n",
    "        conv_reg: weight of the sqrt-weight + input-entropy penalty of each conv TELL layer.\n",
    "        fc_reg: weight of the same penalty on the `fc` head.\n",
    "\n",
    "    Returns:\n",
    "        (loss value, accuracy on the masked nodes).\n",
    "    \"\"\"\n",
    "    model_tell.train()\n",
    "    data = data.to(device)\n",
    "    if data.x is None:\n",
    "        data.x = torch.ones((data.num_nodes, model_tell.num_features))\n",
    "\n",
    "    y = data.y.reshape(-1).to(device).long()\n",
    "    # Supervise only the requested split so validation/test labels never reach the loss.\n",
    "    if mask is None:\n",
    "        mask = torch.ones_like(y, dtype=torch.bool)\n",
    "    mask = mask.to(device)\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    model_tell.fc.phi_in.tau = 10\n",
    "    out = model_tell(data.x.float().to(device), data.edge_index.to(device))\n",
    "    out_m, y_m = out[mask], y[mask]\n",
    "    pred = out_m.argmax(-1)\n",
    "\n",
    "    target = F.one_hot(y_m, num_classes=num_classes).float()\n",
    "    loss = F.binary_cross_entropy(out_m.reshape(-1), target.reshape(-1)) + F.nll_loss(F.log_softmax(out_m, dim=-1), y_m)\n",
    "    for conv in model_tell.convs:\n",
    "        tell = conv.nn[0]\n",
    "        # Conv layers are weighted by conv_reg; fc_reg applies to the head below.\n",
    "        loss = loss + conv_reg * (torch.sqrt(torch.clamp(tell.weight, min=1e-5)).sum(-1).mean() + tell.phi_in.entropy)\n",
    "    loss = loss + fc_reg * (torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)\n",
    "\n",
    "    loss.backward()\n",
    "    zero_nan_gradients(model_tell)\n",
    "    optimizer.step()\n",
    "\n",
    "    total_loss = loss.item()\n",
    "    total_correct = pred.eq(y_m).sum().item() / max(y_m.numel(), 1)\n",
    "    return total_loss, total_correct\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb6e3e1-bd6b-4194-b842-a6de9878a7a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch_scatter\n",
    "\n",
    "def scatter_sum(x, edge_index):\n",
    "    \"\"\"Sum each node's incoming neighbour features and add the node's own features.\n",
    "\n",
    "    Args:\n",
    "        x: node feature tensor of shape (num_nodes, ...).\n",
    "        edge_index: (2, num_edges) long tensor of (source, target) node pairs.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped like `x` with out[v] = x[v] + sum of x[u] over edges (u, v).\n",
    "    \"\"\"\n",
    "    source_nodes, target_nodes = edge_index[0], edge_index[1]\n",
    "    # Built-in index_add_ performs the same scatter-add as torch_scatter.scatter_add without\n",
    "    # needing the compiled torch_scatter extension.\n",
    "    aggregated = torch.zeros_like(x).index_add_(0, target_nodes, x[source_nodes])\n",
    "    return aggregated + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "201b2ba9-5346-4ff2-b482-dc17a0d2d814",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def forward_with_activations(self, x, edge_index, *args, **kwargs):\n",
    "    \"\"\"Replay the model's forward pass while recording each layer's inputs and outputs.\n",
    "\n",
    "    Every conv record holds the layer input `x` (features stacked with their complements),\n",
    "    its neighbourhood sum `x_sum`, the TELL input binarisation `x_bin` (phi_in(x_sum) >= 0.5),\n",
    "    the conv output `y` and `y_bin` (y >= 0.5). A final record does the same for the `fc` head\n",
    "    on the concatenated conv outputs (without `x_sum`). Extra arguments are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (fc output, list of record dicts: conv layers first, fc head last).\n",
    "    \"\"\"\n",
    "    def with_complement(t):\n",
    "        return torch.hstack([t, 1 - t])\n",
    "\n",
    "    records = []\n",
    "    layer_outputs = []\n",
    "    h = self.input_bnorm(x)\n",
    "    for conv in self.convs:\n",
    "        rec = {'x': with_complement(h)}\n",
    "        rec['x_sum'] = scatter_sum(rec['x'], edge_index)\n",
    "        rec['x_bin'] = conv.nn[0].phi_in(rec['x_sum']) >= 0.5\n",
    "        h = conv(with_complement(h), edge_index)\n",
    "        layer_outputs.append(h)\n",
    "        rec['y'] = h\n",
    "        rec['y_bin'] = h >= 0.5\n",
    "        records.append(rec)\n",
    "\n",
    "    h = self.output_bnorm(torch.hstack(layer_outputs))\n",
    "    head = {'x': with_complement(h)}\n",
    "    head['x_bin'] = self.fc.phi_in(head['x']) >= 0.5\n",
    "    h = self.fc(with_complement(h))\n",
    "    head['y'] = h\n",
    "    head['y_bin'] = h >= 0.5\n",
    "    records.append(head)\n",
    "    return h, records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da146ea5-220e-4f94-bde4-5111edf6d8d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x, tau=10):\n",
    "    \"\"\"Temperature-scaled logistic function: 1 / (1 + exp(-tau * x)).\n",
    "\n",
    "    Computed with torch.sigmoid: writing the formula out with torch.exp overflows to inf for\n",
    "    large negative tau * x, which makes its gradient NaN; torch.sigmoid's gradient stays finite.\n",
    "    \"\"\"\n",
    "    return torch.sigmoid(tau * x)\n",
    "    \n",
    "def forward_tell(self, tau):\n",
    "    \"\"\"Build a replacement `forward` for TELL layer `self` using sigmoid temperature `tau`.\n",
    "\n",
    "    The returned function maps `x` through `self.phi_in`, stores the per-feature maximum in\n",
    "    `self.max_in`, the clamped weight-sum penalty in `self.reg_loss` (from `weight_s` when\n",
    "    `use_weight_sigma` is set, else `weight`), and the phi_in entropy plus the mean binary\n",
    "    entropy of the output in `self.entropy_loss`; it returns sigmoid(h @ weight.T + b, tau).\n",
    "    \"\"\"\n",
    "    def fw(x):\n",
    "        h = self.phi_in(x)\n",
    "        self.max_in = h.max(0)[0]\n",
    "\n",
    "        penalised = self.weight_s if self.use_weight_sigma else self.weight\n",
    "        reg_loss = 0 + torch.clamp(penalised, min=1e-5).sum(-1).mean()\n",
    "        entropy_loss = 0 if self.phi_in.entropy is None else 0 + self.phi_in.entropy\n",
    "        self.reg_loss = reg_loss\n",
    "\n",
    "        o = sigmoid(h @ self.weight.t() + self.b, tau=tau)\n",
    "        binary_entropy = o*torch.log(o+1e-8) + (1-o)*torch.log(1-o + 1e-8)\n",
    "        self.entropy_loss = entropy_loss + -binary_entropy.mean()\n",
    "        return o\n",
    "    return fw\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0805169-f70b-46f4-8df3-5d70b8be223b",
   "metadata": {},
   "source": [
    "# Extract Rules for Last Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93945d12-5023-411c-8c5f-fca93c733147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the model is on the selected device before pruning starts.\n",
    "model_tell = model_tell.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1942390-6bcb-4742-9c19-dd33d7fef53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import io\n",
    "\n",
    "def clone_model(model):\n",
    "    \"\"\"Return an independent copy of `model` via an in-memory torch.save / torch.load round-trip.\n",
    "\n",
    "    Serialising means `model` must be picklable (e.g. lambda attributes make torch.save fail).\n",
    "    \"\"\"\n",
    "    with io.BytesIO() as buffer:\n",
    "        torch.save(model, buffer)\n",
    "        buffer.seek(0)  # rewind so torch.load reads from the start\n",
    "        return torch.load(buffer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29e00bd-a899-442c-a70c-fec3b36c60a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparsify the fc head: only its weight_sigma / weight_exp tensors are in the optimizer.\n",
    "optimizer = torch.optim.Adam([model_tell.fc.weight_sigma, model_tell.fc.weight_exp], lr=0.01, weight_decay=0)\n",
    "\n",
    "# for param in model_tell.parameters():\n",
    "#     param.requires_grad = False\n",
    "    \n",
    "print(\"Pruning last layer\")\n",
    "# model_tell.fc.weight_sigma.requires_grad = True\n",
    "# model_tell.fc.weight_exp.requires_grad = True\n",
    "# model_tell.fc.phi_in.w.requires_grad = True\n",
    "# model_tell.fc.phi_in.b.requires_grad = True\n",
    "best_weights = clone_model(model_tell)\n",
    "val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "# Weights above 1e-4 count as still active.\n",
    "n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "# Compared as a tuple: higher validation score first, then fewer active weights.\n",
    "best_situation = (val_acc, -n_w)\n",
    "patience = max_patience = 50\n",
    "\n",
    "for i in range(1000):\n",
    "    # model_tell.fc.forward = forward_tell(model_tell.fc, 10)\n",
    "    train_loss, train_acc = train_sparsity_epoch(model_tell, dataset[0], dataset[0].train_mask, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.1)\n",
    "    # model_tell.fc.forward = forward_tell(model_tell.fc, 1000)\n",
    "    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n",
    "    n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "    # Snapshot when strictly better, or sparser while keeping >= 95% of the reference score.\n",
    "    # NOTE(review): best_situation is overwritten on every accepted snapshot, so the 95%\n",
    "    # tolerance is relative to the latest one and can compound over epochs - confirm intended.\n",
    "    if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.95*best_situation[0]):\n",
    "        best_weights = clone_model(model_tell)\n",
    "        patience = max_patience\n",
    "        best_situation = (val_acc, -n_w)\n",
    "    patience -= 1 \n",
    "    if i%10 == 0:\n",
    "        print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience)\n",
    "    # Early stop after max_patience epochs without an accepted snapshot.\n",
    "    if patience == 0:\n",
    "        break\n",
    "# Restore the last accepted snapshot and report its scores and active-weight count.\n",
    "model_tell = clone_model(best_weights)\n",
    "val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n",
    "n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "print(val_acc, test_acc, n_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b339bf6-5e07-4e6f-8090-cebb127b5c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rules of the pruned fc head, without activation-based support pruning (activations=None)\n",
    "extract_rules(model_tell.fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b84c0b-5377-4442-a280-9412d8c76e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparsify the conv TELL layers one at a time, from the last conv layer to the first.\n",
    "model_tell = model_tell.to(device)\n",
    "\n",
    "\n",
    "best_weights = clone_model(model_tell)\n",
    "\n",
    "for l in reversed(range(len(model_tell.convs))):\n",
    "    # train_sparsity_epoch penalises every layer, but only layer l's weight_sigma / weight_exp\n",
    "    # are in this optimizer, so only they are updated.\n",
    "    optimizer = torch.optim.Adam([model_tell.convs[l].nn[0].weight_sigma, model_tell.convs[l].nn[0].weight_exp], lr=0.005)\n",
    "    \n",
    "    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n",
    "    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "    best_situation = (val_acc, -n_w)\n",
    "    print(best_situation)\n",
    "    patience = max_patience = 50\n",
    "    for i in range(3000):\n",
    "        # tell_layer.forward = forward_tell(tell_layer, 10)\n",
    "        train_loss, train_acc = train_sparsity_epoch(model_tell, dataset[0], dataset[0].train_mask, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.1)\n",
    "        # model_tell.fc.forward = forward_tell(model_tell.fc, 1000)\n",
    "        val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "        test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n",
    "        n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "        # Same acceptance rule as the fc pass: better, or sparser within 95% of the reference score.\n",
    "        if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.95*best_situation[0]):\n",
    "            best_weights = clone_model(model_tell)\n",
    "            patience = max_patience\n",
    "            best_situation = (val_acc, -n_w)\n",
    "        patience -= 1 \n",
    "        if i%10 == 0:\n",
    "            print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience)\n",
    "        if patience == 0:\n",
    "            break    \n",
    "    # Continue from the last accepted snapshot; the next layer builds a fresh optimizer on it.\n",
    "    model_tell = clone_model(best_weights)\n",
    "    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)\n",
    "    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)\n",
    "    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "    print(val_acc, test_acc, n_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701739a9-3a86-4cd4-9f53-2099ef4fbbd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record per-layer inputs/outputs of the pruned model on the whole graph;\n",
    "# activations[-1] is the fc head record, earlier entries are the conv layers.\n",
    "with torch.no_grad():\n",
    "    activations = None\n",
    "    _, activations = forward_with_activations(model_tell, dataset[0].x.to(device), dataset[0].edge_index.to(device))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005bcf33-5c00-40c7-9d14-3d3873fea00e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fc rules per output unit; rules must be active on at least min_support (default 1) nodes\n",
    "last_layer_rules = extract_rules(model_tell.fc, activations=activations[-1]['x_bin'].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff12a665-405e-4451-86a1-f3cc69c5618c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each fc input index to (polarity, conv layer, unit). The fc input is hstack([h, 1 - h]),\n",
    "# where h concatenates the conv outputs, so 'pos' entries come first, then the 'neg'\n",
    "# (complement) copies. Assumes num_layers x hidden_dim matches the model.\n",
    "feat_map = []\n",
    "for pos_neg in ['pos', 'neg']:\n",
    "    for l in range(num_layers):\n",
    "        for d in range(hidden_dim):\n",
    "            # feat_map.append(f'{pos_neg}_{readout}_{l}_{d}')\n",
    "            feat_map.append((pos_neg, l, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "754a780b-41f5-4dae-8039-d7f27df4df39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One set of rules (sorted tuples of fc input indices) per fc output unit\n",
    "last_layer_rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414207d3-2bbf-414b-9e64-2b2077413496",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode each literal into (index, polarity, layer, unit) and drop literals that are true on\n",
    "# every node (they do not discriminate). A rule made only of such literals becomes ().\n",
    "last_layer_rules_renamed = []\n",
    "for c in range(len(last_layer_rules)):\n",
    "    s = set()\n",
    "    for t in last_layer_rules[c]:\n",
    "        new_t = []\n",
    "        for i in t:\n",
    "            if not activations[-1]['x_bin'][:,i].all().item():\n",
    "                new_t.append((i, *feat_map[i]))\n",
    "        s.add(tuple(new_t))\n",
    "    last_layer_rules_renamed.append(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d5b120-e577-4b01-8fe2-5420dded7913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rules with decoded literals (index, polarity, layer, unit); always-true literals removed\n",
    "last_layer_rules_renamed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10642270-acc3-44bc-9337-79d2311acc30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Batch, Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8807bd95-39f0-4cbd-9566-e521c374b7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import to_networkx, subgraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84490661-35e4-4c8b-994e-fbd9ccf696de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_subgraph(data, node_mask):\n",
    "    \"\"\"Return the subgraph of `data` induced by `node_mask` as a networkx graph.\n",
    "\n",
    "    Edges are relabelled to the kept nodes by torch_geometric's `subgraph`, and the result is\n",
    "    converted with `to_networkx` using its default options.\n",
    "    \"\"\"\n",
    "    kept_nodes = node_mask.nonzero(as_tuple=True)[0]\n",
    "    sub_edges = subgraph(kept_nodes, data.edge_index, relabel_nodes=True)[0]\n",
    "    return to_networkx(Data(x=data.x[node_mask], edge_index=sub_edges))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cce867f-a877-41a9-82f1-d4ca73e73bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Displayed again before minimisation (find_minimal_sets is applied further below)\n",
    "last_layer_rules_renamed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ef2ea9-82ae-477d-aa8b-6a843057b4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_minimal_sets(list_of_sets):\n",
    "    \"\"\"Keep only minimal rules: drop every entry whose elements strictly contain another entry's.\n",
    "\n",
    "    Entries with the same elements (including the same literals in a different order) count as\n",
    "    duplicates, and exactly one copy - the first occurrence - is kept.\n",
    "\n",
    "    Args:\n",
    "        list_of_sets: iterable of collections of hashable elements (e.g. tuples of literals).\n",
    "\n",
    "    Returns:\n",
    "        List of the surviving entries, in their original iteration order.\n",
    "    \"\"\"\n",
    "    items = list(list_of_sets)\n",
    "    as_sets = [set(s) for s in items]  # built once instead of on every pairwise comparison\n",
    "    minimal_sets = []\n",
    "    for i, (s, s_set) in enumerate(zip(items, as_sets)):\n",
    "        # Dominated by a strict subset, or a duplicate of an earlier entry.\n",
    "        dominated = any(other < s_set or (other == s_set and j < i) for j, other in enumerate(as_sets))\n",
    "        if not dominated:\n",
    "            minimal_sets.append(s)\n",
    "    return minimal_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae51ea1-4adc-415a-b096-b8349b498c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only minimal rules per fc output unit (drop rules whose literals contain another rule's)\n",
    "last_layer_rules_renamed = [find_minimal_sets(r) for r in last_layer_rules_renamed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a296ce-1a4c-498d-9e50-0274af0d2624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the forward of every TELL layer (each conv's nn[0] and the fc head) with forward_tell\n",
    "# at sigmoid temperature 10. The override is a closure stored on the instance, so pickling the\n",
    "# model afterwards (torch.save / clone_model) will presumably fail - verify before cloning.\n",
    "tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]\n",
    "for tell in tells:\n",
    "    tell.forward = forward_tell(tell, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6615f5-6c33-4731-92bd-fddb3e81dc83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_step_intervals(w, b, xmin, xmax, tau=5, resolution=1000):\n",
    "    \"\"\"Locate the x-ranges on which step(w * x + b, tau) exceeds 0.5.\n",
    "\n",
    "    The function is sampled at `resolution` evenly spaced points of [xmin, xmax]. Each returned\n",
    "    (start, end) pair uses the last sample at or below 0.5 before a rising edge and the first\n",
    "    such sample after a falling edge, so interior intervals slightly over-cover the region;\n",
    "    ranges touching the grid ends start at xmin / end at xmax. `step` must already be in the\n",
    "    notebook namespace (it is imported from `tell`).\n",
    "    \"\"\"\n",
    "    grid = torch.linspace(xmin, xmax, resolution)\n",
    "    is_on = step(w * grid + b, tau) > 0.5\n",
    "    found = []\n",
    "    open_at = grid[0].item() if is_on[0] else None\n",
    "    for k in range(1, len(grid)):\n",
    "        was_on, now_on = bool(is_on[k - 1]), bool(is_on[k])\n",
    "        if now_on and not was_on:\n",
    "            open_at = grid[k - 1].item()  # rising edge\n",
    "        elif was_on and not now_on:\n",
    "            found.append((open_at, grid[k].item()))  # falling edge\n",
    "            open_at = None\n",
    "    if bool(is_on[-1]) and open_at is not None:\n",
    "        found.append((open_at, grid[-1].item()))\n",
    "    return found\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690bd582-e743-41eb-83c0-fc15a86f438c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tell import step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd7d06f-cdb0-4752-b5fd-b2500fcb96de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the fc rules of `class_to_explain` that fire on at least one node, record the step\n",
    "# intervals of their literals, and queue the (conv layer, unit) pairs they depend on.\n",
    "conv_rules = []\n",
    "fc_rules = []\n",
    "model_tell = model_tell.cpu()\n",
    "class_to_explain = 0\n",
    "fc_acts = activations[-1]\n",
    "phi_in = model_tell.fc.phi_in\n",
    "for rule in last_layer_rules_renamed[class_to_explain]:\n",
    "    print(rule)\n",
    "    # Nodes satisfying every literal. An empty rule (all literals were always true) is vacuously\n",
    "    # true, so start from all-True; starting from None would crash on `.sum()` below.\n",
    "    node_mask = torch.ones(fc_acts['x_bin'].shape[0], dtype=torch.bool)\n",
    "    ands = []\n",
    "    for literal in rule:\n",
    "        feat_idx = literal[0]\n",
    "        # Out-of-place AND: on CPU, `.detach().cpu()` returns a view of the cached activations,\n",
    "        # so an in-place `&=` would silently overwrite activations[-1]['x_bin'].\n",
    "        node_mask = node_mask & fc_acts['x_bin'][:, feat_idx].detach().cpu()\n",
    "        column = fc_acts['x'][:, feat_idx]\n",
    "        intervals = find_step_intervals(phi_in.w[feat_idx].cpu(), phi_in.b[feat_idx].cpu(), column.min().cpu(), column.max().cpu(), tau=10, resolution=1000)\n",
    "        ands.append((literal, intervals))\n",
    "    print(node_mask)\n",
    "    if node_mask.sum().item() == 0:\n",
    "        continue\n",
    "    fc_rules.append(ands)\n",
    "    # literal[-2:] is (conv layer, unit) from feat_map.\n",
    "    # NOTE(review): this drops the 'pos'/'neg' polarity, while the conv-rule cell treats\n",
    "    # feat >= layer width as a complement; 'neg' literals may be explained as positive - confirm.\n",
    "    conv_rules.extend(literal[-2:] for literal, _ in ands)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a2895fe-5787-4da1-af5b-09497e086b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ad-hoc inspection: median value of fc input feature 159 (index picked by hand)\n",
    "activations[-1]['x'][:,159].median()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8da7cecc-783f-4bd6-9851-bb0ad384b916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept fc rules: each is a list of (literal, step intervals of that fc input)\n",
    "fc_rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779ba91d-672e-4d5b-8859-3d236a23ef8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worklist over (conv layer, feature) pairs reached from the fc rules: extract the feature's\n",
    "# rules from that layer's TELL unit, record step intervals of its non-trivial input literals,\n",
    "# and enqueue the previous layer's features those literals come from.\n",
    "conv_rules_details = {}\n",
    "while(len(conv_rules)):\n",
    "    layer, feat = conv_rules.pop(0)\n",
    "    phi_in = model_tell.convs[layer].nn[0].phi_in\n",
    "    hd = phi_in.w.shape[0]//2  # NOTE(review): not used below\n",
    "    # Number of output units of this TELL layer; feat >= ho refers to the complement (1 - y) half.\n",
    "    ho = model_tell.convs[layer].nn[0].weight.shape[0]\n",
    "    rules = extract_rules(model_tell.convs[layer].nn[0], feature=feat%ho)\n",
    "    refined_rules = []\n",
    "    for c in range(len(rules)):\n",
    "        s = set()\n",
    "        for t in rules[c]:\n",
    "            new_t = []\n",
    "            for i in t:\n",
    "                # Skip literals that are true on every node.\n",
    "                if not activations[layer]['x_bin'][:,i].all().item():\n",
    "                    intervals = find_step_intervals(phi_in.w[i].cpu(), phi_in.b[i].cpu(), (activations[layer]['x_sum'][:, i]).min().cpu(), (activations[layer]['x_sum'][:, i]).max().cpu(), tau=10, resolution=1000)\n",
    "                    new_t.append((i, tuple(intervals)))\n",
    "                    # Input i of this layer is output i (or its complement) of layer-1; layer 0's\n",
    "                    # inputs are raw node features, so nothing is enqueued for it.\n",
    "                    if layer!=0:\n",
    "                        if (layer-1, i) not in conv_rules_details and (layer-1, i) not in conv_rules:\n",
    "                            conv_rules.append((layer-1, i))\n",
    "            if new_t:\n",
    "                s.add(tuple(new_t))\n",
    "            else:\n",
    "                # Every literal was always true: the rule holds unconditionally.\n",
    "                s.add((True,))\n",
    "        refined_rules.append(s)\n",
    "    # print(rules)\n",
    "    # print(refined_rules)\n",
    "    # print(layer, feat, find_minimal_sets(refined_rules[0]))\n",
    "\n",
    "    # Store (minimal rules, node mask where the feature fires; negated for complement features).\n",
    "    conv_rules_details[layer, feat] = (find_minimal_sets(refined_rules[0]), activations[layer]['y_bin'][:,feat%ho] if feat < ho else ~activations[layer]['y_bin'][:,feat%ho])\n",
    "    print(conv_rules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f615f9d1-7b5d-4cc0-a40d-9a0b54693240",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conv_rules_details  (display disabled here; the last cell of the notebook shows it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb71fc7c-eda3-44f3-bff7-3624d0471fee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Number of explained conv features with a non-empty minimal rule set, per (1-based) layer\n",
    "dict(Counter(layer + 1 for (layer, _feat), (minimal_rules, _mask) in conv_rules_details.items() if len(minimal_rules) != 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8133760c-fe15-4e4f-8272-93828f834b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Firing rate (fraction of nodes) of every explained conv feature, grouped by conv layer.\n",
    "# `feat` indexes an *output* unit of conv `layer` (complement when feat >= width), so use the\n",
    "# polarity-aware output mask stored in conv_rules_details, not that layer's input literals.\n",
    "num_conv_layers = len(activations) - 1  # the last activations entry is the fc head\n",
    "acts = []\n",
    "for i in range(num_conv_layers):\n",
    "    rs = [mask.float().mean().item() for (layer, _feat), (_rules, mask) in conv_rules_details.items() if layer == i]\n",
    "    acts.append(rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3981e922-96a0-4273-af0b-d6ae13842e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-layer activation rates computed in the previous cell\n",
    "acts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2357fb-81fe-4718-9fa9-b35bc438ca61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (layer, feature) -> (minimal rules, node mask where the feature fires)\n",
    "conv_rules_details"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
