{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd2127bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from dgl.data import register_data_args\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from dgl import DGLGraph\n",
    "import pickle\n",
    "import random\n",
    "import numpy as np\n",
    "# from sklearn.metrics import balanced_accuracy_score, f1_score, accuracy_score\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd6f6098",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to split and the directory holding its pickled graphs.\n",
    "dataset = 'Amazon-Kindle'\n",
    "datapath = '../data-selected-classes/Amazon-Kindle/'\n",
    "\n",
    "# Split fractions for train / valid / test.\n",
    "train_ratio = 0.3\n",
    "valid_ratio = 0.2\n",
    "test_ratio = 0.5\n",
    "\n",
    "# The statistics pickle unpacks into the number of time slots (tasks) and classes.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.\n",
    "with open(datapath + 'statistics', 'rb') as file:\n",
    "    num_tasks, num_class = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c7164671",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2322)\n",
      "tensor(696)\n",
      "tensor(464)\n",
      "tensor(1162)\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "tensors used as indices must be long, byte or bool tensors",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [5], line 55\u001b[0m\n\u001b[1;32m     53\u001b[0m valid_mask_idxs_whole \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_nodes_whole) \u001b[38;5;28;01mif\u001b[39;00m g_whole\u001b[38;5;241m.\u001b[39mndata[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnode_idxs\u001b[39m\u001b[38;5;124m'\u001b[39m][i]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;129;01min\u001b[39;00m valid_node_idxs_set])\n\u001b[1;32m     54\u001b[0m test_mask_idxs_whole \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_nodes_whole) \u001b[38;5;28;01mif\u001b[39;00m g_whole\u001b[38;5;241m.\u001b[39mndata[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnode_idxs\u001b[39m\u001b[38;5;124m'\u001b[39m][i]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;129;01min\u001b[39;00m test_node_idxs_set])\n\u001b[0;32m---> 55\u001b[0m \u001b[43mtrain_mask_whole\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtrain_mask_idxs_whole\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m     56\u001b[0m valid_mask_whole[valid_mask_idxs_whole] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m     57\u001b[0m test_mask_whole[test_mask_idxs_whole] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "\u001b[0;31mIndexError\u001b[0m: tensors used as indices must be long, byte or bool tensors"
     ]
    }
   ],
   "source": [
    "# Variant of the split in the next cell: in every time slot, the new nodes of\n",
    "# one randomly chosen class are excluded before splitting into train/valid/test.\n",
    "# NOTE(review): the next cell writes the same mask_seed_{run} files, so under\n",
    "# Run All its masks overwrite these -- confirm which variant should be kept.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.\n",
    "with open(datapath+'sub_graph_whole', 'rb') as file:\n",
    "    g_whole = pickle.load(file)\n",
    "# g_whole is only read below, so load it (and its node ids) once, not per run.\n",
    "whole_node_idxs = g_whole.ndata['node_idxs'].tolist()\n",
    "\n",
    "for run in range(10):\n",
    "    # Seed per run so that mask_seed_{run} is reproducible on re-execution.\n",
    "    random.seed(run)\n",
    "    train_node_idxs_set, valid_node_idxs_set, test_node_idxs_set = set(), set(), set()\n",
    "    masks_supgraphs_list = []\n",
    "\n",
    "    for time_slot in range(num_tasks):\n",
    "        with open(datapath+f'sub_graph_{time_slot}_by_edges', 'rb') as file:\n",
    "            g = pickle.load(file)\n",
    "        n_nodes = g.num_nodes()\n",
    "\n",
    "        # Keep only new nodes whose label y differs from one randomly excluded class.\n",
    "        excluded_class = random.randint(0, num_class-1)\n",
    "        selected_class_mask = g.ndata['y'] != excluded_class\n",
    "        new_nodes_mask = torch.logical_and(selected_class_mask, g.ndata['new_nodes_mask'])\n",
    "        # 1-D long positions of the surviving new nodes.\n",
    "        new_node_idxs = new_nodes_mask.nonzero(as_tuple=True)[0]\n",
    "        n_new_nodes = len(new_node_idxs)\n",
    "\n",
    "        # Shuffle a plain list, then index with a long tensor: an empty\n",
    "        # numpy array would be float64 and could not be used as an index.\n",
    "        perm = list(range(n_new_nodes))\n",
    "        random.shuffle(perm)\n",
    "        perm = torch.tensor(perm, dtype=torch.long)\n",
    "\n",
    "        n_train = int(train_ratio*n_new_nodes)\n",
    "        n_valid = int(valid_ratio*n_new_nodes)\n",
    "        # Test takes the remainder, so int() truncation never drops a node.\n",
    "        ind_train = new_node_idxs[perm[:n_train]]\n",
    "        ind_valid = new_node_idxs[perm[n_train:n_train+n_valid]]\n",
    "        ind_test = new_node_idxs[perm[n_train+n_valid:]]\n",
    "\n",
    "        train_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        valid_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        test_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        train_mask[ind_train] = True\n",
    "        valid_mask[ind_valid] = True\n",
    "        test_mask[ind_test] = True\n",
    "        masks_supgraphs_list.append((train_mask, valid_mask, test_mask))\n",
    "\n",
    "        # Collect the node_idxs ids picked in this slot for the whole-graph masks.\n",
    "        train_node_idxs_set.update(g.ndata['node_idxs'][train_mask].tolist())\n",
    "        valid_node_idxs_set.update(g.ndata['node_idxs'][valid_mask].tolist())\n",
    "        test_node_idxs_set.update(g.ndata['node_idxs'][test_mask].tolist())\n",
    "\n",
    "    # Build the whole-graph bool masks directly by set membership. Indexing with\n",
    "    # torch.tensor(list_of_positions) fails when the list is empty, because\n",
    "    # torch.tensor([]) defaults to float32, which is not a valid index dtype.\n",
    "    train_mask_whole = torch.tensor([i in train_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    valid_mask_whole = torch.tensor([i in valid_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    test_mask_whole = torch.tensor([i in test_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    mask_whole = (train_mask_whole, valid_mask_whole, test_mask_whole)\n",
    "\n",
    "    with open(datapath+f'mask_seed_{run}', 'wb') as file:\n",
    "        pickle.dump((masks_supgraphs_list, mask_whole), file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8b692cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomly split each time slot's new nodes into train / valid / test masks,\n",
    "# project the selected node ids onto the whole graph, and save one pickle per run.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.\n",
    "with open(datapath+'sub_graph_whole', 'rb') as file:\n",
    "    g_whole = pickle.load(file)\n",
    "# g_whole is only read below, so load it (and its node ids) once, not per run.\n",
    "whole_node_idxs = g_whole.ndata['node_idxs'].tolist()\n",
    "\n",
    "for run in range(10):\n",
    "    # Seed per run so that mask_seed_{run} is reproducible on re-execution.\n",
    "    random.seed(run)\n",
    "    train_node_idxs_set, valid_node_idxs_set, test_node_idxs_set = set(), set(), set()\n",
    "    masks_supgraphs_list = []\n",
    "\n",
    "    for time_slot in range(num_tasks):\n",
    "        with open(datapath+f'sub_graph_{time_slot}_by_edges', 'rb') as file:\n",
    "            g = pickle.load(file)\n",
    "        n_nodes = g.num_nodes()\n",
    "\n",
    "        # 1-D long positions of the nodes where new_nodes_mask is set.\n",
    "        new_node_idxs = g.ndata['new_nodes_mask'].nonzero(as_tuple=True)[0]\n",
    "        n_new_nodes = len(new_node_idxs)\n",
    "\n",
    "        # Shuffle a plain list, then index with a long tensor: an empty\n",
    "        # numpy array would be float64 and could not be used as an index.\n",
    "        perm = list(range(n_new_nodes))\n",
    "        random.shuffle(perm)\n",
    "        perm = torch.tensor(perm, dtype=torch.long)\n",
    "\n",
    "        n_train = int(train_ratio*n_new_nodes)\n",
    "        n_valid = int(valid_ratio*n_new_nodes)\n",
    "        # Test takes the remainder, so int() truncation never drops a node.\n",
    "        ind_train = new_node_idxs[perm[:n_train]]\n",
    "        ind_valid = new_node_idxs[perm[n_train:n_train+n_valid]]\n",
    "        ind_test = new_node_idxs[perm[n_train+n_valid:]]\n",
    "\n",
    "        train_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        valid_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        test_mask = torch.zeros(n_nodes, dtype=torch.bool)\n",
    "        train_mask[ind_train] = True\n",
    "        valid_mask[ind_valid] = True\n",
    "        test_mask[ind_test] = True\n",
    "        masks_supgraphs_list.append((train_mask, valid_mask, test_mask))\n",
    "\n",
    "        # Collect the node_idxs ids picked in this slot for the whole-graph masks.\n",
    "        train_node_idxs_set.update(g.ndata['node_idxs'][train_mask].tolist())\n",
    "        valid_node_idxs_set.update(g.ndata['node_idxs'][valid_mask].tolist())\n",
    "        test_node_idxs_set.update(g.ndata['node_idxs'][test_mask].tolist())\n",
    "\n",
    "    # Build the whole-graph bool masks directly by set membership. Indexing with\n",
    "    # torch.tensor(list_of_positions) fails when the list is empty, because\n",
    "    # torch.tensor([]) defaults to float32, which is not a valid index dtype.\n",
    "    train_mask_whole = torch.tensor([i in train_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    valid_mask_whole = torch.tensor([i in valid_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    test_mask_whole = torch.tensor([i in test_node_idxs_set for i in whole_node_idxs], dtype=torch.bool)\n",
    "    mask_whole = (train_mask_whole, valid_mask_whole, test_mask_whole)\n",
    "\n",
    "    with open(datapath+f'mask_seed_{run}', 'wb') as file:\n",
    "        pickle.dump((masks_supgraphs_list, mask_whole), file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f10f4c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
