{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0f559a2f-1baf-4856-8a1d-7b90d7577475",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import pickle as pkl\n",
    "import sys\n",
    "\n",
    "import h5py\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.io\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from scipy.sparse import csr_matrix\n",
    "# eigsh is imported from the public namespace; the private `.eigen.arpack` path is not importable on newer SciPy.\n",
    "from scipy.sparse.linalg import eigsh\n",
    "\n",
    "\n",
    "# Seed NumPy's global RNG so the permutations and random features drawn later are reproducible.\n",
    "np.random.seed(10007)\n",
    "\n",
    "\n",
    "name = \"polblogs\"\n",
    "\n",
    "# Reads `<name>.mat` from the current working directory.\n",
    "data_mat = scipy.io.loadmat(\"{}.mat\".format(name))\n",
    "\n",
    "# D_tot is treated as the node-by-node adjacency matrix by the code further down.\n",
    "D_tot = data_mat['polblogsgiant']\n",
    "# One label entry per node; presumably shape (n_tot, 1) with integer values - confirm against the .mat file.\n",
    "label_tot = data_mat['polblogsgianty']\n",
    "n_tot = len(label_tot)\n",
    "\n",
    "# Map every distinct raw label to a 1-based class id, numbered in order of first appearance.\n",
    "# `temp` is a lookup table indexed by the raw label value (0 means 'not seen yet'), so raw labels\n",
    "# must be integers smaller than its length.\n",
    "temp = np.zeros(5000)\n",
    "cnt = 0\n",
    "# The scan has to start at node 0. Starting at 1 skipped the first node, so a label that only\n",
    "# node 0 carries was never registered and its one-hot entry below landed in column -1.\n",
    "for i in range(0, n_tot):\n",
    "    if temp[label_tot[i]] == 0:\n",
    "        cnt = cnt + 1\n",
    "        temp[label_tot[i]] = cnt\n",
    "\n",
    "# k: number of distinct classes; Pi_tot: one-hot class membership matrix of shape (n_tot, k).\n",
    "k = cnt\n",
    "Pi_tot = np.zeros((n_tot, k))\n",
    "for i in range(0, n_tot):\n",
    "    Pi_tot[i, int((temp[label_tot[i]])[0]) - 1] = 1\n",
    "\n",
    "# Random node order; cross-validation folds are contiguous slices of this permutation.\n",
    "perm_tot = np.random.permutation(n_tot)\n",
    "\n",
    "# Experiment grid: fraction of the training nodes that keep their label, and node-feature type.\n",
    "ratio_l = [0.3, 0.5, 0.7]\n",
    "method_l = ['random', 'ones', 'adj']\n",
    "\n",
    "# fold_list holds the fold boundaries inside perm_tot: fold f covers\n",
    "# perm_tot[fold_list[f-1]:fold_list[f]]. The last boundary is n_tot, so the final fold also\n",
    "# absorbs the remainder of the integer division.\n",
    "fold_num = 10\n",
    "fold_size = int(np.floor(n_tot / fold_num))\n",
    "fold_list = np.array(range(0, fold_num)) * fold_size\n",
    "fold_list = np.append(fold_list, n_tot)\n",
    "\n",
    "# Two synthetic feature matrices of shape (n_tot, T): standard-normal noise and all ones.\n",
    "T = 50\n",
    "features_tot_rand = np.random.normal(size = n_tot * T).reshape((n_tot, T))\n",
    "features_tot_c = np.ones(n_tot * T).reshape((n_tot, T))\n",
    "\n",
    "# Re-index the adjacency matrix into perm_tot order. Rows AND columns must both be permuted;\n",
    "# `D_tot[perm_tot][perm_tot]` applied the permutation to the rows twice and left the columns alone.\n",
    "D = D_tot[perm_tot][:, perm_tot]\n",
    "\n",
    "# Adjacency list {node: [neighbours with an entry equal to 1, ascending]}, built from the\n",
    "# non-zero pattern instead of probing all n_tot * n_tot entries one by one.\n",
    "graph = {node: [] for node in range(n_tot)}\n",
    "src, dst = (D == 1).nonzero()\n",
    "for u, v in sorted(zip(src.tolist(), dst.tolist())):\n",
    "    graph[u].append(v)\n",
    "\n",
    "# For every (fold, labelled ratio, feature type) combination, build a train/test split and pickle\n",
    "# its parts to ./data/ind.<name>_<fold>_<ratio>_<method>.<part>.\n",
    "# NOTE(review): the file naming resembles the Planetoid `ind.<dataset>.<part>` layout - confirm\n",
    "# against whatever loader consumes these files.\n",
    "\n",
    "# open() does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs(\"./data\", exist_ok=True)\n",
    "\n",
    "# Loop-invariant: float copy of the adjacency matrix, used as node features for the 'adj' method.\n",
    "adj_features = D_tot.astype(float)\n",
    "\n",
    "# i is the 1-based fold number, j the labelled fraction. The feature-type variable is called\n",
    "# `method` (it used to be `k`) so that it no longer overwrites the class count `k`.\n",
    "for i in range(1, fold_num + 1):\n",
    "    for j in ratio_l:\n",
    "        for method in method_l:\n",
    "            label_size = int(np.floor(j * (n_tot - fold_size)))\n",
    "\n",
    "            # Fold i is held out as the test set; everything else is training data.\n",
    "            mask = np.ones(n_tot, dtype = bool)\n",
    "            mask[fold_list[i-1]:fold_list[i]] = 0\n",
    "            train_list = perm_tot[mask]\n",
    "            test_list = perm_tot[fold_list[i-1]:fold_list[i]]\n",
    "\n",
    "            # A fresh shuffle is drawn for every combination, so the labelled subset differs\n",
    "            # between ratios and between methods even within the same fold.\n",
    "            perm_train = np.random.permutation(len(train_list))\n",
    "            train_list = train_list[perm_train]\n",
    "            label_list = train_list[0:label_size]\n",
    "\n",
    "            # Node order used by the saved graph: all training nodes first, then the test nodes.\n",
    "            perm_cur = np.hstack([train_list, test_list])\n",
    "\n",
    "            y = Pi_tot[label_list, ]\n",
    "            ty = Pi_tot[test_list, ]\n",
    "            ally = Pi_tot[train_list, ]\n",
    "\n",
    "            if method == 'random':\n",
    "                features = features_tot_rand\n",
    "            elif method == 'ones':\n",
    "                features = features_tot_c\n",
    "            else:\n",
    "                features = adj_features\n",
    "            x = csr_matrix(features[label_list, ])\n",
    "            tx = csr_matrix(features[test_list, ])\n",
    "            allx = csr_matrix(features[train_list, ])\n",
    "\n",
    "            # Re-index the adjacency matrix into perm_cur order. Rows AND columns must both be\n",
    "            # permuted; `D_tot[perm_cur][perm_cur]` permuted the rows twice and never the columns.\n",
    "            D = D_tot[perm_cur][:, perm_cur]\n",
    "\n",
    "            # Adjacency list {node: [neighbours with an entry equal to 1, ascending]}, built from\n",
    "            # the non-zero pattern instead of an n_tot * n_tot element-wise Python loop.\n",
    "            graph = {node: [] for node in range(n_tot)}\n",
    "            src, dst = (D == 1).nonzero()\n",
    "            for ii, jj in sorted(zip(src.tolist(), dst.tolist())):\n",
    "                graph[ii].append(jj)\n",
    "\n",
    "            names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']\n",
    "            obj = [x, y, tx, ty, allx, ally, graph]\n",
    "            for part_name, part in zip(names, obj):\n",
    "                with open(\"./data/ind.{}_{}_{}_{}.{}\".format(name, i, j, method, part_name), 'wb') as f:\n",
    "                    pickle.dump(part, f)\n",
    "\n",
    "            # Test nodes occupy the trailing positions of perm_cur, i.e. len(train_list) .. n_tot - 1.\n",
    "            with open(\"./data/ind.{}_{}_{}_{}.{}\".format(name, i, j, method, 'test.index'), 'wb') as f:\n",
    "                pickle.dump(list(np.array(range(len(train_list), n_tot))), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 378,
   "id": "1a9fb558-a732-4534-952c-7439cef38797",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1137"
      ]
     },
     "execution_count": 378,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: number of nodes covered by the most recent train/test split.\n",
    "np.concatenate((train_list, test_list)).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 383,
   "id": "059c8a7f-c933-4d94-bda6-c4fe01420e2d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1136"
      ]
     },
     "execution_count": 383,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch cell: display the current value of the notebook-level variable `i`.\n",
    "i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d30b4b0f-9ad1-4888-8d97-c1c869ba5155",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "Atry = np.ones((10, 10))\n",
    "Atry[[1, 2], ]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
