{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bed12f-8435-4429-a9b4-87d297a38dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "from torch import nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from syn_dataset import SynGraphDataset\n",
    "from spmotif_dataset import *\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool\n",
    "from utils import *\n",
    "from sklearn.model_selection import train_test_split\n",
    "import shutil\n",
    "import glob\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import pickle\n",
    "import json\n",
    "import io\n",
    "from model import GIN\n",
    "from train_baseline import test_epoch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e60dbe7b-82f1-47ea-a240-e8ae6cef5236",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'BaMultiShapes'\n",
    "seed = 5\n",
    "\n",
    "\n",
    "def get_best_baseline_path(dataset_name):\n",
    "    \"\"\"Return the run directory of the best 'nogumbel=True' baseline, or None.\n",
    "\n",
    "    Scans results/<dataset_name>/*/results.json, keeps the runs whose path\n",
    "    contains 'nogumbel=True', and picks the run with the highest\n",
    "    val_acc_mean (ties: lower val_acc_std, then lower test_acc_std).\n",
    "\n",
    "    Returns:\n",
    "        Directory holding the winning results.json, or None when no\n",
    "        matching run exists.\n",
    "    \"\"\"\n",
    "    result_files = glob.glob(f'results/{dataset_name}/*/results.json')\n",
    "    # Filter first so the emptiness check also covers the case where runs\n",
    "    # exist but none of them is a nogumbel=True run (previously an IndexError).\n",
    "    result_files = [path for path in result_files if 'nogumbel=True' in path]\n",
    "    if not result_files:\n",
    "        return None\n",
    "    records = []\n",
    "    for path in result_files:\n",
    "        # Context manager closes every handle instead of leaking it.\n",
    "        with open(path) as fh:\n",
    "            records.append(json.load(fh))\n",
    "    df = pd.DataFrame(records)\n",
    "    df['fname'] = result_files\n",
    "    # Ascending mean / descending stds: the last row is the best run.\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])\n",
    "    # dirname is separator-agnostic, unlike replacing the literal '/results.json'.\n",
    "    return os.path.dirname(df.iloc[-1]['fname'])\n",
    "\n",
    "\n",
    "best_baseline_path = get_best_baseline_path(dataset_name)\n",
    "if best_baseline_path is None:\n",
    "    # Fail with a clear message instead of a TypeError inside os.path.join.\n",
    "    raise FileNotFoundError(f'No nogumbel=True baseline results found under results/{dataset_name}/')\n",
    "results_path = os.path.join(best_baseline_path, str(seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d05420-5fd9-4920-9c99-d42fcce4fdef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# NOTE: unpickling can execute arbitrary code; only load data.pkl files that\n",
    "# were written by this project's own training runs.\n",
    "# The context manager closes the file handle, which the bare open() never did.\n",
    "with open(os.path.join(results_path, 'data.pkl'), 'rb') as fh:\n",
    "    data = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aa2e89a-31ee-4368-80db-0a85b87a161e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters saved with the selected run; the context manager closes the\n",
    "# file handle, which the bare open() never did.\n",
    "with open(os.path.join(results_path, 'args.json'), 'r') as fh:\n",
    "    args = json.load(fh)\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8a62b2-2159-4525-9bb4-080e3f85a06a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# device = torch.device('cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8771c7d0-42fe-4c80-97e3-3493ab9f182f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset-derived sizes plus the architecture settings stored with the run.\n",
    "dataset = get_dataset(dataset_name)\n",
    "num_classes, num_features = dataset.num_classes, dataset.num_features\n",
    "num_layers, hidden_dim = args['num_layers'], args['hidden_dim']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a409a08-1fdb-4fd0-b261-fab8c8b1607d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the architecture so the saved weights can be loaded into it.\n",
    "model = GIN(\n",
    "    num_classes=num_classes,\n",
    "    num_features=num_features,\n",
    "    num_layers=num_layers,\n",
    "    hidden_dim=hidden_dim,\n",
    "    nogumbel=True,\n",
    "    dropout=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7201525-36bc-41c1-b3b5-8555043429a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the run's best checkpoint, then place the model on `device`.\n",
    "checkpoint_path = os.path.join(results_path, 'best.pt')\n",
    "state_dict = torch.load(checkpoint_path, map_location=device)\n",
    "model.load_state_dict(state_dict)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7ac600e-9555-492a-859a-59e7ca25910f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-create the split that was stored next to the checkpoint.\n",
    "train_indices, val_indices, test_indices = (\n",
    "    data['train_indices'], data['val_indices'], data['test_indices']\n",
    ")\n",
    "\n",
    "train_dataset, val_dataset, test_dataset = (\n",
    "    dataset[train_indices], dataset[val_indices], dataset[test_indices]\n",
    ")\n",
    "# shuffle=False keeps loader order identical to the index lists above.\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)\n",
    "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de56d194-fc7a-4742-8b43-0004d953c32c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of the one-hot node encoding used when ctrees are converted to PyG graphs.\n",
    "NUM_NODE_FEATURES = dataset.num_node_features\n",
    "# Fixed to 1 here rather than read from the dataset (dataset.num_edge_features).\n",
    "NUM_EDGE_FEATURES = 1\n",
    "# Dataset names that take the labelled drawing branches in the plotting cell.\n",
    "TU_DATASETS = [\"MUTAG\", \"Mutagenicity\", \"NCI1\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "622d15e3-3408-4a0a-a2a2-5dad91c72e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Helper functions\"\"\"\n",
    "from tqdm import tqdm\n",
    "import torch_geometric as pyg\n",
    "import torch\n",
    "import networkx as nx\n",
    "from sympy.logic.boolalg import Or, And, Not\n",
    "from sympy.parsing.sympy_parser import parse_expr\n",
    "\n",
    "from pygcanl import canonical\n",
    "\n",
    "def preprocess_dataset(dataset):\n",
    "    \"\"\"Convert graphs into the integer-labelled form handed to `canonical`.\n",
    "\n",
    "    For each graph the node features are collapsed to a class index with\n",
    "    argmax, an `id` tensor numbering the nodes is attached, and the edge\n",
    "    attributes are collapsed the same way. Graphs without usable 2-D edge\n",
    "    attributes get a constant 1 on every edge.\n",
    "\n",
    "    Returns:\n",
    "        list of pyg.data.Data, one per input graph, in input order.\n",
    "    \"\"\"\n",
    "    processed = []\n",
    "    for graph in dataset:\n",
    "        raw_edge_attr = getattr(graph, 'edge_attr', None)\n",
    "        has_edge_labels = (\n",
    "            raw_edge_attr is not None\n",
    "            and raw_edge_attr.dim() >= 2\n",
    "            and raw_edge_attr.size(1) > 0\n",
    "        )\n",
    "        if has_edge_labels:\n",
    "            edge_attr = raw_edge_attr.argmax(dim=1)\n",
    "        else:\n",
    "            # Explicit fallback replaces a bare `except:` that would also have\n",
    "            # hidden unrelated errors (and KeyboardInterrupt).\n",
    "            edge_attr = torch.ones(graph.edge_index.shape[1])\n",
    "        processed.append(\n",
    "            pyg.data.Data(\n",
    "                x=graph.x.argmax(dim=1),\n",
    "                id=torch.arange(graph.num_nodes),\n",
    "                edge_index=graph.edge_index,\n",
    "                y=graph.y,\n",
    "                edge_attr=edge_attr,\n",
    "            )\n",
    "        )\n",
    "    return processed\n",
    "\n",
    "def graph_from_dfs_code(dfs_code):\n",
    "    \"\"\"Rebuild a tree from its space-separated DFS code.\n",
    "\n",
    "    The first token is the root's integer label. After that, a '$' token\n",
    "    moves back to the parent of the current node (the root is its own\n",
    "    parent), and any other token is an edge label followed by the label\n",
    "    of a new child of the current node. Nodes are numbered in creation\n",
    "    order.\n",
    "\n",
    "    Returns:\n",
    "        nx.DiGraph with integer `attr` on nodes and edges, with every\n",
    "        parent->child edge reversed.\n",
    "    \"\"\"\n",
    "    tokens = dfs_code.split(\" \")\n",
    "    tree = nx.DiGraph()\n",
    "\n",
    "    # Root: node 0, recorded as its own parent.\n",
    "    tree.add_node(0, attr=int(tokens[0]))\n",
    "    parent_of = {0: 0}\n",
    "    current = 0\n",
    "    next_node = 1\n",
    "\n",
    "    pos = 1\n",
    "    while pos < len(tokens):\n",
    "        token = tokens[pos]\n",
    "        if token == '$':\n",
    "            # Backtrack one level.\n",
    "            current = parent_of[current]\n",
    "            pos += 1\n",
    "            continue\n",
    "        # (edge label, node label) pair describing a new child.\n",
    "        child_label = tokens[pos + 1]\n",
    "        parent_of[next_node] = current\n",
    "        tree.add_node(next_node, attr=int(child_label))\n",
    "        tree.add_edge(current, next_node, attr=int(token))\n",
    "        current = next_node\n",
    "        next_node += 1\n",
    "        pos += 2\n",
    "    return tree.reverse()\n",
    "\n",
    "def int_to_onehot(attr: int, num_features: int):\n",
    "    \"\"\"Return a list of `num_features` zeros with 1.0 written at index `attr`.\"\"\"\n",
    "    encoding = [0] * num_features\n",
    "    encoding[attr] = 1.\n",
    "    return encoding\n",
    "\n",
    "def nx_to_pyg(\n",
    "        ctree: nx.DiGraph,\n",
    "        num_node_features: int,\n",
    "        num_edge_features: int\n",
    ") -> pyg.data.Data:\n",
    "    \"\"\"Convert a computation tree to a PyG Data object with one-hot node features.\n",
    "\n",
    "    Args:\n",
    "        ctree: tree whose nodes (and edges, if any) carry an integer `attr`.\n",
    "        num_node_features: length of the one-hot node encoding.\n",
    "        num_edge_features: currently unused; edge attributes are passed\n",
    "            through as stored on the graph.\n",
    "\n",
    "    Returns:\n",
    "        The result of pyg.utils.from_networkx with node `attr` grouped into\n",
    "        `x`, and edge `attr` grouped only when the tree has edges.\n",
    "    \"\"\"\n",
    "    # Work on a copy: the original overwrote the caller's node attributes with\n",
    "    # one-hot lists, so converting the same tree twice failed.\n",
    "    ctree = ctree.copy()\n",
    "    for node in ctree.nodes:\n",
    "        ctree.nodes[node][\"attr\"] = int_to_onehot(\n",
    "            ctree.nodes[node][\"attr\"], num_node_features\n",
    "        )\n",
    "    # A single-node tree has no edge attribute to group.\n",
    "    edge_attrs = [\"attr\"] if len(ctree.edges) > 0 else None\n",
    "    return pyg.utils.from_networkx(\n",
    "        ctree, group_node_attrs=[\"attr\"], group_edge_attrs=edge_attrs\n",
    "    )\n",
    "\n",
    "def simplify_expression(str_exp):\n",
    "    \"\"\"Parse `str_exp` with sympy's parse_expr and return the simplified expression.\"\"\"\n",
    "    return parse_expr(str_exp).simplify()\n",
    "\n",
    "def getVariables(expr):\n",
    "    \"\"\"Return whatever `expr.atoms()` reports for the given expression.\"\"\"\n",
    "    atoms = expr.atoms()\n",
    "    return atoms\n",
    "\n",
    "def dfs(ctree, ctree_id, node_mapping=None):\n",
    "    \"\"\"\n",
    "    ! Incorrect (warning kept from the original author)\n",
    "    Fold a computation tree and its node-id twin into one undirected graph.\n",
    "\n",
    "    Node i of `ctree_id` stores, as `attr`, the id of the graph node it stands\n",
    "    for; node i of `ctree` stores the matching node label. This relies on both\n",
    "    trees listing their nodes in the same order. Each distinct id becomes one\n",
    "    node of the result, labelled with the ctree label (translated through\n",
    "    `node_mapping` when given), and each `ctree_id` edge becomes an edge between\n",
    "    the two ids it connects, keeping the edge `attr`.\n",
    "    \"\"\"\n",
    "    merged = nx.Graph()\n",
    "    for idx in range(len(ctree_id.nodes)):\n",
    "        label = ctree.nodes[idx]['attr']\n",
    "        if node_mapping is not None:\n",
    "            label = node_mapping[label]\n",
    "        merged.add_node(ctree_id.nodes[idx]['attr'], attr=label)\n",
    "    for edge in ctree_id.edges:\n",
    "        tail, head = edge\n",
    "        merged.add_edge(\n",
    "            ctree_id.nodes[tail]['attr'],\n",
    "            ctree_id.nodes[head]['attr'],\n",
    "            attr=ctree_id.edges[edge]['attr'],\n",
    "        )\n",
    "    return merged\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a917b8aa-ec38-4545-b447-5d29b937c067",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Indicator vectors of training graphs\n",
    "# start_time = process_time()\n",
    "\n",
    "processed_dataset = preprocess_dataset(\n",
    "    [dataset[i] for i in train_indices])\n",
    "list_of_dfs_id_codes = canonical(processed_dataset, 3)\n",
    "\n",
    "dict_dfs_id_codes = {}  # dfs code -> id code (the last occurrence wins)\n",
    "all_ctree_codes = []    # per graph: its dfs codes, in the order returned\n",
    "graph_cnt_lst = []      # per graph: dfs code -> number of occurrences\n",
    "\n",
    "for graph_codes in list_of_dfs_id_codes:\n",
    "    codes_in_graph = []\n",
    "    code_counts = {}\n",
    "    for entry in graph_codes:\n",
    "        # Entries look like '<key>/<value>' (presumably dfs code / id code);\n",
    "        # split once instead of once per field.\n",
    "        fields = entry.split('/')\n",
    "        dfs_code, id_code = fields[0], fields[1]\n",
    "        codes_in_graph.append(dfs_code)\n",
    "        dict_dfs_id_codes[dfs_code] = id_code\n",
    "        code_counts[dfs_code] = code_counts.get(dfs_code, 0) + 1\n",
    "    graph_cnt_lst.append(code_counts)\n",
    "    all_ctree_codes.append(codes_in_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3f966a-2d63-49b6-be8e-4e3d95b86991",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculated in train, used directly in val, test.\n",
    "unique_ctree_codes = list(dict_dfs_id_codes.keys())\n",
    "\n",
    "# with open(f\"{FOLDER}/dict_dfs_id_codes.pkl\", \"wb\") as file:\n",
    "#     dump(dict_dfs_id_codes, file)\n",
    "# with open(f\"{FOLDER}/unique_ctree_codes.pkl\", \"wb\") as file:\n",
    "#     dump(unique_ctree_codes, file)\n",
    "\n",
    "# One row per training graph, one column per unique ctree code; each cell is\n",
    "# how often that code occurs in the graph (0 when absent).\n",
    "cnt_ind_vec = np.array([\n",
    "    [g_dict.get(ct, 0) for ct in unique_ctree_codes]\n",
    "    for g_dict in graph_cnt_lst\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a82b223e-3a21-47c5-bfed-d1de724028b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Indicator vectors of validation graphs\n",
    "processed_dataset_val = preprocess_dataset(\n",
    "    [dataset[i] for i in val_indices])\n",
    "list_of_dfs_id_codes_val = canonical(processed_dataset_val, 3)\n",
    "\n",
    "dict_dfs_id_codes_val = {}\n",
    "all_ctree_codes_val = []\n",
    "graph_cnt_lst_val = []\n",
    "\n",
    "for graph_codes in list_of_dfs_id_codes_val:\n",
    "    codes_in_graph = []\n",
    "    code_counts = {}\n",
    "    for entry in graph_codes:\n",
    "        fields = entry.split('/')\n",
    "        dfs_code, id_code = fields[0], fields[1]\n",
    "        codes_in_graph.append(dfs_code)\n",
    "        dict_dfs_id_codes_val[dfs_code] = id_code\n",
    "        code_counts[dfs_code] = code_counts.get(dfs_code, 0) + 1\n",
    "    graph_cnt_lst_val.append(code_counts)\n",
    "    all_ctree_codes_val.append(codes_in_graph)\n",
    "\n",
    "# Columns follow the training vocabulary (unique_ctree_codes); codes that only\n",
    "# occur in validation graphs are therefore not counted.\n",
    "cnt_ind_vec_val = np.array([\n",
    "    [code_counts.get(ct, 0) for ct in unique_ctree_codes]\n",
    "    for code_counts in graph_cnt_lst_val\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59bb6488-86c6-494b-8098-e82be75678c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Indicator vectors of test graphs\n",
    "processed_dataset_test = preprocess_dataset(\n",
    "    [dataset[i] for i in test_indices])\n",
    "list_of_dfs_id_codes_test = canonical(processed_dataset_test, 3)\n",
    "\n",
    "dict_dfs_id_codes_test = {}\n",
    "all_ctree_codes_test = []\n",
    "graph_cnt_lst_test = []\n",
    "\n",
    "for graph_codes in list_of_dfs_id_codes_test:\n",
    "    codes_in_graph = []\n",
    "    code_counts = {}\n",
    "    for entry in graph_codes:\n",
    "        fields = entry.split('/')\n",
    "        dfs_code, id_code = fields[0], fields[1]\n",
    "        codes_in_graph.append(dfs_code)\n",
    "        dict_dfs_id_codes_test[dfs_code] = id_code\n",
    "        code_counts[dfs_code] = code_counts.get(dfs_code, 0) + 1\n",
    "    graph_cnt_lst_test.append(code_counts)\n",
    "    all_ctree_codes_test.append(codes_in_graph)\n",
    "\n",
    "# Columns follow the training vocabulary (unique_ctree_codes); codes that only\n",
    "# occur in test graphs are therefore not counted.\n",
    "cnt_ind_vec_test = np.array([\n",
    "    [code_counts.get(ct, 0) for ct in unique_ctree_codes]\n",
    "    for code_counts in graph_cnt_lst_test\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ceba22-18ac-4054-a94d-65122eb9a45e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the one-hot width that is passed to nx_to_pyg in the next cell.\n",
    "NUM_NODE_FEATURES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b57a3ce-5440-4323-8f93-14f16eb03a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Ctree Embeddings\n",
    "model.eval()\n",
    "\n",
    "ctree_embeddings = []\n",
    "with torch.no_grad():\n",
    "    for ct in tqdm(unique_ctree_codes, desc=\"Ctree embeddings\", colour=\"green\"):\n",
    "        ctree_nx = graph_from_dfs_code(ct)\n",
    "        ctree_pyg = nx_to_pyg(\n",
    "            ctree_nx,\n",
    "            NUM_NODE_FEATURES,\n",
    "            NUM_EDGE_FEATURES,\n",
    "        )\n",
    "        # The model was moved to `device`, while the freshly built graph tensors\n",
    "        # live on the CPU; move them so this also runs when CUDA is available.\n",
    "        # Every node gets the constant feature 0.1: only the shape of x is used.\n",
    "        node_features = (torch.ones_like(ctree_pyg.x) * 0.1).to(device)\n",
    "        xs, ys = model.forward_e(\n",
    "            node_features,\n",
    "            ctree_pyg.edge_index.to(device),\n",
    "            batch=None\n",
    "        )\n",
    "\n",
    "        # All but the last entry of ys, concatenated; keep the first row only.\n",
    "        ctree_embeddings.append(torch.hstack(ys[:-1])[0].tolist())\n",
    "\n",
    "ctree_embeddings = np.array(ctree_embeddings)\n",
    "# end_time = process_time()\n",
    "# print(f\"[TIME] gen_ctree: {end_time - start_time} s.ms\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bee11d7-a265-4353-acb6-8794f99955c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocess as mp\n",
    "from argparse import ArgumentParser\n",
    "from pickle import load, dump\n",
    "\n",
    "import numpy as np\n",
    "import shap\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6411909-42bb-41f7-b577-3ef383fc328f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_graph_embedding(row):\n",
    "    \"\"\"Build the concatenated [mean, max, sum] embedding for one count vector.\n",
    "\n",
    "    Args:\n",
    "        row: 1-D array holding one count per row of `ctree_embeddings`.\n",
    "\n",
    "    Returns:\n",
    "        1-D array of length 3 * ctree_embeddings.shape[1]: the count-weighted\n",
    "        sum divided by max(1, total count), the element-wise maximum of the\n",
    "        count-scaled embeddings floored at zero, and the count-weighted sum.\n",
    "    \"\"\"\n",
    "    counts = np.asarray(row)\n",
    "    # Scale embedding i by count i for all ctrees at once instead of looping in\n",
    "    # Python (SHAP calls this for every perturbed sample).\n",
    "    weighted = counts[:, None] * ctree_embeddings\n",
    "    embd = weighted.sum(axis=0)\n",
    "    freq = counts.sum()\n",
    "    # The original initialised the running max as an alias of `embd`, so the\n",
    "    # in-place `+=` silently changed it too. `initial=0.0` gives the intended\n",
    "    # zero starting point and also covers an empty count vector.\n",
    "    max_embd = weighted.max(axis=0, initial=0.0)\n",
    "    return np.hstack([embd / max(1, freq), max_embd, embd])\n",
    "\n",
    "\n",
    "def f(ind_vectors):\n",
    "    \"\"\"Score count vectors with the model head; this is the function SHAP explains.\n",
    "\n",
    "    Args:\n",
    "        ind_vectors: 2-D array, one ctree-count vector per row.\n",
    "\n",
    "    Returns:\n",
    "        1-D numpy array with sigmoid(fc2(relu(fc1(embedding))))[:, 1] per row.\n",
    "    \"\"\"\n",
    "    embedings = np.apply_along_axis(\n",
    "        get_graph_embedding, axis=1, arr=ind_vectors)\n",
    "    # Put the input on the same device as the model; previously a CPU tensor was\n",
    "    # fed to a model that may sit on CUDA.\n",
    "    model_device = next(model.parameters()).device\n",
    "    embedings = torch.as_tensor(embedings, dtype=torch.float32, device=model_device)\n",
    "    with torch.no_grad():\n",
    "        model.eval()\n",
    "        x = model.fc1(embedings)\n",
    "        x = torch.nn.functional.relu(x)\n",
    "        out = torch.sigmoid(model.fc2(x))\n",
    "        # Probability of class 1\n",
    "        out = out[:, 1]\n",
    "    # .cpu() is required before .numpy() when the model runs on CUDA.\n",
    "    return out.cpu().numpy()\n",
    "\n",
    "\n",
    "def calculate_shap(shap_arr, chunk_num, max_rows=10, nsamples=10):\n",
    "    \"\"\"Run KernelSHAP on one chunk of count vectors against an all-zero background.\n",
    "\n",
    "    Args:\n",
    "        shap_arr: 2-D array, one count vector per row.\n",
    "        chunk_num: chunk number, used only in the progress message.\n",
    "        max_rows: explain only the first `max_rows` rows of the chunk; pass\n",
    "            None to explain every row. The default of 10 matches the limit\n",
    "            that used to be hard-coded.\n",
    "        nsamples: `nsamples` argument forwarded to `explainer.shap_values`.\n",
    "\n",
    "    Returns:\n",
    "        The SHAP values returned by the explainer for the explained rows.\n",
    "    \"\"\"\n",
    "    # shape[1] (rather than shap_arr[0]) also works for a chunk with zero rows.\n",
    "    z = np.zeros((1, shap_arr.shape[1]))\n",
    "    print('initializing shap')\n",
    "    explainer = shap.KernelExplainer(f, z)\n",
    "    print('running shap')\n",
    "    print(shap_arr.shape)\n",
    "    rows = shap_arr if max_rows is None else shap_arr[:max_rows, :]\n",
    "    shap_values = explainer.shap_values(X=rows, gc_collect=True, silent=False, nsamples=nsamples)\n",
    "    print('finished shap')\n",
    "    print(f\"Chunk {chunk_num} done!\")\n",
    "    return shap_values\n",
    "\n",
    "\n",
    "procs = 1\n",
    "\n",
    "# np.array_split gives the first `len % procs` chunks one extra row, which is\n",
    "# the same distribution the manual chunk_size/remainder bookkeeping produced.\n",
    "# It also avoids defining a global named `indices` here, a name the next cell\n",
    "# reuses for the SHAP ranking.\n",
    "chunked_ind_vectors = np.array_split(cnt_ind_vec, procs)\n",
    "chunks = [len(chunk) for chunk in chunked_ind_vectors]\n",
    "\n",
    "print(\"Chunk size:\", chunks)\n",
    "print(\"#Chunks:\", len(chunked_ind_vectors))\n",
    "print()\n",
    "\n",
    "# One calculate_shap call per chunk; results come back in chunk order.\n",
    "with mp.Pool(procs) as p:\n",
    "    results = p.starmap(calculate_shap, zip(chunked_ind_vectors, range(len(chunked_ind_vectors))))\n",
    "\n",
    "shap_values = np.concatenate(results, axis=0)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f1c5a5-19a4-41c2-908a-56252584ecab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute SHAP value per feature; argsort is ascending, so the most\n",
    "# important features end up at the tail of `indices`.\n",
    "shap_imp = np.mean(np.abs(shap_values), axis=0)\n",
    "indices = shap_imp.argsort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45362e5-c932-4749-8be6-19ed8d4c9c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(loader):\n",
    "    \"\"\"Run the model over `loader` and collect its outputs.\n",
    "\n",
    "    Returns:\n",
    "        (predictions, probabilities): the argmax class per graph and the raw\n",
    "        per-class model outputs per graph, both as Python lists in loader order.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model_device = next(model.parameters()).device\n",
    "    predictions = []\n",
    "    probabilities = []\n",
    "    # Inference only: skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        for batch in loader:\n",
    "            # Batches come off the loader on the CPU; the model may be on CUDA.\n",
    "            batch = batch.to(model_device)\n",
    "            out = model(\n",
    "                x=batch.x,\n",
    "                edge_index=batch.edge_index,\n",
    "                batch=batch.batch\n",
    "            )\n",
    "            predictions += out.argmax(dim=1).tolist()\n",
    "            probabilities += out.tolist()\n",
    "    return predictions, probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d91597a-a47a-4d97-ac3b-97ec7aace7d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GNN predictions and raw outputs for each split, in loader order.\n",
    "(train_pred, train_prob), (val_pred, val_prob), (test_pred, test_prob) = (\n",
    "    predict(split_loader) for split_loader in (train_loader, val_loader, test_loader)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568cb114-14db-4718-8d13-6887414eaaf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# k: number of top-ranked SHAP features kept as inputs for symbolic regression.\n",
    "# c: complexity charged per variable and per operator in the PySR search.\n",
    "k, c = 200, 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18716045-099a-4c83-bf21-ac8290617a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph weight: the largest model output for that training graph.\n",
    "pysr_weights = list(map(max, train_prob))\n",
    "\n",
    "# `indices` is sorted by ascending importance, so its tail is the top k.\n",
    "top_k_columns = indices[- k:]\n",
    "x_train = cnt_ind_vec[:, top_k_columns]\n",
    "x_val = cnt_ind_vec_val[:, top_k_columns]\n",
    "x_test = cnt_ind_vec_test[:, top_k_columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52d9a0bc-0985-4dc0-a98b-6dc2fe0540fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binarise the counts: 1 where a ctree occurs at least once, else 0.\n",
    "# A vectorised comparison replaces the eval()-driven nested loops; the result\n",
    "# is the same list-of-lists of Python ints.\n",
    "x_train_bin = (np.asarray(x_train) > 0).astype(int).tolist()\n",
    "x_val_bin = (np.asarray(x_val) > 0).astype(int).tolist()\n",
    "x_test_bin = (np.asarray(x_test) > 0).astype(int).tolist()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6384256e-f149-4077-bac7-312e0636fd3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pysr import PySRRegressor # Import this first!\n",
    "# The sympy mappings below reference `sympy`; import it here so this cell does\n",
    "# not depend on an import that only happens in a later cell.\n",
    "import sympy\n",
    "\n",
    "\n",
    "# * ----- Symbolic Regression\n",
    "# start_time = process_time()\n",
    "\n",
    "pysrmodel = PySRRegressor(\n",
    "    # Boolean operators written as Julia one-liners; each returns 0 or 1 in the\n",
    "    # input's own number type.\n",
    "    unary_operators = [\"Not(x) = (x <= zero(x)) * one(x)\"],\n",
    "    binary_operators = [\n",
    "        \"And(x, y) = ((x > zero(x)) & (y > zero(y))) * one(x)\",\n",
    "        \"Or(x, y)  = ((x > zero(x)) | (y > zero(y))) * one(x)\",\n",
    "        # one(x) instead of the Float32 literal 1f0, matching And/Or/Not so the\n",
    "        # result type follows the input type.\n",
    "        \"Xor(x, y) = (((x > 0) & (y <= 0)) | ((x <= 0) & (y > 0))) * one(x)\",\n",
    "    ],\n",
    "    extra_sympy_mappings = {\n",
    "        \"Not\": lambda x: sympy.Piecewise((1.0, (x <= 0)), (0.0, True)),\n",
    "        \"And\": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),\n",
    "        \"Or\":  lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),\n",
    "        \"Xor\": lambda x, y: sympy.Piecewise((1.0, (x > 0) ^ (y > 0)), (0.0, True)),\n",
    "    },\n",
    "\n",
    "    # 0/1 loss per sample.\n",
    "    elementwise_loss = \"loss(prediction, target) = sum(prediction != target)\",\n",
    "    model_selection=\"accuracy\",\n",
    "\n",
    "    complexity_of_variables=c,\n",
    "    complexity_of_operators={'Not': c, 'And': c, 'Or': c, 'Xor': c},\n",
    "\n",
    "    select_k_features = min(k, 10),\n",
    "    # NOTE(review): PySR documents `weights` as an argument of fit(); verify that\n",
    "    # the installed version does not ignore it when it is passed here.\n",
    "    weights = pysr_weights,\n",
    "\n",
    "    # NOTE(review): presumably only effective together with batching=True - confirm.\n",
    "    batch_size = 32,\n",
    "\n",
    "    # Paperwork\n",
    "    temp_equation_file = True,\n",
    "    delete_tempfiles = True,\n",
    "\n",
    "    # Determinism\n",
    "    procs=0,\n",
    "    deterministic=True,\n",
    "    multithreading=False,\n",
    "    random_state=0,\n",
    "    warm_start=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76055526-b27b-4cf3-bfcb-150994d5c8f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy\n",
    "def cal_pysr_acc(X, Y, index=None):\n",
    "    \"\"\"Return the fraction of rows of X on which equation `index` of pysrmodel equals Y.\"\"\"\n",
    "    targets = np.array(Y)\n",
    "    predicted = pysrmodel.predict(X, index=index)\n",
    "    assert targets.shape == predicted.shape , \"Shape mismatch!\"\n",
    "    n_matches = (predicted == targets).sum()\n",
    "    return n_matches / len(targets)\n",
    "\n",
    "\n",
    "pysrmodel.fit(x_train_bin, train_pred)\n",
    "print(pysrmodel)\n",
    "\n",
    "# Integer positions (among the k SHAP-selected columns) of the features that\n",
    "# PySR's feature selection kept. The boolean mask itself was assigned first and\n",
    "# immediately overwritten, so only the conversion remains.\n",
    "selected_ctrees = np.where(pysrmodel.selection_mask_)[0]\n",
    "\n",
    "# Use `equations_`, the attribute the following cells already read, instead of\n",
    "# the differently spelled `equations`.\n",
    "df_equations = pysrmodel.equations_.drop([\"sympy_format\", \"lambda_format\"], axis=1)\n",
    "# Add a column for accuracy.\n",
    "df_equations[\"acc\"] = 1 - df_equations[\"loss\"]\n",
    "# Re-arrange columns to have \"acc\" as the second column.\n",
    "cols = df_equations.columns.tolist()\n",
    "cols.insert(1, cols.pop(-1))\n",
    "df_equations = df_equations[cols]\n",
    "# Round values.\n",
    "for col in [\"acc\", \"loss\", \"score\"]:\n",
    "    df_equations[col] = df_equations[col].round(4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99cd5898-39f7-4cbc-9d25-b0053d968a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `equation` column of the fitted model's equations table.\n",
    "pysrmodel.equations_.equation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef184e1-45e8-4c55-a78d-1d7761d71a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the equation that performs the best on the validation set\n",
    "# Start below any reachable accuracy and with no index, so that a usable\n",
    "# equation is always selected and a stale `best_index` from an earlier run can\n",
    "# never be reported.\n",
    "best_val_acc = -1.0\n",
    "best_index = None\n",
    "val_targets = np.asarray(val_pred)\n",
    "print(\"\\nValidation accuracies:\")\n",
    "for j in range(pysrmodel.equations_.shape[0]):\n",
    "    # Skip the constant equation.\n",
    "    if pysrmodel.equations_.equation.iloc[j] == '1.0':continue\n",
    "    # PySR sometimes fails to evaluate certain formulae\n",
    "    # it usually happens when C is set to a small value.\n",
    "    # We've been unable to identify when and why it happens\n",
    "    try:\n",
    "        # Train and test are evaluated only to confirm the equation is usable\n",
    "        # on every split; selection itself uses the validation predictions.\n",
    "        __ = pysrmodel.predict(x_train_bin, index=j)\n",
    "        __ = pysrmodel.predict(x_test_bin, index=j)\n",
    "        pysr_val_pred = pysrmodel.predict(x_val_bin, index=j)\n",
    "    except ValueError:\n",
    "        print(f\"{j}: failed\")\n",
    "        continue\n",
    "    val_acc = (pysr_val_pred == val_targets).sum() / len(val_targets)\n",
    "    print(f\"{j}: {val_acc}\")\n",
    "    if val_acc > best_val_acc:\n",
    "        best_val_acc = val_acc\n",
    "        best_index = j\n",
    "if best_index is None:\n",
    "    raise RuntimeError(\"No PySR equation could be evaluated on the train/val/test sets.\")\n",
    "print(\"Best equation index:\", best_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc6cdc6-a020-4c14-bb93-0e202158c2aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Metrics\n",
    "# torch.as_tensor(...).long() converts whatever dtype predict() returns; the\n",
    "# legacy torch.LongTensor(ndarray) constructor does not cast a floating-point\n",
    "# array to int64.\n",
    "pysr_train_pred = torch.as_tensor(pysrmodel.predict(x_train_bin, index=best_index)).long()\n",
    "pysr_test_pred = torch.as_tensor(pysrmodel.predict(x_test_bin, index=best_index)).long()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934fac0f-2241-44bb-bfcf-a057a91efca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded override. It replaces the index chosen on the\n",
    "# validation set in the selection cell above - confirm this is intentional.\n",
    "best_index = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83748186-f99c-4ff3-bb38-c8fa9bf58969",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best based on val set\n",
    "equation = pysrmodel.get_best(index=best_index).equation\n",
    "print()\n",
    "print(\"=\" * 50)\n",
    "print(\"Equation:\", equation)\n",
    "print(\"C =\", c)\n",
    "\n",
    "# Agreement between the equation and the GNN's predictions, to 3 decimals.\n",
    "train_acc, test_acc = (\n",
    "    round(cal_pysr_acc(features, gnn_pred, index=best_index), 3)\n",
    "    for features, gnn_pred in ((x_train_bin, train_pred), (x_test_bin, test_pred))\n",
    ")\n",
    "\n",
    "equation = simplify_expression(equation)\n",
    "print(\"Simplified equation:\", equation)\n",
    "print(\"Train accuracy:\", train_acc)\n",
    "print(\"Test accuracy:\", test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d65b7a7b-3593-4602-bbaa-afd51c2bd63f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): label 5 is hard-coded; this lookup fails when df_equations has\n",
    "# no row labelled 5.\n",
    "df_equations.equation[5]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11ccfe0d-c00b-4bbf-ae37-f0224c2caa18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# * ----- Save stuff to disk\n",
    "# Save equations\n",
    "# df_equations.to_csv(f\"{FOLDER}/equations_sample{args.sample}.csv\", index=True)\n",
    "# del df_equations\n",
    "\n",
    "# Save predictions\n",
    "# torch.save(torch.LongTensor(train_pred), f\"{FOLDER}/gnn_train_pred.pt\")\n",
    "# torch.save(torch.LongTensor(test_pred), f\"{FOLDER}/gnn_test_pred.pt\")\n",
    "# # torch.save(pysr_train_pred, f\"{FOLDER}/pysr_train_pred_sample{args.sample}.pt\")\n",
    "# torch.save(pysr_test_pred, f\"{FOLDER}/pysr_test_pred_sample{args.sample}.pt\")\n",
    "# del train_pred, test_pred, pysr_train_pred, pysr_test_pred\n",
    "\n",
    "# # Save pysrmodel\n",
    "# with open(f\"{FOLDER}/pysrmodel_sample{args.sample}.pkl\", \"wb\") as file:\n",
    "#     dump(pysrmodel, file)\n",
    "# del pysrmodel\n",
    "\n",
    "# end_time = process_time()\n",
    "# print(f\"[TIME] gen_formulae: {end_time - start_time} s.ms\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8ff88eb-e6c0-4b90-a764-e42d0846260f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# * ----- Visualize the computation trees present in the formulae.\n",
    "variables_eq = getVariables(equation)\n",
    "\n",
    "node_mapping = None\n",
    "if dataset_name == \"MUTAG\":\n",
    "    node_mapping = {0: \"C\", 1: \"N\", 2: \"O\", 3: \"F\", 4: \"I\", 5: \"Cl\", 6: \"Br\"}\n",
    "elif dataset_name == \"Mutagenicity\":\n",
    "    node_mapping = {0: 'C', 1: 'O', 2: 'Cl', 3: 'H', 4: 'N', 5: 'F', 6: 'Br',\n",
    "                    7: 'S', 8: 'P', 9: 'I', 10: 'Na', 11: 'K', 12: 'Li', 13: 'Ca'}\n",
    "\n",
    "FIGSIZE = (8, 6)\n",
    "EDGE_WIDTH = 1.75\n",
    "NODE_COLOR = \"#FD5D02\"\n",
    "colors = ['green', 'black', 'blue', 'red'] # aromatic, single, double, triple\n",
    "# Datasets drawn with atom labels and bond colours.\n",
    "MOLECULE_DATASETS = [\"MUTAG\", \"Mutagenicity\"]\n",
    "\n",
    "\n",
    "def _draw_and_show(graph, draw_fn, title, labels=None, edge_colors=None):\n",
    "    \"\"\"Draw `graph` with `draw_fn` in a new figure, print `title`, and show it.\n",
    "\n",
    "    TU datasets are drawn with node labels; every TU dataset except NCI1 also\n",
    "    gets `edge_colors`. All other datasets are drawn without labels or colours.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=FIGSIZE)\n",
    "    plt.title(title)\n",
    "    draw_kwargs = {\"node_color\": NODE_COLOR, \"width\": EDGE_WIDTH}\n",
    "    if dataset_name in TU_DATASETS:\n",
    "        draw_kwargs.update(labels=labels, with_labels=True)\n",
    "        if dataset_name != \"NCI1\":\n",
    "            draw_kwargs[\"edge_color\"] = edge_colors\n",
    "    draw_fn(graph, **draw_kwargs)\n",
    "    print(title)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "print(selected_ctrees)\n",
    "# for v in variables_eq:\n",
    "#     if str(v)[0] != \"x\":\n",
    "#         continue\n",
    "#     v = int(str(v)[1:])\n",
    "for v in selected_ctrees:\n",
    "    print(v)\n",
    "    # v indexes the k SHAP-selected columns; map it back to the ctree code.\n",
    "    code_dfs = unique_ctree_codes[indices[- k:][v]]\n",
    "    code_id = dict_dfs_id_codes[code_dfs]\n",
    "\n",
    "    # * ----- Ctree using node attributes\n",
    "    # NOTE(review): dataset_name is set to 'BaMultiShapes' at the top of the\n",
    "    # notebook, so this comparison does not match it - confirm the spelling.\n",
    "    if dataset_name == 'BAMultiShapesDataset':\n",
    "        ctree = graph_from_dfs_code(code_id)\n",
    "    else:\n",
    "        ctree = graph_from_dfs_code(code_dfs)\n",
    "    # graph_from_dfs_code returns reversed edges; flip them back.\n",
    "    ctree = ctree.reverse()\n",
    "\n",
    "    edge_colors = None\n",
    "    if dataset_name in MOLECULE_DATASETS:\n",
    "        edge_colors = [colors[ctree.edges[edge]['attr']] for edge in ctree.edges()]\n",
    "\n",
    "    labeldict = {}\n",
    "    for i in range(len(ctree.nodes)):\n",
    "        if dataset_name in MOLECULE_DATASETS:\n",
    "            labeldict[i] = node_mapping[ctree.nodes[i]['attr']]\n",
    "        elif dataset_name == \"NCI1\":\n",
    "            labeldict[i] = ctree.nodes[i]['attr']\n",
    "\n",
    "    # plt.savefig(f\"{PLOT_FOLDER}/{v}_ctree.png\")\n",
    "    _draw_and_show(ctree, nx.draw_planar, v, labeldict, edge_colors)\n",
    "\n",
    "\n",
    "    # * ----- Ctree using node ids\n",
    "    ctree_id = graph_from_dfs_code(code_id)\n",
    "    ctree_id = ctree_id.reverse()\n",
    "\n",
    "    labeldict = None\n",
    "    if dataset_name in TU_DATASETS:\n",
    "        labeldict = {i: ctree_id.nodes[i]['attr'] for i in range(len(ctree_id.nodes))}\n",
    "\n",
    "    # plt.savefig(f\"{PLOT_FOLDER}/{v}_ctree_id.png\")\n",
    "    # Edge colours are still the ones computed from `ctree` above.\n",
    "    _draw_and_show(ctree_id, nx.draw_planar, v, labeldict, edge_colors)\n",
    "\n",
    "\n",
    "    # * ----- Ctree to subgraph\n",
    "    G = dfs(ctree=ctree, ctree_id=ctree_id, node_mapping=node_mapping)\n",
    "    # Only index `colors` for the datasets that actually draw edge colours; the\n",
    "    # lookup used to run for every dataset even though its result was unused.\n",
    "    edge_colors = None\n",
    "    if dataset_name in MOLECULE_DATASETS:\n",
    "        edge_colors = [colors[G.edges[edge]['attr']] for edge in G.edges()]\n",
    "\n",
    "    labeldict = None\n",
    "    if dataset_name in TU_DATASETS:\n",
    "        labeldict = {i: G.nodes[i]['attr'] for i in G.nodes}\n",
    "\n",
    "    # plt.savefig(f\"{PLOT_FOLDER}/{v}_structure.png\")\n",
    "    _draw_and_show(G, nx.draw_kamada_kawai, v, labeldict, edge_colors)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d2f321-dfb6-4046-9d6b-ecd0869d01d8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6cabeec-f89f-4dca-a3aa-a305a2f62f79",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdd7af4-9ceb-4710-af40-0e15e1605502",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-GraphTrail]",
   "language": "python",
   "name": "conda-env-.conda-GraphTrail-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
