{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pickle\n",
    "# import pandas as pd\n",
    "from urllib.request import urlopen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.tokenize import word_tokenize\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the preprocessed ACM v9 metadata pickle. Override it with the ACM_DATAFILE\n",
    "# environment variable instead of editing this machine-specific default.\n",
    "datafile = os.environ.get('ACM_DATAFILE', '/home/xiaoxue/data/acm.v9/ACM_processed.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "### load the meta data\n",
    "with open(datafile, 'rb') as handle:\n",
    "    data = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "506"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count papers per venue, then rank venues by paper count (most frequent first).\n",
    "venue_dict = defaultdict(int)\n",
    "for item in data.values():  # keys are unused; also avoids shadowing the builtin `id`\n",
    "    if 'venue' in item:\n",
    "        venue_dict[item['venue']] += 1\n",
    "\n",
    "# (count, venue) pairs sorted descending; ties are broken by venue name, descending.\n",
    "venue_list = sorted(((count, venue) for venue, count in venue_dict.items()), reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct venue names that occur in the metadata.\n",
    "venue_set = {venue for _count, venue in venue_list}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from venue_classes import iss, sp, am, ai\n",
    "# Convert each venue group to a set for fast membership tests.\n",
    "iss, sp, am, ai = set(iss), set(sp), set(am), set(ai)\n",
    "\n",
    "def get_label(venue):\n",
    "    \"\"\"Return the class id of `venue`: 0..3 for the groups iss, sp, am, ai\n",
    "    (checked in that order, first match wins), or -1 if it is in none of them.\"\"\"\n",
    "    for label, group in enumerate((iss, sp, am, ai)):\n",
    "        if venue in group:\n",
    "            return label\n",
    "    return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_year = 1999\n",
    "end_year = 2014\n",
    "# One task per two-year slot in [start_year, end_year]; 8 for 1999-2014.\n",
    "num_task = (end_year - start_year) // 2 + 1\n",
    "num_papers_year = [0] * num_task  # number of papers in each time slot\n",
    "\n",
    "# Lower-cased title-token frequencies over classified papers inside the year range.\n",
    "token_dic = defaultdict(int)\n",
    "for item in data.values():\n",
    "    if 'venue' in item and 'year' in item and 'title' in item:\n",
    "        if get_label(item['venue']) == -1:\n",
    "            continue\n",
    "        year = item['year']\n",
    "        if year < start_year or year > end_year:\n",
    "            continue\n",
    "        # year >= start_year here, so the slot index is never negative.\n",
    "        time_slot = (year - start_year) // 2\n",
    "        num_papers_year[time_slot] += 1\n",
    "        for token in word_tokenize(item['title']):\n",
    "            token_dic[token.lower()] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "926"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary: keep tokens whose corpus count lies strictly between the two bounds,\n",
    "# dropping both the most frequent and the rare ones. Index = insertion order.\n",
    "min_token_count, max_token_count = 200, 5000\n",
    "token_idx = {}\n",
    "for token, count in token_dic.items():\n",
    "    if min_token_count < count < max_token_count:\n",
    "        token_idx[token] = len(token_idx)\n",
    "\n",
    "len(token_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "datapath = '../data/ACM/'\n",
    "os.makedirs(datapath, exist_ok=True)\n",
    "# Derive the task count from the year range (two-year slots) so it cannot drift from\n",
    "# the slotting `(year - start_year) // 2`; num_class = the four groups in get_label.\n",
    "num_task = (end_year - start_year) // 2 + 1\n",
    "num_class = 4\n",
    "with open(datapath + 'statistics', 'wb') as file:\n",
    "    pickle.dump((num_task, num_class), file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_feature(text):\n",
    "    \"\"\"Bag-of-words count vector of `text` over the token_idx vocabulary.\n",
    "\n",
    "    Tokens are lower-cased before lookup; out-of-vocabulary tokens are ignored.\n",
    "    Returns a list of ints of length len(token_idx).\n",
    "    \"\"\"\n",
    "    feature = [0] * len(token_idx)\n",
    "    for token in word_tokenize(text):\n",
    "        idx = token_idx.get(token.lower())\n",
    "        if idx is not None:\n",
    "            feature[idx] += 1\n",
    "    return feature\n",
    "\n",
    "\n",
    "# all_items: paper id -> record for every classified paper of any year, so cited\n",
    "# papers outside the year range can still be added as graph nodes.\n",
    "# items_year_id[t]: ids of papers published in two-year slot t.\n",
    "all_items = {}\n",
    "items_year_id = [set() for _ in range(num_task)]\n",
    "\n",
    "for item in data.values():\n",
    "    if 'venue' in item and 'year' in item and 'title' in item:\n",
    "        label = get_label(item['venue'])\n",
    "        if label == -1:\n",
    "            continue\n",
    "        year = item['year']\n",
    "        all_items[item['id']] = {\n",
    "            'label': label,\n",
    "            'feature': get_feature(item['title']),\n",
    "            'ref': item.get('ref', set()),\n",
    "            # Graph construction compares publication years of citing and cited papers.\n",
    "            'year': year,\n",
    "        }\n",
    "        if year < start_year or year > end_year:\n",
    "            continue\n",
    "        items_year_id[(year - start_year) // 2].add(item['id'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global node index for every kept paper, in all_items insertion order.\n",
    "id2idx = {item_id: idx for idx, item_id in enumerate(all_items)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import torch\n",
    "\n",
    "# Publication year of every kept paper, read from the raw metadata.\n",
    "paper_year = {item['id']: item['year'] for item in data.values()\n",
    "              if 'year' in item and item.get('id') in all_items}\n",
    "\n",
    "g_list = []\n",
    "for time_slot in range(num_task):\n",
    "    new_ids = items_year_id[time_slot]\n",
    "    # Papers published in this slot get local indices 0..num_new_items-1;\n",
    "    # earlier-or-same-year papers they cite are appended after them.\n",
    "    id2idx_t = {item_id: idx for idx, item_id in enumerate(new_ids)}\n",
    "    num_new_items = len(id2idx_t)\n",
    "\n",
    "    src, dst = [], []\n",
    "    for item_id in new_ids:\n",
    "        item_year = paper_year[item_id]\n",
    "        for rel_id in all_items[item_id]['ref']:\n",
    "            # Skip citations to unknown papers or to papers newer than the citing one.\n",
    "            if rel_id not in all_items or paper_year[rel_id] > item_year:\n",
    "                continue\n",
    "            if rel_id not in id2idx_t:\n",
    "                id2idx_t[rel_id] = len(id2idx_t)\n",
    "            # Add both directions so the citation graph is undirected.\n",
    "            src += [id2idx_t[item_id], id2idx_t[rel_id]]\n",
    "            dst += [id2idx_t[rel_id], id2idx_t[item_id]]\n",
    "\n",
    "    num_nodes = len(id2idx_t)\n",
    "    node_idxs = [-1] * num_nodes\n",
    "    node_features = [None] * num_nodes\n",
    "    class_label = [-1] * num_nodes\n",
    "    for item_id, idx in id2idx_t.items():\n",
    "        node_idxs[idx] = id2idx[item_id]  # position in the global id2idx numbering\n",
    "        node_features[idx] = all_items[item_id]['feature']\n",
    "        class_label[idx] = all_items[item_id]['label']\n",
    "\n",
    "    g = dgl.graph((torch.tensor(src, dtype=torch.int64), torch.tensor(dst, dtype=torch.int64)),\n",
    "                  num_nodes=num_nodes)\n",
    "    g.ndata['num_new_nodes'] = torch.full((num_nodes,), num_new_items, dtype=torch.int64)\n",
    "    g.ndata['x'] = torch.tensor(node_features)\n",
    "    g.ndata['y'] = torch.tensor(class_label)\n",
    "    g.ndata['node_idxs'] = torch.tensor(node_idxs)\n",
    "    print(g)\n",
    "    g_list.append(g)\n",
    "    with open(datapath + f'graph_{time_slot}_by_edges', 'wb') as file:\n",
    "        pickle.dump(g, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display whatever `g` currently holds; after the per-slot cell it is the last slot's graph.\n",
    "g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f = open('statistics.csv', 'w')\n",
    "# writer = csv.writer(f)\n",
    "\n",
    "# for i in range (len(book_category_years)):\n",
    "#     row = ['{} ({:.2f})'.format(x, 100*x/sum(book_category_years[i])) for x in book_category_years[i]]\n",
    "#     writer.writerow(row)\n",
    "# f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One graph over all kept papers, numbered by the global id2idx. Undirected citation\n",
    "# edges (cited paper not newer than the citing one) are taken from papers published\n",
    "# in any of the num_task time slots; every node gets its feature vector and label.\n",
    "paper_year = {item['id']: item['year'] for item in data.values()\n",
    "              if 'year' in item and item.get('id') in all_items}\n",
    "\n",
    "num_nodes = len(id2idx)\n",
    "node_features = [None] * num_nodes\n",
    "class_label = [-1] * num_nodes\n",
    "for item_id, idx in id2idx.items():\n",
    "    node_features[idx] = all_items[item_id]['feature']\n",
    "    class_label[idx] = all_items[item_id]['label']\n",
    "\n",
    "src, dst = [], []\n",
    "for time_slot in range(num_task):\n",
    "    for item_id in items_year_id[time_slot]:\n",
    "        idx = id2idx[item_id]\n",
    "        for rel_id in all_items[item_id]['ref']:\n",
    "            if rel_id in id2idx and paper_year[rel_id] <= paper_year[item_id]:\n",
    "                src += [idx, id2idx[rel_id]]\n",
    "                dst += [id2idx[rel_id], idx]\n",
    "\n",
    "g = dgl.graph((torch.tensor(src, dtype=torch.int64), torch.tensor(dst, dtype=torch.int64)),\n",
    "              num_nodes=num_nodes)\n",
    "node_features = torch.tensor(node_features)\n",
    "print(node_features.size())\n",
    "g.ndata['x'] = node_features\n",
    "g.ndata['y'] = torch.tensor(class_label)\n",
    "with open(datapath + 'full_graph', 'wb') as file:\n",
    "    pickle.dump(g, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the current `g` a second time, under the name graph_whole.\n",
    "with open(datapath + 'graph_whole', mode='wb') as out_file:\n",
    "    pickle.dump(g, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check of Python's conditional-expression syntax; not used by the pipeline.\n",
    "1 if False else 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
