{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c70e6d8f-d238-40f0-911d-2a57728809b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from syn_dataset import SynGraphDataset\n",
    "from spmotif_dataset import *\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool\n",
    "from utils import *\n",
    "from sklearn.model_selection import train_test_split\n",
    "import shutil\n",
    "import glob\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import pickle\n",
    "import json\n",
    "import io\n",
    "from model import GIN, GINTELL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf8f4f3-6497-407d-844a-63b5ff2f4aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'BBBP'\n",
    "seed = 2\n",
    "def get_best_baseline_path(dataset_name):\n",
    "    l = glob.glob(f'results/{dataset_name}/*/results.json')\n",
    "    fl = [json.load(open(f)) for f in l]\n",
    "    df = pd.DataFrame(fl)\n",
    "    if df.shape[0] == 0: return None\n",
    "    df['fname'] = l\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False')]\n",
    "    fname = df.iloc[-1]['fname']\n",
    "    fname = fname.replace('/results.json', '')\n",
    "    return fname\n",
    "\n",
    "def get_best_path(dataset_name):\n",
    "    l = glob.glob(f'results_logic/{dataset_name}/*/*/results.json')\n",
    "    fl = [json.load(open(f)) for f in l]\n",
    "    df = pd.DataFrame(fl)\n",
    "    if df.shape[0] == 0: return None\n",
    "    df['fname'] = l\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False')]\n",
    "    print(df.tail())\n",
    "    fname = df.iloc[-1]['fname']\n",
    "    fname = fname.replace('/results.json', '')\n",
    "    return fname\n",
    "\n",
    "\n",
    "results_path = os.path.join(get_best_path(dataset_name), str(seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ae3d2d-6f6c-4cea-bbe1-a694dbc42efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "data = pickle.load(open(os.path.join(results_path, 'data.pkl'), 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc57df2-60f0-4942-bdee-fa82429010df",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = json.load(open(os.path.join(results_path, 'args.json'), 'r'))\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dcc9c46-5431-40c6-bdf2-8ea0d923ea8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a534b1de-729d-4740-a8bd-6b809bbbe849",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(dataset_name)\n",
    "num_classes = dataset.num_classes\n",
    "num_features = dataset.num_features\n",
    "num_layers = 5\n",
    "hidden_dim = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6648b363-d9c6-450f-b7ab-d6fe8765d63b",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = list(range(len(dataset)))\n",
    "train_indices, val_test_indices = train_test_split(indices, test_size=0.2,\n",
    "shuffle=True, stratify=dataset.data.y, random_state=1)\n",
    "\n",
    "val_indices = val_test_indices[:len(val_test_indices)//2]\n",
    "test_indices = val_test_indices[len(val_test_indices)//2:]\n",
    "\n",
    "train_dataset = dataset[train_indices]\n",
    "val_dataset = dataset[val_indices]\n",
    "test_dataset = dataset[test_indices]\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09b8ad6-773a-4e90-9839-17c19b6a1e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)\n",
    "# model_tell.dropout = lambda x:x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de540eda-a805-45ab-a230-257aa16c0638",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]\n",
    "# for tell in tells:\n",
    "#     tell.forward = forward_tell(tell, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fad2c355-8f8a-4edf-969e-672e67a9bd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_acc = test_epoch(model_tell, val_loader, device)\n",
    "# test_acc = test_epoch(model_tell, test_loader, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d77c982-c351-4571-b719-9c96d010ea5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b31350b-5ceb-409b-adb8-4a290a0131e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fc6bff2-2d4e-4d25-864a-127716e9ad5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def inverse_sigmoid(x):\n",
    "    \"\"\"Computes the inverse of the sigmoid function (logit function).\"\"\"\n",
    "    return torch.log(x / (1 - x))\n",
    "\n",
    "torch.no_grad()\n",
    "def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):\n",
    "    w = w.clone()\n",
    "    t_in = t_in.clone()\n",
    "    t_out = t_out.clone()\n",
    "    t_out = t_out.item()\n",
    "    ordering_scores = w\n",
    "    sorted_idxs = torch.argsort(ordering_scores, 0, descending=True)\n",
    "    mask = w > 1e-5\n",
    "    if activations is not None:\n",
    "        mask = mask & (activations.sum(0) >= min_support)\n",
    "    total_result = set()\n",
    "\n",
    "    # Filter and sort indices based on the mask\n",
    "    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]\n",
    "    if idxs_to_visit.numel() == 0:\n",
    "        return total_result\n",
    "\n",
    "    # Sort weights based on the filtered indices\n",
    "    sorted_weights = w[idxs_to_visit]\n",
    "    current_combination = []\n",
    "    result = set()\n",
    "\n",
    "    def find_logic_rules_recursive(index, current_sum):\n",
    "        # Stop if the maximum number of rules has been reached\n",
    "        if len(result) >= max_rules:\n",
    "            return\n",
    "\n",
    "        if len(current_combination) > max_rule_len:\n",
    "            return\n",
    "\n",
    "        # Check if the current combination satisfies the condition\n",
    "        if current_sum >= t_out:\n",
    "            c = idxs_to_visit[current_combination].cpu().detach().tolist()\n",
    "            c = tuple(sorted(c))\n",
    "            result.add(c)\n",
    "            return\n",
    "\n",
    "        # Prune if remaining weights can't satisfy t_out\n",
    "        remaining_max_sum = current_sum + sorted_weights[index:].sum()\n",
    "        if remaining_max_sum < t_out:\n",
    "            return\n",
    "\n",
    "        # Explore further combinations\n",
    "        for i in range(index, idxs_to_visit.shape[0]):\n",
    "            # Prune based on activations if provided\n",
    "            if activations is not None and len(current_combination) > 0 and activations[:, idxs_to_visit[current_combination + [i]]].all(-1).sum().item() < min_support:\n",
    "                continue\n",
    "\n",
    "            current_combination.append(i)\n",
    "            find_logic_rules_recursive(i + 1, current_sum + sorted_weights[i])\n",
    "            current_combination.pop()\n",
    "\n",
    "    # Start the recursive process\n",
    "    find_logic_rules_recursive(0, 0)\n",
    "    return result\n",
    "\n",
    "\n",
    "def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=5, min_support=10, out_threshold=0.5):\n",
    "    ws = self.weight\n",
    "    t_in = self.phi_in.t\n",
    "    t_out = -self.b + inverse_sigmoid(torch.tensor(out_threshold))\n",
    "\n",
    "    rules = []\n",
    "    if feature is None:\n",
    "        features = range(self.out_features)\n",
    "    else:\n",
    "        features = [feature]\n",
    "    for i in features:\n",
    "        w = ws[i].to('cpu')\n",
    "        ti = t_in.to('cpu')\n",
    "        to = t_out[i].to('cpu')\n",
    "        rules.append(find_logic_rules(w, ti, to, activations, max_rule_len, max_rules, min_support))\n",
    "\n",
    "    return rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ecb1e08-6efc-4468-86ab-31dd467277cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.ndarray"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282826b6-c20e-4ac5-90ea-4579dcb42b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data, Batch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "def plot_activations(batch_ids, batch, attr):\n",
    "    if type(batch_ids) != list:\n",
    "        batch_ids = [batch_ids]\n",
    "    num_ids = len(batch_ids)\n",
    "    cols = 5\n",
    "    rows = math.ceil(num_ids / cols)\n",
    "    \n",
    "    fig, axs = plt.subplots(rows, cols, figsize=(16*5, 8 * rows))\n",
    "    \n",
    "    if type(axs) != np.ndarray: axs = np.array([axs])\n",
    "    # Flatten axs if it's 2D to simplify indexing\n",
    "    axs = axs.flatten()\n",
    "    for ax in axs:\n",
    "        ax.set_axis_off()\n",
    "    for i, batch_id in enumerate(batch_ids):\n",
    "        node_mask = batch.batch == batch_id  # Get nodes where batch == 0\n",
    "        if node_mask.float().sum() == 0: continue\n",
    "        node_indices = torch.nonzero(node_mask, as_tuple=True)[0]\n",
    "        \n",
    "        subgraph_edge_mask = (batch.batch[batch.edge_index[0]] == batch_id) & \\\n",
    "                             (batch.batch[batch.edge_index[1]] == batch_id)\n",
    "        subgraph_edges = batch.edge_index[:, subgraph_edge_mask]\n",
    "        \n",
    "        node_mapping = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(node_indices)}\n",
    "        remapped_edges = torch.tensor([[node_mapping[e.item()] for e in edge] for edge in subgraph_edges.T])\n",
    "        \n",
    "        G = nx.Graph()\n",
    "        G.add_edges_from(remapped_edges.numpy())\n",
    "        \n",
    "        nx.set_node_attributes(G, {v: k for k, v in node_mapping.items()}, \"original_id\")\n",
    "        \n",
    "        node_colors = []\n",
    "        node_borders = []\n",
    "        \n",
    "        for node in G.nodes:\n",
    "            if attr[batch.batch==batch_id][node] == 1:\n",
    "                node_colors.append(\"lightblue\")  # Fill color\n",
    "                node_borders.append(\"red\")  # Border color for attr == 1\n",
    "            else:\n",
    "                node_colors.append(\"lightblue\")  # Fill color\n",
    "                node_borders.append(\"black\")  # Default border color\n",
    "        \n",
    "        \n",
    "        pos = nx.kamada_kawai_layout(G) \n",
    "        \n",
    "        nx.draw(\n",
    "            G, pos,\n",
    "            node_color=node_colors,\n",
    "            edgecolors=node_borders,  # Border colors\n",
    "            node_size=700,\n",
    "            with_labels=True,\n",
    "            ax = axs[i]\n",
    "        )\n",
    "        \n",
    "        # axs[i].set_title(f\"Class = {batch.y[batch_id]}\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56862307-99d9-4bf1-844a-33d77a317d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hoyer_sparsity_loss(weights, lambda_=1.0, epsilon=1e-12):\n",
    "    \"\"\"\n",
    "    Hoyer's sparsity loss to promote sparsity.\n",
    "    \n",
    "    Args:\n",
    "        weights (torch.Tensor): The weights to regularize.\n",
    "        lambda_ (float): Regularization strength.\n",
    "        epsilon (float): Small value to prevent division by zero.\n",
    "    \n",
    "    Returns:\n",
    "        torch.Tensor: The Hoyer's sparsity loss.\n",
    "    \"\"\"\n",
    "    l1_norm = torch.sum(torch.abs(weights), -1)\n",
    "    l2_norm = torch.sqrt(torch.sum(weights**2, -1) + epsilon)\n",
    "    hoyer = (torch.sqrt(torch.tensor(weights.numel())) - l1_norm / l2_norm) / \\\n",
    "            (torch.sqrt(torch.tensor(weights.numel())) - 1 + epsilon)\n",
    "    loss = lambda_ * (1 - hoyer)\n",
    "    return loss.mean()\n",
    "\n",
    "\n",
    "def train_sparsity_epoch(model_tell, loader, device, optimizer, num_classes, conv_reg=1, fc_reg=1):\n",
    "    model_tell.train()\n",
    "    \n",
    "    total_loss = 0\n",
    "    total_correct = 0\n",
    "    \n",
    "    for data in loader:\n",
    "        try:\n",
    "            loss = 0\n",
    "            if data.x is None:\n",
    "                data.x = torch.ones((data.num_nodes, model_tell.num_features))\n",
    "            if data.y.numel() == 0: continue\n",
    "            if data.x.isnan().any(): continue\n",
    "            if data.y.isnan().any(): continue\n",
    "            y = data.y.reshape(-1).to(device).long()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            model_tell.fc.phi_in.tau = 10\n",
    "            out = model_tell(data.x.float().to(device), data.edge_index.to(device), data.batch.to(device))       \n",
    "            pred = out.argmax(-1)\n",
    "            loss += F.binary_cross_entropy(out.reshape(-1), torch.nn.functional.one_hot(y, num_classes=num_classes).float().reshape(-1)) + F.nll_loss(F.log_softmax(out, dim=-1), y.long())\n",
    "            # tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]\n",
    "            for conv in model_tell.convs:\n",
    "                loss += conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)\n",
    "                #oss += (hoyer_sparsity_loss(torch.clamp(conv.nn[0].weight, min=1e-5)) + conv.nn[0].reg_loss + conv.nn[0].phi_in.entropy) + conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)\n",
    "                # loss += conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)\n",
    "            # loss += fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)\n",
    "            loss += (hoyer_sparsity_loss(torch.clamp(model_tell.fc.weight, min=1e-5)) + model_tell.fc.reg_loss + model_tell.fc.phi_in.entropy) + fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)\n",
    "            # loss += fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)\n",
    "            loss.backward()\n",
    "            zero_nan_gradients(model_tell)\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item() * data.num_graphs / len(loader.dataset)\n",
    "            total_correct += pred.eq(y).sum().item() / len(loader.dataset)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            pass\n",
    "\n",
    "    return total_loss, total_correct\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "201b2ba9-5346-4ff2-b482-dc17a0d2d814",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch_scatter\n",
    "\n",
    "import torch\n",
    "import torch_scatter\n",
    "def scatter_sum(x, edge_index):\n",
    "    # Get target nodes (i.e., the nodes receiving the messages)\n",
    "    target_nodes = edge_index[1]\n",
    "    \n",
    "    # Perform scatter sum\n",
    "    out = torch_scatter.scatter_add(x[edge_index[0]], target_nodes, dim=0, dim_size=x.size(0)) + x\n",
    "\n",
    "    return out\n",
    "@torch.no_grad()\n",
    "def forward_with_activations(self, x, edge_index, batch, *args, **kwargs):\n",
    "    returns = []\n",
    "    x = self.input_bnorm(x)\n",
    "    xs = []\n",
    "    for i, conv in enumerate(self.convs):\n",
    "        ret = {}\n",
    "        ret['x'] = torch.hstack([x, 1-x])\n",
    "        ret['x_sum'] = scatter_sum(ret['x'], edge_index)\n",
    "        ret['x_bin'] = conv.nn[0].phi_in(ret['x_sum']) >= 0.5\n",
    "        x = conv(torch.hstack([x, 1-x]), edge_index)\n",
    "        xs.append(x)\n",
    "        ret['y'] = x\n",
    "        ret['y_bin'] = x>=0.5\n",
    "        ret['batch'] = batch\n",
    "        returns.append(ret)\n",
    "    ret = {}\n",
    "    x_mean = global_mean_pool(torch.hstack(xs), batch)\n",
    "    x_max = global_max_pool(torch.hstack(xs), batch)\n",
    "    x_sum = global_add_pool(torch.hstack(xs), batch)\n",
    "    x = torch.hstack([x_mean, x_max, x_sum])\n",
    "    x = self.output_bnorm(x)\n",
    "    ret['x'] = torch.hstack([x, 1-x])\n",
    "    ret['x_bin'] = self.fc.phi_in(ret['x']) >= 0.5\n",
    "    x = self.fc(torch.hstack([x, 1-x]))\n",
    "    ret['y'] = x\n",
    "    ret['y_bin'] = x>=0.5\n",
    "    returns.append(ret)\n",
    "    return x, returns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da146ea5-220e-4f94-bde4-5111edf6d8d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x, tau=10):\n",
    "    return 1/(1+torch.exp(-tau*x))\n",
    "    \n",
    "def forward_tell(self, tau):\n",
    "    def fw(x):\n",
    "\n",
    "        # x = self.phi_in(torch.hstack([x, 1-x]))\n",
    "        x = self.phi_in(x)\n",
    "        self.max_in, _ = x.max(0)\n",
    "        reg_loss = 0\n",
    "        entropy_loss = 0\n",
    "        if self.use_weight_sigma:\n",
    "            reg_loss += torch.clamp(self.weight_s, min=1e-5).sum(-1).mean()\n",
    "        else:\n",
    "            reg_loss += torch.clamp(self.weight, min=1e-5).sum(-1).mean()\n",
    "        if self.phi_in.entropy is not None:\n",
    "            entropy_loss += self.phi_in.entropy\n",
    "        # print('b', reg_loss, entropy_loss)\n",
    "        self.reg_loss = reg_loss\n",
    "        \n",
    "        w = self.weight\n",
    "        o = sigmoid(x @ w.t() + self.b, tau=tau)\n",
    "        \n",
    "        self.entropy_loss = entropy_loss + -(o*torch.log(o+1e-8) + (1-o)*torch.log(1-o + 1e-8)).mean()\n",
    "        return o\n",
    "    return fw\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0805169-f70b-46f4-8df3-5d70b8be223b",
   "metadata": {},
   "source": [
    "# Extract Rules for Last Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93945d12-5023-411c-8c5f-fca93c733147",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell = model_tell.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1942390-6bcb-4742-9c19-dd33d7fef53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from train_logic import test_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c5e07e5-567a-4b96-a75f-085ab64ed6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_tell.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbfe543d-7278-438c-a52b-a98b76e43be2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import io\n",
    "\n",
    "def clone_model(model):\n",
    "    # Save the model to an in-memory buffer\n",
    "    buffer = io.BytesIO()\n",
    "    model.dropout = None\n",
    "    torch.save(model, buffer)\n",
    "    model.dropout = lambda x:x\n",
    "    \n",
    "    # Rewind the buffer\n",
    "    buffer.seek(0)\n",
    "    \n",
    "    \n",
    "    # Load the saved state into the new instance\n",
    "    cloned_model = torch.load(buffer)\n",
    "    cloned_model.dropout = lambda x:x\n",
    "    return cloned_model\n",
    "\n",
    "# def clone_model(model):\n",
    "#     # Save the model to an in-memory buffer\n",
    "#     buffer = io.BytesIO()\n",
    "#     # model.dropout = None\n",
    "#     torch.save(model, buffer)\n",
    "#     # model.dropout = lambda x:x\n",
    "    \n",
    "#     # Rewind the buffer\n",
    "#     buffer.seek(0)\n",
    "    \n",
    "    \n",
    "#     # Load the saved state into the new instance\n",
    "#     cloned_model = torch.load(buffer)\n",
    "#     # cloned_model.dropout = lambda x:x\n",
    "#     return cloned_model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29e00bd-a899-442c-a70c-fec3b36c60a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "optimizer = torch.optim.Adam([model_tell.fc.weight_sigma, model_tell.fc.weight_exp], lr=0.01)\n",
    "\n",
    "# for param in model_tell.parameters():\n",
    "#     param.requires_grad = False\n",
    "    \n",
    "# print(\"Pruning last layer\")\n",
    "# model_tell.fc.weight_sigma.requires_grad = True\n",
    "# model_tell.fc.weight_exp.requires_grad = True\n",
    "# model_tell.fc.phi_in.w.requires_grad = True\n",
    "# model_tell.fc.phi_in.b.requires_grad = True\n",
    "best_weights = clone_model(model_tell)\n",
    "val_acc = test_epoch(model_tell, val_loader, device)\n",
    "n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "best_situation = (val_acc, -n_w)\n",
    "patience = max_patience = 100\n",
    "for i in range(1000):\n",
    "    # model_tell.fc.forward = forward_tell(model_tell.fc, 10)\n",
    "    train_loss, train_acc = train_sparsity_epoch(model_tell, train_loader, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.1)\n",
    "    # model_tell.fc.forward = forward_tell(model_tell.fc, 1000)\n",
    "    val_acc = test_epoch(model_tell, val_loader, device)\n",
    "    test_acc = test_epoch(model_tell, test_loader, device)\n",
    "    n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "    # if (val_acc, -n_w) > best_situation:\n",
    "    if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.99*best_situation[0]):\n",
    "        best_weights = clone_model(model_tell)\n",
    "        patience = max_patience\n",
    "        best_situation = (val_acc, -n_w)\n",
    "    patience -= 1 \n",
    "    if i%10 == 0:\n",
    "        print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience)\n",
    "    if patience == 0:\n",
    "        break\n",
    "model_tell = clone_model(best_weights)\n",
    "val_acc = test_epoch(model_tell, val_loader, device)\n",
    "test_acc = test_epoch(model_tell, test_loader, device)\n",
    "n_w =  (model_tell.fc.weight>1e-4).sum().item()\n",
    "print(val_acc, test_acc, n_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d772a1b-0f97-40c8-b692-1eec65a59065",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_epoch(best_weights, val_loader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd13e76-fb2b-49c7-984e-a05c79a32558",
   "metadata": {},
   "outputs": [],
   "source": [
    "extract_rules(model_tell.fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d6609a3-72d3-47e6-9085-f1c80a79a506",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_tell = model_tell.to(device)\n",
    "\n",
    "# optimizer = torch.optim.Adam(model_tell.parameters(), lr=0.005)\n",
    "\n",
    "best_weights = clone_model(model_tell)\n",
    "\n",
    "for l in reversed(range(len(model_tell.convs))):\n",
    "    optimizer = torch.optim.Adam([model_tell.convs[l].nn[0].weight_sigma, model_tell.convs[l].nn[0].weight_exp], lr=0.005)\n",
    "    # for param in model_tell.parameters():\n",
    "        # param.requires_grad = False\n",
    "    # print(\"Pruning conv layer\", l)\n",
    "    # model_tell.convs[l].nn[0].weight_sigma.requires_grad = True\n",
    "    # model_tell.convs[l].nn[0].weight_exp.requires_grad = True\n",
    "    # model_tell.fc.phi_in.w.requires_grad = True\n",
    "    # model_tell.fc.phi_in.b.requires_grad = True\n",
    "    # best_weights = model_tell.state_dict()\n",
    "    val_acc = test_epoch(model_tell, val_loader, device)\n",
    "    test_acc = test_epoch(model_tell, test_loader, device)\n",
    "    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "    best_situation = (val_acc, -n_w)\n",
    "    print(best_situation)\n",
    "    patience = max_patience = 50\n",
    "    for i in range(1000):\n",
    "        # model_tell.fc.forward = forward_tell(model_tell.fc, 10)\n",
    "        train_loss, train_acc = train_sparsity_epoch(model_tell, train_loader, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.01)\n",
    "        val_acc = test_epoch(model_tell, val_loader, device)\n",
    "        test_acc = test_epoch(model_tell, test_loader, device)\n",
    "        n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "        # if (val_acc, -n_w) > best_situation:\n",
    "        if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.99*best_situation[0]):\n",
    "            print((val_acc, -n_w), 'is better than', best_situation)\n",
    "            best_weights = clone_model(model_tell)\n",
    "            patience = max_patience\n",
    "            best_situation = (val_acc, -n_w)\n",
    "        patience -= 1 \n",
    "        if i%10 == 0:\n",
    "            print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience, best_situation, test_epoch(best_weights, val_loader, device))\n",
    "        if patience == 0:\n",
    "            break\n",
    "    model_tell = clone_model(best_weights)\n",
    "    val_acc = test_epoch(model_tell, val_loader, device)\n",
    "    test_acc = test_epoch(model_tell, test_loader, device)\n",
    "    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "    print(val_acc, test_acc, n_w)\n",
    "    \n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed29661-3575-497c-856c-d99c4ecf2594",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell = best_weights.to(device)\n",
    "val_acc = test_epoch(model_tell, val_loader, device)\n",
    "test_acc = test_epoch(model_tell, test_loader, device)\n",
    "# n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "print(val_acc, test_acc, n_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f46a622-7d6c-4ff4-950b-0bc5cb56dd8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_epoch(best_weights, test_loader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62862f15-66db-4a75-9d4d-c750887084c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b91a80c-8459-4a7b-80f2-d120c19963ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell = best_weights\n",
    "val_acc = test_epoch(model_tell, val_loader, device)\n",
    "test_acc = test_epoch(model_tell, test_loader, device)\n",
    "n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()\n",
    "print(val_acc, test_acc, n_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8061f0-4818-4729-b384-b1ea21ee3a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_acc = test_epoch(model_tell, val_loader, device)\n",
    "val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0185a11c-2dcd-4b08-aa75-ef8e19fd1387",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_situation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701739a9-3a86-4cd4-9f53-2099ef4fbbd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    activations = None\n",
    "    for batch in val_loader:\n",
    "        _, rets = forward_with_activations(model_tell.to(device), batch.x.to(device).float(), batch.edge_index.to(device), batch.batch.to(device))\n",
    "        # activations.append(model_tell.fc.phi_in(torch.hstack([xs[-1], 1-xs[-1]])))\n",
    "        if activations is None:\n",
    "            activations = rets\n",
    "        else:\n",
    "            for l in range(len(rets)):\n",
    "                for k in rets[l]:\n",
    "                    if k == 'batch':\n",
    "                        rets[l]['batch'] += torch.max(activations[l]['batch'])+1\n",
    "                        activations[l][k] = torch.cat([activations[l][k], rets[l][k]])\n",
    "                    else:\n",
    "                        activations[l][k] = torch.vstack([activations[l][k], rets[l][k]])\n",
    "                    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492d376a-e242-41d7-9281-d8a87dacf076",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005bcf33-5c00-40c7-9d14-3d3873fea00e",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules = extract_rules(model_tell.fc, activations=activations[-1]['x_bin'].cpu(), max_rules=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff12a665-405e-4451-86a1-f3cc69c5618c",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_map = []\n",
    "for pos_neg in ['pos', 'neg']:\n",
    "    for readout in ['mean', 'max', 'sum']:\n",
    "        for l in range(num_layers):\n",
    "            for d in range(hidden_dim):\n",
    "                # feat_map.append(f'{pos_neg}_{readout}_{l}_{d}')\n",
    "                feat_map.append((pos_neg, readout, l, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414207d3-2bbf-414b-9e64-2b2077413496",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules_renamed = []\n",
    "for c in range(len(last_layer_rules)):\n",
    "    s = set()\n",
    "    for t in last_layer_rules[c]:\n",
    "        new_t = []\n",
    "        for i in t:\n",
    "            if not activations[-1]['x_bin'][:,i].all().item():\n",
    "                new_t.append((i, *feat_map[i]))\n",
    "        s.add(tuple(new_t))\n",
    "    last_layer_rules_renamed.append(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e9538e0-85af-4aa6-a2e5-0f1d223bad68",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules_renamed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d5b120-e577-4b01-8fe2-5420dded7913",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_minimal_sets(list_of_sets):\n",
    "    minimal_sets = []\n",
    "    for i, s in enumerate(list_of_sets):\n",
    "        if not any((set(other)<set(s)) or (s == other and i != j) for j, other in enumerate(list_of_sets)):\n",
    "            minimal_sets.append(s)\n",
    "    return minimal_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e78f4946-26ba-48ed-8ff9-201bd9ead5e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules_renamed = [find_minimal_sets(r) for r in last_layer_rules_renamed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6bb32b6-649e-4b2e-9105-426908bb02b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]\n",
    "for tell in tells:\n",
    "    tell.forward = forward_tell(tell, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612ef740-09b4-48b5-b9a3-e59226189aa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(test_loader))\n",
    "y, rets = forward_with_activations(model_tell.to(device), batch.x.to(device).float(), batch.edge_index.to(device), batch.batch.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10642270-acc3-44bc-9337-79d2311acc30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Batch, Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8807bd95-39f0-4cbd-9566-e521c374b7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import to_networkx, subgraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84490661-35e4-4c8b-994e-fbd9ccf696de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_subgraph(data, node_mask):\n",
    "    nodes_to_keep = torch.where(node_mask)[0]\n",
    "    new_edge_index = subgraph(nodes_to_keep, data.edge_index, relabel_nodes=True)[0]\n",
    "    new_x = data.x[node_mask]\n",
    "    G = to_networkx(Data(x=new_x, edge_index=new_edge_index))\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cce867f-a877-41a9-82f1-d4ca73e73bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules_renamed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be9ec2f-e6b5-45de-a842-da67797e5752",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5bfeae5-abc0-4b33-bcff-21aaf40da6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "reprs = []\n",
    "model_tell = model_tell.to(device)\n",
    "for rule in last_layer_rules_renamed[1]:\n",
    "    print(rule)\n",
    "    # for literal in rule:\n",
    "    representatives = []\n",
    "    n = 0\n",
    "    for data in val_dataset:\n",
    "        batch = Batch.from_data_list([data])\n",
    "        \n",
    "        try:\n",
    "            y, rets = forward_with_activations(model_tell, batch.x.to(device).float(), batch.edge_index.to(device), batch.batch.to(device))      \n",
    "        except: continue \n",
    "        ll_feats = [literal[0] for literal in rule]\n",
    "        if not rets[-1]['x_bin'][:,ll_feats].all(): continue\n",
    "        node_mask = torch.ones(data.x.shape[0]).bool()\n",
    "        for literal in rule:\n",
    "            m = (rets[literal[3]]['y_bin'][:,literal[4]].detach().cpu())\n",
    "            if m.all():continue\n",
    "            # if literal[1] == 'neg':\n",
    "            #     m = ~m\n",
    "            node_mask &= m\n",
    "        if node_mask.sum().item() == 0: continue\n",
    "        n+=1\n",
    "        G = get_subgraph(data, node_mask)\n",
    "        isomorphic = False\n",
    "        for G1, _, _, _ in representatives:\n",
    "            if nx.is_isomorphic(G, G1): \n",
    "                isomorphic=True\n",
    "                break\n",
    "        if not isomorphic:\n",
    "            representatives.append((G,batch,node_mask, y))\n",
    "    print(n, len(representatives))\n",
    "    representatives = representatives[:10]\n",
    "    if len(representatives):plot_activations(list(range(len(representatives))), Batch.from_data_list([x[1] for x in representatives]), torch.cat([x[2] for x in representatives]))\n",
    "    reprs.append(representatives)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b4dd3a4-05a7-4e6c-9130-f947fc8c1dbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "\t# 0\tC\n",
    "\t# 1\tO\n",
    "\t# 2\tCl\n",
    "\t# 3\tH\n",
    "\t# 4\tN\n",
    "\t# 5\tF\n",
    "\t# 6\tBr\n",
    "\t# 7\tS\n",
    "\t# 8\tP\n",
    "\t# 9\tI\n",
    "\t# 10\tNa\n",
    "\t# 11\tK\n",
    "\t# 12\tLi\n",
    "\t# 13\tCa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c20bd1f-9a21-4048-9f01-ac460014763f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_to_explain = 0\n",
    "instance_to_show = val_dataset[2].cpu()\n",
    "assert(instance_to_show.y == class_to_explain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7feb2338-c1e5-4ee4-ac09-12ce9d87c4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe09116-8bd1-4238-b3d5-895039027be2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_step_intervals(w, b, xmin, xmax, tau=5, resolution=1000):\n",
    "    # Sample x values\n",
    "    xs = torch.linspace(xmin, xmax, resolution)\n",
    "    wxb = w * xs + b\n",
    "    ys = step(wxb, tau)\n",
    "\n",
    "    intervals = []\n",
    "    above = ys[0] > 0.5\n",
    "    start = xs[0].item() if above else None\n",
    "\n",
    "    for i in range(1, len(xs)):\n",
    "        curr = ys[i] > 0.5\n",
    "        if curr and not above:\n",
    "            # Rising edge\n",
    "            start = xs[i-1].item()\n",
    "        elif not curr and above:\n",
    "            # Falling edge\n",
    "            end = xs[i].item()\n",
    "            intervals.append((start, end))\n",
    "            start = None\n",
    "        above = curr\n",
    "\n",
    "    if above and start is not None:\n",
    "        intervals.append((start, xs[-1].item()))\n",
    "\n",
    "    return intervals\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2734f9d3-faab-4cbe-ac57-11104d7823a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tell import step\n",
    "\n",
    "conv_rules = []\n",
    "fc_rules = []\n",
    "model_tell = model_tell.cpu()\n",
    "\n",
    "batch = Batch.from_data_list([instance_to_show])\n",
    "instance_to_show_y, instance_to_show_rets = forward_with_activations(model_tell, batch.x.float(), batch.edge_index, batch.batch)      \n",
    "for rule in last_layer_rules_renamed[class_to_explain]:\n",
    "    print(rule)\n",
    "    # for literal in rule:\n",
    "\n",
    "    \n",
    "    ll_feats = [literal[0] for literal in rule]\n",
    "    # if not rets[-1]['x_bin'][:,ll_feats].all(): continue\n",
    "    node_mask = torch.ones(data.x.shape[0]).bool()\n",
    "    phi_in = model_tell.fc.phi_in\n",
    "    ands = []\n",
    "    node_mask = None\n",
    "    for literal in rule:\n",
    "        m = (activations[-1]['x_bin'][:,literal[0]].detach().cpu())\n",
    "        if node_mask is None:\n",
    "            node_mask = m\n",
    "        else: node_mask &= m\n",
    "        intervals = find_step_intervals(phi_in.w[literal[0]].cpu(), phi_in.b[literal[0]].cpu(), activations[-1]['x'][:, literal[0]].min().cpu(), activations[-1]['x'][:, literal[0]].max().cpu(), tau=10, resolution=1000)\n",
    "        ands.append((literal, intervals))\n",
    "    print(node_mask)\n",
    "    if node_mask.sum().item() == 0: continue\n",
    "    fc_rules.append(ands)\n",
    "\n",
    "    for literal,_ in ands:\n",
    "        conv_rules.append(literal[-2:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3ea51f-4ee2-49a1-b2d0-943eb9a2d2e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[0].nn[0].phi_in.w.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800cae72-9557-4561-bdcf-47ed0cdaaffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "conv_rules_details = {}\n",
    "while(len(conv_rules)):\n",
    "    layer, feat = conv_rules.pop(0)\n",
    "    phi_in = model_tell.convs[layer].nn[0].phi_in\n",
    "    hd = phi_in.w.shape[0]//2\n",
    "    ho = model_tell.convs[layer].nn[0].weight.shape[0]\n",
    "    rules = extract_rules(model_tell.convs[layer].nn[0], feature=feat%ho)\n",
    "    refined_rules = []\n",
    "    for c in range(len(rules)):\n",
    "        s = set()\n",
    "        for t in rules[c]:\n",
    "            new_t = []\n",
    "            for i in t:\n",
    "                if not activations[layer]['x_bin'][:,i].all().item():\n",
    "                    intervals = find_step_intervals(phi_in.w[i].cpu(), phi_in.b[i].cpu(), (activations[layer]['x_sum'][:, i]).min().cpu(), (activations[layer]['x_sum'][:, i]).max().cpu(), tau=10, resolution=1000)\n",
    "                    new_t.append((i, tuple(intervals)))\n",
    "                    if layer!=0:\n",
    "                        if (layer-1, i) not in conv_rules_details and (layer-1, i) not in conv_rules:\n",
    "                            conv_rules.append((layer-1, i))\n",
    "            if new_t:\n",
    "                s.add(tuple(new_t))\n",
    "            else:\n",
    "                s.add((True,))\n",
    "        refined_rules.append(s)\n",
    "    # print(rules)\n",
    "    # print(refined_rules)\n",
    "    # print(layer, feat, find_minimal_sets(refined_rules[0]))\n",
    "\n",
    "    conv_rules_details[layer, feat] = (find_minimal_sets(refined_rules[0]), instance_to_show_rets[layer]['y_bin'][:,feat%ho] if feat < ho else ~instance_to_show_rets[layer]['y_bin'][:,feat%ho])\n",
    "    print(conv_rules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3361a4-a41a-411f-b492-bc523ce134fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "conv_rules_details"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61f92501-1858-4cf7-b7a6-3d47b3651b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "dict(Counter([x[0]+1 for x in conv_rules_details.keys()if len(conv_rules_details[x][0])!=0 ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f3e18b-3e2b-4342-b2bb-f3b210bd7aed",
   "metadata": {},
   "outputs": [],
   "source": [
    "[x[0]+1 for x in conv_rules_details.keys() if len(conv_rules_details[x][0])==0 ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe8796f-6035-4ee2-8a64-9f23d2d56352",
   "metadata": {},
   "outputs": [],
   "source": [
    "acts = []\n",
    "for i in range(5):\n",
    "    rs = [global_add_pool(activations[x[0]]['x_bin'][:,[x[1]%activations[x[0]]['x_bin'].shape[1]]], activations[x[0]]['batch']).float().mean().item() for x in conv_rules_details.keys() if x[0] == i]\n",
    "    acts.append(rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b479125-6a64-4c15-81f9-567fd8b0e0fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 4\n",
    "rs = [global_add_pool(activations[x[0]]['x_bin'][:,[x[1]%activations[x[0]]['x_bin'].shape[1]]], activations[x[0]]['batch']).float().mean().item() for x in conv_rules_details.keys() if x[0] == i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0d994f-e496-44ba-9d63-aaad8464b05f",
   "metadata": {},
   "outputs": [],
   "source": [
    "acts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e29389-6fc7-4c14-bf9b-f76d7507d70e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 5, figsize=(16, 3))\n",
    "\n",
    "for i in range(5):\n",
    "    print(max(1, min(10, len(acts[i]))))\n",
    "    axes[i].hist(acts[i], bins=max(1, min(10, len(acts[i])-1)))\n",
    "    axes[i].set_title(f'Layer {i+1}')\n",
    "    axes[i].set_xlabel('% Activated Nodes')\n",
    "    axes[i].set_ylabel('% Number of Rules')\n",
    "    axes[i].set_xlim(0, 1)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fc6b22-bf56-4e3c-8c2e-be203f00d4d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(acts[3], bins=100)\n",
    "plt.xlim(-6e-9, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83cd787c-d499-4a3c-bd34-c035b51566a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(acts[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8561013b-2645-4fd5-8075-449eb7953f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(acts[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3607e71-a239-4e3a-a007-6c31fd48ef77",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(acts[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4f9df2-e0a5-4ae4-99ef-95c4e5f0f1d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "acts[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e5d0f0-f1a4-4e40-9362-b7613939e350",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(0,28)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603e5523-734a-42d5-bf3e-972cd59e3fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(0,45)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d9809b-59fa-4376-9179-e6ad0a13ddd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(0,48)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7507b5-5b12-4040-8e72-980b2cfc8af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(0,52)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d4f13f2-9f8e-4ed3-8ac9-a294765c52d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(1,5)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d12f36-62f3-4e84-975f-f76867fa7d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[-1]['x'][:,837]/14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e3f70d-5307-43f7-8aa3-c729b7636acf",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['y'][:,5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8daa78f-06b5-46e7-8117-e1c9739ed057",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[0]['y_bin'][:,45%32])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4519d970-7684-4824-897c-c2677b9abafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(1,5)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d0e04f-2604-42ee-9408-64412da51056",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,(28,45,48)].all(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9afa7a-0e5f-4ec0-bd7f-80dd83e9e094",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,(28,45,52)].all(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42fab8e8-fdd8-402b-b547-62c90fbae18f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), conv_rules_details[(0,52)][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83367bea-bbab-4f92-96ce-6bc2b64cfa03",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,52])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa6385e-f889-41f8-b7a5-e144986893a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,48])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b69399-5411-4afe-a956-80f3b73d7268",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[0]['x_sum'][:,(0,7)][[16,19]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f1c18f5-3a1a-43b1-9e11-40b35ba662ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[0]['x_bin'][:,(0,7)][[16,19]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa6437d-e3cc-4448-bac4-deaf1b970ade",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x_bin'][:,48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9f37e1-af0e-470d-b417-9ef8c6521156",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['y'][:,5][4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0068ca9a-cc7b-462e-a477-269dec9ba048",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x_sum'][:,28][4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "407f9236-63df-497a-9d14-6c590e04fefa",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x_sum'][:,45%32][4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea9f1a37-d290-4216-9473-f93adcb98495",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a156dff-b942-4730-9b90-91b63c573e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[1].nn[0].phi_in(instance_to_show_rets[1]['x_sum'])[:,(28,45,52)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc0c54a3-161b-4a08-803d-32c205e6f818",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[1].nn[0].phi_in(instance_to_show_rets[1]['x_sum']+instance_to_show_rets[1]['x'])[:,(28,45,52)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164311c4-ff7d-4e76-ba60-4e8f7765e6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['y'][:,(5,)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f687bf-3d8b-4d23-b65d-163f3ec21df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[1].nn[0].weight[5,(28,45,52)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17608230-d042-4dff-8efc-4ed30d2d5e52",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d382d961-15e8-451d-adb0-587799baec75",
   "metadata": {},
   "outputs": [],
   "source": [
    "-model_tell.convs[1].nn[0].b[5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f50d654-127f-4b0f-9ab5-f21ead367417",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a13e805-f853-41ae-a11a-44b2edad93ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,(28,45,52)].all(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e4c0e0b-6913-4770-a0d1-34e2c9ee4dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,(28,45,48)].all(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae0c1b4-ad36-40fc-b9b4-b1efc0351d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['y'][:,5]>0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "054cf972-f424-44a1-8cdb-86892c99abc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x_bin'][:,45]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fbd2752-ee32-4eb1-9651-5b4a03877a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer = 1\n",
    "feat = 5%32\n",
    "rules = extract_rules(model_tell.convs[layer].nn[0], feature=feat, max_rules=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37cb3454-4c21-4e82-b9a0-1d7bbdffccbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a91ffa-641d-4314-bdf3-2cd8714a0213",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd5bf47b-f908-409c-85f1-a9b95da8c9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "12%7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8669cd08-bd55-4b7a-b1db-d9854ae4e5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = set(activations[1]['x_bin'][:,(28, 45, 48)].all(-1).nonzero().reshape(-1).tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caeaa378-0f23-4d91-bc05-f72d725d2cbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "s2 = set(activations[1]['x_bin'][:,(28, 45, 52)].all(-1).nonzero().reshape(-1).tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43455bff-8f27-4ecd-9e70-0d123b56e848",
   "metadata": {},
   "outputs": [],
   "source": [
    "s2 < s1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3cb163-7cb0-4aab-abea-43630cb5b66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "extract_rules(model_tell.convs[0].nn[0])[28]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54451623-5f4b-4d7a-bbc3-a851453eea93",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations[1]['y'][:,(0, 9)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707e2efb-61b1-4ada-86db-9b389c8c0526",
   "metadata": {},
   "outputs": [],
   "source": [
    " plot_activations([0], Batch.from_data_list([instance_to_show]), instance_to_show_rets[1]['x_bin'][:,(21,36,61)].all(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a256c66b-5fc2-44ce-9b71-3354fc2efae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x'][:,38]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a4e1fd3-6f6b-4d3c-97c8-6dd0c17576f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_to_show_rets[1]['x_bin'][:,(21,36,61)].all(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b86a8a-4c1e-48d4-9bd0-1fcd29a6447d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, (r,m) in conv_rules_details.items():\n",
    "    print(k, r, m)\n",
    "    plot_activations([0], Batch.from_data_list([instance_to_show]), m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09aec343-fab4-4241-8474-4a5b514b3e84",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2604fab3-5ee2-4503-8aba-d68e42c15ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "literal = (829, 'neg', 'sum', 0, 29)\n",
    "phi_in = model_tell.fc.phi_in\n",
    "r = torch.arange(activations[-1]['x'][:, literal[0]].min(), activations[-1]['x'][:, literal[0]].max(), activations[-1]['x'][:, literal[0]].abs().max()/100).to(device)\n",
    "output = step(phi_in.w[literal[0]]*r+phi_in.b[literal[0]], tau=10).reshape(-1)\n",
    "plt.plot(r.reshape(-1).cpu().numpy(), output.cpu().detach().numpy())\n",
    "plt.scatter(activations[-1]['x'][:, literal[0]].cpu(), activations[-1]['x_bin'][:, literal[0]].cpu())\n",
    "plt.scatter(activations[-1]['x'][:, literal[0]].cpu(), activations[-1]['y'][:,1].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7597443-b63a-4af9-9f4a-39b1ea44cd18",
   "metadata": {},
   "outputs": [],
   "source": [
    "find_step_intervals(phi_in.w[literal[0]].cpu(), phi_in.b[literal[0]].cpu(), activations[-1]['x'][:, literal[0]].min().cpu(), activations[-1]['x'][:, literal[0]].max().cpu(), tau=10, resolution=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ca4302-52fe-4633-9bfb-8072760f21c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "literal = (321, 'pos', 'sum', 0, 1)\n",
    "phi_in = model_tell.fc.phi_in\n",
    "r = torch.arange(activations[-1]['x'][:, literal[0]].min(), activations[-1]['x'][:, literal[0]].max(), activations[-1]['x'][:, literal[0]].abs().max()/100).to(device)\n",
    "output = step(phi_in.w[literal[0]]*r+phi_in.b[literal[0]], tau=10).reshape(-1)\n",
    "plt.plot(r.reshape(-1).cpu().numpy(), output.cpu().detach().numpy())\n",
    "plt.scatter(activations[-1]['x'][:, literal[0]].cpu(), activations[-1]['x_bin'][:, literal[0]].cpu())\n",
    "plt.scatter(activations[-1]['x'][:, literal[0]].cpu(), activations[-1]['y'][:,1].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b59ff0-0d72-4654-8211-3dd1ae7f6fce",
   "metadata": {},
   "outputs": [],
   "source": [
    "find_step_intervals(phi_in.w[literal[0]].cpu(), phi_in.b[literal[0]].cpu(), activations[-1]['x'][:, literal[0]].min().cpu(), activations[-1]['x'][:, literal[0]].max().cpu(), tau=10, resolution=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7d2b85d-68bf-4d51-8093-f0418247e68b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f42556-42fc-49fc-8041-93fa16ebe002",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations[-1]['x'][:,literal[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad232f4a-1c42-428b-8fcb-a397b85a6a59",
   "metadata": {},
   "outputs": [],
   "source": [
    "literal[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af667b6c-c78c-40b0-a78b-748b130c5f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "literal[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e872dd0b-4afe-4d48-ad11-d8762912dbce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28155fc8-b5ea-4e4f-864b-6b56df02fc59",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_scatter import scatter_add\n",
    "\n",
    "# Example node features (num_nodes x num_features)\n",
    "x = rets[literal[3]]['x_bin'].float().cpu()\n",
    "\n",
    "if literal[1] == 'neg':\n",
    "    x = -x\n",
    "\n",
    "# Example edge_index (2 x num_edges), undirected edges should be included in both directions\n",
    "edge_index = batch.edge_index.cpu()\n",
    "\n",
    "# Include self-loops if not already present\n",
    "from torch_geometric.utils import add_self_loops\n",
    "\n",
    "edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
    "\n",
    "# Aggregate features of neighbors (including self)\n",
    "x_prime = scatter_add(x[edge_index[1]], edge_index[0], dim=0, dim_size=x.size(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c57ee8-5ca3-4f1c-81f2-ae9941484ee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tell import step\n",
    "r = torch.arange(x_prime[:, literal[4]].min(), x_prime[:, literal[4]].max(), x_prime[:, literal[4]].abs().max()/100)\n",
    "# r = torch.arange(0,100,0.1)\n",
    "output = step(phi_in.w[literal[4]]*r+phi_in.b[literal[4]], tau=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1ef9cfe-42e4-4939-bccc-4f105a664738",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(r, output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07148aec-1cd1-425d-94cc-e977b8e434f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import aggr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b56fb1ab-b6b8-470c-89e6-e9e27b1294c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e404aca7-b266-476f-bb6f-9554c954dda7",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_prime[:,literal[4]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee2fa1c-7f6c-4da4-ac17-a0616f5c3090",
   "metadata": {},
   "outputs": [],
   "source": [
    "phi_in(torch.arange(x_prime[:, literal[4]].min(), x_prime[:, literal[4]].max(), x_prime[:, literal[4]].min()/10).unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "994f85a5-845a-4c8b-b658-313f698bada2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2702c97c-2f4c-449c-9907-9aa392e11b03",
   "metadata": {},
   "outputs": [],
   "source": [
    "suma(rets[literal[3]]['x'], batch.edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3014144-f18f-4cf3-9cb8-9619d45aba05",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_layer_rules_renamed[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1d3653-b56a-4817-b14f-b818a7ac6e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "extract_rules(model_tell.convs[0].nn[0], activations=activations[0]['x_bin'].cpu())[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd5d002e-1ca5-48c9-8245-be9342dc2a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = extract_rules(model_tell.convs[1].nn[0], activations=activations[1]['x_bin'].cpu())[21]\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f62e967f-7e77-4258-a567-53f69c0efe42",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = extract_rules(model_tell.convs[0].nn[0], activations=activations[0]['x_bin'].cpu())[21]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef6d363d-2288-4219-9069-ed7ddb66e104",
   "metadata": {},
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "094438b4-3c1c-4701-8835-36b56c632e8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3ea3a8-fcd7-4b3f-80e4-963bfd9dc4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every rule in `a`, record the rows of activations[1]['x_bin'] on which all of the\n",
    "# rule's columns are non-zero, and print how many such rows there are.\n",
    "s = {}\n",
    "for aa in a:\n",
    "    row_mask = activations[1]['x_bin'][:, aa].all(-1)  # conjunction over the rule's columns\n",
    "    print(aa, row_mask.sum())\n",
    "    s[aa] = set(row_mask.nonzero().flatten().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c525cb9-9b2f-4d07-9d88-052998093e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_subset_values(d):\n",
    "    minimal_dict = {}\n",
    "    items = list(d.items())\n",
    "    \n",
    "    for i, (key_i, set_i) in enumerate(items):\n",
    "        is_subset = False\n",
    "        for j, (key_j, set_j) in enumerate(items):\n",
    "            if i != j and set_i.issubset(set_j):\n",
    "                is_subset = True\n",
    "                break\n",
    "        if not is_subset:\n",
    "            minimal_dict[key_i] = set_i\n",
    "    \n",
    "    return minimal_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05bd6bcd-019b-4bea-98de-08795a0521f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "remove_subset_values(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b604d73-9ffe-47ce-916b-10ecfa5f0578",
   "metadata": {},
   "outputs": [],
   "source": [
    "set(torch.nonzero(activations[1]['x_bin'][:,aa].all(-1)).flatten().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a449c577-5301-40e4-9743-f95c56da6c51",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations(1, 1, [29])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c542049d-164e-4fb1-8060-3566f174934e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[2].nn[0].phi_in.w[43]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a36abd8-80aa-4ab2-969a-992f7120cbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_tell.convs[2].nn[0].phi_in.b[43]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e4a12b6-a4a7-44f4-9818-da14d9b21b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_activations(1, 1, [29])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd7d06f-cdb0-4752-b5fd-b2500fcb96de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run model_tell over val_loader without gradients, collect the phi_in outputs of\n",
    "# model_tell.convs[2].nn[0], and binarise them at 0.5.\n",
    "# NOTE: this rebinds the notebook-level name `activations` to a single int tensor; cells that\n",
    "# index it as activations[k]['x_bin'] expect the earlier structure and need it re-created first.\n",
    "print(\"Extracting last layer rules\")\n",
    "with torch.no_grad():\n",
    "    activations = []\n",
    "    for batch in val_loader:\n",
    "        # Only xs is read below; ys and acts are unpacked but not used in this cell.\n",
    "        xs, ys, acts = forward_with_activations(model_tell, batch.x.to(device), batch.edge_index.to(device), batch.batch.to(device))\n",
    "        # Disabled alternative: probe model_tell.fc.phi_in on xs[-1] instead of convs[2].\n",
    "        # activations.append(model_tell.fc.phi_in(torch.hstack([xs[-1], 1-xs[-1]])))\n",
    "        # phi_in receives xs[2] stacked horizontally with its complement 1 - xs[2].\n",
    "        activations.append(model_tell.convs[2].nn[0].phi_in(torch.hstack([xs[2], 1-xs[2]])))\n",
    "# Merge the per-batch outputs into one tensor (torch.cat along the first dimension).\n",
    "activations = torch.cat(activations)\n",
    "# Binarise: 1 where the output is >= 0.5, otherwise 0.\n",
    "activations = (activations>=0.5).int()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a2895fe-5787-4da1-af5b-09497e086b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8da7cecc-783f-4bd6-9851-bb0ad384b916",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules = extract_rules(model_tell.convs[2].nn[0], feature=30, max_rules=1000000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0e9e51-056b-447a-b5bc-96595afaa499",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779ba91d-672e-4d5b-8859-3d236a23ef8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prune every rule group in `rules` against the binarised `activations` and record, for each\n",
    "# surviving rule, the number of rows on which all of its literals are active.\n",
    "# `rules[c]` is replaced in place by the pruned rules; `supports[c]` is aligned with it.\n",
    "# NOTE(review): assumes `activations` is 2-D (rows x literals) and that rule literals are\n",
    "# integer column indices - confirm against extract_rules.\n",
    "# The per-rule / per-literal debug prints were removed: with max_rules=1000000 they flood the\n",
    "# cell output. The scratch list is no longer called `s`, so the dict `s` consumed by\n",
    "# remove_subset_values(s) is not clobbered when this cell runs.\n",
    "always_on = activations.all(0).tolist()  # literal active on every row -> carries no information\n",
    "ever_on = activations.any(0).tolist()    # literal active on at least one row\n",
    "supports = []\n",
    "for c in range(len(rules)):\n",
    "    kept_rules = []\n",
    "    kept_supports = []\n",
    "    for t in rules[c]:\n",
    "        # Literals that hold on every row do not constrain the rule; drop them.\n",
    "        new_t = [i for i in t if not always_on[i]]\n",
    "        # Keep the rule only if each remaining literal is active somewhere; otherwise its\n",
    "        # support is necessarily zero.\n",
    "        if all(ever_on[i] for i in new_t):\n",
    "            kept_rules.append(tuple(new_t))\n",
    "            kept_supports.append(activations[:, new_t].all(-1).sum())\n",
    "    supports.append(kept_supports)\n",
    "    rules[c] = kept_rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f615f9d1-7b5d-4cc0-a40d-9a0b54693240",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8133760c-fe15-4e4f-8272-93828f834b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten highest (support, rule) pairs of rule group 0, in descending order.\n",
    "sorted(zip(supports[0], rules[0]), reverse=True)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3981e922-96a0-4273-af0b-d6ae13842e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "activations.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2357fb-81fe-4718-9fa9-b35bc438ca61",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
