{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "dfa156e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import csv\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import torch\n",
    "import dgl\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# Root directory of the Amazon dataset. Later cells build file names by plain\n",
    "# string concatenation onto this value, so the trailing slash is required.\n",
    "datapath = f'../data/Amazon/'\n",
    "# The 'statistics' pickle unpacks into exactly two values. num_task is used\n",
    "# below as the number of per-time-slot graph files; num_class is not referenced\n",
    "# in the cells visible here.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.\n",
    "with open(datapath+'statistics', 'rb') as file:\n",
    "    num_task, num_class = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "1596baa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque \n",
    "\n",
    "def _build_adjacency(g):\n",
    "    '''Index the edges of ``g`` by endpoint in one pass over the edge list.\n",
    "\n",
    "    Returns:\n",
    "        (out_adj, in_adj): two dicts mapping a node id to the list of its\n",
    "        successors / predecessors, each list in edge order. Nodes without\n",
    "        edges in a given direction have no entry.\n",
    "    '''\n",
    "    src, dst = g.edges()\n",
    "    out_adj, in_adj = {}, {}\n",
    "    for s, d in zip(src.tolist(), dst.tolist()):\n",
    "        out_adj.setdefault(s, []).append(d)\n",
    "        in_adj.setdefault(d, []).append(s)\n",
    "    return out_adj, in_adj\n",
    "\n",
    "\n",
    "def dfs(g, u, visited, adjacency=None):\n",
    "    '''Collect the connected component of ``g`` that contains ``u``.\n",
    "\n",
    "    Despite the name, the traversal is breadth-first (FIFO queue). Edge\n",
    "    direction is ignored: both successors and predecessors are followed.\n",
    "\n",
    "    Args:\n",
    "        g: graph whose ``edges()`` returns a ``(src, dst)`` pair of tensors.\n",
    "        u: id of the start node.\n",
    "        visited: set of node ids already assigned to a component. It is\n",
    "            updated in place with every node reached from ``u``.\n",
    "        adjacency: optional result of ``_build_adjacency(g)``. Pass it when\n",
    "            calling ``dfs`` repeatedly on the same graph so the edge list is\n",
    "            indexed only once; when omitted it is built here.\n",
    "\n",
    "    Returns:\n",
    "        (count, subnodes): the number of newly visited nodes and the list of\n",
    "        those nodes in visiting order.\n",
    "    '''\n",
    "    if adjacency is None:\n",
    "        adjacency = _build_adjacency(g)\n",
    "    out_adj, in_adj = adjacency\n",
    "\n",
    "    queue = deque([u])\n",
    "    subnodes = []\n",
    "    while queue:\n",
    "        v = queue.popleft()\n",
    "        # A node can be queued several times before it is first popped.\n",
    "        if v in visited:\n",
    "            continue\n",
    "        visited.add(v)\n",
    "        subnodes.append(v)\n",
    "        # Successors first, then predecessors, each in edge order. This keeps\n",
    "        # the visiting order identical to scanning the full edge list per node.\n",
    "        neighbours = out_adj.get(v, []) + in_adj.get(v, [])\n",
    "        queue.extend(x for x in neighbours if x not in visited)\n",
    "    return len(subnodes), subnodes\n",
    "\n",
    "\n",
    "def get_largest_cluster(g):\n",
    "    '''Return the node ids of the largest connected component of ``g``.\n",
    "\n",
    "    Components are found with ``dfs`` (edge direction ignored). Ties on size\n",
    "    are broken by comparing the node lists, exactly as sorting the\n",
    "    ``(count, subnodes)`` tuples in descending order would.\n",
    "\n",
    "    Returns:\n",
    "        List of node ids in visiting order, or an empty list when ``g`` has\n",
    "        no nodes.\n",
    "    '''\n",
    "    adjacency = _build_adjacency(g)\n",
    "    visited = set()\n",
    "    components = []\n",
    "    for u in g.nodes().tolist():\n",
    "        if u not in visited:\n",
    "            components.append(dfs(g, u, visited, adjacency))\n",
    "    if not components:\n",
    "        return []\n",
    "    # max() over the tuples selects the entry a descending sort would put first.\n",
    "    return max(components)[1]\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ecc1bc67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph(num_nodes=11757, num_edges=98516,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=5506, num_edges=37434,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=8088, num_edges=65436,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=13391, num_edges=146950,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=19099, num_edges=297950,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=21851, num_edges=355690,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=24353, num_edges=495036,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n",
      "Graph(num_nodes=26564, num_edges=607048,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n"
     ]
    }
   ],
   "source": [
    "# For every time slot, load its graph and keep only the largest connected\n",
    "# component; the resulting subgraphs are collected in sub_g_list.\n",
    "sub_g_list = []\n",
    "for time_slot in range(num_task):\n",
    "    with open(datapath+f'graph_{time_slot}_by_edges', 'rb') as file:\n",
    "        g = pickle.load(file)\n",
    "    # The file is not needed any more once the graph has been unpickled.\n",
    "    subnodes = get_largest_cluster(g)\n",
    "    sub_g = g.subgraph(subnodes)\n",
    "    print(sub_g)\n",
    "    sub_g_list.append(sub_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "013db930",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "51763"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect every node_idxs value that occurs in at least one per-task subgraph\n",
    "# and show how many distinct values there are. The set is built in this cell\n",
    "# so that it also runs top-to-bottom on a fresh kernel.\n",
    "appeared_nodes = set()\n",
    "for i in range(num_task):\n",
    "    appeared_nodes.update(sub_g_list[i].ndata['node_idxs'].tolist())\n",
    "len(appeared_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "4b28265f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Union of the node_idxs values over all per-task subgraphs.\n",
    "appeared_nodes = set()\n",
    "for i in range(num_task):\n",
    "    task_node_idxs = sub_g_list[i].ndata['node_idxs'].tolist()\n",
    "    appeared_nodes.update(task_node_idxs)\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.\n",
    "with open(datapath+f'graph_whole', 'rb') as file:\n",
    "    g_whole = pickle.load(file)\n",
    "\n",
    "# Store each node's own id as a feature before restricting the graph\n",
    "# (presumably so the subgraph keeps the original ids - TODO confirm).\n",
    "g_whole.ndata['node_idxs'] = g_whole.nodes()\n",
    "subnodes = list(appeared_nodes)\n",
    "g_whole_sub = g_whole.subgraph(subnodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "c75386a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([    0,     1,     2,  ..., 61182, 61183, 61184])"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the node_idxs feature of the full graph (cell output only, no state change).\n",
    "g_whole.ndata['node_idxs']\n",
    "# g_whole_sub.ndata['node_idxs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "6abe6863",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the restricted whole graph as 'sub_graph_whole' inside datapath.\n",
    "# An existing file with that name is overwritten ('wb' mode).\n",
    "with open(datapath+f'sub_graph_whole', 'wb') as file:\n",
    "    pickle.dump(g_whole_sub, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "09504815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the unrestricted graph of every time slot into g_list, in slot order.\n",
    "g_list = []\n",
    "for time_slot in range(num_task):\n",
    "    with open(datapath+f'graph_{time_slot}_by_edges', 'rb') as file:\n",
    "        g = pickle.load(file)\n",
    "    # Appending after the file is closed; the unpickled graph no longer needs it.\n",
    "    g_list.append(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "8ab0dc4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every task graph, flag the nodes whose node_idxs value has not occurred\n",
    "# in any earlier task. The flag is stored as the int32 node feature\n",
    "# 'new_nodes_mask' (1 = first appearance, 0 = seen before).\n",
    "prev_idxs = set()\n",
    "\n",
    "for t in range(num_task):\n",
    "    curr_idxs = sub_g_list[t].ndata['node_idxs'].tolist()\n",
    "    # curr_idxs already holds plain ints, so test membership on it directly\n",
    "    # instead of indexing the tensor and calling .item() once per node.\n",
    "    new_nodes_mask = torch.tensor(\n",
    "        [idx not in prev_idxs for idx in curr_idxs], dtype=torch.int32\n",
    "    )\n",
    "    sub_g_list[t].ndata['new_nodes_mask'] = new_nodes_mask\n",
    "\n",
    "    # Update only after the mask is built, so ids are compared against\n",
    "    # strictly earlier tasks.\n",
    "    prev_idxs.update(curr_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "39f18bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write every per-task subgraph to its own pickle file inside datapath.\n",
    "for i in range(num_task):\n",
    "    with open(datapath+f'sub_graph_{i}_by_edges', 'wb') as file:\n",
    "        pickle.dump(sub_g_list[i], file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adb7a803",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
