{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd2127bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from dgl.data import register_data_args\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from dgl import DGLGraph\n",
    "import pickle\n",
    "import random\n",
    "import numpy as np\n",
    "# from sklearn.metrics import balanced_accuracy_score, f1_score, accuracy_score\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd6f6098",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset selection and split ratios. datapath is derived from the dataset name so the two cannot drift apart.\n",
    "dataset = 'Amazon-Kindle'\n",
    "datapath = f'../data-selected-classes/{dataset}/'\n",
    "train_ratio, valid_ratio, test_ratio = 0.3, 0.2, 0.5\n",
    "assert abs(train_ratio + valid_ratio + test_ratio - 1) < 1e-9, 'split ratios must sum to 1'\n",
    "\n",
    "# 'statistics' is a pickled pair: (number of time slots / tasks, number of classes).\n",
    "# NOTE(review): pickle.load executes code from the file; only use trusted local data.\n",
    "with open(datapath+'statistics', 'rb') as file:\n",
    "    num_tasks, num_class = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7164671",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-time-slot class distribution: one CSV row per sub-graph, one cell per class\n",
    "# formatted as 'count (percent%)', where percent is relative to the nodes whose label is in 0..num_class-1.\n",
    "# NOTE(review): written to the current working directory, so runs for different datasets overwrite each other.\n",
    "with open('class_statistics.csv', 'w', newline='') as csvfile:\n",
    "    writer = csv.writer(csvfile)\n",
    "    for time_slot in range(num_tasks):\n",
    "        with open(datapath+f'sub_graph_{time_slot}_by_edges', 'rb') as file:\n",
    "            g = pickle.load(file)\n",
    "\n",
    "        labels = g.ndata['y']\n",
    "        # Vectorised per-class count (builtin sum() over a tensor iterates element by element).\n",
    "        class_numbers = [(labels == label).sum().item() for label in range(num_class)]\n",
    "        total = sum(class_numbers)\n",
    "        # Guard against an empty sub-graph instead of dividing by zero.\n",
    "        class_ratios = [x/total*100 if total else 0.0 for x in class_numbers]\n",
    "        writer.writerow([f'{x} ({y:.1f}%)' for (x,y) in zip(class_numbers, class_ratios)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8b692cca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "1\n",
      "0\n",
      "2\n",
      "3\n",
      "2\n",
      "1\n",
      "3\n",
      "3\n",
      "1\n",
      "2\n",
      "2\n",
      "0\n",
      "1\n",
      "1\n",
      "2\n",
      "1\n",
      "1\n",
      "1\n",
      "0\n",
      "3\n",
      "1\n",
      "0\n",
      "3\n",
      "2\n",
      "1\n",
      "0\n",
      "0\n",
      "1\n",
      "0\n",
      "3\n",
      "2\n",
      "0\n",
      "2\n",
      "2\n",
      "3\n",
      "0\n",
      "0\n",
      "2\n",
      "1\n",
      "2\n",
      "2\n",
      "0\n",
      "3\n",
      "3\n",
      "1\n",
      "0\n",
      "0\n",
      "2\n",
      "0\n",
      "0\n",
      "1\n",
      "0\n",
      "0\n",
      "0\n",
      "3\n",
      "1\n",
      "0\n",
      "3\n",
      "1\n",
      "2\n",
      "0\n",
      "2\n",
      "0\n",
      "2\n",
      "1\n",
      "3\n",
      "2\n",
      "0\n",
      "0\n",
      "0\n",
      "3\n",
      "0\n",
      "1\n",
      "2\n",
      "0\n",
      "2\n",
      "3\n",
      "1\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "# Build NUM_RUNS random train/valid/test splits and pickle each to mask_seed_{run}.\n",
    "# For every time slot one class is drawn at random and excluded; the remaining *new* nodes of that\n",
    "# slot's sub-graph are shuffled and split by train_ratio / valid_ratio, and the test split takes the\n",
    "# remainder (so test_ratio is implicit). Per-slot masks are mapped onto the whole graph via 'node_idxs'.\n",
    "NUM_RUNS = 10\n",
    "\n",
    "def index_to_mask(indices, size):\n",
    "    \"\"\"Return a bool tensor of length `size` that is True exactly at `indices`.\"\"\"\n",
    "    mask = torch.zeros(size, dtype=torch.bool)\n",
    "    mask[indices] = True\n",
    "    return mask\n",
    "\n",
    "def membership_mask(node_ids, wanted):\n",
    "    \"\"\"Return a bool tensor that is True at position i iff node_ids[i] is in the set `wanted`.\"\"\"\n",
    "    return torch.tensor([node_id in wanted for node_id in node_ids], dtype=torch.bool)\n",
    "\n",
    "# The whole graph does not change between runs, so load it (and its id list) only once.\n",
    "with open(datapath+'sub_graph_whole', 'rb') as file:\n",
    "    g_whole = pickle.load(file)\n",
    "node_idxs_whole = g_whole.ndata['node_idxs'].tolist()\n",
    "\n",
    "for run in range(NUM_RUNS):\n",
    "    # Seed with the run number so that mask_seed_{run} is reproducible.\n",
    "    random.seed(run)\n",
    "    train_node_idxs_set, valid_node_idxs_set, test_node_idxs_set = set(), set(), set()\n",
    "    masks_supgraphs_list = []\n",
    "    excluded_classes = []\n",
    "\n",
    "    for time_slot in range(num_tasks):\n",
    "        with open(datapath+f'sub_graph_{time_slot}_by_edges', 'rb') as file:\n",
    "            g = pickle.load(file)\n",
    "        n_nodes = g.num_nodes()\n",
    "\n",
    "        excluded_class = random.randint(0, num_class-1)\n",
    "        excluded_classes.append(excluded_class)\n",
    "        selected_class_mask = g.ndata['y'] != excluded_class\n",
    "        new_nodes_mask = torch.logical_and(selected_class_mask, g.ndata['new_nodes_mask'])\n",
    "        new_node_idxs = new_nodes_mask.nonzero(as_tuple=True)[0]\n",
    "        n_new_nodes = len(new_node_idxs)\n",
    "\n",
    "        order = list(range(n_new_nodes))\n",
    "        random.shuffle(order)\n",
    "        # Explicit long dtype: an empty torch.tensor([]) would be float and unusable as an index.\n",
    "        shuffled_idxs = new_node_idxs[torch.tensor(order, dtype=torch.long)]\n",
    "        n_train, n_valid = int(train_ratio*n_new_nodes), int(valid_ratio*n_new_nodes)\n",
    "\n",
    "        train_mask = index_to_mask(shuffled_idxs[:n_train], n_nodes)\n",
    "        valid_mask = index_to_mask(shuffled_idxs[n_train:n_train+n_valid], n_nodes)\n",
    "        test_mask = index_to_mask(shuffled_idxs[n_train+n_valid:], n_nodes)\n",
    "        masks_supgraphs_list.append((train_mask, valid_mask, test_mask))\n",
    "\n",
    "        train_node_idxs_set.update(g.ndata['node_idxs'][train_mask].tolist())\n",
    "        valid_node_idxs_set.update(g.ndata['node_idxs'][valid_mask].tolist())\n",
    "        test_node_idxs_set.update(g.ndata['node_idxs'][test_mask].tolist())\n",
    "    print(f'run {run}: excluded class per time slot = {excluded_classes}')\n",
    "\n",
    "    # Whole-graph masks: a node is selected iff its global id was selected in any time slot.\n",
    "    mask_whole = (membership_mask(node_idxs_whole, train_node_idxs_set),\n",
    "                  membership_mask(node_idxs_whole, valid_node_idxs_set),\n",
    "                  membership_mask(node_idxs_whole, test_node_idxs_set))\n",
    "    with open(datapath+f'mask_seed_{run}', 'wb') as file:\n",
    "        pickle.dump((masks_supgraphs_list, mask_whole), file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f10f4c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
