{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eec9383d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import torch\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import gb\n",
    "from gb.exutil import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42772178",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gb.model import GraphSequential, PreprocessA, PreprocessX, PreprocessAUsingXMetric, GCN, RGCN, ProGNN, GNNGuard, \\\n",
    "    GRAND, MLP, SoftMedianPropagation\n",
    "from gb.pert import sp_edge_diff_matrix, sp_feat_diff_matrix\n",
    "from gb.torchext import mul\n",
    "from gb import metric, preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf0149c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"cora\"\n",
    "A, X, y = gb.data.get_dataset(dataset)\n",
    "N, D = X.shape\n",
    "C = y.max().item() + 1\n",
    "train_nodes, val_nodes, test_nodes = gb.data.get_splits(y)[0]  # [0] = select first split\n",
    "\n",
    "A = A.cuda()\n",
    "X = X.cuda()\n",
    "y = y.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ed3e31d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ptb_rate = 0.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92497aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "budget = int(ptb_rate * (A.cpu().numpy().sum() // 2))\n",
    "budget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d64cc77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ptb_value = str(int(ptb_rate*100))\n",
    "ptb_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac0a205-7a97-4cb5-a4d1-12b47af88a97",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Create the output directories idempotently: `!mkdir` reports an error on every\n",
    "# re-run once the directory already exists, and it shells out for something the\n",
    "# standard library can do directly.\n",
    "for out_dir in (\"feature_vals\", \"accuracy_vals\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51fccd43-3e76-4949-b71d-cc8dd435ff2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_dict = {}\n",
    "accuracy_dict['GCN']={}\n",
    "accuracy_dict['GRAND']={}\n",
    "accuracy_dict['GNNGuard']={}\n",
    "accuracy_dict['GCNSVD']={}\n",
    "accuracy_dict['ProGNN']={}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953ffc20",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0b9975e",
   "metadata": {},
   "source": [
    "## GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fc3d1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's RNG before any model is created in this section.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Keyword arguments (learning rate, weight decay) forwarded to every `fit` call below.\n",
    "fit_kwargs = {\"lr\": 1e-2, \"weight_decay\": 5e-4}\n",
    "\n",
    "\n",
    "def make_model():\n",
    "    \"\"\"Return a new, untrained GCN (hidden_dims=[64], dropout=0.5) moved to the GPU.\"\"\"\n",
    "    gcn = gb.model.GCN(n_feat=D, n_class=C, hidden_dims=[64], dropout=0.5)\n",
    "    return gcn.cuda()\n",
    "\n",
    "\n",
    "# Auxiliary model fitted on the unperturbed graph; reused by the evasion attack below.\n",
    "aux_model = make_model()\n",
    "aux_model.fit((A, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "364534d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_accuracy = gb.metric.accuracy(aux_model(A, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCN']['clean']=clean_accuracy\n",
    "\n",
    "print(\"Clean test acc:   \", clean_accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "143d4649",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_vals = aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "381db2e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in clean_vals.items():\n",
    "    print(v.shape)\n",
    "    clean_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd716315",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez(f'feature_vals/gcn_clean_{ptb_value}.npz', **clean_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b68deb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eed80322",
   "metadata": {},
   "source": [
    "### Poisoning global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96813024",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Meta-attack objective: fit a fresh GCN on the flipped graph and score it.\n",
    "\n",
    "    `A_flip` selects the adjacency entries to flip; for a 0/1 adjacency matrix,\n",
    "    `A + A_flip * (1 - 2 * A)` toggles exactly the selected entries.\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Meta-Attack w/ Adam: training is made differentiable only when a gradient\n",
    "    # with respect to the perturbation is actually requested.\n",
    "    victim = make_model()\n",
    "    victim.fit(\n",
    "        (A_pert, X), y, train_nodes, val_nodes,\n",
    "        progress=False, **fit_kwargs, differentiable=A_pert.requires_grad,\n",
    "    )\n",
    "\n",
    "    logits = victim(A_pert, X)\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()\n",
    "\n",
    "\n",
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    objective = loss_fn(A_flip)\n",
    "    (grad,) = torch.autograd.grad(objective, A_flip)\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ca24ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### PGD for Meta-Attack ##########\n",
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn, \\\n",
    "                                      base_lr=0.01, grad_clip=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5627bc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab67f602",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"Adversarial edges:\", pert.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91a2cdd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_model = make_model()\n",
    "pois_model.fit((A_pert, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)\n",
    "pois_accuracy = gb.metric.accuracy(pois_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCN']['pois']=pois_accuracy\n",
    "\n",
    "print(\"Poisoned test acc:\", pois_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc94b28e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_vals=pois_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "748d3c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in pois_vals.items():\n",
    "    print(v.shape)\n",
    "    pois_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa20daa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gcn_gp_'+ptb_value+'.npz', **pois_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eecb144b",
   "metadata": {},
   "source": [
    "### Evasion global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0e4d7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Aux-attack objective: score the already-trained `aux_model` on the flipped graph.\n",
    "\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Aux-Attack: no retraining, the model fitted on the clean graph is queried directly.\n",
    "    # NOTE: `grad_fn` from the poisoning section resolves `loss_fn` at call time,\n",
    "    # so it differentiates this definition from here on.\n",
    "    logits = aux_model(A_pert, X)\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e475409e",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### PGD for Aux-Attack ###########\n",
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn,\\\n",
    "                                      base_lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4541f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)\n",
    "print(\"Adversarial edges:\", pert.shape[0])\n",
    "evas_accuracy = gb.metric.accuracy(aux_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCN']['evas'] = evas_accuracy\n",
    "\n",
    "print(\"Evasion test acc: \", evas_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91fcf281",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "aux_model(A_pert,X)\n",
    "evasion_vals=aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d435668",
   "metadata": {},
   "outputs": [],
   "source": [
    "evasion_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c5c7996",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in evasion_vals.items():\n",
    "    print(v.shape)\n",
    "    evasion_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b6c589",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gcn_ge_'+ptb_value+'.npz', **evasion_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a584396",
   "metadata": {},
   "source": [
    "## GCN-SVD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95bafa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank handed to gb.preprocess.low_rank and gb.metric.eigenspace_alignment.\n",
    "rank = 50\n",
    "fit_kwargs = {\"lr\": 1e-2, \"weight_decay\": 5e-4}\n",
    "\n",
    "\n",
    "def make_model():\n",
    "    \"\"\"Return an untrained GCN-SVD pipeline on the GPU.\n",
    "\n",
    "    The pipeline has two named stages: `low_rank` (adjacency preprocessing) and `gcn`.\n",
    "    \"\"\"\n",
    "    def approximate(adj):\n",
    "        return gb.preprocess.low_rank(adj, rank)\n",
    "\n",
    "    stages = OrderedDict([\n",
    "        (\"low_rank\", gb.model.PreprocessA(approximate)),\n",
    "        (\"gcn\", gb.model.GCN(n_feat=D, n_class=C, hidden_dims=[64], dropout=0.5)),\n",
    "    ])\n",
    "    return gb.model.GraphSequential(stages).cuda()\n",
    "\n",
    "\n",
    "aux_model = make_model()\n",
    "aux_model.fit((A, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)\n",
    "\n",
    "# Preprocessed clean adjacency; the attack objectives add their perturbation to it.\n",
    "A_low_rank = aux_model.low_rank(A)\n",
    "# Only referenced by the optional (commented-out) \"w/ weights\" variant of the poisoning loss.\n",
    "A_weights = gb.metric.eigenspace_alignment(A, rank)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a60661d4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clean_accuracy = gb.metric.accuracy(aux_model(A, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCNSVD']['clean']=clean_accuracy\n",
    "\n",
    "print(\"Clean test acc:   \", clean_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ebec93",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_vals = aux_model.gcn.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b26304",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in clean_vals.items():\n",
    "    print(v.shape)\n",
    "    clean_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a0984fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez(f'feature_vals/gcnsvd_clean_{ptb_value}.npz', **clean_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b6dcef4",
   "metadata": {},
   "source": [
    "### Poisoning global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633aab11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Meta-attack objective for GCN-SVD.\n",
    "\n",
    "    The flip perturbation is added to the cached `A_low_rank` and a fresh model,\n",
    "    built without its `low_rank` stage, is trained on the result.\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_diff = A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Optional \"w/ weights\" variant (disabled): rescale the perturbation.\n",
    "    # A_diff = A_diff * A_weights\n",
    "\n",
    "    A_pert = A_low_rank + A_diff\n",
    "\n",
    "    # Meta-Attack: differentiable training only when a gradient is requested.\n",
    "    victim = make_model().sub(exclude=[\"low_rank\"])\n",
    "    victim.fit(\n",
    "        (A_pert, X), y, train_nodes, val_nodes,\n",
    "        progress=False, **fit_kwargs, differentiable=A_pert.requires_grad,\n",
    "    )\n",
    "\n",
    "    logits = victim(A_pert, X)\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02f50f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(A_flip):\n",
    "    return torch.autograd.grad(loss_fn(A_flip), A_flip)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3949c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "########## PGD for Meta-Attack ##########\n",
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn,\\\n",
    "                                      base_lr=0.1, grad_clip=0.1)\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6bd0742",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_model = make_model()\n",
    "pois_model.fit((A_pert, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)\n",
    "pois_accuracy = gb.metric.accuracy(pois_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCNSVD']['pois']=pois_accuracy\n",
    "\n",
    "print(\"Poisoned test acc:\", pois_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb8eb93",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pois_vals=pois_model.gcn.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d045cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in pois_vals.items():\n",
    "    print(v.shape)\n",
    "    pois_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ab887f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gcnsvd_gp_'+ptb_value+'.npz', **pois_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47de3097",
   "metadata": {},
   "source": [
    "### Evasion global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6eeea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Aux-attack objective for GCN-SVD.\n",
    "\n",
    "    Adds the flip perturbation to the cached `A_low_rank` and scores the trained\n",
    "    `aux_model` with its `low_rank` stage excluded.\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A_low_rank + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Aux-Attack: reuse the trained weights, skipping the preprocessing stage.\n",
    "    gcn_only = aux_model.sub(exclude=[\"low_rank\"])\n",
    "    logits = gcn_only(A_pert, X)\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()\n",
    "\n",
    "\n",
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    objective = loss_fn(A_flip)\n",
    "    (grad,) = torch.autograd.grad(objective, A_flip)\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3481b9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### PGD for Aux-Attack ###########\n",
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn,\\\n",
    "                                      loss_fn, base_lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86facf71",
   "metadata": {},
   "outputs": [],
   "source": [
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)\n",
    "print(\"Adversarial edges:\", pert.shape[0])\n",
    "evas_accuracy = gb.metric.accuracy(aux_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GCNSVD']['evas']=evas_accuracy\n",
    "\n",
    "print(\"Evasion test acc: \", evas_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55470f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model(A_pert,X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f1e491",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model(A_pert,X)\n",
    "evasion_vals=aux_model.gcn.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa00810",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in evasion_vals.items():\n",
    "    print(v.shape)\n",
    "    evasion_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d70987",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gcnsvd_ge_'+ptb_value+'.npz', **evasion_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c09b4ea",
   "metadata": {},
   "source": [
    "## GNNGuard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46862f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `rank` is re-bound here but not referenced anywhere else in this cell.\n",
    "rank = 50\n",
    "fit_kwargs = {\"lr\": 1e-2, \"weight_decay\": 5e-4}\n",
    "\n",
    "\n",
    "def make_model(div_limit=1e-6):\n",
    "    \"\"\"Return an untrained GNNGuard model on the GPU.\n",
    "\n",
    "    `div_limit` is forwarded unchanged to `gb.model.GNNGuard`.\n",
    "    \"\"\"\n",
    "    guard = gb.model.GNNGuard(\n",
    "        n_feat=D, n_class=C, hidden_dims=[64], dropout=0.5, div_limit=div_limit\n",
    "    )\n",
    "    return guard.cuda()\n",
    "\n",
    "\n",
    "aux_model = make_model()\n",
    "aux_model.fit((A, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e1ace6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clean_accuracy = gb.metric.accuracy(aux_model(A, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GNNGuard']['clean']=clean_accuracy\n",
    "\n",
    "print(\"Clean test acc:   \", clean_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f31d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Capture GNNGuard's own activations. Without this assignment `clean_vals` still\n",
    "# refers to the dictionary taken from the GCN-SVD model in the previous section,\n",
    "# and those values would be written to gnnguard_clean_*.npz.\n",
    "clean_vals = aux_model.feature_vals\n",
    "for k, v in clean_vals.items():\n",
    "    print(v.shape)\n",
    "    # Move each entry to the CPU so it can be written with np.savez.\n",
    "    clean_vals[k] = v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9d4712",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez(f'feature_vals/gnnguard_clean_{ptb_value}.npz', **clean_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "598ebbd7",
   "metadata": {},
   "source": [
    "### Poisoning global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c7fd582",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Meta-attack objective for GNNGuard.\n",
    "\n",
    "    Trains a fresh model for at most 50 epochs on the flipped graph and returns\n",
    "    the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # \"w/ real div_limit\": no constructor overrides, so make_model uses its default.\n",
    "    alteration = {}\n",
    "\n",
    "    # Meta-Attack: differentiable training only when a gradient is requested.\n",
    "    victim = make_model(**alteration)\n",
    "    victim.fit(\n",
    "        (A_pert, X), y, train_nodes, val_nodes,\n",
    "        progress=False, **fit_kwargs, max_epochs=50, differentiable=A_pert.requires_grad,\n",
    "    )\n",
    "    logits = victim(A_pert, X)\n",
    "\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28020302",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(A_flip):\n",
    "    return torch.autograd.grad(loss_fn(A_flip), A_flip)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322b0bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn,\\\n",
    "                                      base_lr=0.1, grad_clip=0.1)\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6b31e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_model = make_model()\n",
    "pois_model.fit((A_pert, X), y, train_nodes, val_nodes, progress=False, **fit_kwargs)\n",
    "pois_accuracy = gb.metric.accuracy(pois_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GNNGuard']['pois']=pois_accuracy\n",
    "\n",
    "print(\"Poisoned test acc:\", pois_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a557c4e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_vals=pois_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c632d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in pois_vals.items():\n",
    "    print(v.shape)\n",
    "    pois_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c2e32b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gnnguard_gp_'+ptb_value+'.npz', **pois_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f0d9268",
   "metadata": {},
   "source": [
    "### Evasion global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14cdf28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Aux-attack objective for GNNGuard: score the trained `aux_model` on the flipped graph.\n",
    "\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # \"w/ real div_limit\": no field overrides are requested.\n",
    "    alteration = {}\n",
    "\n",
    "    # Aux-Attack. NOTE(review): changed_fields presumably overrides the given fields\n",
    "    # only inside the `with` block; with an empty mapping it should be a no-op - confirm.\n",
    "    with gb.model.changed_fields(aux_model, **alteration):\n",
    "        logits = aux_model(A_pert, X)\n",
    "\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()\n",
    "\n",
    "\n",
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    objective = loss_fn(A_flip)\n",
    "    (grad,) = torch.autograd.grad(objective, A_flip)\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6a4555",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn, base_lr=0.1)\n",
    "\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)\n",
    "print(\"Adversarial edges:\", pert.shape[0])\n",
    "evas_accuracy = gb.metric.accuracy(aux_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GNNGuard']['evas']=evas_accuracy\n",
    "\n",
    "print(\"Evasion test acc: \", evas_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99b69731",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model(A_pert,X)\n",
    "evasion_vals=aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f031818d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in evasion_vals.items():\n",
    "    print(v.shape)\n",
    "    evasion_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e33a0b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/gnnguard_ge_'+ptb_value+'.npz', **evasion_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62d55c5b",
   "metadata": {},
   "source": [
    "## ProGNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb730a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training options used for the clean and the final poisoned ProGNN fits.\n",
    "fit_kwargs = {\n",
    "    \"gnn_lr\": 0.01,\n",
    "    \"gnn_weight_decay\": 0.0005,\n",
    "    \"adj_lr\": 0.01,\n",
    "    \"adj_momentum\": 0.9,\n",
    "    \"reg_adj_deviate\": 1.0,\n",
    "}\n",
    "\n",
    "\n",
    "def make_model(A):\n",
    "    \"\"\"Return an untrained ProGNN on the GPU, built from adjacency `A` and a GCN backbone.\"\"\"\n",
    "    backbone = GCN(\n",
    "        n_feat=D, n_class=C, bias=True, activation=\"relu\", hidden_dims=[64], dropout=0.5\n",
    "    )\n",
    "    return gb.model.ProGNN(A, backbone).cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce222e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model = make_model(A)\n",
    "model_args = filter_model_args(aux_model, A, X)\n",
    "aux_model.fit(model_args, y, train_nodes, val_nodes, progress=True, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb00025",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_vals = aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59df2748",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clean_accuracy = gb.metric.accuracy(aux_model(X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['ProGNN']['clean']=clean_accuracy\n",
    "\n",
    "print(\"Clean test acc:   \", clean_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73da04d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-read the activations here, after the clean-accuracy forward pass, as the other\n",
    "# model sections do. The earlier capture happened before that forward pass, so it\n",
    "# may hold values left over from training.\n",
    "clean_vals = aux_model.feature_vals\n",
    "for k, v in clean_vals.items():\n",
    "    print(v.shape)\n",
    "    # Move each entry to the CPU so it can be written with np.savez.\n",
    "    clean_vals[k] = v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "102792f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez(f'feature_vals/prognn_clean_{ptb_value}.npz', **clean_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19a9fcb3",
   "metadata": {},
   "source": [
    "### Poisoning global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "189a9e2d-45be-484b-ad01-cb640b2de535",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65928a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training options used only inside the ProGNN meta-attack objective.\n",
    "fit_kwargs2 = {\n",
    "    \"gnn_lr\": 0.01,\n",
    "    \"gnn_weight_decay\": 0.0005,\n",
    "    \"adj_lr\": 0.01,\n",
    "    \"adj_momentum\": 0.9,\n",
    "    \"reg_adj_deviate\": 1.0,\n",
    "    \"adj_optim_interval\": 2,\n",
    "    \"reg_adj_l1\": 5e-4,\n",
    "    \"reg_adj_nuclear\": 0,\n",
    "    \"reg_feat_smooth\": 1e-3,\n",
    "}\n",
    "\n",
    "\n",
    "def loss_fn(A_flip):\n",
    "    \"\"\"Meta-attack objective for ProGNN.\n",
    "\n",
    "    Builds and trains a fresh ProGNN on the flipped graph with `fit_kwargs2` and\n",
    "    returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Meta-Attack: differentiable training only when a gradient is requested.\n",
    "    victim = make_model(A_pert)\n",
    "    victim_args = filter_model_args(victim, A_pert, X)\n",
    "    victim.fit(\n",
    "        victim_args, y, train_nodes, val_nodes,\n",
    "        progress=False, **fit_kwargs2, differentiable=A_pert.requires_grad,\n",
    "    )\n",
    "    logits = victim(X)\n",
    "\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4cb5c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(A_flip):\n",
    "    return torch.autograd.grad(loss_fn(A_flip), A_flip)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f6d3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn,\\\n",
    "                                      base_lr=0.1, grad_clip=0.1)\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90019bfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_model = make_model(A_pert)\n",
    "model_args = filter_model_args(pois_model, A_pert, X)\n",
    "pois_model.fit(model_args, y, train_nodes, val_nodes, progress=True, **fit_kwargs)\n",
    "pois_accuracy = gb.metric.accuracy(pois_model(X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['ProGNN']['pois']=pois_accuracy\n",
    "\n",
    "print(\"Poisoned test acc:\", pois_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "790c68cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "pois_vals=pois_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e0aa7e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for k,v in pois_vals.items():\n",
    "    print(v.shape)\n",
    "    pois_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ba961f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/prognn_gp_'+ptb_value+'.npz', **pois_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5824c3e",
   "metadata": {},
   "source": [
    "### Evasion global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4030508b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Aux-attack objective for ProGNN: score the trained `aux_model` on the flipped graph.\n",
    "\n",
    "    Side effect: overwrites `aux_model.S` with the perturbed adjacency and does not\n",
    "    restore it afterwards.\n",
    "    Returns the mean of tanh(margin) over the test nodes.\n",
    "    \"\"\"\n",
    "    A_pert = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    # Aux-Attack: the model is called with features only, so the perturbed graph is\n",
    "    # injected through its `S` attribute.\n",
    "    aux_model.S = A_pert\n",
    "    logits = aux_model(X)\n",
    "    test_margin = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return test_margin.tanh().mean()\n",
    "\n",
    "\n",
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    objective = loss_fn(A_flip)\n",
    "    (grad,) = torch.autograd.grad(objective, A_flip)\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904ed619",
   "metadata": {},
   "outputs": [],
   "source": [
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn, base_lr=0.1)\n",
    "\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)\n",
    "print(\"Adversarial edges:\", pert.shape[0])\n",
    "\n",
    "# loss_fn assigns aux_model.S on every call, so after PGD it holds whichever candidate\n",
    "# was evaluated last. Install the final perturbed adjacency explicitly so the reported\n",
    "# accuracy (and the activations read in the next cell) correspond to `pert`.\n",
    "aux_model.S = A_pert\n",
    "evas_accuracy = gb.metric.accuracy(aux_model(X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['ProGNN']['evas']=evas_accuracy\n",
    "\n",
    "print(\"Evasion test acc: \", evas_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2668690",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model(X)\n",
    "evasion_vals=aux_model.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd919426",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in evasion_vals.items():\n",
    "    evasion_vals[k]=v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97efe1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.savez('feature_vals/prognn_ge_'+ptb_value+'.npz', **evasion_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5683c0f",
   "metadata": {},
   "source": [
    "## GRAND"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4303c154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options for the inner MLP and for the GRAND wrapper, respectively.\n",
    "model_kwargs1 = {\"hidden_dims\": [64], \"dropout\": 0.5}\n",
    "model_kwargs2 = {\"dropnode\": 0.5, \"mlp_input_dropout\": 0.5, \"order\": 2}\n",
    "\n",
    "\n",
    "def make_model():\n",
    "    \"\"\"Return an untrained GRAND model (wrapping an MLP) on the GPU.\"\"\"\n",
    "    mlp = MLP(n_feat=D, n_class=C, bias=True, **model_kwargs1)\n",
    "    return GRAND(mlp, **model_kwargs2).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8ac8479",
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_model = make_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d1dc699",
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_kwargs = dict(lr=0.1, weight_decay=1e-4)\n",
    "aux_model.fit((A,X), y, train_nodes, val_nodes, progress=True, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f3fb17",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clean_accuracy = gb.metric.accuracy(aux_model(A,X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GRAND']['clean']=clean_accuracy\n",
    "\n",
    "print(\"Clean test acc:   \", clean_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e45d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_vals = aux_model.mlp.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02acb518",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in clean_vals.items():\n",
    "    clean_vals[k] = v.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5a0d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Persist the clean-model feature values; the file name carries the perturbation rate in percent.\n",
    "np.savez(f'feature_vals/grand_clean_{ptb_value}.npz', **clean_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7681edea",
   "metadata": {},
   "source": [
    "### Poisoning global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8919174",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Poisoning (meta-attack) objective for a candidate edge-flip matrix.\n",
    "\n",
    "    A fresh GRAND model is fitted on the perturbed graph and then scored on\n",
    "    the test nodes. Returns the mean of tanh(margin), with the margin computed\n",
    "    by `gb.metric.margin`.\n",
    "    \"\"\"\n",
    "    # Assuming A is a 0/1 adjacency matrix, (1 - 2 * A) is +1 on non-edges and -1 on\n",
    "    # edges, so an entry of 1 in A_flip toggles the matching entry of A.\n",
    "    perturbed = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    ############# Meta-Attack ############\n",
    "    # Differentiable training is requested only when the perturbed matrix tracks gradients.\n",
    "    victim = make_model()\n",
    "    victim.fit(\n",
    "        (perturbed, X), y, train_nodes, val_nodes,\n",
    "        progress=False,\n",
    "        **fit_kwargs,\n",
    "        max_epochs=100,\n",
    "        differentiable=perturbed.requires_grad,\n",
    "    )\n",
    "    logits = victim(perturbed, X)\n",
    "    ######################################\n",
    "\n",
    "    margins = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return margins.tanh().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e297fff7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    loss = loss_fn(A_flip)\n",
    "    # autograd.grad yields one gradient per input; there is exactly one input here.\n",
    "    (gradient,) = torch.autograd.grad(loss, A_flip)\n",
    "    return gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5ae992",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the projected-gradient-descent attack with the poisoning loss and gradient functions,\n",
    "# then apply the resulting edge perturbation to the clean adjacency matrix.\n",
    "pert, _ = gb.attack.proj_grad_descent(\n",
    "    A.shape, True, A.device, budget, grad_fn, loss_fn,\n",
    "    base_lr=0.1, grad_clip=0.1,\n",
    ")\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2923919a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a new GRAND model on the poisoned graph and record its test accuracy.\n",
    "pois_model = make_model()\n",
    "pois_model.fit(\n",
    "    (A_pert, X), y, train_nodes, val_nodes,\n",
    "    progress=False,\n",
    "    **fit_kwargs,\n",
    ")\n",
    "pois_accuracy = gb.metric.accuracy(\n",
    "    pois_model(A_pert, X)[test_nodes],\n",
    "    y[test_nodes],\n",
    ").item()\n",
    "accuracy_dict['GRAND']['pois'] = pois_accuracy\n",
    "\n",
    "print(\"Poisoned test acc:\", pois_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd7f13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature values exposed by the poisoned model's inner MLP; the bare name on the last line displays them.\n",
    "pois_vals = pois_model.mlp.feature_vals\n",
    "pois_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497d7978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace every stored value by its `.cpu()` copy, updating the mapping in place.\n",
    "for key in list(pois_vals):\n",
    "    pois_vals[key] = pois_vals[key].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eba5385",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Store the poisoned-model feature values, tagged with the perturbation rate in percent.\n",
    "np.savez(f'feature_vals/grand_gp_{ptb_value}.npz', **pois_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7988d3f",
   "metadata": {},
   "source": [
    "### Evasion global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1767619",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(A_flip):\n",
    "    \"\"\"Evasion (aux-attack) objective for a candidate edge-flip matrix.\n",
    "\n",
    "    The already fitted `aux_model` is evaluated on the perturbed graph without\n",
    "    any retraining. Returns the mean of tanh(margin) over the test nodes, with\n",
    "    the margin computed by `gb.metric.margin`.\n",
    "    \"\"\"\n",
    "    # Assuming A is a 0/1 adjacency matrix, (1 - 2 * A) is +1 on non-edges and -1 on\n",
    "    # edges, so an entry of 1 in A_flip toggles the matching entry of A.\n",
    "    perturbed = A + A_flip * (1 - 2 * A)\n",
    "\n",
    "    ############# Aux-Attack #############\n",
    "    logits = aux_model(perturbed, X)\n",
    "\n",
    "    margins = gb.metric.margin(logits[test_nodes, :], y[test_nodes])\n",
    "    return margins.tanh().mean()\n",
    "\n",
    "def grad_fn(A_flip):\n",
    "    \"\"\"Return the gradient of `loss_fn(A_flip)` with respect to `A_flip`.\"\"\"\n",
    "    loss = loss_fn(A_flip)\n",
    "    # autograd.grad yields one gradient per input; there is exactly one input here.\n",
    "    (gradient,) = torch.autograd.grad(loss, A_flip)\n",
    "    return gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88173c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the projected-gradient-descent attack with the evasion loss and gradient functions.\n",
    "pert, _ = gb.attack.proj_grad_descent(A.shape, True, A.device, budget, grad_fn, loss_fn, base_lr=0.1)\n",
    "\n",
    "# Apply the edge perturbation and evaluate the unchanged clean model on the perturbed graph.\n",
    "A_pert = A + gb.pert.edge_diff_matrix(pert, A)\n",
    "print(\"Adversarial edges:\", pert.shape[0])\n",
    "evas_accuracy = gb.metric.accuracy(aux_model(A_pert, X)[test_nodes], y[test_nodes]).item()\n",
    "accuracy_dict['GRAND']['evas'] = evas_accuracy\n",
    "\n",
    "# Report the measured accuracy (the value was previously missing from this print call).\n",
    "print(\"Evasion test acc: \", evas_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1d17d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run a forward pass on the perturbed graph before reading `feature_vals` from the inner MLP.\n",
    "# NOTE(review): presumably the forward call refreshes `mlp.feature_vals` - confirm in gb.model.\n",
    "aux_model(A_pert,X)\n",
    "evasion_vals = aux_model.mlp.feature_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "682deff7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace every stored value by its `.cpu()` copy, updating the mapping in place.\n",
    "for key in list(evasion_vals):\n",
    "    evasion_vals[key] = evasion_vals[key].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99e1fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Store the evasion feature values, tagged with the perturbation rate in percent.\n",
    "np.savez(f'feature_vals/grand_ge_{ptb_value}.npz', **evasion_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7b94ac-9e22-4e39-956c-1743ac514671",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "53086bad-e1e5-45e5-b791-adab9af2713b",
   "metadata": {},
   "source": [
    "## Save the accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd0794a-7912-44ae-afd1-fb340713a949",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the accuracies collected so far before they are written to disk.\n",
    "accuracy_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbf793b9-8e91-40b9-959f-00d701b4abe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Create the output directory without complaining when it already exists\n",
    "# (a plain `mkdir` reports an error whenever the directory is already present).\n",
    "os.makedirs('accuracy_vals', exist_ok=True)\n",
    "\n",
    "# Pickle the accuracy dictionary; the file name carries the perturbation rate in percent.\n",
    "save_path = 'accuracy_vals/'+ptb_value+'.pkl'\n",
    "with open(save_path, 'wb') as file:\n",
    "    pickle.dump(accuracy_dict, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ec863ce-0fc9-4bee-be04-d6e0143bc817",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
