{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff033e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa0d775c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import optim\n",
    "from torch.nn import functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "from tqdm import tqdm\n",
    "from scipy import sparse\n",
    "from deeprobust.graph import utils\n",
    "from deeprobust.graph.global_attack import BaseAttack\n",
    "from dataset import load_graph_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "052d74f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gcn_with_edge_removal import gcn_with_edge_removal\n",
    "from sgc_with_edge_removal import sgc_with_edge_removal\n",
    "from gcn_with_edge_perturbation import gcn_with_edge_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d8a572ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PGDAttack(BaseAttack):\n",
    "    \"\"\"PGD attack for graph data.\n",
    "    Parameters\n",
    "    ----------\n",
    "    model :\n",
    "        model to attack. Default `None`.\n",
    "    nnodes : int\n",
    "        number of nodes in the input graph\n",
    "    loss_type: str\n",
    "        attack loss type, chosen from ['CE', 'CW']\n",
    "    feature_shape : tuple\n",
    "        shape of the input node features\n",
    "    attack_structure : bool\n",
    "        whether to attack graph structure\n",
    "    attack_features : bool\n",
    "        whether to attack node features\n",
    "    device: str\n",
    "        'cpu' or 'cuda'\n",
    "    Examples\n",
    "    --------\n",
    "    >>> from deeprobust.graph.data import Dataset\n",
    "    >>> from deeprobust.graph.defense import GCN\n",
    "    >>> from deeprobust.graph.global_attack import PGDAttack\n",
    "    >>> from deeprobust.graph.utils import preprocess\n",
    "    >>> data = Dataset(root='/tmp/', name='cora')\n",
    "    >>> adj, features, labels = data.adj, data.features, data.labels\n",
    "    >>> adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False) # conver to tensor\n",
    "    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test\n",
    "    >>> # Setup Victim Model\n",
    "    >>> victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,\n",
    "                        nhid=16, dropout=0.5, weight_decay=5e-4, device='cpu').to('cpu')\n",
    "    >>> victim_model.fit(features, adj, labels, idx_train)\n",
    "    >>> # Setup Attack Model\n",
    "    >>> model = PGDAttack(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device='cpu').to('cpu')\n",
    "    >>> model.attack(features, adj, labels, idx_train, n_perturbations=10)\n",
    "    >>> modified_adj = model.modified_adj\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None, attack_structure=True, attack_features=False, device='cpu', adj_orig = None):\n",
    "\n",
    "        super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)\n",
    "\n",
    "        assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'\n",
    "\n",
    "        self.loss_type = loss_type\n",
    "        self.modified_adj = None\n",
    "        self.modified_features = None\n",
    "\n",
    "        if attack_structure:\n",
    "            assert nnodes is not None, 'Please give nnodes='\n",
    "            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes*(nnodes-1)/2)))\n",
    "            self.adj_changes.data.fill_(0)\n",
    "            \n",
    "            \"\"\"add hook\"\"\"\n",
    "            self.adj_changes.register_hook(adj_orig)\n",
    "\n",
    "        if attack_features:\n",
    "            assert True, 'Topology Attack does not support attack feature'\n",
    "\n",
    "        self.complementary = None\n",
    "\n",
    "    def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, epochs=5, **kwargs):\n",
    "        \"\"\"Generate perturbations on the input graph.\n",
    "        Parameters\n",
    "        ----------\n",
    "        ori_features :\n",
    "            Original (unperturbed) node feature matrix\n",
    "        ori_adj :\n",
    "            Original (unperturbed) adjacency matrix\n",
    "        labels :\n",
    "            node labels\n",
    "        idx_train :\n",
    "            node training indices\n",
    "        n_perturbations : int\n",
    "            Number of perturbations on the input graph. Perturbations could\n",
    "            be edge removals/additions or feature removals/additions.\n",
    "        epochs:\n",
    "            number of training epochs\n",
    "        \"\"\"\n",
    "\n",
    "        victim_model = self.surrogate\n",
    "\n",
    "        self.sparse_features = sp.issparse(ori_features)\n",
    "#         ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)\n",
    "\n",
    "        victim_model.eval()\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "        modified_adj = self.get_modified_adj(ori_adj)\n",
    "#         adj_norm = utils.normalize_adj_tensor(modified_adj)\n",
    "#         output = victim_model(ori_features, adj_norm)\n",
    "#         # loss = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "#         loss = self._loss(output[idx_train], labels[idx_train])\n",
    "#         adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]\n",
    "\n",
    "#         if self.loss_type == 'CE':\n",
    "#             lr = 200 \n",
    "#             self.adj_changes.data.add_(lr * adj_grad)\n",
    "\n",
    "#         if self.loss_type == 'CW':\n",
    "#             lr = 0.1\n",
    "#             self.adj_changes.data.add_(lr * adj_grad)\n",
    "\n",
    "#         self.projection(n_perturbations)\n",
    "\n",
    "        \n",
    "    \n",
    "#         for t in tqdm(range(epochs)):\n",
    "        print('ori', ori_adj)\n",
    "        print('mod', self.modified_adj)\n",
    "#         while torch.sum(ori_adj - self.modified_adj) < n_perturbations:\n",
    "        for t in tqdm(range(epochs)):\n",
    "#             print(ori_adj)\n",
    "            modified_adj = self.get_modified_adj(ori_adj)\n",
    "            adj_norm = utils.normalize_adj_tensor(modified_adj)\n",
    "            output = victim_model(ori_features, adj_norm)\n",
    "            # loss = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "            loss = self._loss(output[idx_train], labels[idx_train])\n",
    "            adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]\n",
    "\n",
    "            if self.loss_type == 'CE':\n",
    "                lr = 200 / np.sqrt(t+1)\n",
    "                self.adj_changes.data.add_(lr * adj_grad)\n",
    "\n",
    "            if self.loss_type == 'CW':\n",
    "                lr = 0.1 / np.sqrt(t+1)\n",
    "                self.adj_changes.data.add_(lr * adj_grad)\n",
    "\n",
    "            self.projection(n_perturbations)\n",
    "\n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "        self.random_sample(ori_adj, ori_features, labels, idx_train, n_perturbations)\n",
    "        self.modified_adj = self.get_modified_adj(ori_adj).detach()\n",
    "        \n",
    "        \n",
    "        \n",
    "        \"\"\"deleted check adj\"\"\"\n",
    "#         self.check_adj_tensor(self.modified_adj)\n",
    "\n",
    "\n",
    "    def random_sample(self, ori_adj, ori_features, labels, idx_train, n_perturbations):\n",
    "        K = 20\n",
    "        best_loss = -1000\n",
    "        victim_model = self.surrogate\n",
    "        victim_model.eval()\n",
    "        with torch.no_grad():\n",
    "            s = self.adj_changes.cpu().detach().numpy()\n",
    "            for i in range(K):\n",
    "                sampled = np.random.binomial(1, s)\n",
    "\n",
    "                # print(sampled.sum())\n",
    "                if sampled.sum() > n_perturbations:\n",
    "                    continue\n",
    "                self.adj_changes.data.copy_(torch.tensor(sampled))\n",
    "                modified_adj = self.get_modified_adj(ori_adj)\n",
    "                adj_norm = utils.normalize_adj_tensor(modified_adj)\n",
    "                output = victim_model(ori_features, adj_norm)\n",
    "                loss = self._loss(output[idx_train], labels[idx_train])\n",
    "                # loss = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "                # print(loss)\n",
    "                if best_loss < loss:\n",
    "                    best_loss = loss\n",
    "                    best_s = sampled\n",
    "            self.adj_changes.data.copy_(torch.tensor(best_s))\n",
    "\n",
    "    def _loss(self, output, labels):\n",
    "        if self.loss_type == \"CE\":\n",
    "            loss = F.nll_loss(output, labels)\n",
    "        if self.loss_type == \"CW\":\n",
    "            onehot = utils.tensor2onehot(labels)\n",
    "            best_second_class = (output - 1000*onehot).argmax(1)\n",
    "            margin = output[np.arange(len(output)), labels] - \\\n",
    "                   output[np.arange(len(output)), best_second_class]\n",
    "            k = 0\n",
    "            loss = -torch.clamp(margin, min=k).mean()\n",
    "            # loss = torch.clamp(margin.sum()+50, min=k)\n",
    "        return loss\n",
    "\n",
    "    def projection(self, n_perturbations):\n",
    "        # projected = torch.clamp(self.adj_changes, 0, 1)\n",
    "        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:\n",
    "            left = (self.adj_changes - 1).min()\n",
    "            right = self.adj_changes.max()\n",
    "            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)\n",
    "            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))\n",
    "        else:\n",
    "            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))\n",
    "\n",
    "    def get_modified_adj(self, ori_adj):\n",
    "\n",
    "        if self.complementary is None:\n",
    "#             self.complementary = torch.zeros_like(ori_adj)\n",
    "#             self.complementary = -ori_adj\n",
    "            self.complementary = (torch.ones_like(ori_adj) - torch.eye(self.nnodes).to(self.device) - ori_adj) - ori_adj\n",
    "\n",
    "        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)\n",
    "        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)\n",
    "\n",
    "        m[tril_indices[0], tril_indices[1]] = self.adj_changes\n",
    "        m = m + m.t()\n",
    "        modified_adj = self.complementary * m + ori_adj\n",
    "        \n",
    "        modified_adj = modified_adj * ori_adj\n",
    "        \n",
    "\n",
    "        return modified_adj\n",
    "\n",
    "    def bisection(self, a, b, n_perturbations, epsilon):\n",
    "        def func(x):\n",
    "            return torch.clamp(self.adj_changes-x, 0, 1).sum() - n_perturbations\n",
    "\n",
    "        miu = a\n",
    "        while ((b-a) >= epsilon):\n",
    "            miu = (a+b)/2\n",
    "            # Check if middle point is root\n",
    "            if (func(miu) == 0.0):\n",
    "                break\n",
    "            # Decide the side to repeat the steps\n",
    "            if (func(miu)*func(a) < 0):\n",
    "                b = miu\n",
    "            else:\n",
    "                a = miu\n",
    "        # print(\"The value of root is : \",\"%.4f\" % miu)\n",
    "        return miu\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e70072f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pubmed: \n",
    "# data_set = 'pubmed'\n",
    "num_remove1 = 443\n",
    "num_remove3 = 1330\n",
    "num_remove5 = 2216\n",
    "# rate = 0.03\n",
    "\n",
    "data_set = 'pubmed'\n",
    "\n",
    "# num_remove_001 = 52\n",
    "# num_remove_003 = 52 * 3\n",
    "# num_remove_005 = 52 * 5\n",
    "rate = [0.01, 0.03, 0.05]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2928bb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_remove5 = 2216"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "455174c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zizhang/anaconda3/lib/python3.8/site-packages/dgl/heterograph.py:3719: DGLWarning: DGLGraph.adjacency_matrix_scipy is deprecated. Please replace it with:\n",
      "\n",
      "\tDGLGraph.adjacency_matrix(transpose, scipy_fmt=\"csr\").\n",
      "\n",
      "  dgl_warning('DGLGraph.adjacency_matrix_scipy is deprecated. '\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)\n",
    "\n",
    "adj = graph.adjacency_matrix_scipy()\n",
    "features = sparse.csr_matrix(feat.numpy())\n",
    "labels = labels.numpy().astype(int)\n",
    "idx_train = np.where(train_mask == 1)[0]\n",
    "idx_val = np.where(val_mask == 1)[0]\n",
    "idx_test = np.where(test_mask == 1)[0]\n",
    "idx_unlabeled = np.union1d(idx_val, idx_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "00ad5e17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from deeprobust.graph.utils import preprocess\n",
    "from deeprobust.graph.defense import GCN\n",
    "adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fd002faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.ones_like(adj) - torch.eye(graph.num_nodes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05097d06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f7a2e632",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "ori tensor([[1., 1., 0.,  ..., 0., 0., 0.],\n",
      "        [1., 1., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 1.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 1., 1., 0.],\n",
      "        [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "        [0., 0., 0.,  ..., 0., 1., 1.]], device='cuda:0')\n",
      "mod None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                                     | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 20%|█████████                                    | 1/5 [00:05<00:20,  5.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 40%|██████████████████                           | 2/5 [00:10<00:15,  5.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 60%|███████████████████████████                  | 3/5 [00:15<00:10,  5.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 80%|████████████████████████████████████         | 4/5 [00:21<00:05,  5.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 5/5 [00:26<00:00,  5.29s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n",
      "m: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "Parameter containing:\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# device = torch.device(\"cpu\")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "adj = adj.to(device)\n",
    "features = features.to(device)\n",
    "labels = labels.to(device)\n",
    "adj_0 = adj.reshape(-1).to(device)\n",
    "\n",
    "\n",
    "victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,\n",
    "                    nhid=16, dropout=0.0, weight_decay=5e-4, device=device)\n",
    "victim_model = victim_model.to(device)\n",
    "victim_model.fit(features, adj, labels, idx_train)\n",
    "# Setup Attack Model\n",
    "model = PGDAttack(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device=device, adj_orig=adj_0).to(device)\n",
    "model = model.to(device)\n",
    "model.attack(features, adj, labels, idx_train, n_perturbations=num_remove)\n",
    "modified_adj = model.modified_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "09763d2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 0.,  ..., 0., 0., 0.],\n",
       "        [1., 1., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 1.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 1., 1., 0.],\n",
       "        [0., 0., 0.,  ..., 1., 1., 1.],\n",
       "        [0., 0., 0.,  ..., 0., 1., 1.]], device='cuda:0')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.modified_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fb99fb51",
   "metadata": {},
   "outputs": [],
   "source": [
    "adj = adj.detach().cpu()\n",
    "modified_adj = modified_adj.detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "54e6449f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.sum(adj > )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f630b942",
   "metadata": {},
   "outputs": [],
   "source": [
    "f_l_remove = np.where(adj.numpy() > modified_adj.numpy())[0]\n",
    "t_l_remove = np.where(adj.numpy() > modified_adj.numpy())[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "65449912",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(f_l_remove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6a4b2d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "f_l_add = np.where(adj.numpy() < modified_adj.numpy())[0]\n",
    "t_l_add = np.where(adj.numpy() < modified_adj.numpy())[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "cc547e48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], dtype=int64)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f_l_add"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6e6a9e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_candidate_remove_list(f_l_candidate,t_l_candidate, graph):\n",
    "    candidiate_remove_list = []\n",
    "    f_l = graph.edges()[0].numpy()\n",
    "    t_l = graph.edges()[1].numpy()\n",
    "\n",
    "    for i in range(len(f_l_candidate)):\n",
    "        temp_f_l_index = np.where(f_l == f_l_candidate[i])[0]\n",
    "        temp_t_l_index = np.where(t_l == t_l_candidate[i])[0]\n",
    "\n",
    "        candidiate_remove_list.extend(list(set(temp_f_l_index) & set(temp_t_l_index )))\n",
    "    return candidiate_remove_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "c1406fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "candidate_remove_node_index = generate_candidate_remove_list(f_l_remove, t_l_remove, graph)\n",
    "temp_add_f = f_l_add \n",
    "temp_add_t = t_l_add"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3de98f93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "112441\n",
      "\n",
      "Test accuracy 20.40%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.204"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp_gcn = gcn_with_edge_p(dataset=data_set, remove_edge_index=candidate_remove_node_index, \n",
    "                           add_from_index = temp_add_f , add_to_index = temp_add_t)\n",
    "temp_gcn.train_evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dce2c7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b3abf4bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ori_adj = adj\n",
    "# (torch.ones_like(ori_adj) - torch.eye(2708) - ori_adj) - ori_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "69612acf",
   "metadata": {},
   "outputs": [],
   "source": [
    "df =  pd.DataFrame([f_l.astype(int), t_l.astype(int)]).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2e1aa5cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('PGD/new_' + data_set + '_remove_rate_' +str(rate_005) + '.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0aebe40d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.read_csv('PGD/' + data_set + '_remove_rate_' +str(rate) + '.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "4fd84550",
   "metadata": {},
   "outputs": [],
   "source": [
    "def PGD_attack(data_set, rate, graph, df):\n",
    "#     df_dice = pd.read_csv('PGD/new_' + data_set + '_remove_rate_' +str(rate) + '.csv')\n",
    "    df_dice = df\n",
    "    remove_list = graph.edge_ids(df_dice[0].values, df_dice[1].values)\n",
    "    remove_list = list(remove_list)[0:]\n",
    "    train_gcn = gcn_with_edge_removal(dataset = data_set, remove_edge_index=remove_list)\n",
    "    train_gcn_acc = train_gcn.train_evaluate()\n",
    "    return train_gcn_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "a1f7db1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.read_csv('PGD/new_' + 'citeseer' + '_remove_rate_' +str(0.05) + '.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "07929939",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = df[0].values\n",
    "b = df[1].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "fae229b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4416, 4416)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(a), len(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "8d69e532",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([    6,    19,    22, ..., 19681, 19681, 19699])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "a01074cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([13042, 11111, 11915, ..., 19635, 19699, 19681])"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "e33795df",
   "metadata": {},
   "outputs": [
    {
     "ename": "DGLError",
     "evalue": "Error: (6, 13042) does not form a valid edge.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mDGLError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_18663/2656219010.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0medge_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m13042\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/dgl/heterograph.py\u001b[0m in \u001b[0;36medge_ids\u001b[0;34m(self, u, v, force_multi, return_uv, etype)\u001b[0m\n\u001b[1;32m   3098\u001b[0m                 \u001b[0;31m# Raise error since some (u, v) pair is not a valid edge.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3099\u001b[0m                 \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero_1d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_neg_one\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3100\u001b[0;31m                 raise DGLError(\"Error: (%d, %d) does not form a valid edge.\" % (\n\u001b[0m\u001b[1;32m   3101\u001b[0m                     \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather_row\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3102\u001b[0m                     F.as_scalar(F.gather_row(v, idx))))\n",
      "\u001b[0;31mDGLError\u001b[0m: Error: (6, 13042) does not form a valid edge."
     ]
    }
   ],
   "source": [
    "graph.edge_ids(6, 13042)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "865b7bf6",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "only one element tensors can be converted to Python scalars",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_18663/124428566.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# PGD_attack_white_box(data_set, rate, graph)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mPGD_attack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.05\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_18663/255080488.py\u001b[0m in \u001b[0;36mPGD_attack\u001b[0;34m(data_set, rate, graph, df)\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m#     df_dice = pd.read_csv('PGD/new_' + data_set + '_remove_rate_' +str(rate) + '.csv')\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mdf_dice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mremove_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0medge_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf_dice\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_dice\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0mremove_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mremove_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mtrain_gcn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgcn_with_edge_removal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mremove_edge_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mremove_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/dgl/heterograph.py\u001b[0m in \u001b[0;36medge_ids\u001b[0;34m(self, u, v, force_multi, return_uv, etype)\u001b[0m\n\u001b[1;32m   3099\u001b[0m                 \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero_1d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_neg_one\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3100\u001b[0m                 raise DGLError(\"Error: (%d, %d) does not form a valid edge.\" % (\n\u001b[0;32m-> 3101\u001b[0;31m                     \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather_row\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3102\u001b[0m                     F.as_scalar(F.gather_row(v, idx))))\n\u001b[1;32m   3103\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meid\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_int\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0meid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/dgl/backend/pytorch/tensor.py\u001b[0m in \u001b[0;36mas_scalar\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mas_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_preferred_sparse_format\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: only one element tensors can be converted to Python scalars"
     ]
    }
   ],
   "source": [
    "# PGD_attack_white_box(data_set, rate, graph)\n",
    "PGD_attack(data_set, 0.05, graph, df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b13a74df",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dice = pd.read_csv('PGD/' + 'pubmed' + '_remove_rate_' +str(0.05) + '.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "41b1bf52",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dice = pd.read_csv('PGD/' + 'pubmed' + '_remove_rate_' +str(0.03) + '.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d4cacaff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>342</td>\n",
       "      <td>993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>449</td>\n",
       "      <td>1277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>636</td>\n",
       "      <td>1727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>653</td>\n",
       "      <td>1819</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>993</td>\n",
       "      <td>342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>469</th>\n",
       "      <td>19635</td>\n",
       "      <td>19681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>470</th>\n",
       "      <td>19681</td>\n",
       "      <td>19517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>471</th>\n",
       "      <td>19681</td>\n",
       "      <td>19635</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>472</th>\n",
       "      <td>19681</td>\n",
       "      <td>19699</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>473</th>\n",
       "      <td>19699</td>\n",
       "      <td>19681</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>474 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         0      1\n",
       "0      342    993\n",
       "1      449   1277\n",
       "2      636   1727\n",
       "3      653   1819\n",
       "4      993    342\n",
       "..     ...    ...\n",
       "469  19635  19681\n",
       "470  19681  19517\n",
       "471  19681  19635\n",
       "472  19681  19699\n",
       "473  19699  19681\n",
       "\n",
       "[474 rows x 2 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_dice"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
