{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d016facb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from scipy import sparse\n",
    "from deeprobust.graph.data import Dataset\n",
    "from deeprobust.graph.defense import GCN\n",
    "# from deeprobust.graph.global_attack import Metattack\n",
    "\n",
    "import math\n",
    "\n",
    "import scipy.sparse as sp\n",
    "from torch import optim\n",
    "from torch.nn import functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "from tqdm import tqdm\n",
    "from deeprobust.graph import utils\n",
    "from deeprobust.graph.global_attack import BaseAttack\n",
    "\n",
    "from dataset import load_graph_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe433d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BaseMeta(BaseAttack):\n",
    "    \"\"\"Abstract base class for meta attack. Adversarial Attacks on Graph Neural\n",
    "    Networks via Meta Learning, ICLR 2019,\n",
    "    https://openreview.net/pdf?id=Bylnx209YX\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model :\n",
    "        model to attack. Default `None`.\n",
    "    nnodes : int\n",
    "        number of nodes in the input graph\n",
    "    feature_shape : tuple\n",
    "        shape of the input node features\n",
    "    lambda_ : float\n",
    "        lambda_ is used to weight the two objectives in Eq. (10) in the paper.\n",
    "    attack_structure : bool\n",
    "        whether to attack graph structure\n",
    "    attack_features : bool\n",
    "        whether to attack node features\n",
    "    undirected : bool\n",
    "        whether the graph is undirected\n",
    "    device: str\n",
    "        'cpu' or 'cuda'\n",
    "    adj_mask : torch.Tensor or None\n",
    "        optional (nnodes, nnodes) mask. When given (and attack_structure is True),\n",
    "        the gradient w.r.t. `adj_changes` is multiplied element-wise by it, e.g.\n",
    "        pass the dense original adjacency so only existing edges get gradient.\n",
    "        Default `None` (no masking).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model=None, nnodes=None, feature_shape=None, lambda_=0.5, attack_structure=True, attack_features=False, undirected=True, device='cpu', adj_mask=None):\n",
    "\n",
    "        super(BaseMeta, self).__init__(model, nnodes, attack_structure, attack_features, device)\n",
    "        self.lambda_ = lambda_\n",
    "\n",
    "        assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'\n",
    "\n",
    "        self.modified_adj = None\n",
    "        self.modified_features = None\n",
    "\n",
    "        if attack_structure:\n",
    "            self.undirected = undirected\n",
    "            assert nnodes is not None, 'Please give nnodes='\n",
    "            self.adj_changes = Parameter(torch.FloatTensor(nnodes, nnodes))\n",
    "            self.adj_changes.data.fill_(0)\n",
    "            if adj_mask is not None:\n",
    "                # register_hook needs a callable grad -> grad, so wrap the mask in a\n",
    "                # lambda. The mask comes from the constructor argument (not a notebook\n",
    "                # global) so the class does not depend on cell execution order.\n",
    "                mask = torch.as_tensor(adj_mask, dtype=self.adj_changes.dtype)\n",
    "                self.adj_changes.register_hook(lambda grad: grad * mask.to(grad.device))\n",
    "\n",
    "        if attack_features:\n",
    "            assert feature_shape is not None, 'Please give feature_shape='\n",
    "            self.feature_changes = Parameter(torch.FloatTensor(feature_shape))\n",
    "            self.feature_changes.data.fill_(0)\n",
    "\n",
    "        self.with_relu = model.with_relu\n",
    "\n",
    "    def attack(self, adj, labels, n_perturbations):\n",
    "        pass\n",
    "\n",
    "    def get_modified_adj(self, ori_adj):\n",
    "        \"\"\"Return `ori_adj` plus the current perturbation (diagonal of `adj_changes`\n",
    "        zeroed, symmetrized when undirected, clamped to [-1, 1]).\"\"\"\n",
    "        adj_changes_square = self.adj_changes - torch.diag(torch.diag(self.adj_changes, 0))\n",
    "        if self.undirected:\n",
    "            adj_changes_square = adj_changes_square + torch.transpose(adj_changes_square, 1, 0)\n",
    "        adj_changes_square = torch.clamp(adj_changes_square, -1, 1)\n",
    "        modified_adj = adj_changes_square + ori_adj\n",
    "        return modified_adj\n",
    "\n",
    "    def get_modified_features(self, ori_features):\n",
    "        return ori_features + self.feature_changes\n",
    "\n",
    "    def filter_potential_singletons(self, modified_adj):\n",
    "        \"\"\"\n",
    "        Computes a mask for entries potentially leading to singleton nodes, i.e. one of the two nodes corresponding to\n",
    "        the entry has degree 1 and there is an edge between the two nodes.\n",
    "        \"\"\"\n",
    "\n",
    "        degrees = modified_adj.sum(0)\n",
    "        degree_one = (degrees == 1)\n",
    "        resh = degree_one.repeat(modified_adj.shape[0], 1).float()\n",
    "        l_and = resh * modified_adj\n",
    "        if self.undirected:\n",
    "            l_and = l_and + l_and.t()\n",
    "        flat_mask = 1 - l_and\n",
    "        return flat_mask\n",
    "\n",
    "    def self_training_label(self, labels, idx_train):\n",
    "        # Predict the labels of the unlabeled nodes to use them for self-training.\n",
    "        output = self.surrogate.output\n",
    "        labels_self_training = output.argmax(1)\n",
    "        labels_self_training[idx_train] = labels[idx_train]\n",
    "        return labels_self_training\n",
    "\n",
    "\n",
    "    def log_likelihood_constraint(self, modified_adj, ori_adj, ll_cutoff):\n",
    "        \"\"\"\n",
    "        Computes a mask for entries that, if the edge corresponding to the entry is added/removed, would lead to the\n",
    "        log likelihood constraint to be violated.\n",
    "        Note that different data type (float, double) can affect the final results.\n",
    "        \"\"\"\n",
    "        t_d_min = torch.tensor(2.0).to(self.device)\n",
    "        if self.undirected:\n",
    "            t_possible_edges = np.array(np.triu(np.ones((self.nnodes, self.nnodes)), k=1).nonzero()).T\n",
    "        else:\n",
    "            t_possible_edges = np.array((np.ones((self.nnodes, self.nnodes)) - np.eye(self.nnodes)).nonzero()).T\n",
    "        allowed_mask, current_ratio = utils.likelihood_ratio_filter(t_possible_edges,\n",
    "                                                                    modified_adj,\n",
    "                                                                    ori_adj, t_d_min,\n",
    "                                                                    ll_cutoff, undirected=self.undirected)\n",
    "        return allowed_mask, current_ratio\n",
    "\n",
    "    def get_adj_score(self, adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff):\n",
    "        \"\"\"Score adjacency entries for removal (higher = larger expected loss increase).\n",
    "\n",
    "        Removal-only variant: the original Metattack scores with\n",
    "        `adj_grad * (-2 * modified_adj + 1)` (additions and deletions); here only\n",
    "        entries that are non-zero in `modified_adj` get a gradient-based score.\n",
    "        NOTE(review): after the min-shift below every zero entry of `modified_adj`\n",
    "        scores `-min`; if no edge scores above that, argmax can select a non-edge\n",
    "        and the update in `attack` adds 0 (a wasted perturbation). Confirm intended.\n",
    "        \"\"\"\n",
    "        adj_meta_grad = adj_grad * (-modified_adj)\n",
    "        # Make sure that the minimum entry is 0.\n",
    "        adj_meta_grad -= adj_meta_grad.min()\n",
    "        # Filter self-loops\n",
    "        adj_meta_grad -= torch.diag(torch.diag(adj_meta_grad, 0))\n",
    "        # Suppress entries that could lead to singleton nodes.\n",
    "        singleton_mask = self.filter_potential_singletons(modified_adj)\n",
    "        adj_meta_grad = adj_meta_grad * singleton_mask\n",
    "\n",
    "        if ll_constraint:\n",
    "            allowed_mask, self.ll_ratio = self.log_likelihood_constraint(modified_adj, ori_adj, ll_cutoff)\n",
    "            allowed_mask = allowed_mask.to(self.device)\n",
    "            adj_meta_grad = adj_meta_grad * allowed_mask\n",
    "        return adj_meta_grad\n",
    "\n",
    "    def get_feature_score(self, feature_grad, modified_features):\n",
    "        feature_meta_grad = feature_grad * (-2 * modified_features + 1)\n",
    "        feature_meta_grad -= feature_meta_grad.min()\n",
    "        return feature_meta_grad\n",
    "\n",
    "\n",
    "class Metattack(BaseMeta):\n",
    "    \"\"\"Meta-gradient attack with self-training (Metattack).\n",
    "\n",
    "    In this notebook version, structure perturbations only remove edges: the\n",
    "    chosen entry of `adj_changes` is decreased by the current value of\n",
    "    `modified_adj` at that entry.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, nnodes, feature_shape=None, attack_structure=True, attack_features=False, undirected=True, device='cpu', with_bias=False, lambda_=0.5, train_iters=100, lr=0.1, momentum=0.9, adj_mask=None):\n",
    "\n",
    "        super(Metattack, self).__init__(model, nnodes, feature_shape, lambda_, attack_structure, attack_features, undirected, device, adj_mask=adj_mask)\n",
    "        self.momentum = momentum\n",
    "        self.lr = lr\n",
    "        self.train_iters = train_iters\n",
    "        self.with_bias = with_bias\n",
    "\n",
    "        # Plain-tensor mirror of the surrogate's layers; inner_train differentiates\n",
    "        # through their updates (create_graph=True) to obtain meta-gradients.\n",
    "        self.weights = []\n",
    "        self.biases = []\n",
    "        self.w_velocities = []\n",
    "        self.b_velocities = []\n",
    "\n",
    "        self.hidden_sizes = self.surrogate.hidden_sizes\n",
    "        self.nfeat = self.surrogate.nfeat\n",
    "        self.nclass = self.surrogate.nclass\n",
    "\n",
    "        previous_size = self.nfeat\n",
    "        for nhid in self.hidden_sizes:\n",
    "            weight = Parameter(torch.FloatTensor(previous_size, nhid).to(device))\n",
    "            w_velocity = torch.zeros(weight.shape).to(device)\n",
    "            self.weights.append(weight)\n",
    "            self.w_velocities.append(w_velocity)\n",
    "\n",
    "            if self.with_bias:\n",
    "                bias = Parameter(torch.FloatTensor(nhid).to(device))\n",
    "                b_velocity = torch.zeros(bias.shape).to(device)\n",
    "                self.biases.append(bias)\n",
    "                self.b_velocities.append(b_velocity)\n",
    "\n",
    "            previous_size = nhid\n",
    "\n",
    "        output_weight = Parameter(torch.FloatTensor(previous_size, self.nclass).to(device))\n",
    "        output_w_velocity = torch.zeros(output_weight.shape).to(device)\n",
    "        self.weights.append(output_weight)\n",
    "        self.w_velocities.append(output_w_velocity)\n",
    "\n",
    "        if self.with_bias:\n",
    "            output_bias = Parameter(torch.FloatTensor(self.nclass).to(device))\n",
    "            output_b_velocity = torch.zeros(output_bias.shape).to(device)\n",
    "            self.biases.append(output_bias)\n",
    "            self.b_velocities.append(output_b_velocity)\n",
    "\n",
    "        self._initialize()\n",
    "\n",
    "    def _initialize(self):\n",
    "        \"\"\"Re-draw all inner weights/biases uniformly and reset their velocities.\"\"\"\n",
    "        for w, v in zip(self.weights, self.w_velocities):\n",
    "            stdv = 1. / math.sqrt(w.size(1))\n",
    "            w.data.uniform_(-stdv, stdv)\n",
    "            v.data.fill_(0)\n",
    "\n",
    "        if self.with_bias:\n",
    "            for b, v in zip(self.biases, self.b_velocities):\n",
    "                # Each bias uses its own width (== its layer's output size); reading\n",
    "                # `w` here would reuse the output layer's size for every bias.\n",
    "                stdv = 1. / math.sqrt(b.size(0))\n",
    "                b.data.uniform_(-stdv, stdv)\n",
    "                v.data.fill_(0)\n",
    "\n",
    "    def inner_train(self, features, adj_norm, idx_train, idx_unlabeled, labels):\n",
    "        \"\"\"Re-initialise the inner weights and train them for `train_iters` steps of\n",
    "        momentum gradient descent on the labeled loss. The updates keep their\n",
    "        autograd graph (create_graph=True) so meta-gradients can reach the perturbations.\n",
    "        \"\"\"\n",
    "        self._initialize()\n",
    "\n",
    "        # Start from fresh leaf tensors so the graph built below begins here.\n",
    "        for ix in range(len(self.hidden_sizes) + 1):\n",
    "            self.weights[ix] = self.weights[ix].detach()\n",
    "            self.weights[ix].requires_grad = True\n",
    "            self.w_velocities[ix] = self.w_velocities[ix].detach()\n",
    "            self.w_velocities[ix].requires_grad = True\n",
    "\n",
    "            if self.with_bias:\n",
    "                self.biases[ix] = self.biases[ix].detach()\n",
    "                self.biases[ix].requires_grad = True\n",
    "                self.b_velocities[ix] = self.b_velocities[ix].detach()\n",
    "                self.b_velocities[ix].requires_grad = True\n",
    "\n",
    "        for j in range(self.train_iters):\n",
    "            hidden = features\n",
    "            for ix, w in enumerate(self.weights):\n",
    "                b = self.biases[ix] if self.with_bias else 0\n",
    "                if self.sparse_features:\n",
    "                    hidden = adj_norm @ torch.spmm(hidden, w) + b\n",
    "                else:\n",
    "                    hidden = adj_norm @ hidden @ w + b\n",
    "\n",
    "                if self.with_relu and ix != len(self.weights) - 1:\n",
    "                    hidden = F.relu(hidden)\n",
    "\n",
    "            output = F.log_softmax(hidden, dim=1)\n",
    "            loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "\n",
    "            weight_grads = torch.autograd.grad(loss_labeled, self.weights, create_graph=True)\n",
    "            self.w_velocities = [self.momentum * v + g for v, g in zip(self.w_velocities, weight_grads)]\n",
    "            if self.with_bias:\n",
    "                bias_grads = torch.autograd.grad(loss_labeled, self.biases, create_graph=True)\n",
    "                self.b_velocities = [self.momentum * v + g for v, g in zip(self.b_velocities, bias_grads)]\n",
    "\n",
    "            self.weights = [w - self.lr * v for w, v in zip(self.weights, self.w_velocities)]\n",
    "            if self.with_bias:\n",
    "                self.biases = [b - self.lr * v for b, v in zip(self.biases, self.b_velocities)]\n",
    "\n",
    "    def get_meta_grad(self, features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training):\n",
    "        \"\"\"Forward pass with the inner-trained weights; returns the gradients of the\n",
    "        attack loss w.r.t. `adj_changes` / `feature_changes` (None for a part that is\n",
    "        not attacked). Also prints diagnostics computed with the true labels of\n",
    "        `idx_unlabeled`.\n",
    "        \"\"\"\n",
    "\n",
    "        hidden = features\n",
    "        for ix, w in enumerate(self.weights):\n",
    "            b = self.biases[ix] if self.with_bias else 0\n",
    "            if self.sparse_features:\n",
    "                hidden = adj_norm @ torch.spmm(hidden, w) + b\n",
    "            else:\n",
    "                hidden = adj_norm @ hidden @ w + b\n",
    "            if self.with_relu and ix != len(self.weights) - 1:\n",
    "                hidden = F.relu(hidden)\n",
    "\n",
    "        output = F.log_softmax(hidden, dim=1)\n",
    "\n",
    "        loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "        loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled])\n",
    "        loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled])\n",
    "\n",
    "        if self.lambda_ == 1:\n",
    "            attack_loss = loss_labeled\n",
    "        elif self.lambda_ == 0:\n",
    "            attack_loss = loss_unlabeled\n",
    "        else:\n",
    "            attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled\n",
    "\n",
    "        print('GCN loss on unlabled data: {}'.format(loss_test_val.item()))\n",
    "        print('GCN acc on unlabled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item()))\n",
    "        print('attack loss: {}'.format(attack_loss.item()))\n",
    "\n",
    "        adj_grad, feature_grad = None, None\n",
    "        if self.attack_structure:\n",
    "            adj_grad = torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]\n",
    "        if self.attack_features:\n",
    "            feature_grad = torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]\n",
    "        return adj_grad, feature_grad\n",
    "\n",
    "    def attack(self, ori_features, ori_adj, labels, idx_train, idx_unlabeled, n_perturbations, ll_constraint=True, ll_cutoff=0.004):\n",
    "        \"\"\"Generate n_perturbations on the input graph.\n",
    "        Parameters\n",
    "        ----------\n",
    "        ori_features :\n",
    "            Original (unperturbed) node feature matrix\n",
    "        ori_adj :\n",
    "            Original (unperturbed) adjacency matrix\n",
    "        labels :\n",
    "            node labels\n",
    "        idx_train :\n",
    "            node training indices\n",
    "        idx_unlabeled:\n",
    "            unlabeled nodes indices\n",
    "        n_perturbations : int\n",
    "            Number of perturbations on the input graph. Perturbations could\n",
    "            be edge removals/additions or feature removals/additions.\n",
    "        ll_constraint: bool\n",
    "            whether to exert the likelihood ratio test constraint\n",
    "        ll_cutoff : float\n",
    "            The critical value for the likelihood ratio test of the power law distributions.\n",
    "            See the Chi square distribution with one degree of freedom. Default value 0.004\n",
    "            corresponds to a p-value of roughly 0.95. It would be ignored if `ll_constraint`\n",
    "            is False.\n",
    "        \"\"\"\n",
    "\n",
    "        self.sparse_features = sp.issparse(ori_features)\n",
    "        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)\n",
    "        labels_self_training = self.self_training_label(labels, idx_train)\n",
    "        modified_adj = ori_adj\n",
    "        modified_features = ori_features\n",
    "\n",
    "        for _ in tqdm(range(n_perturbations), desc=\"Perturbing graph\"):\n",
    "            if self.attack_structure:\n",
    "                modified_adj = self.get_modified_adj(ori_adj)\n",
    "\n",
    "            if self.attack_features:\n",
    "                modified_features = self.get_modified_features(ori_features)\n",
    "\n",
    "            adj_norm = utils.normalize_adj_tensor(modified_adj)\n",
    "            self.inner_train(modified_features, adj_norm, idx_train, idx_unlabeled, labels)\n",
    "\n",
    "            adj_grad, feature_grad = self.get_meta_grad(modified_features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training)\n",
    "\n",
    "            adj_meta_score = torch.tensor(0.0).to(self.device)\n",
    "            feature_meta_score = torch.tensor(0.0).to(self.device)\n",
    "            if self.attack_structure:\n",
    "                adj_meta_score = self.get_adj_score(adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff)\n",
    "            if self.attack_features:\n",
    "                feature_meta_score = self.get_feature_score(feature_grad, modified_features)\n",
    "\n",
    "            if adj_meta_score.max() >= feature_meta_score.max():\n",
    "                adj_meta_argmax = torch.argmax(adj_meta_score)\n",
    "                # Row-major unravel on Python ints (same result as utils.unravel_index,\n",
    "                # without its deprecated tensor `//`).\n",
    "                row_idx, col_idx = divmod(adj_meta_argmax.item(), ori_adj.shape[1])\n",
    "                # Edge removal: subtract the current entry (the original Metattack\n",
    "                # flips the entry with -2 * modified_adj + 1 instead).\n",
    "                delta = -modified_adj[row_idx][col_idx]\n",
    "                self.adj_changes.data[row_idx][col_idx] += delta\n",
    "                if self.undirected:\n",
    "                    self.adj_changes.data[col_idx][row_idx] += delta\n",
    "            else:\n",
    "                feature_meta_argmax = torch.argmax(feature_meta_score)\n",
    "                row_idx, col_idx = divmod(feature_meta_argmax.item(), ori_features.shape[1])\n",
    "                self.feature_changes.data[row_idx][col_idx] += (-2 * modified_features[row_idx][col_idx] + 1)\n",
    "\n",
    "            del adj_norm\n",
    "\n",
    "        if self.attack_structure:\n",
    "            self.modified_adj = self.get_modified_adj(ori_adj).detach()\n",
    "        if self.attack_features:\n",
    "            self.modified_features = self.get_modified_features(ori_features).detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0fb48638",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zizhang/anaconda3/lib/python3.8/site-packages/dgl/heterograph.py:3719: DGLWarning: DGLGraph.adjacency_matrix_scipy is deprecated. Please replace it with:\n",
      "\n",
      "\tDGLGraph.adjacency_matrix(transpose, scipy_fmt=\"csr\").\n",
      "\n",
      "  dgl_warning('DGLGraph.adjacency_matrix_scipy is deprecated. '\n"
     ]
    }
   ],
   "source": [
    "# Dataset configuration. Alternative setting used previously:\n",
    "# data_set, num_remove, rate = 'citeseer', 227, 0.05\n",
    "data_set, num_remove, rate = 'cora', 51, 0.01\n",
    "\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)\n",
    "\n",
    "# Convert to scipy / numpy objects (used by the GCN surrogate and the attack below).\n",
    "adj = graph.adjacency_matrix_scipy()\n",
    "features = sparse.csr_matrix(feat.numpy())\n",
    "labels = labels.numpy().astype(int)\n",
    "idx_train, idx_val, idx_test = (np.where(mask == 1)[0] for mask in (train_mask, val_mask, test_mask))\n",
    "idx_unlabeled = np.union1d(idx_val, idx_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "feccba0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense torch copy of the scipy adjacency matrix `adj`.\n",
    "adj_0 = torch.from_numpy(adj.toarray())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e081b8fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seeds so surrogate training and the attack's inner-weight\n",
    "# initialisation are reproducible across Restart & Run All.\n",
    "SEED = 15\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Switch to torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") to use a GPU.\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "# Surrogate model: GCN without ReLU or bias, trained on the clean graph.\n",
    "surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,\n",
    "                nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)\n",
    "surrogate = surrogate.to(device)\n",
    "surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)\n",
    "\n",
    "# Attack model: structure-only; lambda_=0 makes the attack loss use only the\n",
    "# self-training labels on the unlabeled nodes.\n",
    "model = Metattack(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,\n",
    "                  attack_structure=True, attack_features=False, device=device, lambda_=0)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8530e903",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:   0%|                                  | 0/51 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.929415225982666\n",
      "GCN acc on unlabled data: 0.4\n",
      "attack loss: 1.9275988340377808\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zizhang/anaconda3/lib/python3.8/site-packages/deeprobust/graph/utils.py:542: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  rows = index // array_shape[1]\n",
      "\r",
      "Perturbing graph:   2%|▌                         | 1/51 [00:05<04:19,  5.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.925334095954895\n",
      "GCN acc on unlabled data: 0.416\n",
      "attack loss: 1.9249495267868042\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:   4%|█                         | 2/51 [00:10<04:07,  5.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.922761082649231\n",
      "GCN acc on unlabled data: 0.444\n",
      "attack loss: 1.920738935470581\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:   6%|█▌                        | 3/51 [00:14<03:57,  4.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.924565315246582\n",
      "GCN acc on unlabled data: 0.37\n",
      "attack loss: 1.921918272972107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:   8%|██                        | 4/51 [00:20<03:54,  4.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9311691522598267\n",
      "GCN acc on unlabled data: 0.38533333333333336\n",
      "attack loss: 1.9288384914398193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  10%|██▌                       | 5/51 [00:25<03:50,  5.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9303354024887085\n",
      "GCN acc on unlabled data: 0.4033333333333333\n",
      "attack loss: 1.9288017749786377\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  12%|███                       | 6/51 [00:30<03:49,  5.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9324870109558105\n",
      "GCN acc on unlabled data: 0.322\n",
      "attack loss: 1.930759072303772\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  14%|███▌                      | 7/51 [00:35<03:43,  5.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9123566150665283\n",
      "GCN acc on unlabled data: 0.5613333333333334\n",
      "attack loss: 1.9112718105316162\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  16%|████                      | 8/51 [00:40<03:35,  5.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.925897240638733\n",
      "GCN acc on unlabled data: 0.45\n",
      "attack loss: 1.9241998195648193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  18%|████▌                     | 9/51 [00:45<03:27,  4.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9307804107666016\n",
      "GCN acc on unlabled data: 0.26666666666666666\n",
      "attack loss: 1.929033637046814\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  20%|████▉                    | 10/51 [00:49<03:21,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.915075659751892\n",
      "GCN acc on unlabled data: 0.42333333333333334\n",
      "attack loss: 1.9141992330551147\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  22%|█████▍                   | 11/51 [00:54<03:15,  4.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.934322476387024\n",
      "GCN acc on unlabled data: 0.31866666666666665\n",
      "attack loss: 1.9327548742294312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  24%|█████▉                   | 12/51 [00:59<03:09,  4.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9285029172897339\n",
      "GCN acc on unlabled data: 0.398\n",
      "attack loss: 1.9258671998977661\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  25%|██████▎                  | 13/51 [01:04<03:03,  4.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9248052835464478\n",
      "GCN acc on unlabled data: 0.44\n",
      "attack loss: 1.92327082157135\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  27%|██████▊                  | 14/51 [01:09<02:58,  4.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.923409342765808\n",
      "GCN acc on unlabled data: 0.4046666666666667\n",
      "attack loss: 1.9197838306427002\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  29%|███████▎                 | 15/51 [01:13<02:53,  4.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9232165813446045\n",
      "GCN acc on unlabled data: 0.464\n",
      "attack loss: 1.9227714538574219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  31%|███████▊                 | 16/51 [01:19<02:52,  4.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.927187204360962\n",
      "GCN acc on unlabled data: 0.41133333333333333\n",
      "attack loss: 1.9255775213241577\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  33%|████████▎                | 17/51 [01:24<02:50,  5.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.92559015750885\n",
      "GCN acc on unlabled data: 0.414\n",
      "attack loss: 1.9234133958816528\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  35%|████████▊                | 18/51 [01:29<02:47,  5.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9239634275436401\n",
      "GCN acc on unlabled data: 0.36133333333333334\n",
      "attack loss: 1.921826958656311\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  37%|█████████▎               | 19/51 [01:34<02:43,  5.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9262465238571167\n",
      "GCN acc on unlabled data: 0.3893333333333333\n",
      "attack loss: 1.9245606660842896\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  39%|█████████▊               | 20/51 [01:39<02:35,  5.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9259014129638672\n",
      "GCN acc on unlabled data: 0.4633333333333333\n",
      "attack loss: 1.9239341020584106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  41%|██████████▎              | 21/51 [01:44<02:29,  4.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9303821325302124\n",
      "GCN acc on unlabled data: 0.37066666666666664\n",
      "attack loss: 1.9281840324401855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  43%|██████████▊              | 22/51 [01:49<02:23,  4.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9241212606430054\n",
      "GCN acc on unlabled data: 0.4\n",
      "attack loss: 1.9208382368087769\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  45%|███████████▎             | 23/51 [01:54<02:19,  4.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9275639057159424\n",
      "GCN acc on unlabled data: 0.324\n",
      "attack loss: 1.9261317253112793\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  47%|███████████▊             | 24/51 [01:59<02:14,  4.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9270457029342651\n",
      "GCN acc on unlabled data: 0.3373333333333333\n",
      "attack loss: 1.9253835678100586\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  49%|████████████▎            | 25/51 [02:04<02:09,  4.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9324660301208496\n",
      "GCN acc on unlabled data: 0.25466666666666665\n",
      "attack loss: 1.9299601316452026\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  51%|████████████▋            | 26/51 [02:09<02:05,  5.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9212052822113037\n",
      "GCN acc on unlabled data: 0.4686666666666667\n",
      "attack loss: 1.9207406044006348\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  53%|█████████████▏           | 27/51 [02:14<01:59,  5.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9257272481918335\n",
      "GCN acc on unlabled data: 0.38133333333333336\n",
      "attack loss: 1.9228146076202393\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  55%|█████████████▋           | 28/51 [02:19<01:54,  4.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9312303066253662\n",
      "GCN acc on unlabled data: 0.3233333333333333\n",
      "attack loss: 1.9289153814315796\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  57%|██████████████▏          | 29/51 [02:24<01:49,  4.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9254690408706665\n",
      "GCN acc on unlabled data: 0.396\n",
      "attack loss: 1.925419569015503\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  59%|██████████████▋          | 30/51 [02:29<01:44,  4.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.927166223526001\n",
      "GCN acc on unlabled data: 0.444\n",
      "attack loss: 1.9260003566741943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  61%|███████████████▏         | 31/51 [02:33<01:38,  4.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9239091873168945\n",
      "GCN acc on unlabled data: 0.422\n",
      "attack loss: 1.921765923500061\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  63%|███████████████▋         | 32/51 [02:38<01:33,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9304267168045044\n",
      "GCN acc on unlabled data: 0.398\n",
      "attack loss: 1.9280604124069214\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  65%|████████████████▏        | 33/51 [02:43<01:28,  4.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9285930395126343\n",
      "GCN acc on unlabled data: 0.364\n",
      "attack loss: 1.9277514219284058\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  67%|████████████████▋        | 34/51 [02:48<01:23,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9252110719680786\n",
      "GCN acc on unlabled data: 0.366\n",
      "attack loss: 1.9234908819198608\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  69%|█████████████████▏       | 35/51 [02:53<01:18,  4.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9295194149017334\n",
      "GCN acc on unlabled data: 0.432\n",
      "attack loss: 1.927905559539795\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  71%|█████████████████▋       | 36/51 [02:58<01:13,  4.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9308432340621948\n",
      "GCN acc on unlabled data: 0.36866666666666664\n",
      "attack loss: 1.9299235343933105\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  73%|██████████████████▏      | 37/51 [03:03<01:08,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9274826049804688\n",
      "GCN acc on unlabled data: 0.376\n",
      "attack loss: 1.9267438650131226\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  75%|██████████████████▋      | 38/51 [03:08<01:03,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9252082109451294\n",
      "GCN acc on unlabled data: 0.42\n",
      "attack loss: 1.9219974279403687\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  76%|███████████████████      | 39/51 [03:13<00:58,  4.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9293347597122192\n",
      "GCN acc on unlabled data: 0.3466666666666667\n",
      "attack loss: 1.9258058071136475\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  78%|███████████████████▌     | 40/51 [03:18<00:54,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9292950630187988\n",
      "GCN acc on unlabled data: 0.3466666666666667\n",
      "attack loss: 1.9282150268554688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  80%|████████████████████     | 41/51 [03:23<00:49,  4.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9226198196411133\n",
      "GCN acc on unlabled data: 0.3893333333333333\n",
      "attack loss: 1.92319917678833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  82%|████████████████████▌    | 42/51 [03:27<00:44,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9295105934143066\n",
      "GCN acc on unlabled data: 0.3526666666666667\n",
      "attack loss: 1.927152156829834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  84%|█████████████████████    | 43/51 [03:32<00:39,  4.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9284571409225464\n",
      "GCN acc on unlabled data: 0.3606666666666667\n",
      "attack loss: 1.9265474081039429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  86%|█████████████████████▌   | 44/51 [03:37<00:34,  4.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9309136867523193\n",
      "GCN acc on unlabled data: 0.364\n",
      "attack loss: 1.92872953414917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  88%|██████████████████████   | 45/51 [03:42<00:29,  4.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9226818084716797\n",
      "GCN acc on unlabled data: 0.44\n",
      "attack loss: 1.9215643405914307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  90%|██████████████████████▌  | 46/51 [03:47<00:24,  4.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9278286695480347\n",
      "GCN acc on unlabled data: 0.37666666666666665\n",
      "attack loss: 1.9264802932739258\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  92%|███████████████████████  | 47/51 [03:52<00:19,  4.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9277111291885376\n",
      "GCN acc on unlabled data: 0.326\n",
      "attack loss: 1.9250712394714355\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  94%|███████████████████████▌ | 48/51 [03:57<00:15,  5.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9231277704238892\n",
      "GCN acc on unlabled data: 0.44066666666666665\n",
      "attack loss: 1.9211496114730835\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  96%|████████████████████████ | 49/51 [04:03<00:10,  5.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.9303103685379028\n",
      "GCN acc on unlabled data: 0.36933333333333335\n",
      "attack loss: 1.9279593229293823\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Perturbing graph:  98%|████████████████████████▌| 50/51 [04:08<00:05,  5.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GCN loss on unlabled data: 1.930135726928711\n",
      "GCN acc on unlabled data: 0.4226666666666667\n",
      "attack loss: 1.929218053817749\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Perturbing graph: 100%|█████████████████████████| 51/51 [04:13<00:00,  4.97s/it]\n"
     ]
    }
   ],
   "source": [
    "# Attack: run the meta-learning structure attack defined earlier in this notebook.\n",
    "# NOTE(review): ll_constraint=False presumably disables the degree-distribution\n",
    "# (log-likelihood) check; confirm against the attack class.\n",
    "model.attack(\n",
    "    features, adj, labels, idx_train, idx_unlabeled,\n",
    "    n_perturbations=num_remove,\n",
    "    ll_constraint=False,\n",
    ")\n",
    "# The attack object exposes its result as `modified_adj` once attack() returns.\n",
    "modified_adj = model.modified_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "03476781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detach the perturbed adjacency from autograd and move it to the CPU so the cells below\n",
    "# can call .numpy() on it and compare it entry-wise with the original `adj`.\n",
    "modified_adj = modified_adj.detach().cpu()\n",
    "assert tuple(modified_adj.shape) == tuple(adj.shape), \"perturbed and original adjacency shapes differ\""
   ]
  },
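  {
   "cell_type": "markdown",
   "id": "5a0270af",
   "metadata": {},
   "source": [
    "## Extract the removed edges\n",
    "\n",
    "Compare the original adjacency with the perturbed one to list the `(row, col)` entries the attack changed. Then sanity-check them and save them so the edge-removal experiment can reload them."
   ]
  },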
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "754c3bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row (f_l) and column (t_l) indices of every adjacency entry the attack changed.\n",
    "# The element-wise comparison and np.where are evaluated once and unpacked, instead of\n",
    "# being recomputed for each index array. With a symmetric (undirected) adjacency each\n",
    "# changed edge appears twice, as (i, j) and (j, i).\n",
    "f_l, t_l = np.where(adj != modified_adj.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "159b7e6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(102,)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of changed adjacency entries (both directions of each undirected edge are counted).\n",
    "np.shape(f_l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2fd669fd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: every changed entry must decrease (an edge removal), never add an edge,\n",
    "# because the saved list is later mapped back to existing edges via graph.edge_ids.\n",
    "# A plain net sum would let additions cancel removals, so check the sign explicitly.\n",
    "# For a 0/1 adjacency the net difference then equals the number of changed entries.\n",
    "edge_diff = np.asarray(adj - modified_adj.numpy())\n",
    "assert (edge_diff >= 0).all(), \"attack added edges; only removals are expected\"\n",
    "assert edge_diff.sum() == f_l.size, \"net edge difference does not match the changed entries\"\n",
    "edge_diff.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e9b49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count distinct undirected edges among the changed entries. Each removed edge should\n",
    "# appear as both (i, j) and (j, i), and the attack was asked for num_remove perturbations.\n",
    "removed_pairs = {(min(i, j), max(i, j)) for i, j in zip(f_l.tolist(), t_l.tolist())}\n",
    "assert len(removed_pairs) * 2 == f_l.size, \"changed entries are not symmetric\"\n",
    "assert len(removed_pairs) == num_remove, \"removed edge count differs from num_remove\"\n",
    "len(removed_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8300a0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the perturbed adjacency (a dense CPU tensor, since .numpy() is called on it above).\n",
    "modified_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "857041ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Persist the changed (row, col) pairs as a two-column CSV. The headers stay '0' and '1'\n",
    "# because META_attack reads the columns back by those names.\n",
    "removed_edges = pd.DataFrame({0: f_l.astype(int), 1: t_l.astype(int)})\n",
    "os.makedirs('metattack', exist_ok=True)  # to_csv raises if the output directory is missing\n",
    "removed_edges.to_csv(f'metattack/{data_set}_remove_rate_{rate}.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a488729",
   "metadata": {},
   "outputs": [],
   "source": [
    "def META_attack(data_set, rate, graph, csv_dir='metattack'):\n",
    "    \"\"\"Retrain a GCN with the Metattack-removed edges dropped and return its evaluation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_set : str\n",
    "        Dataset name; part of the CSV file name and passed to ``gcn_with_edge_removal``.\n",
    "    rate :\n",
    "        Removal rate the CSV was generated for (part of the file name).\n",
    "    graph :\n",
    "        Graph whose ``edge_ids(src, dst)`` maps the saved (row, col) pairs to edge ids\n",
    "        (presumably a DGL graph; confirm).\n",
    "    csv_dir : str\n",
    "        Directory holding the saved edge lists. Default ``'metattack'``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The result of ``gcn_with_edge_removal(...).train_evaluate()`` (treated as accuracy).\n",
    "    \"\"\"\n",
    "    # Imported here so the function does not depend on a later cell having run first.\n",
    "    from gcn_with_edge_removal import gcn_with_edge_removal\n",
    "\n",
    "    removed_edges = pd.read_csv(f'{csv_dir}/{data_set}_remove_rate_{rate}.csv')\n",
    "    # Columns '0' / '1' hold the row / column index of every changed adjacency entry.\n",
    "    remove_list = list(graph.edge_ids(removed_edges['0'].values, removed_edges['1'].values))\n",
    "    train_gcn = gcn_with_edge_removal(dataset=data_set, remove_edge_index=remove_list)\n",
    "    return train_gcn.train_evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d7fd4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gcn_with_edge_removal import gcn_with_edge_removal\n",
    "# NOTE(review): sgc_with_edge_removal is not used in the cells shown; remove if unused elsewhere.\n",
    "from sgc_with_edge_removal import sgc_with_edge_removal\n",
    "\n",
    "# Evaluate a GCN trained on the graph with the Metattack-selected edges removed.\n",
    "META_attack(data_set=data_set, rate=rate, graph=graph)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
