{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import logging\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits, check_dataset_valid\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance, from_sparse_GCN, from_sparse_GPRGNN\n",
    "\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "\n",
    "from robust_diffusion.attacks import create_attack\n",
    "from robust_diffusion.helper.utils import accuracy, calculate_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configs\n",
    "dataset_name = \"cora_ml\"\n",
    "\n",
    "# model_name in [\"GCN\", \"DenseGCN\", \"GAT\", \"GPRGNN\", \"DenseGPRGNN\", \"APPNP\", \"ChebNetII\", \"SoftMedian_GDC\"]\n",
    "model_name = \"GCN\"\n",
    "n_splits = 5\n",
    "\n",
    "training_split = None\n",
    "validation_split = None\n",
    "training_split_type = None\n",
    "validation_split_type = None\n",
    "test_split = None\n",
    "test_split_type = None\n",
    "\n",
    "model_params = None\n",
    "epsilon = 0.1\n",
    "\n",
    "# attack_name in [\"PRBCD\", \"LRBCD\", \"EvaAttack\", \"Evafast\", \"PGD\"] \n",
    "attack_name = \"PRBCD\"\n",
    "attack_params = None\n",
    "\n",
    "inductive = False\n",
    "self_training = False\n",
    "robust_training = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n",
      "Loading dataset = cora_ml\n",
      "Found 5 splits!\n",
      "Loading pretrained GCN model on cora_ml dataset for 5 splits\n"
     ]
    }
   ],
   "source": [
    "## Loading general configs (like dataset_root, etc.) and initial parameters\n",
    "general_config = yaml.safe_load(open(\"conf/general-config.yaml\"))\n",
    "default_dataset_configs = yaml.safe_load(open(\"conf/data-configs.yaml\")).get(\"configs\").get(\"default\")\n",
    "default_model_configs = yaml.safe_load(open(\"conf/model-configs.yaml\")).get(\"configs\")\n",
    "default_attack_configs = yaml.safe_load(open(\"conf/attack-configs.yaml\")).get(\"configs\")\n",
    "\n",
    "\n",
    "# extracting configs \n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "    \n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# Trains the specified model on the given graph and saves the model artifacts, and the splits.\n",
    "\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "dataset_splits = [\n",
    "    split_record for split_record in os.listdir(splits_root) \n",
    "    if split_record.split(\"-\")[0] == dataset_name \n",
    "    and check_dataset_valid(split_record=split_record, training_split=training_split,\n",
    "                            validation_split=validation_split, training_split_type=training_split_type, \n",
    "                            validation_split_type=validation_split_type, test_split=test_split, \n",
    "                            test_split_type=test_split_type, splits_root=splits_root)]\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "if creating_splits > 0:\n",
    "    raise ValueError(\"Not enough splits for the dataset. Create the splits by running training scripts.\")\n",
    "\n",
    "# creating remaining needed dataset splits\n",
    "print(f\"Found {len(dataset_splits)} splits!\")\n",
    "\n",
    "print(f\"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/5 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/torch_geometric/data/in_memory_dataset.py:293: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
      "  warnings.warn(msg)\n",
      "100%|██████████| 5/5 [00:03<00:00,  1.56it/s]\n"
     ]
    }
   ],
   "source": [
    "clean_accs = []\n",
    "pert_accs = []\n",
    "for split_file in tqdm(dataset_splits[:n_splits]):\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive, \n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    training_attr = data[\"training_attr\"]\n",
    "    training_adj = data[\"training_adj\"]\n",
    "    validation_attr = data[\"validation_attr\"]\n",
    "    validation_adj = data[\"validation_adj\"]\n",
    "    labels = data[\"labels\"]\n",
    "    training_idx = data[\"training_idx\"]\n",
    "    validation_idx = data[\"validation_idx\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    unlabeled_mask = data[\"unlabeled_mask\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "    data_config = data[\"config\"]\n",
    "\n",
    "    try:\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params, \n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, \n",
    "            test_mask=test_mask, unlabeled_mask=unlabeled_mask,\n",
    "            split_name=split_name, dataset_info=dataset_info, \n",
    "            inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        print(e)\n",
    "        raise ValueError(\"Model not found. Run training scripts to train the model.\")\n",
    "\n",
    "    model = model_instance[\"model\"]\n",
    "    acc = model_instance[\"accuracy\"]\n",
    "    model_params = model_instance[\"model_params\"]\n",
    "    model_storage_name = model_instance[\"model_storage_name\"]\n",
    "    clean_accs.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(SparseTensor(row=tensor([   0,    0,    0,  ..., 2993, 2993, 2994], device='cuda:0'),\n",
       "              col=tensor([1636, 1638, 2357,  ...,  745, 1865, 1452], device='cuda:0'),\n",
       "              val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),\n",
       "              size=(2995, 2995), nnz=16316, density=0.18%),\n",
       " torch.Size([2995]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_adj, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "self_training = True\n",
    "if self_training:\n",
    "    logits = model(training_attr, training_adj)\n",
    "    pseudolabels = torch.argmax(logits, dim=1)\n",
    "    pseudolabels[training_idx] = labels[training_idx]\n",
    "    train_labels = pseudolabels\n",
    "else:\n",
    "    train_labels = labels\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "robust_training = True\n",
    "robust_epsilon = 0.1\n",
    "if robust_training:\n",
    "    n_train_edges = count_edges_for_idx(training_adj, training_idx) # num edges connected to train nodes\n",
    "    m_train = int(n_train_edges) / 2\n",
    "    n_perturbations_train = int(round(robust_epsilon * m_train))\n",
    "\n",
    "    n_val_edges = count_edges_for_idx(validation_adj, validation_idx) # num edges connected to val nodes\n",
    "    m_val = int(n_val_edges) / 2\n",
    "    n_perturbations_val = int(round(robust_epsilon * m_val))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init attack adjs \n",
    "adj_attacked_val = validation_adj.detach()\n",
    "adj_attacked_train = training_adj.detach()\n",
    "# init trace variables\n",
    "acc_trace_train = []\n",
    "acc_trace_val = []\n",
    "acc_trace_train_pert = []\n",
    "acc_trace_val_pert = []\n",
    "loss_trace = []\n",
    "loss_trace_val = []\n",
    "gamma_trace = []\n",
    "best_loss=np.inf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:07<00:00,  1.48s/it], ?it/s]\n",
      "100%|██████████| 500/500 [01:59<00:00,  4.19it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 53.94it/s]09:16:41, 131.18s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 48.63it/s]5:03:27, 54.11s/it]  \n",
      "100%|██████████| 5/5 [00:00<00:00, 58.29it/s]4:32:24, 29.48s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.52it/s]4:53:44, 17.90s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 61.75it/s]:34:00, 11.50s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00, 56.32it/s]:21:14,  7.64s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 58.80it/s]:18:42,  5.19s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 61.32it/s]:58:44,  3.58s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.46it/s]:05:07,  2.51s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.79it/s]1:28:31,  1.78s/it]\n",
      "100%|██████████| 500/500 [09:22<00:00,  1.13s/it]\n",
      "100%|██████████| 5/5 [00:02<00:00,  1.80it/s]145:19:13, 175.03s/it]\n",
      "100%|██████████| 5/5 [00:01<00:00,  3.75it/s]102:01:15, 122.92s/it]\n",
      "100%|██████████| 5/5 [00:01<00:00,  2.53it/s]71:37:43, 86.33s/it]  \n",
      "100%|██████████| 5/5 [00:10<00:00,  2.02s/it]51:01:10, 61.51s/it]\n",
      "100%|██████████| 5/5 [00:10<00:00,  2.01s/it]39:41:56, 47.88s/it]\n",
      "100%|██████████| 5/5 [00:10<00:00,  2.03s/it]31:35:02, 38.10s/it]\n",
      "100%|██████████| 5/5 [00:09<00:00,  1.92s/it]26:12:49, 31.64s/it]\n",
      "100%|██████████| 5/5 [00:02<00:00,  1.98it/s]21:20:35, 25.77s/it]\n",
      "100%|██████████| 5/5 [00:01<00:00,  3.70it/s]15:47:45, 19.08s/it]\n",
      "100%|██████████| 5/5 [00:01<00:00,  2.92it/s]11:36:17, 14.02s/it]\n",
      "100%|██████████| 500/500 [02:58<00:00,  2.80it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 54.99it/s]53:12:26, 64.30s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.24it/s]37:15:53, 45.05s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 58.61it/s]26:06:35, 31.57s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.57it/s]18:18:11, 22.14s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.99it/s]12:50:42, 15.54s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 55.41it/s]9:01:29, 10.92s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00, 59.75it/s]6:21:05,  7.69s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 58.31it/s]4:28:46,  5.43s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 27.00it/s]3:10:15,  3.84s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.71it/s]2:16:49,  2.76s/it]\n",
      "100%|██████████| 500/500 [02:03<00:00,  4.04it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 26.00it/s]32:17:35, 39.16s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.48it/s]22:39:42, 27.49s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.47it/s]15:53:47, 19.29s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 60.02it/s]11:09:52, 13.55s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 25.25it/s]7:51:03,  9.53s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00,  8.32it/s]5:33:39,  6.75s/it]\n",
      "100%|██████████| 5/5 [00:05<00:00,  1.03s/it]4:03:18,  4.93s/it]\n",
      "100%|██████████| 5/5 [00:08<00:00,  1.79s/it]5:37:26,  6.84s/it]\n",
      "100%|██████████| 5/5 [00:08<00:00,  1.79s/it]7:43:14,  9.39s/it]\n",
      "100%|██████████| 5/5 [00:08<00:00,  1.75s/it]8:35:36, 10.45s/it]\n",
      "100%|██████████| 500/500 [01:49<00:00,  4.59it/s]\n",
      "100%|██████████| 5/5 [00:01<00:00,  4.90it/s]37:11:36, 45.25s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.65it/s]26:20:37, 32.06s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.95it/s]18:28:26, 22.49s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 34.52it/s]12:57:58, 15.79s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 25.23it/s]9:07:38, 11.12s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00, 25.46it/s]6:27:22,  7.87s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.55it/s]4:34:53,  5.59s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 62.64it/s]3:14:36,  3.96s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.26it/s]2:18:15,  2.81s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 61.88it/s]1:38:55,  2.01s/it]\n",
      "100%|██████████| 500/500 [01:46<00:00,  4.71it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 53.90it/s]27:19:01, 33.35s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 57.12it/s]19:09:07, 23.39s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.90it/s]13:26:23, 16.42s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.83it/s]9:26:30, 11.54s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00, 59.26it/s]6:38:50,  8.13s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 55.57it/s]4:41:15,  5.73s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 60.48it/s]3:19:08,  4.06s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 61.47it/s]2:21:30,  2.89s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.05it/s]1:41:16,  2.07s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 60.54it/s]1:13:14,  1.49s/it]\n",
      "100%|██████████| 500/500 [01:50<00:00,  4.53it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 48.63it/s]27:57:41, 34.25s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 59.69it/s]19:36:15, 24.02s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 58.65it/s]13:45:13, 16.86s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 24.81it/s]9:39:46, 11.85s/it] \n",
      "100%|██████████| 5/5 [00:00<00:00, 58.85it/s]6:49:39,  8.37s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 54.74it/s]4:48:46,  5.91s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 62.89it/s]3:24:26,  4.18s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 60.51it/s]2:25:02,  2.97s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00, 15.76it/s]1:43:51,  2.13s/it]\n",
      "100%|██████████| 5/5 [00:00<00:00,  8.36it/s]1:18:22,  1.60s/it]\n",
      " 54%|█████▍    | 270/500 [01:35<01:21,  2.82it/s]\n",
      "Training...:   2%|▏         | 70/3000 [26:10<18:15:44, 22.44s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 84\u001b[0m\n\u001b[1;32m     80\u001b[0m \u001b[38;5;66;03m# adversary_val = PGD(attr=attr, adj=adj, labels=labels, model=model, idx_attack=idx_val,\u001b[39;00m\n\u001b[1;32m     81\u001b[0m \u001b[38;5;66;03m#                 device=device, data_device=data_device, binary_attr=binary_attr,\u001b[39;00m\n\u001b[1;32m     82\u001b[0m \u001b[38;5;66;03m#                 make_undirected=make_undirected, **val_attack_params)\u001b[39;00m\n\u001b[1;32m     83\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 84\u001b[0m \u001b[43madversary_val\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_perturbations_val\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     85\u001b[0m adj_pert \u001b[38;5;241m=\u001b[39m adversary_val\u001b[38;5;241m.\u001b[39mget_modified_adj()\n\u001b[1;32m     86\u001b[0m \u001b[38;5;66;03m#adj_attacked_val = (adj_pert[0].detach(), adj_pert[1].detach())\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/typeguard/__init__.py:912\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    910\u001b[0m memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m    911\u001b[0m check_argument_types(memo)\n\u001b[0;32m--> 912\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    913\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    914\u001b[0m     check_return_type(retval, memo)\n",
      "File \u001b[0;32m/workspace/Project_EvoWire/EVAttack/experiments/robust_diffusion/attacks/base_attack.py:129\u001b[0m, in \u001b[0;36mAttack.attack\u001b[0;34m(self, n_perturbations, **kwargs)\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    120\u001b[0m \u001b[38;5;124;03mExecutes the attack on the model updating the attributes\u001b[39;00m\n\u001b[1;32m    121\u001b[0m \u001b[38;5;124;03mself.adj_adversary and self.attr_adversary accordingly.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    126\u001b[0m \u001b[38;5;124;03m    number of perturbations (attack budget in terms of node additions/deletions) that constrain the atack\u001b[39;00m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_perturbations \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 129\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_attack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_perturbations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    130\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    131\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr_adversary \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr\n",
      "File \u001b[0;32m/workspace/Project_EvoWire/EVAttack/experiments/robust_diffusion/attacks/prbcd.py:123\u001b[0m, in \u001b[0;36mPRBCD._attack\u001b[0;34m(self, n_perturbations, **kwargs)\u001b[0m\n\u001b[1;32m    121\u001b[0m probability_mass_update \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperturbed_edge_weight\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m    122\u001b[0m \u001b[38;5;66;03m# Projection to stay within relaxed `L_0` budget (Algorithm 1, line 8)\u001b[39;00m\n\u001b[0;32m--> 123\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperturbed_edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mAttack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproject\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    124\u001b[0m \u001b[43m    \u001b[49m\u001b[43mn_perturbations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mperturbed_edge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    125\u001b[0m \u001b[38;5;66;03m# For monitoring\u001b[39;00m\n\u001b[1;32m    126\u001b[0m probability_mass_projected \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperturbed_edge_weight\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/typeguard/__init__.py:912\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    910\u001b[0m memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m    911\u001b[0m check_argument_types(memo)\n\u001b[0;32m--> 912\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    913\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    914\u001b[0m     check_return_type(retval, memo)\n",
      "File \u001b[0;32m/workspace/Project_EvoWire/EVAttack/experiments/robust_diffusion/attacks/base_attack.py:264\u001b[0m, in \u001b[0;36mAttack.project\u001b[0;34m(n_perturbations, values, eps, inplace)\u001b[0m\n\u001b[1;32m    262\u001b[0m     left \u001b[38;5;241m=\u001b[39m (values \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmin()\n\u001b[1;32m    263\u001b[0m     right \u001b[38;5;241m=\u001b[39m values\u001b[38;5;241m.\u001b[39mmax()\n\u001b[0;32m--> 264\u001b[0m     miu \u001b[38;5;241m=\u001b[39m \u001b[43mAttack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbisection\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mleft\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mright\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_perturbations\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    265\u001b[0m     values\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mcopy_(torch\u001b[38;5;241m.\u001b[39mclamp(\n\u001b[1;32m    266\u001b[0m         values \u001b[38;5;241m-\u001b[39m miu, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39meps, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m eps\n\u001b[1;32m    267\u001b[0m     ))\n\u001b[1;32m    268\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m/workspace/Project_EvoWire/EVAttack/experiments/robust_diffusion/attacks/base_attack.py:290\u001b[0m, in \u001b[0;36mAttack.bisection\u001b[0;34m(edge_weights, a, b, n_perturbations, epsilon, iter_max)\u001b[0m\n\u001b[1;32m    288\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    289\u001b[0m         a \u001b[38;5;241m=\u001b[39m miu\n\u001b[0;32m--> 290\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ((b \u001b[38;5;241m-\u001b[39m a) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m epsilon):\n\u001b[1;32m    291\u001b[0m         \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m    292\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m miu\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "attack = 'PRBCD'\n",
    "attack_params = dict(\n",
    "    epochs=500,\n",
    "    fine_tune_epochs=100,\n",
    "    keep_heuristic=\"WeightOnly\",\n",
    "    search_space_size=100_000,\n",
    "    do_synchronize=True,\n",
    "    loss_type=\"tanhMargin\",\n",
    ")\n",
    "\n",
    "train_params = dict(\n",
    "    lr=1e-2,\n",
    "    weight_decay=1e-3,\n",
    "    patience=300,\n",
    "    max_epochs=3000\n",
    ")\n",
    "\n",
    "make_undirected = True\n",
    "binary_attr = False\n",
    "data_device = 0\n",
    "balance_test = True\n",
    "\n",
    "device = 0\n",
    "seed = 0\n",
    "\n",
    "loss_type = 'tanhMargin'\n",
    "validate_every = 10\n",
    "\n",
    "train_attack_params = {'epochs': 5, \n",
    "                'fine_tune_epochs': 0,\n",
    "                'keep_heuristic': 'WeightOnly',\n",
    "                'search_space_size': 100_000,\n",
    "                'do_synchronize': True,\n",
    "                'attack_loss_type': 'tanhMargin'}\n",
    "\n",
    "############################### OPTIMIZER ###################################\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=train_params['lr'], weight_decay=train_params['weight_decay'])\n",
    "\n",
    "for it in tqdm(range(train_params['max_epochs']), desc='Training...'):\n",
    "    # Generate adversarial adjacency\n",
    "    if robust_epsilon > 0:\n",
    "        model.eval()\n",
    "        adversary = create_attack(attack, attr=training_attr, adj=training_adj, labels=train_labels, model=model, idx_attack=np.array(training_idx),\n",
    "                                device=device, data_device=data_device, binary_attr=binary_attr,\n",
    "                                make_undirected=make_undirected, **train_attack_params)\n",
    "        \n",
    "        adversary.attack(n_perturbations_train)\n",
    "\n",
    "        adj_pert = adversary.get_modified_adj()\n",
    "        del adversary\n",
    "        #adj_attacked_train = (adj_pert[0].detach(), adj_pert[1].detach())\n",
    "        adj_attacked_train = adj_pert\n",
    "\n",
    "    # train step\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    logits = model(training_attr, adj_attacked_train)\n",
    "    loss = calculate_loss(logits[training_idx], train_labels[training_idx], loss_type)\n",
    "    loss.backward()\n",
    "    optimizer.step()        \n",
    "    train_accuracy = accuracy(logits.cpu(), train_labels.cpu(), training_idx)\n",
    "    acc_trace_train_pert.append(train_accuracy)\n",
    "\n",
    "    # log clean accuracy\n",
    "    logits = model(training_attr, training_adj)\n",
    "    train_accuracy_clean = accuracy(logits.cpu(), train_labels.cpu(), training_idx)\n",
    "    val_accuracy_clean = accuracy(logits.cpu(), labels.cpu(), validation_idx)\n",
    "    acc_trace_train.append(train_accuracy_clean)\n",
    "    acc_trace_val.append(val_accuracy_clean)\n",
    "\n",
    "\n",
    "    # val step \n",
    "    if it % validate_every == 0:\n",
    "        if robust_epsilon > 0: \n",
    "            #if not self_training:\n",
    "                # ALWAYS RUN FULL DISCRETE ATTACK FOR VAL\n",
    "                adversary_val = create_attack(attack, attr=validation_attr, adj=validation_adj, labels=labels, model=model, idx_attack=np.array(validation_idx),\n",
    "                        device=device, data_device=data_device, binary_attr=binary_attr,\n",
    "                        make_undirected=make_undirected, **attack_params)\n",
    "                # adversary_val = PGD(attr=attr, adj=adj, labels=labels, model=model, idx_attack=idx_val,\n",
    "                #                 device=device, data_device=data_device, binary_attr=binary_attr,\n",
    "                #                 make_undirected=make_undirected, **val_attack_params)\n",
    "                model.eval()\n",
    "                adversary_val.attack(n_perturbations_val)\n",
    "                adj_pert = adversary_val.get_modified_adj()\n",
    "                #adj_attacked_val = (adj_pert[0].detach(), adj_pert[1].detach())\n",
    "                adj_attacked_val = adj_pert\n",
    "                del adversary_val\n",
    "            #else: # when self training use training attacked adj -> TODO change\n",
    "            #    adj_attacked_val = adj_attacked_train\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            logits_val = model(validation_attr, adj_attacked_val)\n",
    "            loss_val = calculate_loss(logits_val[training_idx], labels[validation_idx], loss_type)\n",
    "            # save val statistic\n",
    "            loss_trace_val.append(loss_val.item())\n",
    "            val_accuracy = accuracy(logits_val.cpu(), labels.cpu(), validation_idx)\n",
    "            acc_trace_val_pert.append(val_accuracy)\n",
    "\n",
    "            logging.info(f'train acc (pert/clean): {train_accuracy} / {train_accuracy_clean}')\n",
    "            logging.info(f'val acc (pert/clean): {val_accuracy} / {val_accuracy_clean}')\n",
    "    \n",
    "    # save train statistics\n",
    "    loss_trace.append(loss.item())\n",
    "\n",
    "    # save new best model and break if patience is reached\n",
    "    if loss_val < best_loss:\n",
    "        best_loss = loss_val\n",
    "        best_epoch = it\n",
    "        best_state = {key: value.cpu() for key, value in model.state_dict().items()}\n",
    "    else:\n",
    "        if it >= best_epoch + train_params['patience']:\n",
    "            break\n",
    "\n",
    "# restore the best validation state\n",
    "model.load_state_dict(best_state)\n",
    "model.eval()\n",
    "\n",
    "print({\n",
    "    'loss_trace': loss_trace,\n",
    "    'loss_trace_val': loss_trace_val,\n",
    "    'gamma_trace': gamma_trace,\n",
    "    # 'model_path': model_path,\n",
    "    'acc_trace_train': acc_trace_train,\n",
    "    'acc_trace_val': acc_trace_val,\n",
    "    'acc_trace_train_pert': acc_trace_train_pert,\n",
    "    'acc_trace_val_pert': acc_trace_val_pert\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
