{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# ! pip install gmpy2\n",
    "# ! pip install statsmodels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# sys.path.append(\"/workspace/Project_EvoWire/EVAttack/adversarial_training\")\n",
    "\n",
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "from torch_sparse import SparseTensor\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits, check_dataset_valid\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance, load_robust_model_instance, from_sparse_GCN, from_sparse_GPRGNN\n",
    "from utils.attack import load_attack_class, attack_storage_label, create_attack_instance, load_attack_instance\n",
    "\n",
    "from robust_diffusion.data import SparseGraph\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "from robust_diffusion.helper import utils as robust_utils\n",
    "from robust_diffusion.train import train\n",
    "\n",
    "from sacred import Experiment\n",
    "\n",
    "from sparse_smoothing.models import GCN, GAT, APPNPNet, CNN_MNIST, GIN\n",
    "from sparse_smoothing.training import train_gnn, train_pytorch\n",
    "from sparse_smoothing.prediction import predict_smooth_gnn, predict_smooth_pytorch, sample_multiple_graphs\n",
    "from sparse_smoothing.cert import binary_certificate, binary_certificate_grid\n",
    "from sparse_smoothing.utils import (load_and_standardize, split, accuracy_majority,\n",
    "                                    sample_perturbed_mnist, sample_batch_pyg, get_mnist_dataloaders)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configs\n",
    "dataset_name = \"cora_ml\"\n",
    "\n",
    "# model_name in [\"GCN\", \"DenseGCN\", \"GAT\", \"GPRGNN\", \"DenseGPRGNN\", \"APPNP\", \"ChebNetII\", \"SoftMedian_GDC\"]\n",
    "model_name = \"GCN\"\n",
    "n_splits = 10\n",
    "\n",
    "training_split = None\n",
    "validation_split = None\n",
    "training_split_type = None\n",
    "validation_split_type = None\n",
    "test_split = None\n",
    "test_split_type = None\n",
    "\n",
    "model_params = None\n",
    "epsilon = 0.01\n",
    "\n",
    "# attack_name in [\"PRBCD\", \"LRBCD\", \"EvaAttack\", \"Evafast\", \"PGD\"]\n",
    "attack_name = \"EvaAttack\"\n",
    "train_attack_name = None\n",
    "attack_params = None\n",
    "\n",
    "inductive = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n",
      "Loading dataset = cora_ml\n",
      "Found 10 splits!\n",
      "Loading pretrained GCN model on cora_ml dataset for 10 splits\n"
     ]
    }
   ],
   "source": [
    "general_config = yaml.safe_load(open(\"conf/general-config.yaml\"))\n",
    "default_dataset_configs = yaml.safe_load(open(\"conf/data-configs.yaml\")).get(\"configs\").get(\"default\")\n",
    "default_model_configs = yaml.safe_load(open(\"conf/model-configs.yaml\")).get(\"configs\")\n",
    "default_attack_configs = yaml.safe_load(open(\"conf/attack-configs.yaml\")).get(\"configs\")\n",
    "\n",
    "\n",
    "# extracting configs \n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "    \n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# Trains the specified model on the given graph and saves the model artifacts, and the splits.\n",
    "\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "dataset_splits = [\n",
    "    split_record for split_record in os.listdir(splits_root) \n",
    "    if split_record.split(\"-\")[0] == dataset_name \n",
    "    and check_dataset_valid(split_record=split_record, training_split=training_split,\n",
    "                            validation_split=validation_split, training_split_type=training_split_type, \n",
    "                            validation_split_type=validation_split_type, test_split=test_split, \n",
    "                            test_split_type=test_split_type, splits_root=splits_root)]\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "if creating_splits > 0:\n",
    "    raise ValueError(\"Not enough splits for the dataset. Create the splits by running training scripts.\")\n",
    "\n",
    "# creating remaining needed dataset splits\n",
    "print(f\"Found {len(dataset_splits)} splits!\")\n",
    "\n",
    "print(f\"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits\")\n",
    "\n",
    "clean_accs = []\n",
    "pert_accs = []\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for split_file in tqdm(dataset_splits[:n_splits]):\n",
    "split_file = dataset_splits[0]\n",
    "split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "data = load_dataset_splits(\n",
    "    dataset_name, split_code, inductive=inductive, \n",
    "    dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "training_attr = data[\"training_attr\"]\n",
    "training_adj = data[\"training_adj\"]\n",
    "labels = data[\"labels\"]\n",
    "training_idx = data[\"training_idx\"]\n",
    "validation_idx = data[\"validation_idx\"]\n",
    "test_attr = data[\"test_attr\"]\n",
    "test_adj = data[\"test_adj\"]\n",
    "test_mask = data[\"test_mask\"]\n",
    "unlabeled_mask = data[\"unlabeled_mask\"]\n",
    "test_idx = test_mask.nonzero(as_tuple=True)[0]\n",
    "dataset_info = data[\"dataset_info\"]\n",
    "split_name = data[\"split_name\"]\n",
    "\n",
    "try:\n",
    "    if train_attack_name is None:\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params, \n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, \n",
    "            test_mask=test_mask, unlabeled_mask=unlabeled_mask,\n",
    "            split_name=split_name, dataset_info=dataset_info, \n",
    "            inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    else:\n",
    "        model_instance = load_robust_model_instance(\n",
    "            model_name=model_name, model_params=model_params, \n",
    "            dataset_info=dataset_info, \n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,\n",
    "            split_name=split_name, inductive=inductive,\n",
    "            models_root=models_root, self_training=True, robust_training=True, train_attack_name=train_attack_name, robust_epsilon=0.2,\n",
    "            default_model_configs=default_model_configs, suffix='', device=device)\n",
    "\n",
    "except FileNotFoundError as e:\n",
    "    print(e)\n",
    "    raise ValueError(\"Model not found. Run training scripts to train the model.\")\n",
    "\n",
    "\n",
    "model = model_instance[\"model\"]\n",
    "acc = model_instance[\"accuracy\"]\n",
    "model_params = model_instance[\"model_params\"]\n",
    "model_storage_name = model_instance[\"model_storage_name\"]\n",
    "clean_accs.append(acc)\n",
    "\n",
    "if attack_name == \"PGD\" and model_name == \"GCN\":\n",
    "    model = from_sparse_GCN(model, model_params)\n",
    "elif attack_name == \"PGD\" and model_name == \"GPRGNN\":\n",
    "    model = from_sparse_GPRGNN(model, model_params)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples_eval = 1_000\n",
    "pf_plus_att = 0\n",
    "pf_minus_att = 0\n",
    "pf_plus_adj = 0.0\n",
    "pf_minus_adj = 0.0\n",
    "# pf_plus_adj = 0.0\n",
    "# pf_minus_adj = 0.0\n",
    "\n",
    "smoothing_config = {\n",
    "    \"n_samples\": n_samples_eval,\n",
    "    \"pf_plus_att\": pf_plus_att,\n",
    "    \"pf_minus_att\": pf_minus_att,\n",
    "    \"pf_plus_adj\": pf_plus_adj,\n",
    "    \"pf_minus_adj\": pf_minus_adj,\n",
    "}\n",
    "batch_size = 50\n",
    "p_lower = 0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<00:00, 95.45it/s]\n"
     ]
    }
   ],
   "source": [
    "def filtered_induced_subgraph(attr, adj, labels, target_mask, idxs=None, depth=2):\n",
    "    # computing the k-hop neighbors of the target node\n",
    "    target_idx = target_mask.nonzero(as_tuple=True)[0]\n",
    "    k_adj = adj.clone()\n",
    "    for _ in range(depth):\n",
    "        k_adj = k_adj + (adj @ k_adj)\n",
    "    filter_mask = (k_adj)[target_idx].to_torch_sparse_coo_tensor().coalesce().indices()[1].unique()\n",
    "\n",
    "    filtered_attr = attr[filter_mask]\n",
    "    filtered_adj = adj[filter_mask, filter_mask]\n",
    "    filtered_labels = labels[filter_mask]\n",
    "    \n",
    "    if idxs is None:\n",
    "        return filtered_attr, filtered_adj, filtered_labels\n",
    "    \n",
    "    filtered_idxs = []\n",
    "    for idx in idxs:\n",
    "        mask = torch.zeros_like(target_mask)\n",
    "        mask[idx] = 1\n",
    "        mask_new = mask[filter_mask]\n",
    "        idx_new = mask_new.nonzero(as_tuple=True)[0]\n",
    "        filtered_idxs.append(idx_new)\n",
    "    # import pdb; pdb.set_trace()\n",
    "    return filtered_attr, filtered_adj, filtered_labels, filtered_idxs\n",
    "\n",
    "def smooth_model_p(test_attr, test_adj, labels, test_mask, model, smoothing_config, dataset_info, batch_size=50):\n",
    "    edge_idx = torch.stack([test_adj.coo()[0], test_adj.coo()[1]]).long().to(device)\n",
    "    attr_idx = torch.stack(list(test_attr.nonzero(as_tuple=True)))\n",
    "\n",
    "    votes = torch.zeros((test_attr.shape[0], dataset_info[\"n_classes\"])).to(device)\n",
    "\n",
    "    for i in tqdm(range(smoothing_config[\"n_samples\"] // batch_size)):\n",
    "        attr_idx_batch, edge_idx_batch = sample_multiple_graphs(\n",
    "            attr_idx=attr_idx, edge_idx=edge_idx,\n",
    "            sample_config=smoothing_config, n=test_attr.shape[0], d=0, nsamples=batch_size)\n",
    "\n",
    "        test_attr_bulk = test_attr.repeat(batch_size, 1).to(device)\n",
    "\n",
    "        test_adj_bulk = torch.sparse_coo_tensor(indices=edge_idx_batch, values=torch.ones_like(edge_idx_batch[0]).float(), size=(test_attr_bulk.shape[0], test_attr_bulk.shape[0])).to(device)\n",
    "\n",
    "        votes += F.one_hot(model(test_attr_bulk, test_adj_bulk).argmax(1), dataset_info[\"n_classes\"]).float().reshape(-1, test_attr.shape[0], dataset_info[\"n_classes\"]).sum(dim=0)\n",
    "\n",
    "    y_true_mask = F.one_hot(labels).bool()\n",
    "    p_emps = (votes / smoothing_config[\"n_samples\"])[y_true_mask]\n",
    "    p_emps_test = p_emps[test_mask]\n",
    "    return p_emps_test\n",
    "\n",
    "certification_results = smooth_model_p(test_attr, test_adj, labels, test_mask, model, smoothing_config, dataset_info, batch_size=batch_size)\n",
    "certified_mask = (certification_results > p_lower)\n",
    "test_targets = test_mask.nonzero(as_tuple=True)[0][certified_mask.cpu()]\n",
    "target_mask = torch.zeros_like(test_mask).to(device)\n",
    "target_mask[test_targets[:10]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(241, device='cuda:0')"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(certification_results > 0.6).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(10, device='cuda:0')"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_mask.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_attr, filtered_adj, filtered_labels, filtered_idxs = filtered_induced_subgraph(\n",
    "    test_attr, test_adj, labels, target_mask,\n",
    "    idxs=[training_idx, validation_idx, test_idx, unlabeled_mask.nonzero(as_tuple=True)[0], target_mask.nonzero(as_tuple=True)[0]], depth=2)\n",
    "training_idx_filtered = filtered_idxs[0]\n",
    "validation_idx_filtered = filtered_idxs[1]\n",
    "test_idx_filtered = filtered_idxs[2]\n",
    "unlabeled_idx_filtered = filtered_idxs[3]\n",
    "target_idx_filtered = filtered_idxs[4]\n",
    "\n",
    "filtered_dataset_info = deepcopy(dataset_info)\n",
    "filtered_dataset_info[\"n_nodes\"] = filtered_attr.size(0)\n",
    "\n",
    "test_mask_filtered = torch.zeros(size=(filtered_attr.shape[0], ), dtype=bool).to(device)\n",
    "test_mask_filtered[test_idx_filtered] = 1\n",
    "unlabeled_mask_filtered = torch.zeros(size=(filtered_attr.shape[0], ), dtype=bool).to(device)\n",
    "unlabeled_mask_filtered[unlabeled_idx_filtered] = 1\n",
    "target_mask_filtered = torch.zeros(size=(filtered_attr.shape[0], ), dtype=bool).to(device)\n",
    "target_mask_filtered[target_idx_filtered] = 1\n",
    "target_mask_filtered.sum()\n",
    "\n",
    "training_mask_filtered = torch.zeros_like(test_mask_filtered)\n",
    "training_mask_filtered[training_idx_filtered] = 1\n",
    "validation_mask_filtered = torch.zeros_like(test_mask_filtered)\n",
    "validation_mask_filtered[validation_idx_filtered] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<00:00, 130.15it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smooth_model_p(filtered_attr, filtered_adj, filtered_labels, target_mask_filtered, model, smoothing_config, filtered_dataset_info, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from robust_diffusion.data import count_edges_for_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(177, device='cuda:0')"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_mask_filtered.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseTensor(row=tensor([   0,    0,    0,  ..., 1783, 1783, 1784], device='cuda:0'),\n",
       "             col=tensor([   1,  844,  882,  ...,  610, 1237,  988], device='cuda:0'),\n",
       "             val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),\n",
       "             size=(1785, 1785), nnz=11430, density=0.36%)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filtered_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "         iter : 1\n",
      "pop_best_eval : 0.20241178572177887\n",
      "  median_eval : 0.20374512672424316\n",
      "    mean_eval : 0.20366939902305603\n",
      "\n",
      "         iter : 2\n",
      "pop_best_eval : 0.20220917463302612\n",
      "  median_eval : 0.2034052312374115\n",
      "    mean_eval : 0.2033187448978424\n",
      "\n",
      "         iter : 3\n",
      "pop_best_eval : 0.20219609141349792\n",
      "  median_eval : 0.20309151709079742\n",
      "    mean_eval : 0.20303694903850555\n",
      "\n",
      "         iter : 4\n",
      "pop_best_eval : 0.20159479975700378\n",
      "  median_eval : 0.20284314453601837\n",
      "    mean_eval : 0.2027219831943512\n",
      "\n",
      "         iter : 5\n",
      "pop_best_eval : 0.20154903829097748\n",
      "  median_eval : 0.20247715711593628\n",
      "    mean_eval : 0.20244497060775757\n",
      "\n",
      "         iter : 6\n",
      "pop_best_eval : 0.20101962983608246\n",
      "  median_eval : 0.20228107273578644\n",
      "    mean_eval : 0.20221519470214844\n",
      "\n",
      "         iter : 7\n",
      "pop_best_eval : 0.20101962983608246\n",
      "  median_eval : 0.20203924179077148\n",
      "    mean_eval : 0.2019588202238083\n",
      "\n",
      "         iter : 8\n",
      "pop_best_eval : 0.20052289962768555\n",
      "  median_eval : 0.20177781581878662\n",
      "    mean_eval : 0.20170116424560547\n",
      "\n",
      "         iter : 9\n",
      "pop_best_eval : 0.20052289962768555\n",
      "  median_eval : 0.20150327682495117\n",
      "    mean_eval : 0.2014436423778534\n",
      "\n",
      "         iter : 10\n",
      "pop_best_eval : 0.20003923773765564\n",
      "  median_eval : 0.2012287825345993\n",
      "    mean_eval : 0.20117373764514923\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anonymous/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/EVAttack/experiments/eva_attack.py:58: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  var_coef = var - torch.tensor(self.variances[-10]).mean()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "         iter : 11\n",
      "pop_best_eval : 0.19998040795326233\n",
      "  median_eval : 0.20101962983608246\n",
      "    mean_eval : 0.20092423260211945\n",
      "\n",
      "         iter : 12\n",
      "pop_best_eval : 0.1994575411081314\n",
      "  median_eval : 0.2007516473531723\n",
      "    mean_eval : 0.20063968002796173\n",
      "\n",
      "         iter : 13\n",
      "pop_best_eval : 0.1989281326532364\n",
      "  median_eval : 0.20040525496006012\n",
      "    mean_eval : 0.2003130316734314\n",
      "\n",
      "         iter : 14\n",
      "pop_best_eval : 0.1989281326532364\n",
      "  median_eval : 0.20009806752204895\n",
      "    mean_eval : 0.1999942511320114\n",
      "\n",
      "         iter : 15\n",
      "pop_best_eval : 0.1976209431886673\n",
      "  median_eval : 0.19975818693637848\n",
      "    mean_eval : 0.19965800642967224\n",
      "\n",
      "         iter : 16\n",
      "pop_best_eval : 0.1976209431886673\n",
      "  median_eval : 0.1994575411081314\n",
      "    mean_eval : 0.19937051832675934\n",
      "\n",
      "         iter : 17\n",
      "pop_best_eval : 0.1976209431886673\n",
      "  median_eval : 0.19921571016311646\n",
      "    mean_eval : 0.1991075575351715\n",
      "\n",
      "         iter : 18\n",
      "pop_best_eval : 0.19698040187358856\n",
      "  median_eval : 0.1989085078239441\n",
      "    mean_eval : 0.1987791359424591\n",
      "\n",
      "         iter : 19\n",
      "pop_best_eval : 0.19660131633281708\n",
      "  median_eval : 0.1984902173280716\n",
      "    mean_eval : 0.19837430119514465\n",
      "\n",
      "         iter : 20\n",
      "pop_best_eval : 0.19660131633281708\n",
      "  median_eval : 0.19815033674240112\n",
      "    mean_eval : 0.1980431079864502\n",
      "\n",
      "         iter : 21\n",
      "pop_best_eval : 0.1962549388408661\n",
      "  median_eval : 0.19779087603092194\n",
      "    mean_eval : 0.19770437479019165\n",
      "\n",
      "         iter : 22\n",
      "pop_best_eval : 0.19566668570041656\n",
      "  median_eval : 0.19751638174057007\n",
      "    mean_eval : 0.19739967584609985\n",
      "\n",
      "         iter : 23\n",
      "pop_best_eval : 0.19554249942302704\n",
      "  median_eval : 0.1971960961818695\n",
      "    mean_eval : 0.1970653235912323\n",
      "\n",
      "         iter : 24\n",
      "pop_best_eval : 0.19507192075252533\n",
      "  median_eval : 0.1967451274394989\n",
      "    mean_eval : 0.1966795027256012\n",
      "\n",
      "         iter : 25\n",
      "pop_best_eval : 0.19503270089626312\n",
      "  median_eval : 0.19645753502845764\n",
      "    mean_eval : 0.19638171792030334\n",
      "\n",
      "         iter : 26\n",
      "pop_best_eval : 0.1948823630809784\n",
      "  median_eval : 0.19624187052249908\n",
      "    mean_eval : 0.19614094495773315\n",
      "\n",
      "         iter : 27\n",
      "pop_best_eval : 0.1944902092218399\n",
      "  median_eval : 0.19591505825519562\n",
      "    mean_eval : 0.19588187336921692\n",
      "\n",
      "         iter : 28\n",
      "pop_best_eval : 0.1944510042667389\n",
      "  median_eval : 0.19569936394691467\n",
      "    mean_eval : 0.1956002563238144\n",
      "\n",
      "         iter : 29\n",
      "pop_best_eval : 0.19426144659519196\n",
      "  median_eval : 0.1953464299440384\n",
      "    mean_eval : 0.19527427852153778\n",
      "\n",
      "         iter : 30\n",
      "pop_best_eval : 0.1935882568359375\n",
      "  median_eval : 0.19500654935836792\n",
      "    mean_eval : 0.19491606950759888\n",
      "\n",
      "         iter : 31\n",
      "pop_best_eval : 0.1935490369796753\n",
      "  median_eval : 0.19461439549922943\n",
      "    mean_eval : 0.19451957941055298\n",
      "\n",
      "         iter : 32\n",
      "pop_best_eval : 0.19298040866851807\n",
      "  median_eval : 0.19420263171195984\n",
      "    mean_eval : 0.1941247582435608\n",
      "\n",
      "         iter : 33\n",
      "pop_best_eval : 0.192718967795372\n",
      "  median_eval : 0.19388890266418457\n",
      "    mean_eval : 0.1937967836856842\n",
      "\n",
      "         iter : 34\n",
      "pop_best_eval : 0.19251635670661926\n",
      "  median_eval : 0.1935621052980423\n",
      "    mean_eval : 0.19350354373455048\n",
      "\n",
      "         iter : 35\n",
      "pop_best_eval : 0.19215689599514008\n",
      "  median_eval : 0.19328106939792633\n",
      "    mean_eval : 0.19320738315582275\n",
      "\n",
      "         iter : 36\n",
      "pop_best_eval : 0.19194118678569794\n",
      "  median_eval : 0.19300003349781036\n",
      "    mean_eval : 0.1929103136062622\n",
      "\n",
      "         iter : 37\n",
      "pop_best_eval : 0.1917385756969452\n",
      "  median_eval : 0.19275818765163422\n",
      "    mean_eval : 0.19268307089805603\n",
      "\n",
      "         iter : 38\n",
      "pop_best_eval : 0.1913791000843048\n",
      "  median_eval : 0.19250981509685516\n",
      "    mean_eval : 0.19243302941322327\n",
      "\n",
      "         iter : 39\n",
      "pop_best_eval : 0.1913791000843048\n",
      "  median_eval : 0.1922614723443985\n",
      "    mean_eval : 0.1921783834695816\n",
      "\n",
      "         iter : 40\n",
      "pop_best_eval : 0.1909346580505371\n",
      "  median_eval : 0.19194774329662323\n",
      "    mean_eval : 0.1919277012348175\n",
      "\n",
      "         iter : 41\n",
      "pop_best_eval : 0.1903790980577469\n",
      "  median_eval : 0.1917647272348404\n",
      "    mean_eval : 0.1916828751564026\n",
      "\n",
      "         iter : 42\n",
      "pop_best_eval : 0.1903790980577469\n",
      "  median_eval : 0.19156211614608765\n",
      "    mean_eval : 0.1914554387331009\n",
      "\n",
      "         iter : 43\n",
      "pop_best_eval : 0.1903790980577469\n",
      "  median_eval : 0.1913006603717804\n",
      "    mean_eval : 0.19122692942619324\n",
      "\n",
      "         iter : 44\n",
      "pop_best_eval : 0.1900000274181366\n",
      "  median_eval : 0.19102618098258972\n",
      "    mean_eval : 0.19099563360214233\n",
      "\n",
      "         iter : 45\n",
      "pop_best_eval : 0.18945100903511047\n",
      "  median_eval : 0.19078433513641357\n",
      "    mean_eval : 0.1907324492931366\n",
      "\n",
      "         iter : 46\n",
      "pop_best_eval : 0.18945100903511047\n",
      "  median_eval : 0.19058826565742493\n",
      "    mean_eval : 0.19051691889762878\n",
      "\n",
      "         iter : 47\n",
      "pop_best_eval : 0.18945100903511047\n",
      "  median_eval : 0.1903986930847168\n",
      "    mean_eval : 0.1903512328863144\n",
      "\n",
      "         iter : 48\n",
      "pop_best_eval : 0.18938565254211426\n",
      "  median_eval : 0.19020918011665344\n",
      "    mean_eval : 0.19016146659851074\n",
      "\n",
      "         iter : 49\n",
      "pop_best_eval : 0.189124196767807\n",
      "  median_eval : 0.1900196373462677\n",
      "    mean_eval : 0.18996284902095795\n",
      "\n",
      "         iter : 50\n",
      "pop_best_eval : 0.1890522986650467\n",
      "  median_eval : 0.18984968960285187\n",
      "    mean_eval : 0.1897965371608734\n",
      "\n",
      "         iter : 51\n",
      "pop_best_eval : 0.1886666864156723\n",
      "  median_eval : 0.18968629837036133\n",
      "    mean_eval : 0.1896020919084549\n",
      "\n",
      "         iter : 52\n",
      "pop_best_eval : 0.1886666864156723\n",
      "  median_eval : 0.18947716057300568\n",
      "    mean_eval : 0.18943344056606293\n",
      "\n",
      "         iter : 53\n",
      "pop_best_eval : 0.1882091611623764\n",
      "  median_eval : 0.18924838304519653\n",
      "    mean_eval : 0.18917863070964813\n",
      "\n",
      "         iter : 54\n",
      "pop_best_eval : 0.1881830394268036\n",
      "  median_eval : 0.1890392303466797\n",
      "    mean_eval : 0.18897999823093414\n",
      "\n",
      "         iter : 55\n",
      "pop_best_eval : 0.18809805810451508\n",
      "  median_eval : 0.18881700932979584\n",
      "    mean_eval : 0.18881645798683167\n",
      "\n",
      "         iter : 56\n",
      "pop_best_eval : 0.1876862794160843\n",
      "  median_eval : 0.18871243298053741\n",
      "    mean_eval : 0.18866106867790222\n",
      "\n",
      "         iter : 57\n",
      "pop_best_eval : 0.1876862794160843\n",
      "  median_eval : 0.18853597342967987\n",
      "    mean_eval : 0.18850216269493103\n",
      "\n",
      "         iter : 58\n",
      "pop_best_eval : 0.18758173286914825\n",
      "  median_eval : 0.18839871883392334\n",
      "    mean_eval : 0.188349187374115\n",
      "\n",
      "         iter : 59\n",
      "pop_best_eval : 0.18758173286914825\n",
      "  median_eval : 0.1882810741662979\n",
      "    mean_eval : 0.1882399022579193\n",
      "\n",
      "         iter : 60\n",
      "pop_best_eval : 0.18745100498199463\n",
      "  median_eval : 0.18812420964241028\n",
      "    mean_eval : 0.18809887766838074\n",
      "\n",
      "         iter : 61\n",
      "pop_best_eval : 0.18745100498199463\n",
      "  median_eval : 0.18801309168338776\n",
      "    mean_eval : 0.1879735291004181\n",
      "\n",
      "         iter : 62\n",
      "pop_best_eval : 0.18695425987243652\n",
      "  median_eval : 0.18790853023529053\n",
      "    mean_eval : 0.18785887956619263\n",
      "\n",
      "         iter : 63\n",
      "pop_best_eval : 0.18695425987243652\n",
      "  median_eval : 0.1877451092004776\n",
      "    mean_eval : 0.18769454956054688\n",
      "\n",
      "         iter : 64\n",
      "pop_best_eval : 0.18695425987243652\n",
      "  median_eval : 0.18761441111564636\n",
      "    mean_eval : 0.18755808472633362\n",
      "\n",
      "         iter : 65\n",
      "pop_best_eval : 0.18669936060905457\n",
      "  median_eval : 0.18743793666362762\n",
      "    mean_eval : 0.1874038279056549\n",
      "\n",
      "         iter : 66\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.187333345413208\n",
      "    mean_eval : 0.18728306889533997\n",
      "\n",
      "         iter : 67\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.18722224235534668\n",
      "    mean_eval : 0.18717119097709656\n",
      "\n",
      "         iter : 68\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.18711112439632416\n",
      "    mean_eval : 0.18705593049526215\n",
      "\n",
      "         iter : 69\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.18701308965682983\n",
      "    mean_eval : 0.1869765818119049\n",
      "\n",
      "         iter : 70\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.1869215965270996\n",
      "    mean_eval : 0.18687745928764343\n",
      "\n",
      "         iter : 71\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.1868496835231781\n",
      "    mean_eval : 0.18680250644683838\n",
      "\n",
      "         iter : 72\n",
      "pop_best_eval : 0.18589544296264648\n",
      "  median_eval : 0.18675817549228668\n",
      "    mean_eval : 0.18670804798603058\n",
      "\n",
      "         iter : 73\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18669934570789337\n",
      "    mean_eval : 0.18663640320301056\n",
      "\n",
      "         iter : 74\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18661439418792725\n",
      "    mean_eval : 0.18657676875591278\n",
      "\n",
      "         iter : 75\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18657517433166504\n",
      "    mean_eval : 0.18652008473873138\n",
      "\n",
      "         iter : 76\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18649674952030182\n",
      "    mean_eval : 0.18645748496055603\n",
      "\n",
      "         iter : 77\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.1864640712738037\n",
      "    mean_eval : 0.1864166259765625\n",
      "\n",
      "         iter : 78\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.1863986998796463\n",
      "    mean_eval : 0.18636572360992432\n",
      "\n",
      "         iter : 79\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18633335828781128\n",
      "    mean_eval : 0.18629243969917297\n",
      "\n",
      "         iter : 80\n",
      "pop_best_eval : 0.18580394983291626\n",
      "  median_eval : 0.18629413843154907\n",
      "    mean_eval : 0.18626277148723602\n",
      "\n",
      "         iter : 81\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18624839186668396\n",
      "    mean_eval : 0.1862066090106964\n",
      "\n",
      "         iter : 82\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18622225522994995\n",
      "    mean_eval : 0.18617823719978333\n",
      "\n",
      "         iter : 83\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18616993725299835\n",
      "    mean_eval : 0.18613973259925842\n",
      "\n",
      "         iter : 84\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18612419068813324\n",
      "    mean_eval : 0.1860913634300232\n",
      "\n",
      "         iter : 85\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18609806895256042\n",
      "    mean_eval : 0.1860620081424713\n",
      "\n",
      "         iter : 86\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.18607845902442932\n",
      "    mean_eval : 0.18602855503559113\n",
      "\n",
      "         iter : 87\n",
      "pop_best_eval : 0.1855098158121109\n",
      "  median_eval : 0.1860457807779312\n",
      "    mean_eval : 0.18600115180015564\n",
      "\n",
      "         iter : 88\n",
      "pop_best_eval : 0.18529413640499115\n",
      "  median_eval : 0.1860000342130661\n",
      "    mean_eval : 0.1859554946422577\n",
      "\n",
      "         iter : 89\n",
      "pop_best_eval : 0.18529413640499115\n",
      "  median_eval : 0.1859542578458786\n",
      "    mean_eval : 0.1858963668346405\n",
      "\n",
      "         iter : 90\n",
      "pop_best_eval : 0.18529413640499115\n",
      "  median_eval : 0.18589544296264648\n",
      "    mean_eval : 0.18584796786308289\n",
      "\n",
      "         iter : 91\n",
      "pop_best_eval : 0.18529413640499115\n",
      "  median_eval : 0.18580393493175507\n",
      "    mean_eval : 0.18578192591667175\n",
      "\n",
      "         iter : 92\n",
      "pop_best_eval : 0.18529413640499115\n",
      "  median_eval : 0.18575818836688995\n",
      "    mean_eval : 0.185724139213562\n",
      "\n",
      "         iter : 93\n",
      "pop_best_eval : 0.18518301844596863\n",
      "  median_eval : 0.18571242690086365\n",
      "    mean_eval : 0.18567173182964325\n",
      "\n",
      "         iter : 94\n",
      "pop_best_eval : 0.18518301844596863\n",
      "  median_eval : 0.18567320704460144\n",
      "    mean_eval : 0.18563666939735413\n",
      "\n",
      "         iter : 95\n",
      "pop_best_eval : 0.1850849688053131\n",
      "  median_eval : 0.18562746047973633\n",
      "    mean_eval : 0.18558953702449799\n",
      "\n",
      "         iter : 96\n",
      "pop_best_eval : 0.1850457489490509\n",
      "  median_eval : 0.18558171391487122\n",
      "    mean_eval : 0.1855398416519165\n",
      "\n",
      "         iter : 97\n",
      "pop_best_eval : 0.18492811918258667\n",
      "  median_eval : 0.185490220785141\n",
      "    mean_eval : 0.1854657083749771\n",
      "\n",
      "         iter : 98\n",
      "pop_best_eval : 0.18474511802196503\n",
      "  median_eval : 0.18544445931911469\n",
      "    mean_eval : 0.18541117012500763\n",
      "\n",
      "         iter : 99\n",
      "pop_best_eval : 0.18474511802196503\n",
      "  median_eval : 0.18541832268238068\n",
      "    mean_eval : 0.18536311388015747\n",
      "\n",
      "         iter : 100\n",
      "pop_best_eval : 0.18474511802196503\n",
      "  median_eval : 0.18536603450775146\n",
      "    mean_eval : 0.18531455099582672\n",
      "\n",
      "> \u001b[0;32m/home/anonymous/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/EVAttack/experiments/utils/attack.py\u001b[0m(99)\u001b[0;36mcreate_attack_instance\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m     97 \u001b[0;31m    \u001b[0madversary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_attack_edge\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m     98 \u001b[0;31m    \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m---> 99 \u001b[0;31m    \u001b[0mpert_adj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpert_attr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0madversary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_pertubations\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    100 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    101 \u001b[0;31m    pert_acc = accuracy(\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "attack = create_attack_instance(\n",
    "    attack_name='EVACERT', attack_params={\"target_mask\": None, \"p_emp\": 0.7}, epsilon=0.9,\n",
    "    test_attr=filtered_attr, test_adj=filtered_adj, labels=filtered_labels, model=model,\n",
    "    dataset_info=filtered_dataset_info, model_storage_name=model_storage_name, \n",
    "    split_name=split_name, test_mask=(test_mask_filtered | training_mask_filtered | validation_mask_filtered | unlabeled_mask_filtered),\n",
    "    unlabeled_mask=unlabeled_mask_filtered,\n",
    "    default_attack_configs=default_attack_configs, reports_root=reports_root,\n",
    "    device=device, inductive=False, save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd;laksd;lkas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'res' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mres\u001b[49m\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, test_attr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], dataset_info[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_classes\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'res' is not defined"
     ]
    }
   ],
   "source": [
    "res.reshape(-1, test_attr.shape[0], dataset_info[\"n_classes\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1,  ..., 0, 6, 4], device='cuda:0')"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(test_attr, test_adj).argmax(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2810, 7])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "votes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset: Citeseer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1681, 6])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "votes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_ra=5, max_rd=15, min_p_emp=0.6970\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 237.95it/s]\n",
      "100%|██████████| 119/119 [00:01<00:00, 117.95it/s]\n",
      "100%|██████████| 119/119 [00:01<00:00, 118.91it/s]\n",
      "100%|██████████| 119/119 [00:00<00:00, 119.75it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([[[1.00000000e+00, 9.25585839e-01, 8.77216634e-01, ...,\n",
       "          3.14674074e-04, 3.01370131e-04, 2.26988960e-04],\n",
       "         [5.78727581e-01, 5.64900508e-01, 5.39485839e-01, ...,\n",
       "          1.94350990e-04, 1.87571855e-04, 1.44081136e-04],\n",
       "         [3.50743988e-01, 3.44727581e-01, 3.31680508e-01, ...,\n",
       "          1.20001397e-04, 1.16677412e-04, 9.12919249e-05],\n",
       "         ...,\n",
       "         [1.28831584e-01, 1.28332114e-01, 1.25261588e-01, ...,\n",
       "          4.57111298e-05, 4.50736867e-05, 3.64743927e-05],\n",
       "         [9.88633731e-02, 7.82875843e-02, 7.69457141e-02, ...,\n",
       "          2.82010496e-05, 2.79936845e-05, 2.30049082e-05],\n",
       "         [9.05499231e-02, 4.77533481e-02, 4.72535683e-02, ...,\n",
       "          1.73938925e-05, 1.73775450e-05, 1.44904406e-05]],\n",
       " \n",
       "        [[1.00000000e+00, 4.84232218e-01, 1.48983160e-01, ...,\n",
       "          1.56185153e-07, 1.37007374e-07, 4.65204185e-08],\n",
       "         [4.16614222e-01, 2.97413466e-01, 9.81322181e-02, ...,\n",
       "          9.57619890e-08, 8.48595532e-08, 3.10888583e-08],\n",
       "         [2.52493468e-01, 1.82614222e-01, 6.41934655e-02, ...,\n",
       "          5.87052469e-08, 5.25343529e-08, 2.05951879e-08],\n",
       "         ...,\n",
       "         [9.27432388e-02, 6.87863440e-02, 2.70110676e-02, ...,\n",
       "          2.20517024e-08, 2.01058255e-08, 8.85190316e-09],\n",
       "         [5.62080235e-02, 4.21992388e-02, 1.73999440e-02, ...,\n",
       "          1.35122208e-08, 1.24300894e-08, 5.75449341e-09],\n",
       "         [3.40654688e-02, 2.58816235e-02, 1.11652228e-02, ...,\n",
       "          8.27843100e-09, 7.68149009e-09, 3.72361846e-09]],\n",
       " \n",
       "        [[1.00000000e+00, 6.84381738e-01, 4.79229867e-01, ...,\n",
       "          5.32553211e-06, 2.78206788e-06, 3.38864319e-07],\n",
       "         [4.90131033e-01, 4.18716205e-01, 2.98281738e-01, ...,\n",
       "          3.30074560e-06, 1.79491493e-06, 2.08266980e-07],\n",
       "         [2.97049111e-01, 2.56131033e-01, 1.85496205e-01, ...,\n",
       "          2.04478384e-06, 1.15384505e-06, 1.27975868e-07],\n",
       "         ...,\n",
       "         [1.09108948e-01, 9.57897642e-02, 7.15667110e-02, ...,\n",
       "          7.83632233e-07, 4.72391168e-07, 4.82938425e-08],\n",
       "         [6.61266352e-02, 5.85649480e-02, 4.44033642e-02, ...,\n",
       "          4.84794980e-07, 3.01038713e-07, 2.96586991e-08],\n",
       "         [4.00767486e-02, 3.58002352e-02, 2.75309320e-02, ...,\n",
       "          2.99793972e-07, 1.91390223e-07, 1.82110158e-08]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[1.00000000e+00, 9.72303855e-01, 9.54301360e-01, ...,\n",
       "          6.19017254e-03, 4.75297608e-03, 1.58305780e-03],\n",
       "         [5.95887550e-01, 5.93214457e-01, 5.86203855e-01, ...,\n",
       "          3.88364523e-03, 3.04439263e-03, 1.15952749e-03],\n",
       "         [4.35501995e-01, 3.61887550e-01, 3.59994457e-01, ...,\n",
       "          2.43439546e-03, 1.94533640e-03, 8.25407879e-04],\n",
       "         ...,\n",
       "         [3.31005324e-01, 2.37278785e-01, 1.35661570e-01, ...,\n",
       "          9.54158762e-04, 7.89261001e-04, 3.94827505e-04],\n",
       "         [2.51663833e-01, 1.96221324e-01, 1.02494785e-01, ...,\n",
       "          5.96671782e-04, 5.01307995e-04, 2.67527427e-04],\n",
       "         [1.83156262e-01, 1.50575833e-01, 9.47963643e-02, ...,\n",
       "          3.72854442e-04, 3.17874343e-04, 1.79440339e-04]],\n",
       " \n",
       "        [[1.00000000e+00, 9.16426932e-01, 8.62104438e-01, ...,\n",
       "          2.93269672e-04, 2.66052869e-04, 1.68715477e-04],\n",
       "         [5.75363428e-01, 5.59349656e-01, 5.30326932e-01, ...,\n",
       "          1.81378626e-04, 1.66167454e-04, 1.08763873e-04],\n",
       "         [3.48705108e-01, 3.41363428e-01, 3.26129656e-01, ...,\n",
       "          1.12139358e-04, 1.03705048e-04, 6.98875233e-05],\n",
       "         ...,\n",
       "         [1.28082684e-01, 1.27096429e-01, 1.23222708e-01, ...,\n",
       "          4.28233284e-05, 4.03088145e-05, 2.86123536e-05],\n",
       "         [7.76258692e-02, 7.75386842e-02, 7.57100289e-02, ...,\n",
       "          2.64508669e-05, 2.51058832e-05, 1.82400360e-05],\n",
       "         [7.23947687e-02, 4.72994692e-02, 4.65046682e-02, ...,\n",
       "          1.63331757e-05, 1.56273624e-05, 1.16026393e-05]],\n",
       " \n",
       "        [[1.00000000e+00, 9.81892813e-01, 9.70123141e-01, ...,\n",
       "          7.66919517e-03, 7.19336342e-03, 5.60969691e-03],\n",
       "         [5.99409665e-01, 5.99025947e-01, 5.95792813e-01, ...,\n",
       "          4.78002258e-03, 4.52341526e-03, 3.59991483e-03],\n",
       "         [5.76386594e-01, 4.05037880e-01, 3.65805947e-01, ...,\n",
       "          2.97765446e-03, 2.84171375e-03, 2.30443051e-03],\n",
       "         ...,\n",
       "         [3.82753570e-01, 3.22663390e-01, 2.21268194e-01, ...,\n",
       "          1.15370293e-03, 1.11850888e-03, 9.38086506e-04],\n",
       "         [2.83026406e-01, 2.47969570e-01, 1.87879390e-01, ...,\n",
       "          7.17607642e-04, 7.00852164e-04, 5.96775307e-04],\n",
       "         [2.02163882e-01, 1.81938406e-01, 1.46544610e-01, ...,\n",
       "          4.46148902e-04, 4.38810203e-04, 3.78984508e-04]]]),\n",
       " array([[[1.00000000e+00, 9.22406255e-01, 8.71970321e-01, ...,\n",
       "          3.07243375e-04, 2.89109478e-04, 2.06758882e-04],\n",
       "         [5.77559690e-01, 5.62973488e-01, 5.36306255e-01, ...,\n",
       "          1.89847537e-04, 1.80141157e-04, 1.31820483e-04],\n",
       "         [3.50036176e-01, 3.43559690e-01, 3.29753488e-01, ...,\n",
       "          1.17272031e-04, 1.12173959e-04, 8.38612260e-05],\n",
       "         ...,\n",
       "         [1.28571598e-01, 1.27903137e-01, 1.24553776e-01, ...,\n",
       "          4.47086078e-05, 4.34195255e-05, 3.37450268e-05],\n",
       "         [8.84639209e-02, 7.80275980e-02, 7.65167367e-02, ...,\n",
       "          2.75934605e-05, 2.69911626e-05, 2.13507470e-05],\n",
       "         [8.42472248e-02, 4.75957806e-02, 4.69935820e-02, ...,\n",
       "          1.70256567e-05, 1.67699560e-05, 1.34879187e-05]],\n",
       " \n",
       "        [[1.00000000e+00, 4.77260132e-01, 1.37479217e-01, ...,\n",
       "          1.52444612e-07, 1.30835482e-07, 3.63367970e-08],\n",
       "         [4.14053308e-01, 2.93187959e-01, 9.11601317e-02, ...,\n",
       "          9.34949948e-08, 8.11190127e-08, 2.49169665e-08],\n",
       "         [2.50941399e-01, 1.80053308e-01, 5.99679586e-02, ...,\n",
       "          5.73313110e-08, 5.02673586e-08, 1.68546474e-08],\n",
       "         ...,\n",
       "         [9.21731493e-02, 6.78456963e-02, 2.54589989e-02, ...,\n",
       "          2.15470428e-08, 1.92731370e-08, 7.47796725e-09],\n",
       "         [5.58625147e-02, 4.16291493e-02, 1.64592963e-02, ...,\n",
       "          1.32063665e-08, 1.19254298e-08, 4.92180499e-09],\n",
       "         [3.38560695e-02, 2.55361147e-02, 1.05951333e-02, ...,\n",
       "          8.09306474e-09, 7.37563575e-09, 3.21895881e-09]],\n",
       " \n",
       "        [[1.00000000e+00, 6.78442724e-01, 4.69430494e-01, ...,\n",
       "          5.11523666e-06, 2.43508039e-06, 3.30189632e-07],\n",
       "         [4.87949577e-01, 4.15116802e-01, 2.92342724e-01, ...,\n",
       "          3.17329381e-06, 1.58461948e-06, 2.03009594e-07],\n",
       "         [2.95727016e-01, 2.53949577e-01, 1.81896802e-01, ...,\n",
       "          1.96754033e-06, 1.02639326e-06, 1.24789573e-07],\n",
       "         ...,\n",
       "         [1.08623330e-01, 9.49884948e-02, 7.02446165e-02, ...,\n",
       "          7.55259962e-07, 4.25576921e-07, 4.71234863e-08],\n",
       "         [6.58323213e-02, 5.80793302e-02, 4.36020948e-02, ...,\n",
       "          4.67599664e-07, 2.72666442e-07, 2.89493923e-08],\n",
       "         [3.98983766e-02, 3.55059213e-02, 2.70453142e-02, ...,\n",
       "          2.89372569e-07, 1.74194907e-07, 1.77811329e-08]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[1.00000000e+00, 9.70303186e-01, 9.51000257e-01, ...,\n",
       "          5.88158488e-03, 4.24380645e-03, 7.42927906e-04],\n",
       "         [5.95152685e-01, 5.92001931e-01, 5.84203186e-01, ...,\n",
       "          3.69662241e-03, 2.73580497e-03, 6.50357862e-04],\n",
       "         [4.06107416e-01, 3.61152685e-01, 3.58781931e-01, ...,\n",
       "          2.32104830e-03, 1.75831358e-03, 5.16820223e-04],\n",
       "         ...,\n",
       "         [3.20208417e-01, 2.19463889e-01, 1.35216197e-01, ...,\n",
       "          9.12525276e-04, 7.20565749e-04, 2.81480340e-04],\n",
       "         [2.45120253e-01, 1.85424417e-01, 8.46798886e-02, ...,\n",
       "          5.71439366e-04, 4.59674509e-04, 1.98832176e-04],\n",
       "         [1.79190456e-01, 1.44032253e-01, 8.39994573e-02, ...,\n",
       "          3.57562069e-04, 2.92641927e-04, 1.37806853e-04]],\n",
       " \n",
       "        [[1.00000000e+00, 9.13072135e-01, 8.56569023e-01, ...,\n",
       "          2.85429498e-04, 2.53116581e-04, 1.47370602e-04],\n",
       "         [5.74131179e-01, 5.57316445e-01, 5.26972135e-01, ...,\n",
       "          1.76627005e-04, 1.58327279e-04, 9.58275850e-05],\n",
       "         [3.47958290e-01, 3.40131179e-01, 3.24096445e-01, ...,\n",
       "          1.09259588e-04, 9.89534268e-05, 6.20473488e-05],\n",
       "         ...,\n",
       "         [1.27808371e-01, 1.26643812e-01, 1.22475890e-01, ...,\n",
       "          4.17655616e-05, 3.85634991e-05, 2.57325833e-05],\n",
       "         [7.74596189e-02, 7.72643711e-02, 7.52574123e-02, ...,\n",
       "          2.58097961e-05, 2.40481163e-05, 1.64947207e-05],\n",
       "         [6.57447542e-02, 4.71332189e-02, 4.62303551e-02, ...,\n",
       "          1.59446480e-05, 1.49862916e-05, 1.05448724e-05]],\n",
       " \n",
       "        [[1.00000000e+00, 9.80254538e-01, 9.67419987e-01, ...,\n",
       "          7.41650393e-03, 6.77642288e-03, 4.92174501e-03],\n",
       "         [5.98807911e-01, 5.98033053e-01, 5.94154538e-01, ...,\n",
       "          4.62687638e-03, 4.27072402e-03, 3.18297429e-03],\n",
       "         [5.52316439e-01, 3.65322124e-01, 3.64813053e-01, ...,\n",
       "          2.88483858e-03, 2.68856754e-03, 2.05173927e-03],\n",
       "         ...,\n",
       "         [3.73912374e-01, 3.08075417e-01, 1.97198039e-01, ...,\n",
       "          1.11961078e-03, 1.06225683e-03, 8.45270623e-04],\n",
       "         [2.77668106e-01, 2.39128374e-01, 1.73291417e-01, ...,\n",
       "          6.96945732e-04, 6.66760012e-04, 5.40523256e-04],\n",
       "         [1.98916428e-01, 1.76580106e-01, 1.37703414e-01, ...,\n",
       "          4.33626533e-04, 4.18148293e-04, 3.44892357e-04]]]),\n",
       " array([[[1.        , 0.05287516, 0.08724402, ..., 0.9976935 ,\n",
       "          0.99961557, 0.99963597],\n",
       "         [0.41336094, 0.42204555, 0.43897516, ..., 0.99847009,\n",
       "          0.99976209, 0.99977286],\n",
       "         [0.64446118, 0.64736094, 0.65526555, ..., 0.99899211,\n",
       "          0.99985282, 0.99985837],\n",
       "         ...,\n",
       "         [0.80487702, 0.86876193, 0.86994358, ..., 0.99956981,\n",
       "          0.99994372, 0.99994504],\n",
       "         [0.8306891 , 0.91995123, 0.92014833, ..., 0.99972089,\n",
       "          0.99996522, 0.99996579],\n",
       "         [0.86675461, 0.9317771 , 0.95098524, ..., 0.9998196 ,\n",
       "          0.99997851, 0.99997872]],\n",
       " \n",
       "        [[1.        , 0.52257341, 0.86224613, ..., 0.99999985,\n",
       "          0.99999987, 0.99999996],\n",
       "         [0.58588555, 0.70671116, 0.90867341, ..., 0.99999991,\n",
       "          0.99999992, 0.99999997],\n",
       "         [0.74902155, 0.81988555, 0.93993116, ..., 0.99999994,\n",
       "          0.99999995, 0.99999998],\n",
       "         ...,\n",
       "         [0.90781324, 0.93213185, 0.97450395, ..., 0.99999998,\n",
       "          0.99999998, 0.99999999],\n",
       "         [0.94412924, 0.95835724, 0.98351825, ..., 0.99999999,\n",
       "          0.99999999, 1.        ],\n",
       "         [0.96613893, 0.97445564, 0.98939126, ..., 0.99999999,\n",
       "          0.99999999, 1.        ]],\n",
       " \n",
       "        [[1.        , 0.32122157, 0.53001559, ..., 0.99999487,\n",
       "          0.99999755, 0.99999967],\n",
       "         [0.51192711, 0.58467974, 0.70732157, ..., 0.99999682,\n",
       "          0.9999984 , 0.9999998 ],\n",
       "         [0.70419825, 0.74592711, 0.81789974, ..., 0.99999803,\n",
       "          0.99999897, 0.99999988],\n",
       "         ...,\n",
       "         [0.89134922, 0.90496621, 0.92968065, ..., 0.99999924,\n",
       "          0.99999957, 0.99999995],\n",
       "         [0.93415104, 0.94189322, 0.95635261, ..., 0.99999953,\n",
       "          0.99999973, 0.99999997],\n",
       "         [0.96009154, 0.96447744, 0.97292724, ..., 0.99999971,\n",
       "          0.99999982, 0.99999998]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[1.        , 0.02029411, 0.03348528, ..., 0.99266812,\n",
       "          0.99336321, 0.99530864],\n",
       "         [0.40139361, 0.40229946, 0.40639411, ..., 0.99542441,\n",
       "          0.9958139 , 0.99695666],\n",
       "         [0.45574445, 0.63539361, 0.63551946, ..., 0.99714624,\n",
       "          0.99736272, 0.99803288],\n",
       "         ...,\n",
       "         [0.62904847, 0.69680997, 0.81086285, ..., 0.99889181,\n",
       "          0.99895658, 0.99918581],\n",
       "         [0.72412634, 0.76383247, 0.83159397, ..., 0.99930997,\n",
       "          0.99934466, 0.99947832],\n",
       "         [0.80217112, 0.82521434, 0.86525743, ..., 0.99957057,\n",
       "          0.99958877, 0.99966652]],\n",
       " \n",
       "        [[1.        , 0.07048553, 0.11630112, ..., 0.99967614,\n",
       "          0.99968348, 0.99974802],\n",
       "         [0.4198294 , 0.4327185 , 0.45658553, ..., 0.99980008,\n",
       "          0.99980325, 0.99984077],\n",
       "         [0.64838145, 0.6538294 , 0.6659385 , ..., 0.99987663,\n",
       "          0.99987776, 0.99989953],\n",
       "         ...,\n",
       "         [0.87084718, 0.87113785, 0.87386385, ..., 0.99993628,\n",
       "          0.99995288, 0.99996015],\n",
       "         [0.88828726, 0.92139118, 0.92252425, ..., 0.99994299,\n",
       "          0.99997077, 0.99997495],\n",
       "         [0.90166258, 0.95205196, 0.9524252 , ..., 0.99995421,\n",
       "          0.99998187, 0.99998427]],\n",
       " \n",
       "        [[1.        , 0.01956239, 0.03227794, ..., 0.99255526,\n",
       "          0.99317699, 0.99500138],\n",
       "         [0.40112485, 0.40185599, 0.40566239, ..., 0.99535601,\n",
       "          0.99570104, 0.99677043],\n",
       "         [0.4449938 , 0.63023977, 0.63507599, ..., 0.99710479,\n",
       "          0.99729432, 0.99792002],\n",
       "         ...,\n",
       "         [0.62509965, 0.69029443, 0.8001122 , ..., 0.99887658,\n",
       "          0.99893146, 0.99914436],\n",
       "         [0.72173312, 0.75988365, 0.82507843, ..., 0.99930075,\n",
       "          0.99932943, 0.99945319],\n",
       "         [0.80072068, 0.82282112, 0.86130861, ..., 0.99956497,\n",
       "          0.99957954, 0.9996513 ]]]))"
      ]
     },
     "execution_count": 200,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "binary_certificate(votes=votes.cpu(), pre_votes=votes.cpu(), n_samples=n_samples_eval, conf_alpha=0.1, \n",
    "    pf_plus=pf_plus_adj, pf_minus=pf_minus_adj, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'binary_certificate_grid' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [196]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m grid_binary_class, \u001b[38;5;241m*\u001b[39m_ \u001b[38;5;241m=\u001b[39m \u001b[43mbinary_certificate_grid\u001b[49m(pf_plus\u001b[38;5;241m=\u001b[39mpf_plus_att, pf_minus\u001b[38;5;241m=\u001b[39mpf_minus_att,\n\u001b[1;32m      2\u001b[0m                                             p_emps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.8\u001b[39m, reverse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, progress_bar\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'binary_certificate_grid' is not defined"
     ]
    }
   ],
   "source": [
    "grid_binary_class, *_ = binary_certificate_grid(pf_plus=pf_plus_att, pf_minus=pf_minus_att,\n",
    "                                            p_emps=0.8, reverse=False, progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1,  ..., 4, 6, 3], device='cuda:0')"
      ]
     },
     "execution_count": 183,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:06<00:00, 31.54it/s]\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3],\n",
       "        [1, 2, 3]])"
      ]
     },
     "execution_count": 195,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor([1, 2, 3]).repeat(10, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000e+00, 1.4100e+02, 0.0000e+00, 1.0000e+00, 9.5760e+03, 0.0000e+00,\n",
       "        2.8200e+02], device='cuda:0')"
      ]
     },
     "execution_count": 194,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "votes[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 851862])"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_idx_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    0,     0,     0,  ..., 28090, 28091, 28098],\n",
       "        [   21,    28,    76,  ..., 28085, 28086, 28089]], device='cuda:0')"
      ]
     },
     "execution_count": 147,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_idx_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28100, 2879])"
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bulk_attr_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseTensor(row=tensor([   0,    0,    0,  ..., 2808, 2808, 2809], device='cuda:0'),\n",
       "             col=tensor([1579, 1581, 2241,  ...,  730, 1787, 1399], device='cuda:0'),\n",
       "             val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),\n",
       "             size=(2810, 2810), nnz=15962, density=0.20%)"
      ]
     },
     "execution_count": 138,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 853516])"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_idx_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "attr_idx_batch = attr_idx_batch.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    0,     0,  2810,  2810,  5620,  5620,  8430,  8430, 11240, 11240,\n",
       "         14050, 14050, 16860, 16860, 19670, 19670, 22480, 22480, 25290, 25290]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 136,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attr_idx_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseSmoothingModel(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, model, attr):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.attr = attr\n",
    "\n",
    "    def forward(self, attr_idx, edge_idx, n, d):\n",
    "        batch_size = n // self.attr.shape[0]\n",
    "        return self.model(self.attr, (edge_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Dimension 'n_nodes' of inconsistent size. Got both 852036 and 2810.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [131]\u001b[0m, in \u001b[0;36m<cell line: 17>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m attr_idx \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]])\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m     15\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m10\u001b[39m\n\u001b[0;32m---> 17\u001b[0m votes \u001b[38;5;241m=\u001b[39m \u001b[43mpredict_smooth_gnn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     18\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattr_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msmoothing_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mSparseSmoothingModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     19\u001b[0m \u001b[43m    \u001b[49m\u001b[43mn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mnc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset_info\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_classes\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/sparse_smoothing/prediction.py:47\u001b[0m, in \u001b[0;36mpredict_smooth_gnn\u001b[0;34m(attr_idx, edge_idx, sample_config, model, n, d, nc, batch_size)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(nbatches)):\n\u001b[1;32m     43\u001b[0m     attr_idx_batch, edge_idx_batch \u001b[38;5;241m=\u001b[39m sample_multiple_graphs(\n\u001b[1;32m     44\u001b[0m             attr_idx\u001b[38;5;241m=\u001b[39mattr_idx, edge_idx\u001b[38;5;241m=\u001b[39medge_idx,\n\u001b[1;32m     45\u001b[0m             sample_config\u001b[38;5;241m=\u001b[39msample_config, n\u001b[38;5;241m=\u001b[39mn, d\u001b[38;5;241m=\u001b[39md, nsamples\u001b[38;5;241m=\u001b[39mbatch_size)\n\u001b[0;32m---> 47\u001b[0m     predictions \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattr_idx_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     48\u001b[0m \u001b[43m                        \u001b[49m\u001b[43mn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39margmax(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     49\u001b[0m     preds_onehot \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mone_hot(predictions, \u001b[38;5;28mint\u001b[39m(nc))\u001b[38;5;241m.\u001b[39mreshape(batch_size, n, nc)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m     50\u001b[0m     votes 
\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m preds_onehot\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [130]\u001b[0m, in \u001b[0;36mSparseSmoothingModel.forward\u001b[0;34m(self, attr_idx, edge_idx, n, d)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, attr_idx, edge_idx, n, d):\n\u001b[1;32m      9\u001b[0m     batch_size \u001b[38;5;241m=\u001b[39m n \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 10\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_idx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:911\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    909\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    910\u001b[0m     memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 911\u001b[0m     \u001b[43mcheck_argument_types\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmemo\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    912\u001b[0m     retval \u001b[38;5;241m=\u001b[39m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    913\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torchtyping/typechecker.py:342\u001b[0m, in \u001b[0;36mpatch_typeguard.<locals>.check_argument_types\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    340\u001b[0m     _check_memo(memo)\n\u001b[1;32m    341\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:  \u001b[38;5;66;03m# suppress long traceback\u001b[39;00m\n\u001b[0;32m--> 342\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;241m*\u001b[39mexc\u001b[38;5;241m.\u001b[39margs) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[1;32m    343\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "\u001b[0;31mTypeError\u001b[0m: Dimension 'n_nodes' of inconsistent size. Got both 852036 and 2810."
     ]
    }
   ],
   "source": [
    "\n",
    "adj = test_adj.clone()\n",
    "if isinstance(adj, SparseTensor):\n",
    "    row, col, edge_weight = adj.t().coo()\n",
    "    edge_idx = torch.stack([row, col], dim=0).to(device)\n",
    "elif isinstance(adj, tuple):\n",
    "    edge_idx = adj[0].to(device)\n",
    "attr_idx = torch.tensor([[0, 0]]).cuda()\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch_geometric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mData\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_features'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0madj\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch_sparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSparseTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mattr_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_features'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0medge_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0medge_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0md\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_classes'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Defines the computation performed at every call.\n",
      "\n",
      "Should be overridden by all subclasses.\n",
      "\n",
      ".. note::\n",
      "    Although the recipe for forward pass needs to be defined within\n",
      "    this function, one should call the :class:`Module` instance afterwards\n",
      "    instead of this since the former takes care of running the\n",
      "    registered hooks while the latter silently ignores them.\n",
      "\u001b[0;31mFile:\u001b[0m      ~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training_adapted/robust_diffusion/models/gcn.py\n",
      "\u001b[0;31mType:\u001b[0m      method\n"
     ]
    }
   ],
   "source": [
    "model.forward?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'is_sparse'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Input \u001b[0;32mIn [89]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattr_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:912\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    910\u001b[0m memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m    911\u001b[0m check_argument_types(memo)\n\u001b[0;32m--> 912\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    913\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    914\u001b[0m     check_return_type(retval, memo)\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training_adapted/robust_diffusion/models/gcn.py:218\u001b[0m, in \u001b[0;36mGCN.forward\u001b[0;34m(self, data, adj, attr_idx, edge_idx, edge_weight, n, d)\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    208\u001b[0m             data: Optional[Union[Data, TensorType[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_nodes\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_features\u001b[39m\u001b[38;5;124m\"\u001b[39m]]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    209\u001b[0m             adj: Optional[Union[SparseTensor,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    216\u001b[0m             n: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    217\u001b[0m             d: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m TensorType[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_nodes\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_classes\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m--> 218\u001b[0m     x, edge_idx, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mGCN\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_forward_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    220\u001b[0m     
device \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameters())\u001b[38;5;241m.\u001b[39mdevice\n\u001b[1;32m    221\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdevice \u001b[38;5;241m!=\u001b[39m device:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:912\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    910\u001b[0m memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m    911\u001b[0m check_argument_types(memo)\n\u001b[0;32m--> 912\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    913\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    914\u001b[0m     check_return_type(retval, memo)\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training_adapted/robust_diffusion/models/gcn.py:272\u001b[0m, in \u001b[0;36mGCN.parse_forward_input\u001b[0;34m(data, adj, attr_idx, edge_idx, edge_weight, n, d)\u001b[0m\n\u001b[1;32m    270\u001b[0m     edge_idx \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack([edge_idx_rows, edge_idx_cols], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m    271\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 272\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43madj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_sparse\u001b[49m:\n\u001b[1;32m    273\u001b[0m         adj \u001b[38;5;241m=\u001b[39m adj\u001b[38;5;241m.\u001b[39mto_sparse()\n\u001b[1;32m    275\u001b[0m     x, edge_idx, edge_weight \u001b[38;5;241m=\u001b[39m data, adj\u001b[38;5;241m.\u001b[39m_indices(), adj\u001b[38;5;241m.\u001b[39m_values()\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'is_sparse'"
     ]
    }
   ],
   "source": [
    "model(attr_idx=attr_idx, edge_idx=edge_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch_geometric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mData\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_features'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0madj\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch_sparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSparseTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mattr_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_features'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0medge_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0medge_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'nnz'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0md\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtyping_extensions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'__torchtyping__'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'details'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'n_nodes'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_classes'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cls_name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'TensorType'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Defines the computation performed at every call.\n",
      "\n",
      "Should be overridden by all subclasses.\n",
      "\n",
      ".. note::\n",
      "    Although the recipe for forward pass needs to be defined within\n",
      "    this function, one should call the :class:`Module` instance afterwards\n",
      "    instead of this since the former takes care of running the\n",
      "    registered hooks while the latter silently ignores them.\n",
      "\u001b[0;31mFile:\u001b[0m      ~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training_adapted/robust_diffusion/models/gcn.py\n",
      "\u001b[0;31mType:\u001b[0m      method\n"
     ]
    }
   ],
   "source": [
    "model.forward?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# edge_idx = torch.LongTensor(np.stack(graph.adj_matrix.nonzero())). cuda()\n",
    "# attr_idx = torch.LongTensor(np.stack(graph.attr_matrix.nonzero())).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseSmoothingModel(torch.nn.Module):\n",
    "    \"\"\"Adapter that lets `predict_smooth_gnn` drive a model called as `model(attr, adj)`.\n",
    "\n",
    "    `predict_smooth_gnn` invokes its model with the keywords `attr_idx`, `edge_idx`, `n`\n",
    "    and `d`. This wrapper accepts that call, ignores the sampled attribute index and\n",
    "    feeds the stored attribute tensor, tiled once per sampled graph, to the wrapped model.\n",
    "\n",
    "    Args:\n",
    "        model: Wrapped network, called as `model(attr, (edge_idx, edge_weight))`.\n",
    "        attr: Attribute tensor of one graph; `attr.shape[0]` is its node count.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, attr):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.attr = attr\n",
    "\n",
    "    def forward(self, attr_idx, edge_idx, n, d):\n",
    "        \"\"\"Run the wrapped model on a batch of stacked graph samples.\n",
    "\n",
    "        Args:\n",
    "            attr_idx: Unused; the stored attributes are used for every sample.\n",
    "            edge_idx: Edge index of the stacked batch, indexed as (2, nnz).\n",
    "            n: Total node count of the batch (samples times nodes per graph).\n",
    "            d: Unused.\n",
    "\n",
    "        Returns:\n",
    "            The wrapped model's output for the whole batch.\n",
    "        \"\"\"\n",
    "        batch_size = n // self.attr.shape[0]\n",
    "        attr_batch = self.attr.repeat(batch_size, 1)\n",
    "        # The wrapped model's argument type check rejects `(edge_idx, None)` with a\n",
    "        # TypeError, so hand it an explicit weight of one for every edge instead.\n",
    "        edge_weight = torch.ones(edge_idx.shape[1], dtype=torch.float32, device=edge_idx.device)\n",
    "        return self.model(attr_batch, (edge_idx, edge_weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "type of argument \"adj\" must be one of (SparseTensor, FloatTensor, Tuple, typing_extensions.Annotated[torch.Tensor, {'__torchtyping__': True, 'details': ('n_nodes', 'n_nodes',), 'cls_name': 'TensorType'}], NoneType); got tuple instead",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [57]\u001b[0m, in \u001b[0;36m<cell line: 35>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     32\u001b[0m attr_idx \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]])\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m     34\u001b[0m model_ \u001b[38;5;241m=\u001b[39m SparseSmoothingModel(model\u001b[38;5;241m.\u001b[39mto(device), attr_idx\u001b[38;5;241m.\u001b[39mto(device))\n\u001b[0;32m---> 35\u001b[0m pre_votes \u001b[38;5;241m=\u001b[39m \u001b[43mpredict_smooth_gnn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattr_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     36\u001b[0m \u001b[43m                                \u001b[49m\u001b[43msample_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_config_pre_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     37\u001b[0m \u001b[43m                                \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mnc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     38\u001b[0m \u001b[43m                                \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     40\u001b[0m votes \u001b[38;5;241m=\u001b[39m predict_smooth_gnn(attr_idx\u001b[38;5;241m=\u001b[39mattr_idx, edge_idx\u001b[38;5;241m=\u001b[39medge_idx,\n\u001b[1;32m     41\u001b[0m                             sample_config\u001b[38;5;241m=\u001b[39msample_config_eval,\n\u001b[1;32m     42\u001b[0m                             model\u001b[38;5;241m=\u001b[39mmodel_, n\u001b[38;5;241m=\u001b[39mtest_attr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], d\u001b[38;5;241m=\u001b[39mtest_attr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m], nc\u001b[38;5;241m=\u001b[39mlabels\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m     43\u001b[0m                             batch_size\u001b[38;5;241m=\u001b[39mbatch_size_eval)\n\u001b[1;32m     45\u001b[0m acc_majority \u001b[38;5;241m=\u001b[39m accuracy_majority(\n\u001b[1;32m     46\u001b[0m     votes\u001b[38;5;241m=\u001b[39mvotes, labels\u001b[38;5;241m=\u001b[39mlabels\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy(), idx\u001b[38;5;241m=\u001b[39midx_attack)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/sparse_smoothing/prediction.py:47\u001b[0m, in \u001b[0;36mpredict_smooth_gnn\u001b[0;34m(attr_idx, edge_idx, sample_config, model, n, d, nc, batch_size)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(nbatches)):\n\u001b[1;32m     43\u001b[0m     attr_idx_batch, edge_idx_batch \u001b[38;5;241m=\u001b[39m sample_multiple_graphs(\n\u001b[1;32m     44\u001b[0m             attr_idx\u001b[38;5;241m=\u001b[39mattr_idx, edge_idx\u001b[38;5;241m=\u001b[39medge_idx,\n\u001b[1;32m     45\u001b[0m             sample_config\u001b[38;5;241m=\u001b[39msample_config, n\u001b[38;5;241m=\u001b[39mn, d\u001b[38;5;241m=\u001b[39md, nsamples\u001b[38;5;241m=\u001b[39mbatch_size)\n\u001b[0;32m---> 47\u001b[0m     predictions \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattr_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattr_idx_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     48\u001b[0m \u001b[43m                        \u001b[49m\u001b[43mn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39margmax(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     49\u001b[0m     preds_onehot \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mone_hot(predictions, \u001b[38;5;28mint\u001b[39m(nc))\u001b[38;5;241m.\u001b[39mreshape(batch_size, n, nc)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m     50\u001b[0m     votes 
\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m preds_onehot\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [54]\u001b[0m, in \u001b[0;36mSparseSmoothingModel.forward\u001b[0;34m(self, attr_idx, edge_idx, n, d)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, attr_idx, edge_idx, n, d):\n\u001b[1;32m      9\u001b[0m     batch_size \u001b[38;5;241m=\u001b[39m n \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 10\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrepeat\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:911\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    909\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    910\u001b[0m     memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 911\u001b[0m     \u001b[43mcheck_argument_types\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmemo\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    912\u001b[0m     retval \u001b[38;5;241m=\u001b[39m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    913\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/torchtyping/typechecker.py:338\u001b[0m, in \u001b[0;36mpatch_typeguard.<locals>.check_argument_types\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    336\u001b[0m memo\u001b[38;5;241m.\u001b[39mname_to_size \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m    337\u001b[0m memo\u001b[38;5;241m.\u001b[39mname_to_shape \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 338\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43m_check_argument_types\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    339\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    340\u001b[0m     _check_memo(memo)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:757\u001b[0m, in \u001b[0;36mcheck_argument_types\u001b[0;34m(memo)\u001b[0m\n\u001b[1;32m    755\u001b[0m             check_type(description, value, expected_type, memo)\n\u001b[1;32m    756\u001b[0m         \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:  \u001b[38;5;66;03m# suppress unnecessarily long tracebacks\u001b[39;00m\n\u001b[0;32m--> 757\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;241m*\u001b[39mexc\u001b[38;5;241m.\u001b[39margs) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[1;32m    759\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "\u001b[0;31mTypeError\u001b[0m: type of argument \"adj\" must be one of (SparseTensor, FloatTensor, Tuple, typing_extensions.Annotated[torch.Tensor, {'__torchtyping__': True, 'details': ('n_nodes', 'n_nodes',), 'cls_name': 'TensorType'}], NoneType); got tuple instead"
     ]
    }
   ],
   "source": [
    "# Sample counts for the preliminary and the final smoothing run, and the sampling batch size.\n",
    "n_samples_pre_eval = 10\n",
    "n_samples_eval = 1000\n",
    "batch_size_eval = 10\n",
    "\n",
    "sample_config = {\n",
    "    'pf_plus_adj': pf_plus_adj,\n",
    "    'pf_plus_att': pf_plus_att,\n",
    "    'pf_minus_adj': pf_minus_adj,\n",
    "    'pf_minus_att': pf_minus_att,\n",
    "} # smoothing_config\n",
    "\n",
    "# Same flip probabilities for both runs; only the number of samples differs.\n",
    "sample_config_eval = {**sample_config, 'n_samples': n_samples_eval}\n",
    "sample_config_pre_eval = {**sample_config, 'n_samples': n_samples_pre_eval}\n",
    "\n",
    "# Pull the edge index out of the adjacency, whichever container it arrives in.\n",
    "if isinstance(test_adj, SparseTensor):\n",
    "    row, col, _unused_weight = test_adj.t().coo()\n",
    "    edge_idx = torch.stack([row, col], dim=0).to(device)\n",
    "elif isinstance(test_adj, tuple):\n",
    "    edge_idx = test_adj[0].to(device)\n",
    "else:\n",
    "    # Fail loudly rather than silently reusing an `edge_idx` left over from an earlier run.\n",
    "    raise TypeError(f'Unsupported adjacency type: {type(test_adj).__name__}')\n",
    "\n",
    "# Placeholder attribute index: SparseSmoothingModel ignores it and uses the dense attributes.\n",
    "# Placed with `.to(device)` like every other tensor here instead of a hard-coded `.cuda()`.\n",
    "attr_idx = torch.tensor([[0, 0]]).to(device)\n",
    "\n",
    "model_ = SparseSmoothingModel(model.to(device), test_attr.to(device))\n",
    "\n",
    "# Arguments shared by both prediction runs; `nc` is passed as a plain Python int.\n",
    "predict_kwargs = dict(\n",
    "    attr_idx=attr_idx, edge_idx=edge_idx, model=model_,\n",
    "    n=test_attr.shape[0], d=test_attr.shape[1], nc=int(labels.max()) + 1,\n",
    "    batch_size=batch_size_eval)\n",
    "\n",
    "pre_votes = predict_smooth_gnn(sample_config=sample_config_pre_eval, **predict_kwargs)\n",
    "votes = predict_smooth_gnn(sample_config=sample_config_eval, **predict_kwargs)\n",
    "\n",
    "acc_majority = accuracy_majority(\n",
    "    votes=votes, labels=labels.cpu().numpy(), idx=idx_attack)\n",
    "\n",
    "# Highest per-node vote count, restricted to the attacked nodes.\n",
    "votes_max = votes.max(1)[idx_attack]\n",
    "\n",
    "# Fraction of nodes whose top-voted class matches between the two runs\n",
    "# (assumes the vote arrays are NumPy arrays - TODO confirm).\n",
    "agreement = (votes.argmax(1) == pre_votes.argmax(1)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0]], device='cuda:0')"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the placeholder attribute index that is handed to predict_smooth_gnn.\n",
    "attr_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   0,    0,    0,  ..., 2808, 2808, 2809],\n",
       "        [1579, 1581, 2241,  ...,  730, 1787, 1399]], device='cuda:0')"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the edge index extracted from test_adj (row and column indices stacked along dim 0).\n",
    "edge_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseTensor(row=tensor([   0,    0,    0,  ..., 2808, 2808, 2809], device='cuda:0'),\n",
       "             col=tensor([1579, 1581, 2241,  ...,  730, 1787, 1399], device='cuda:0'),\n",
       "             val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),\n",
       "             size=(2810, 2810), nnz=15962, density=0.20%)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the test adjacency (displayed as a SparseTensor in the saved output).\n",
    "test_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([   0,    0,    0,  ..., 2809, 2809, 2809], device='cuda:0'),\n",
       " tensor([  49,   66,  107,  ..., 2327, 2561, 2573], device='cuda:0'))"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Positions of the non-zero attribute entries, returned as one index tensor per dimension.\n",
    "test_attr.long().nonzero(as_tuple=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2810, 2879])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attribute matrix dimensions; these are the `n` and `d` passed to predict_smooth_gnn.\n",
    "test_attr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0')"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Integer-cast values found at the non-zero attribute positions.\n",
    "test_attr.long()[test_attr.nonzero(as_tuple=True)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(indices=tensor([[   0,    0,    0,  ..., 2809, 2809, 2809],\n",
       "                       [  49,   66,  107,  ..., 2327, 2561, 2573]]),\n",
       "       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
       "       device='cuda:0', size=(2810, 2879), nnz=142286, layout=torch.sparse_coo)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sparse COO representation of the attribute matrix (indices, values, size).\n",
    "test_attr.to_sparse_coo()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "only integer tensors of a single element can be converted to an index",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [38]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse_coo_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnonzero\u001b[49m\u001b[43m(\u001b[49m\u001b[43mas_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnonzero\u001b[49m\u001b[43m(\u001b[49m\u001b[43mas_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index"
     ]
    }
   ],
   "source": [
    "# `torch.sparse_coo_tensor` needs its indices as a single (ndim, nnz) tensor. Passing the\n",
    "# tuple returned by `nonzero(as_tuple=True)` raised a TypeError, so stack it first.\n",
    "nonzero_idx = test_attr.nonzero(as_tuple=True)\n",
    "torch.sparse_coo_tensor(torch.stack(nonzero_idx), test_attr[nonzero_idx], test_attr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the dense attribute matrix.\n",
    "test_attr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the 'EVAFAST' attack through create_attack_instance, one keyword per line.\n",
    "attack = create_attack_instance(\n",
    "    attack_name='EVAFAST',\n",
    "    attack_params=attack_params,\n",
    "    epsilon=epsilon,\n",
    "    test_attr=test_attr,\n",
    "    test_adj=test_adj,\n",
    "    labels=labels,\n",
    "    model=model,\n",
    "    dataset_info=dataset_info,\n",
    "    model_storage_name=model_storage_name,\n",
    "    split_name=split_name,\n",
    "    test_mask=test_mask,\n",
    "    unlabeled_mask=unlabeled_mask,\n",
    "    default_attack_configs=default_attack_configs,\n",
    "    reports_root=reports_root,\n",
    "    device=device,\n",
    "    inductive=False,\n",
    "    save=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
