{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5db49b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.io\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from scipy.io import mmread"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13749dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pokec (region_job) configuration: data location, sensitive attribute and prediction target.\n",
    "# (A previous duplicate of this cell was mis-indented and called an undefined load_pokec();\n",
    "# the loading is done inline in the next cell.)\n",
    "dataset = 'region_job'\n",
    "sens_attr = \"region\"\n",
    "predict_attr = \"I_am_working_in_field\"\n",
    "path = \"datasets/pokec_dataset/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3cb69f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Pokec table and keep only users whose label and sensitive attribute are known.\n",
    "idx_features_labels = pd.read_csv(os.path.join(path, \"{}.csv\".format(dataset)))\n",
    "\n",
    "header = list(idx_features_labels.columns)\n",
    "header.remove(\"user_id\")\n",
    "header.remove(sens_attr)\n",
    "header.remove(predict_attr)\n",
    "\n",
    "features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)\n",
    "labels = idx_features_labels[predict_attr].values\n",
    "sens = idx_features_labels[sens_attr].values\n",
    "# Only nodes for which label and sensitive attributes are available (>= 0) are utilized.\n",
    "# dtype=int keeps both index arrays valid for indexing even when one of them is empty.\n",
    "sens_idx = set(np.where(sens >= 0)[0])\n",
    "label_idx = np.where(labels >= 0)[0]\n",
    "idx_used = np.asarray(list(sens_idx & set(label_idx)), dtype=int)\n",
    "idx_nonused = np.asarray(list(set(np.arange(len(labels))).difference(set(idx_used))), dtype=int)\n",
    "\n",
    "features = features[idx_used, :]\n",
    "labels = labels[idx_used]\n",
    "sens = sens[idx_used]\n",
    "\n",
    "idx = np.array(idx_features_labels[\"user_id\"], dtype=int)\n",
    "edges_unordered = np.genfromtxt(os.path.join(path, \"{}_relationship.txt\".format(dataset)), dtype=int)\n",
    "\n",
    "idx_n = idx[idx_nonused]\n",
    "idx = idx[idx_used]\n",
    "# Keep edges whose endpoints are both retained users. np.isin replaces the per-element\n",
    "# `elem not in idx_n` scan (O(E * |idx_n|)); the index lists and their set intersection\n",
    "# are built exactly as before, so the resulting edge order is unchanged.\n",
    "used_ind1 = np.flatnonzero(~np.isin(edges_unordered[:, 0], idx_n)).tolist()\n",
    "used_ind2 = np.flatnonzero(~np.isin(edges_unordered[:, 1], idx_n)).tolist()\n",
    "intersect_ind = list(set(used_ind1) & set(used_ind2))\n",
    "edges_unordered = edges_unordered[intersect_ind, :]\n",
    "\n",
    "# build graph: re-index user_ids to row positions of the retained users\n",
    "idx_map = {j: i for i, j in enumerate(idx)}\n",
    "assert np.isin(edges_unordered, idx).all(), \"edge endpoint not found among retained user_ids\"\n",
    "edges_un = np.array(list(map(idx_map.get, edges_unordered.flatten())),\n",
    "                    dtype=int).reshape(edges_unordered.shape)\n",
    "\n",
    "adj = sp.coo_matrix((np.ones(edges_un.shape[0]), (edges_un[:, 0], edges_un[:, 1])),\n",
    "                    shape=(labels.shape[0], labels.shape[0]),\n",
    "                    dtype=np.float32)\n",
    "# networkx 3.0 removed from_scipy_sparse_matrix; use from_scipy_sparse_array when available.\n",
    "from_sparse = getattr(nx, \"from_scipy_sparse_array\", None) or nx.from_scipy_sparse_matrix\n",
    "G = from_sparse(adj)\n",
    "# Restrict to the largest connected component.\n",
    "g_nx_ccs = (G.subgraph(c).copy() for c in nx.connected_components(G))\n",
    "g_nx = max(g_nx_ccs, key=len)\n",
    "\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "node_ids = list(g_nx.nodes())\n",
    "idx_s = node_ids  # same list object: the shuffle below reorders node_ids as well\n",
    "random.shuffle(idx_s)\n",
    "\n",
    "# Reorder rows into the shuffled LCC order and drop constant (zero-variance) feature columns.\n",
    "features = features[idx_s, :]\n",
    "features = features[:, np.where(np.std(np.array(features.todense()), axis=0) != 0)[0]]\n",
    "\n",
    "features = torch.FloatTensor(np.array(features.todense()))\n",
    "labels = torch.LongTensor(labels[idx_s])\n",
    "\n",
    "sens = torch.LongTensor(sens[idx_s])\n",
    "# Binarize: every label above 1 becomes 1, every positive sensitive value becomes 1.\n",
    "labels[labels > 1] = 1\n",
    "sens[sens > 0] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "efc9f94b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the labels / sens tensors produced by the cell above as .npy files.\n",
    "for npy_name, values in (('labels_pokec1.npy', labels), ('sens_pokec1.npy', sens)):\n",
    "    np.save(npy_name, np.array(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb52628f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge list of the largest connected component (node ids as in g_nx), written for DeepWalk.\n",
    "# The destination is configurable (DEEPWALK_EDGELIST) instead of a hard-coded absolute path.\n",
    "# NOTE(review): file name 'unc.edgelist' kept from the original -- confirm it is intended for this dataset.\n",
    "edgelist_path = os.environ.get(\"DEEPWALK_EDGELIST\", os.path.join(\"deepwalk-master\", \"example_graphs\", \"unc.edgelist\"))\n",
    "os.makedirs(os.path.dirname(edgelist_path) or \".\", exist_ok=True)\n",
    "nx.write_edgelist(g_nx, edgelist_path, data=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "294ed816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): nothing earlier in this notebook writes these files (the Pokec cells save\n",
    "# labels_pokec1.npy / sens_pokec1.npy); they presumably come from the German-credit section\n",
    "# further down or from an earlier session -- confirm and move this cell after the code that saves them.\n",
    "labels = np.load('labels_german.npy')\n",
    "sens = np.load('sens_german.npy')\n",
    "\n",
    "# Class balance: counts of 0 and 1 in the target label and in the sensitive attribute.\n",
    "for value in (0, 1):\n",
    "    print(len(np.where(labels == value)[0]))\n",
    "    print(len(np.where(sens == value)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6511ce09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oklahoma97 (socfb): `A` is used as the adjacency matrix, `local_info` as the node attribute table.\n",
    "mat = scipy.io.loadmat('datasets/socfb-Oklahoma97/Oklahoma97.mat')\n",
    "Adj = mat['A']\n",
    "feats_raw = mat['local_info']  # kept unmodified so this cell can be re-run\n",
    "\n",
    "# Majors kept as classes; a node's label is the position of its major in this list.\n",
    "labels_tobeused = [236, 238, 258, 265, 292, 304, 320, 326, 340, 361, 364, 367, 382, 386, 388]\n",
    "# Keep rows with no 0 in attribute columns 0, 1, 2, 4, 5 (0 presumably marks a missing value)\n",
    "# whose major (column 2) is one of the selected classes. Vectorized form of the per-row loop.\n",
    "has_info = np.all(feats_raw[:, [0, 1, 2, 4, 5]] != 0, axis=1)\n",
    "major_ok = np.isin(feats_raw[:, 2], labels_tobeused)\n",
    "idx_used = np.flatnonzero(has_info & major_ok).tolist()\n",
    "idx_nonused = np.asarray(list(set(np.arange(np.shape(feats_raw)[0])).difference(set(idx_used))))\n",
    "\n",
    "# Label is the major (column 2); sensitive attribute is gender (column 1, shifted by -1).\n",
    "labels = np.array([labels_tobeused.index(major) for major in feats_raw[idx_used, 2]], dtype=np.intp)\n",
    "sens = np.array(feats_raw[idx_used, 1] - 1)\n",
    "\n",
    "# Node features: selected rows, attribute columns 0, 3, 4, 5, 6.\n",
    "feats = feats_raw[idx_used, :][:, [0, 3, 4, 5, 6]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "adfc3271",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, col) pairs of every stored entry of Adj.\n",
    "adj_rows, adj_cols, _ = scipy.sparse.find(Adj)\n",
    "edges = np.column_stack((adj_rows, adj_cols))\n",
    "\n",
    "# Drop edges touching an unselected node. np.isin replaces the per-element `not in`\n",
    "# scan over idx_nonused; the index lists are identical, so the edge order is unchanged.\n",
    "used_ind1 = np.flatnonzero(~np.isin(edges[:, 0], idx_nonused)).tolist()\n",
    "used_ind2 = np.flatnonzero(~np.isin(edges[:, 1], idx_nonused)).tolist()\n",
    "intersect_ind = list(set(used_ind1) & set(used_ind2))\n",
    "edges = edges[intersect_ind, :]\n",
    "\n",
    "# Re-index endpoints from Adj row ids to positions in idx_used.\n",
    "idx_map = {j: i for i, j in enumerate(idx_used)}\n",
    "edges = np.array(list(map(idx_map.get, edges.flatten())),\n",
    "                 dtype=int).reshape(edges.shape)\n",
    "\n",
    "adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
    "                    shape=(labels.shape[0], labels.shape[0]),\n",
    "                    dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5e408b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest connected component of the selected-node graph, then a seeded shuffle of its nodes.\n",
    "# networkx 3.0 removed from_scipy_sparse_matrix; use from_scipy_sparse_array when available.\n",
    "from_sparse = getattr(nx, \"from_scipy_sparse_array\", None) or nx.from_scipy_sparse_matrix\n",
    "G = from_sparse(adj)\n",
    "g_nx_ccs = (G.subgraph(c).copy() for c in nx.connected_components(G))\n",
    "g_nx = max(g_nx_ccs, key=len)\n",
    "\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "node_ids = list(g_nx.nodes())\n",
    "idx_s = node_ids  # same list object: the shuffle below reorders node_ids as well\n",
    "random.shuffle(idx_s)\n",
    "print(len(idx_s))\n",
    "\n",
    "# Reorder rows into the shuffled LCC order and drop constant (zero-variance) feature columns.\n",
    "feats = feats[idx_s, :]\n",
    "feats = feats[:, np.where(np.std(np.array(feats), axis=0) != 0)[0]]\n",
    "feats = torch.FloatTensor(np.array(feats, dtype=float))\n",
    "\n",
    "labels = torch.LongTensor(labels[idx_s])\n",
    "sens = torch.LongTensor(np.array(sens[idx_s], dtype=int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5d81acd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only edges inside the LCC and re-index them into the shuffled order idx_s.\n",
    "idx_map_n = {j: int(i) for i, j in enumerate(idx_s)}\n",
    "\n",
    "idx_nonused2 = np.asarray(list(set(np.arange(G.number_of_nodes())).difference(set(idx_s))))\n",
    "# np.isin replaces the per-element `not in` scan; index lists and edge order are unchanged.\n",
    "used_ind1 = np.flatnonzero(~np.isin(edges[:, 0], idx_nonused2)).tolist()\n",
    "used_ind2 = np.flatnonzero(~np.isin(edges[:, 1], idx_nonused2)).tolist()\n",
    "intersect_ind = list(set(used_ind1) & set(used_ind2))\n",
    "edges = edges[intersect_ind, :]\n",
    "\n",
    "edges = np.array(list(map(idx_map_n.get, edges.flatten())),\n",
    "                 dtype=int).reshape(edges.shape)\n",
    "print(np.shape(edges))\n",
    "\n",
    "# Symmetrize: entry (i, j) becomes max(adj[i, j], adj[j, i]).\n",
    "adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
    "                    shape=(labels.shape[0], labels.shape[0]),\n",
    "                    dtype=np.float32)\n",
    "adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)\n",
    "\n",
    "# Rebuild the edge list from the symmetrized matrix.\n",
    "adj_rows, adj_cols, _ = scipy.sparse.find(adj)\n",
    "edges = np.column_stack((adj_rows, adj_cols))\n",
    "print(np.shape(edges))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4e61c1a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory checks on the raw Oklahoma97 adjacency Adj.\n",
    "# NOTE(review): this cell overwrites adj, G, g_nx, node_ids and idx_s computed above.\n",
    "# networkx 3.0 removed from_scipy_sparse_matrix; use from_scipy_sparse_array when available.\n",
    "from_sparse = getattr(nx, \"from_scipy_sparse_array\", None) or nx.from_scipy_sparse_matrix\n",
    "g = from_sparse(Adj)\n",
    "edgelist = nx.generate_edgelist(g, data=False)  # lazy generator of edge-line strings\n",
    "# Indices of stored entries whose value is not 1 (empty if Adj is a 0/1 matrix).\n",
    "print(np.where(scipy.sparse.find(Adj)[2] != 1))\n",
    "print(np.shape(sens))\n",
    "\n",
    "# Induced subgraph of Adj on the selected nodes. The previous version indexed with the\n",
    "# undefined name `used_idx` (NameError); the selected node list is `idx_used`.\n",
    "adj = Adj[:, idx_used]\n",
    "adj = adj[idx_used, :]\n",
    "print(np.where(adj.todense() == 1))\n",
    "\n",
    "G = from_sparse(adj)\n",
    "g_nx_ccs = (G.subgraph(c).copy() for c in nx.connected_components(G))\n",
    "g_nx = max(g_nx_ccs, key=len)\n",
    "\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "node_ids = list(g_nx.nodes())\n",
    "idx_s = node_ids  # same list object: the shuffle below reorders node_ids as well\n",
    "random.shuffle(idx_s)\n",
    "# Short preview instead of dumping every node id into the notebook output.\n",
    "print(len(node_ids), node_ids[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ad9f2a29",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_206905/2585262802.py:16: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  idx_features_labels['Gender'][idx_features_labels['Gender'] == 'Female'] = 1\n",
      "/tmp/ipykernel_206905/2585262802.py:17: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  idx_features_labels['Gender'][idx_features_labels['Gender'] == 'Male'] = 0\n"
     ]
    }
   ],
   "source": [
    "# German credit dataset: Gender is the sensitive attribute, GoodCustomer the target.\n",
    "sens_attr = \"Gender\"\n",
    "predict_attr = \"GoodCustomer\"\n",
    "path = \"datasets/german/\"\n",
    "dataset = 'german'\n",
    "idx_features_labels = pd.read_csv(os.path.join(path, \"{}.csv\".format(dataset)))\n",
    "header = list(idx_features_labels.columns)\n",
    "header.remove(predict_attr)\n",
    "header.remove('OtherLoansAtStore')\n",
    "header.remove('PurposeOfLoan')\n",
    "header.remove(sens_attr)\n",
    "\n",
    "# Sensitive attribute: Female -> 1, Male -> 0. Assign through .loc: the former chained\n",
    "# indexing (df[col][mask] = v) raised SettingWithCopyWarning and is a silent no-op\n",
    "# under pandas Copy-on-Write.\n",
    "idx_features_labels.loc[idx_features_labels['Gender'] == 'Female', 'Gender'] = 1\n",
    "idx_features_labels.loc[idx_features_labels['Gender'] == 'Male', 'Gender'] = 0\n",
    "\n",
    "edges = np.genfromtxt(os.path.join(path, \"{}_edges.txt\".format(dataset))).astype('int')\n",
    "\n",
    "features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)\n",
    "# Copy so that remapping -1 -> 0 neither mutates the DataFrame nor fails on a read-only view.\n",
    "labels = idx_features_labels[predict_attr].to_numpy(copy=True)\n",
    "labels[labels == -1] = 0\n",
    "sens = idx_features_labels[sens_attr].values.astype(int)\n",
    "\n",
    "adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
    "                    shape=(labels.shape[0], labels.shape[0]),\n",
    "                    dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "c4523d8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symmetrise: the result is the elementwise maximum of adj and adj.T.\n",
    "sym_mask = adj.T > adj\n",
    "adj = adj + adj.T.multiply(sym_mask) - adj.multiply(sym_mask)\n",
    "\n",
    "# Keep only the largest connected component of the resulting graph.\n",
    "G = nx.from_scipy_sparse_matrix(adj)\n",
    "largest_cc = max(nx.connected_components(G), key=len)\n",
    "g_nx = G.subgraph(largest_cc).copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f9ec61c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n",
      "1000\n"
     ]
    }
   ],
   "source": [
    "# Node counts of the full graph and of its largest connected component.\n",
    "print(G.number_of_nodes())\n",
    "print(g_nx.number_of_nodes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3a9ff77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "node_ids = list(g_nx.nodes())\n",
    "idx_s = node_ids  # same list object: shuffling idx_s also reorders node_ids\n",
    "random.shuffle(idx_s)\n",
    "\n",
    "# Permute node rows into the shuffled order, then keep only non-constant columns.\n",
    "dense_features = features[idx_s, :].toarray()\n",
    "nonconstant_cols = np.where(np.std(dense_features, axis=0) != 0)[0]\n",
    "features = torch.FloatTensor(dense_features[:, nonconstant_cols])\n",
    "labels = torch.LongTensor(labels[idx_s])\n",
    "sens = torch.LongTensor(sens[idx_s])\n",
    "\n",
    "# Map original node id -> position in the shuffled order.\n",
    "idx_map_n = {node: int(pos) for pos, node in enumerate(idx_s)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4b57c238",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1000])\n"
     ]
    }
   ],
   "source": [
    "print(sens.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ecb2d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the node feature tensor (rows = nodes, columns = kept features).\n",
    "print(features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c99b2538",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# NOTE(review): this repeats the shuffle cell above. After that cell has run, `features`\n",
    "# is already a torch tensor (which has no .todense()) and labels/sens are already\n",
    "# permuted, so re-running the shuffle here would fail. Only shuffle while `features`\n",
    "# is still the scipy sparse matrix built by the loading cell.\n",
    "if sp.issparse(features):\n",
    "    seed = 19\n",
    "    random.seed(seed)\n",
    "    node_ids = list(g_nx.nodes())\n",
    "    idx_s = node_ids\n",
    "    random.shuffle(idx_s)\n",
    "\n",
    "    features = features[idx_s, :]\n",
    "    features = features[:, np.where(np.std(np.array(features.todense()), axis=0) != 0)[0]]\n",
    "\n",
    "    features = torch.FloatTensor(np.array(features.todense()))\n",
    "    labels = torch.LongTensor(labels[idx_s])\n",
    "    sens = torch.LongTensor(sens[idx_s])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d677d21a",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_map_n = {j: int(i) for i, j in enumerate(idx_s)}\n",
    "\n",
    "# Drop edges with an endpoint outside the kept node set (the largest connected component):\n",
    "# idx_map_n.get would return None for such endpoints and the int cast below would fail.\n",
    "edges = edges[np.isin(edges, list(idx_map_n)).all(axis=1)]\n",
    "\n",
    "# Relabel every endpoint to its position in the shuffled node order.\n",
    "edges = np.array(list(map(idx_map_n.get, edges.flatten())),\n",
    "                 dtype=int).reshape(edges.shape)\n",
    "adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
    "                    shape=(labels.shape[0], labels.shape[0]),\n",
    "                    dtype=np.float32)\n",
    "# Symmetrise: elementwise maximum of adj and adj.T.\n",
    "adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)\n",
    "\n",
    "# Edge list of the symmetric matrix; sp.find is computed once and reused.\n",
    "rows, cols, _ = sp.find(adj)\n",
    "edges = np.stack((rows, cols), axis=1)\n",
    "g_nx = nx.from_scipy_sparse_matrix(adj)\n",
    "edges = torch.LongTensor(edges.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "511350ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the raw (unshuffled) features, labels and sensitive attribute from the node table.\n",
    "features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)\n",
    "# Binarise labels (-1 -> 0) here, as the loading cell does, instead of depending on\n",
    "# whatever an earlier cell left in idx_features_labels.\n",
    "labels = idx_features_labels[predict_attr].replace(-1, 0).values\n",
    "sens = idx_features_labels[sens_attr].values.astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7042873a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx is 0..N-1, so idx_map is the identity mapping from node id to row index.\n",
    "idx = np.arange(features.shape[0])\n",
    "idx_map = dict(zip(idx, range(len(idx))))\n",
    "flat_ids = edges_unordered.flatten()\n",
    "edges = np.array([idx_map.get(node) for node in flat_ids],\n",
    "                 dtype=int).reshape(edges_unordered.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f0ba547d",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_nodes = labels.shape[0]\n",
    "adj = sp.coo_matrix(\n",
    "    (np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),\n",
    "    shape=(n_nodes, n_nodes),\n",
    "    dtype=np.float32,\n",
    ")\n",
    "\n",
    "# Build the symmetric adjacency matrix (elementwise maximum of adj and adj.T).\n",
    "sym_mask = adj.T > adj\n",
    "adj = adj + adj.T.multiply(sym_mask) - adj.multiply(sym_mask)\n",
    "\n",
    "# Keep only the largest connected component.\n",
    "G = nx.from_scipy_sparse_matrix(adj)\n",
    "largest_cc = max(nx.connected_components(G), key=len)\n",
    "g_nx = G.subgraph(largest_cc).copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "41900cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "# Random node order; idx_s and node_ids are the same list object.\n",
    "node_ids = list(g_nx.nodes())\n",
    "idx_s = node_ids\n",
    "random.shuffle(idx_s)\n",
    "\n",
    "# Reorder rows, then drop feature columns that are constant across nodes.\n",
    "shuffled_dense = np.array(features[idx_s, :].todense())\n",
    "varying_cols = np.where(np.std(shuffled_dense, axis=0) != 0)[0]\n",
    "features = torch.FloatTensor(shuffled_dense[:, varying_cols])\n",
    "labels = torch.LongTensor(labels[idx_s])\n",
    "sens = torch.LongTensor(sens[idx_s])\n",
    "\n",
    "# Original node id -> new (shuffled) position.\n",
    "idx_map_n = {old_id: int(new_pos) for new_pos, old_id in enumerate(idx_s)}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ec40490d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "# Node ids of G that are missing from the shuffled node list idx_s;\n",
    "# an empty array means every node of G was kept.\n",
    "all_node_ids = set(np.arange(G.number_of_nodes()))\n",
    "idx_nonused2 = np.asarray(list(all_node_ids.difference(set(idx_s))))\n",
    "print(idx_nonused2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7442051d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
