{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import csv\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import torch\n",
    "import dgl\n",
    "import os\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_research_area(venue):\n",
    "    \"\"\"Return the research-area class id of a venue name.\n",
    "\n",
    "    The name is upper-cased before lookup. The result is 0, 1, 2 or 3 when\n",
    "    the name is in the db, dm, ai or cv list respectively, and -1 otherwise.\n",
    "    \"\"\"\n",
    "    venues_by_class = (\n",
    "        {'SIGMOD', 'ICDE', 'VLDB', 'EDBT', 'PODS', 'ICDT', 'DASFAA', 'SSDBM', 'CIKM'},  # 0: db\n",
    "        {'KDD', 'ICDM', 'SDM', 'PKDD', 'PAKDD'},  # 1: dm\n",
    "        {'IJCAI', 'AAAI', 'NIPS', 'ICML', 'ECML', 'ACML', 'IJCNN', 'UAI', 'ECAI', 'COLT', 'ACL', 'KR'},  # 2: ai\n",
    "        {'CVPR', 'ICCV', 'ECCV', 'ACCV', 'MM', 'ICPR', 'ICIP', 'ICME'},  # 3: cv\n",
    "    )\n",
    "    name = venue.upper()\n",
    "    # Classes are checked in order 0..3, exactly like the original if-chain.\n",
    "    for class_id, venues in enumerate(venues_by_class):\n",
    "        if name in venues:\n",
    "            return class_id\n",
    "    return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives every artifact written by this notebook.\n",
    "datapath = '../data/DBLP/'\n",
    "if not os.path.exists(datapath):\n",
    "    os.makedirs(datapath)\n",
    "\n",
    "# num_task sizes the per-time-slot lists built in later cells; num_class is\n",
    "# stored alongside it in the statistics file.\n",
    "num_task, num_class = 8, 4\n",
    "\n",
    "with open(datapath + 'statistics', 'wb') as file:\n",
    "    pickle.dump((num_task, num_class), file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-processed DBLP dump, read as one JSON object per line.\n",
    "# Set the DBLP_PROCESSED environment variable to use a different location;\n",
    "# the default is the original hard-coded path.\n",
    "dataset_processed = os.environ.get(\n",
    "    'DBLP_PROCESSED', '/home/xiaoxue/data/dblp.v13/dblpv13_processed.json')\n",
    "\n",
    "# First pass: collect the ids of the target papers and count how often each\n",
    "# keyword and field of study (fos) occurs among them.\n",
    "target_papers_all_id = set()\n",
    "# One id set per time slot; a paper lands in slot max((year - 1999) // 2, 0).\n",
    "target_papers_year_id = [set() for _ in range(num_task)]\n",
    "keywords_dic = defaultdict(int)\n",
    "fos_dic = defaultdict(int)\n",
    "\n",
    "# The context manager closes the dump when the loop ends or raises; the\n",
    "# handle was previously left open.\n",
    "with open(dataset_processed) as infile:\n",
    "    for line in infile:\n",
    "        item = json.loads(line)\n",
    "        # Skip records that lack keywords, fos, a year <= 2014 or a raw venue.\n",
    "        if 'keywords' not in item or len(item['keywords']) == 0:\n",
    "            continue\n",
    "        if 'fos' not in item or len(item['fos']) == 0:\n",
    "            continue\n",
    "        if 'year' not in item or item['year'] > 2014:\n",
    "            continue\n",
    "        if 'venue' not in item or 'raw' not in item['venue']:\n",
    "            continue\n",
    "        # Keep only venues that map to one of the known research areas.\n",
    "        if get_research_area(item['venue']['raw']) < 0:\n",
    "            continue\n",
    "\n",
    "        target_papers_all_id.add(item['_id'])\n",
    "        target_papers_year_id[max((item['year'] - 1999) // 2, 0)].add(item['_id'])\n",
    "\n",
    "        for kw in item['keywords']:\n",
    "            keywords_dic[kw] += 1\n",
    "        for fos in item['fos']:\n",
    "            fos_dic[fos] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "244\n"
     ]
    }
   ],
   "source": [
    "# Select features: keep a keyword / fos only if it was counted at least\n",
    "# MIN_FEATURE_COUNT times in the first pass.\n",
    "MIN_FEATURE_COUNT = 500\n",
    "\n",
    "# Filtering with a comprehension keeps the insertion order of the count\n",
    "# dicts, so the column indices below match the former copy-then-delete code.\n",
    "keywords_dic_reduced = {\n",
    "    kw: count for kw, count in keywords_dic.items() if count >= MIN_FEATURE_COUNT\n",
    "}\n",
    "fos_dic_reduced = {\n",
    "    fos: count for fos, count in fos_dic.items() if count >= MIN_FEATURE_COUNT\n",
    "}\n",
    "\n",
    "# Map every kept feature to its column index in the indicator vectors.\n",
    "keywords_idx = {kw: col for col, kw in enumerate(keywords_dic_reduced)}\n",
    "fos_idx = {fos: col for col, fos in enumerate(fos_dic_reduced)}\n",
    "\n",
    "num_keywords = len(keywords_dic_reduced)\n",
    "num_fos = len(fos_dic_reduced)\n",
    "# Total feature dimension (keyword columns followed by fos columns).\n",
    "print(num_keywords + num_fos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "61185"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of target papers collected in the first pass. The previous name,\n",
    "# target_papers, is not defined anywhere in this notebook.\n",
    "len(target_papers_all_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second pass: build all_papers, keyed by paper id.\n",
    "#   * target paper: keyword / fos indicator vectors, all references, class, year\n",
    "#   * related paper: a non-target record that references at least one target\n",
    "#     paper; only those target references and the year are kept\n",
    "all_papers = {}\n",
    "\n",
    "# The context manager closes the dump; the handle was previously left open.\n",
    "with open(dataset_processed) as infile:\n",
    "    for line in infile:\n",
    "        item = json.loads(line)\n",
    "        # Same field filters as the first pass.\n",
    "        if 'keywords' not in item or len(item['keywords']) == 0:\n",
    "            continue\n",
    "        if 'fos' not in item or len(item['fos']) == 0:\n",
    "            continue\n",
    "        if 'year' not in item or item['year'] > 2014:\n",
    "            continue\n",
    "        if 'venue' not in item or 'raw' not in item['venue']:\n",
    "            continue\n",
    "        references = item.get('references') or []\n",
    "\n",
    "        if item['_id'] in target_papers_all_id:\n",
    "            # Target paper: save the full information.\n",
    "            keyword_feat = [0] * num_keywords\n",
    "            fos_feat = [0] * num_fos\n",
    "            for kw in item['keywords']:\n",
    "                if kw in keywords_idx:\n",
    "                    keyword_feat[keywords_idx[kw]] = 1\n",
    "            # This loop used to sit inside the keyword loop and was repeated\n",
    "            # once per keyword; running it once gives the same vector.\n",
    "            for fos in item['fos']:\n",
    "                if fos in fos_idx:\n",
    "                    fos_feat[fos_idx[fos]] = 1\n",
    "            paper = {\n",
    "                'keyword': keyword_feat,\n",
    "                'fos': fos_feat,\n",
    "                'ref': references,\n",
    "                'class': get_research_area(item['venue']['raw']),\n",
    "                'year': item['year'],\n",
    "            }\n",
    "        else:\n",
    "            # Related paper: only the references to target papers are needed.\n",
    "            target_refs = [ref for ref in references if ref in target_papers_all_id]\n",
    "            if not target_refs:\n",
    "                continue\n",
    "            # Use this record's own year. The old code stored the variable\n",
    "            # `year`, which still held the year of the last target paper seen\n",
    "            # (or was undefined if none had been seen yet).\n",
    "            paper = {'ref': target_refs, 'year': item['year']}\n",
    "        all_papers[item['_id']] = paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the feature -> column-index maps under datapath.\n",
    "for idx_filename, idx_map in (('keywords.idx', keywords_idx), ('fos.idx', fos_idx)):\n",
    "    with open(datapath + idx_filename, 'wb') as file:\n",
    "        pickle.dump(idx_map, file)\n",
    "\n",
    "# To reload the maps instead of recomputing them:\n",
    "# with open(datapath+'keywords.idx', 'rb') as file:\n",
    "#     keywords_idx = pickle.load(file)\n",
    "# with open(datapath+'fos.idx', 'rb') as file:\n",
    "#     fos_idx = pickle.load(file)\n",
    "# num_keywords = len(keywords_idx)\n",
    "# num_fos = len(fos_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the paper dict built by the second pass.\n",
    "all_papers_file = datapath + 'all_papers'\n",
    "with open(all_papers_file, 'wb') as file:\n",
    "    pickle.dump(all_papers, file)\n",
    "\n",
    "# To reload it instead of re-parsing the dump:\n",
    "# with open(datapath+'all_papers', 'rb') as file:\n",
    "#     all_papers = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "379007\n"
     ]
    }
   ],
   "source": [
    "# Number of entries in all_papers (target papers plus related papers).\n",
    "print (len(all_papers))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [63], line 41\u001b[0m\n\u001b[1;32m     39\u001b[0m                 \u001b[38;5;28;01mif\u001b[39;00m ref_ref_id \u001b[38;5;129;01min\u001b[39;00m id2idx:\n\u001b[1;32m     40\u001b[0m                     g\u001b[38;5;241m.\u001b[39madd_edges(idx, id2idx[ref_ref_id])\n\u001b[0;32m---> 41\u001b[0m                     \u001b[43mg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_edges\u001b[49m\u001b[43m(\u001b[49m\u001b[43mid2idx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mref_ref_id\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;66;03m# print (node_features.shape)\u001b[39;00m\n\u001b[1;32m     43\u001b[0m \n\u001b[1;32m     44\u001b[0m \u001b[38;5;66;03m# for i in range(len(node_features)):\u001b[39;00m\n\u001b[1;32m     45\u001b[0m \u001b[38;5;66;03m#     if node_features[i] == []:\u001b[39;00m\n\u001b[1;32m     46\u001b[0m \u001b[38;5;66;03m#         print (i)\u001b[39;00m\n\u001b[1;32m     47\u001b[0m node_features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(node_features)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.8/site-packages/dgl/heterograph.py:463\u001b[0m, in \u001b[0;36mDGLHeteroGraph.add_edges\u001b[0;34m(self, u, v, data, etype)\u001b[0m\n\u001b[1;32m    358\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Add multiple new edges for the specified edge type\u001b[39;00m\n\u001b[1;32m    359\u001b[0m \n\u001b[1;32m    360\u001b[0m \u001b[38;5;124;03mThe i-th new edge will be from ``u[i]`` to ``v[i]``.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    460\u001b[0m \u001b[38;5;124;03mremove_edges\u001b[39;00m\n\u001b[1;32m    461\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    462\u001b[0m \u001b[38;5;66;03m# TODO(xiangsx): block do not support add_edges\u001b[39;00m\n\u001b[0;32m--> 463\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprepare_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mu\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mu\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    464\u001b[0m v \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mprepare_tensor(\u001b[38;5;28mself\u001b[39m, v, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mv\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    466\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m etype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.8/site-packages/dgl/utils/checks.py:46\u001b[0m, in \u001b[0;36mprepare_tensor\u001b[0;34m(g, data, name)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m (F\u001b[38;5;241m.\u001b[39mndim(data) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m F\u001b[38;5;241m.\u001b[39mshape(data)[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m        \u001b[38;5;66;03m# empty tensor\u001b[39;00m\n\u001b[1;32m     43\u001b[0m             F\u001b[38;5;241m.\u001b[39mdtype(data) \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (F\u001b[38;5;241m.\u001b[39mint32, F\u001b[38;5;241m.\u001b[39mint64)):\n\u001b[1;32m     44\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m DGLError(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mExpect argument \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m to have data type int32 or int64,\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m     45\u001b[0m                        \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m but got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(name, F\u001b[38;5;241m.\u001b[39mdtype(data)))\n\u001b[0;32m---> 46\u001b[0m     ret \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mcopy_to(F\u001b[38;5;241m.\u001b[39mastype(data, g\u001b[38;5;241m.\u001b[39midtype), \u001b[43mg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m)\n\u001b[1;32m     48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m F\u001b[38;5;241m.\u001b[39mndim(ret) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m     49\u001b[0m     ret \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39munsqueeze(ret, 
\u001b[38;5;241m0\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.8/site-packages/dgl/heterograph.py:5393\u001b[0m, in \u001b[0;36mDGLHeteroGraph.device\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   5368\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m   5369\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdevice\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m   5370\u001b[0m     \u001b[38;5;124;03m\"\"\"Get the device of the graph.\u001b[39;00m\n\u001b[1;32m   5371\u001b[0m \n\u001b[1;32m   5372\u001b[0m \u001b[38;5;124;03m    Returns\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   5391\u001b[0m \u001b[38;5;124;03m    The case of heterogeneous graphs is the same.\u001b[39;00m\n\u001b[1;32m   5392\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 5393\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mto_backend_ctx(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_graph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mctx\u001b[49m)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.8/site-packages/dgl/heterograph_index.py:182\u001b[0m, in \u001b[0;36mHeteroGraphIndex.ctx\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    173\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m    174\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mctx\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    175\u001b[0m     \u001b[38;5;124;03m\"\"\"Return the context of this graph index.\u001b[39;00m\n\u001b[1;32m    176\u001b[0m \n\u001b[1;32m    177\u001b[0m \u001b[38;5;124;03m    Returns\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    180\u001b[0m \u001b[38;5;124;03m        The context of the graph.\u001b[39;00m\n\u001b[1;32m    181\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 182\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_CAPI_DGLHeteroContext\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32mdgl/_ffi/_cython/./function.pxi:294\u001b[0m, in \u001b[0;36mdgl._ffi._cy3.core.FunctionBase.__call__\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mdgl/_ffi/_cython/./function.pxi:186\u001b[0m, in \u001b[0;36mdgl._ffi._cy3.core.make_ret\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.8/site-packages/dgl/_ffi/runtime_ctypes.py:142\u001b[0m, in \u001b[0;36mDGLContext.__new__\u001b[0;34m(cls, device_type, device_id)\u001b[0m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, device_type, device_id):\n\u001b[1;32m    141\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (device_type, device_id) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_cache:\n\u001b[0;32m--> 142\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cache\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_id\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m    144\u001b[0m     inst \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msuper\u001b[39m(DGLContext, \u001b[38;5;28mcls\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__new__\u001b[39m(DGLContext)\n\u001b[1;32m    146\u001b[0m     inst\u001b[38;5;241m.\u001b[39mdevice_type \u001b[38;5;241m=\u001b[39m device_type\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Create one graph over all target papers. Every link is stored as two\n",
    "# directed edges. A paper is linked to each target paper it references and to\n",
    "# each target paper referenced by one of its references.\n",
    "id2idx = {paper_id: node for node, paper_id in enumerate(target_papers_all_id)}\n",
    "\n",
    "num_nodes = len(id2idx)\n",
    "node_features = [[] for _ in range(num_nodes)]\n",
    "class_label = [-1] * num_nodes\n",
    "\n",
    "# Edge endpoints are collected here and inserted with a single add_edges call\n",
    "# after the loop. Calling add_edges once per edge was slow enough that the\n",
    "# saved run of this cell was interrupted inside that call.\n",
    "edge_src, edge_dst = [], []\n",
    "\n",
    "n = 0\n",
    "for paper_id, idx in id2idx.items():\n",
    "    n += 1\n",
    "    if n % 5000 == 0:\n",
    "        print(n)\n",
    "    paper = all_papers[paper_id]\n",
    "    paper_year = paper['year']\n",
    "    node_features[idx] = paper['keyword'] + paper['fos']\n",
    "    class_label[idx] = paper['class']\n",
    "    for ref_id in paper['ref']:\n",
    "        # Ignore references that were not stored, self references, and\n",
    "        # references dated after the citing paper.\n",
    "        if ref_id not in all_papers or ref_id == paper_id or all_papers[ref_id]['year'] > paper_year:\n",
    "            continue\n",
    "        if ref_id in id2idx:\n",
    "            ref_idx = id2idx[ref_id]\n",
    "            edge_src += [idx, ref_idx]\n",
    "            edge_dst += [ref_idx, idx]\n",
    "        # Second hop: target papers referenced by this reference.\n",
    "        for ref_ref_id in all_papers[ref_id]['ref']:\n",
    "            if ref_ref_id in id2idx:\n",
    "                ref_ref_idx = id2idx[ref_ref_id]\n",
    "                edge_src += [idx, ref_ref_idx]\n",
    "                edge_dst += [ref_ref_idx, idx]\n",
    "\n",
    "g = dgl.DGLGraph()\n",
    "g.add_nodes(num_nodes)\n",
    "# Same edges in the same order as before, added in one batch.\n",
    "g.add_edges(torch.tensor(edge_src, dtype=torch.int64),\n",
    "            torch.tensor(edge_dst, dtype=torch.int64))\n",
    "\n",
    "node_features = torch.tensor(node_features)\n",
    "class_label = torch.tensor(class_label)\n",
    "g.ndata['x'] = node_features\n",
    "g.ndata['y'] = class_label\n",
    "with open(datapath+'graph_whole', 'wb') as file:\n",
    "    pickle.dump(g, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Graph(num_nodes=61185, num_edges=725040,\n",
       "      ndata_schemes={'x': Scheme(shape=(244,), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
       "      edata_schemes={})"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the summary of the graph currently bound to g.\n",
    "g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph(num_nodes=18007, num_edges=19892,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=4483, num_edges=5290,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=6558, num_edges=7113,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=10564, num_edges=12996,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=14374, num_edges=19260,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=14545, num_edges=19672,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=14545, num_edges=21732,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n",
      "Graph(num_nodes=13998, num_edges=22820,\n",
      "      ndata_schemes={'num_new_nodes': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(244,), dtype=torch.int64), 'node_idxs': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64)}\n",
      "      edata_schemes={})\n"
     ]
    }
   ],
   "source": [
    "# Create one graph per time slot. The papers of the slot get the local node\n",
    "# indices 0 .. num_new_papers-1; the target papers they link to (one or two\n",
    "# reference hops away) are appended after them.\n",
    "g_list = []\n",
    "for time_slot in range(num_task):\n",
    "    slot_paper_ids = target_papers_year_id[time_slot]\n",
    "\n",
    "    # Local index of every node of this slot's graph. This used to read\n",
    "    # len(id2idx), which gave every new paper the same index.\n",
    "    id2idx_t = {}\n",
    "    for paper_id in slot_paper_ids:\n",
    "        id2idx_t[paper_id] = len(id2idx_t)\n",
    "    num_new_papers = len(id2idx_t)\n",
    "\n",
    "    # Walk the references once: register linked target papers as nodes and\n",
    "    # collect the edge endpoints so they can be inserted in one batch.\n",
    "    edge_src, edge_dst = [], []\n",
    "    for paper_id in slot_paper_ids:\n",
    "        paper = all_papers[paper_id]\n",
    "        # Local index; the global id2idx index was used here before.\n",
    "        idx = id2idx_t[paper_id]\n",
    "        paper_year = paper['year']\n",
    "        for ref_id in paper['ref']:\n",
    "            # Ignore references that were not stored, self references, and\n",
    "            # references dated after the citing paper.\n",
    "            if ref_id not in all_papers or ref_id == paper_id or all_papers[ref_id]['year'] > paper_year:\n",
    "                continue\n",
    "            if ref_id in id2idx:\n",
    "                # setdefault keeps an existing index or assigns the next one.\n",
    "                ref_idx = id2idx_t.setdefault(ref_id, len(id2idx_t))\n",
    "                edge_src += [idx, ref_idx]\n",
    "                edge_dst += [ref_idx, idx]\n",
    "            for ref_ref_id in all_papers[ref_id]['ref']:\n",
    "                if ref_ref_id in id2idx:\n",
    "                    ref_ref_idx = id2idx_t.setdefault(ref_ref_id, len(id2idx_t))\n",
    "                    edge_src += [idx, ref_ref_idx]\n",
    "                    edge_dst += [ref_ref_idx, idx]\n",
    "\n",
    "    num_nodes_t = len(id2idx_t)\n",
    "    node_features = [[] for _ in range(num_nodes_t)]\n",
    "    node_idxs = [-1] * num_nodes_t\n",
    "    class_label = [-1] * num_nodes_t\n",
    "    for paper_id, idx in id2idx_t.items():\n",
    "        # Rows are addressed by the local index (the global index was used\n",
    "        # before); node_idxs records the paper's index in the whole graph.\n",
    "        node_features[idx] = all_papers[paper_id]['keyword'] + all_papers[paper_id]['fos']\n",
    "        node_idxs[idx] = id2idx[paper_id]\n",
    "        class_label[idx] = all_papers[paper_id]['class']\n",
    "\n",
    "    g = dgl.DGLGraph()\n",
    "    g.add_nodes(num_nodes_t)\n",
    "    g.add_edges(torch.tensor(edge_src, dtype=torch.int64),\n",
    "                torch.tensor(edge_dst, dtype=torch.int64))\n",
    "\n",
    "    node_features = torch.tensor(node_features)\n",
    "    class_label = torch.tensor(class_label)\n",
    "    node_idxs = torch.tensor(node_idxs)\n",
    "    # One entry per node of this graph; the length used to be len(id2idx).\n",
    "    g.ndata['num_new_nodes'] = torch.tensor([num_new_papers] * num_nodes_t)\n",
    "    g.ndata['x'] = node_features\n",
    "    g.ndata['node_idxs'] = node_idxs\n",
    "    g.ndata['y'] = class_label\n",
    "    print(g)\n",
    "    g_list.append(g)\n",
    "    with open(datapath+f'graph_{time_slot}_by_edges', 'wb') as file:\n",
    "        pickle.dump(g, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch data for the set-membership experiments in the following cells.\n",
    "s, a = {1, 2, 3}, [8, 4, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# True when at least one element of a is a member of s.\n",
    "not s.isdisjoint(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add 5 and 6 to s in place.\n",
    "s |= {5, 6}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 2, 3, 5, 6}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display s after the in-place update above.\n",
    "s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3, 5, 6]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Elements of s greater than 2, in the set's iteration order.\n",
    "list(filter(lambda value: value > 2, s))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
