{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Open the archive that the last cell of this notebook writes with np.savez.\n",
    "npz_path = \"adult.npz\"\n",
    "data = np.load(npz_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['x_train',\n",
       " 'x_test',\n",
       " 'y_train',\n",
       " 'y_test',\n",
       " 'attr_train',\n",
       " 'attr_test',\n",
       " 'train_inds',\n",
       " 'valid_inds']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Names of the arrays stored in the archive, in storage order.\n",
    "list(data.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200, 2)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the 'y_test' array held in the archive.\n",
    "np.shape(data['y_test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull every stored array out of the archive into a same-named variable.\n",
    "array_keys = ('x_train', 'x_test', 'y_train', 'y_test',\n",
    "              'attr_train', 'attr_test', 'train_inds', 'valid_inds')\n",
    "(x_train, x_test, y_train, y_test,\n",
    " attr_train, attr_test, train_inds, valid_inds) = (data[key] for key in array_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7080, 1369, 6071, ..., 2718, 2053, 7625], dtype=int64)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the 'train_inds' array that was read from the archive.\n",
    "train_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20000, 30)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the 'x_train' array read from the archive.\n",
    "np.shape(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small synthetic split: 800 training / 200 test rows with 58 features each.\n",
    "# Rows form four contiguous groups (label x attribute), and every group draws\n",
    "# its features from a different distribution.\n",
    "# NOTE(review): the next cell reassigns every array defined here before the\n",
    "# save, so this cell has no effect on adult.npz - confirm it is still needed.\n",
    "SEED = 0  # fixed seed so the draws and the train/valid split are reproducible\n",
    "rng = np.random.default_rng(SEED)\n",
    "n_train, n_test, n_feat = 800, 200, 58\n",
    "\n",
    "# One-hot labels: column 0 marks the first 488 / 122 rows, column 1 the rest.\n",
    "y_train_ = np.zeros((n_train, 2))\n",
    "y_train_[:488, 0] = 1\n",
    "y_train_[:, 1] = 1 - y_train_[:, 0]\n",
    "y_test_ = np.zeros((n_test, 2))\n",
    "y_test_[:122, 0] = 1\n",
    "y_test_[:, 1] = 1 - y_test_[:, 0]\n",
    "\n",
    "# Binary attribute: set on the leading part of each label group.\n",
    "attr_train_ = np.zeros((n_train, 1))\n",
    "attr_train_[:392] = 1\n",
    "attr_train_[488:656] = 1\n",
    "attr_test_ = np.zeros((n_test, 1))\n",
    "attr_test_[:98] = 1\n",
    "attr_test_[122:164] = 1\n",
    "\n",
    "# (start, stop, sampler, sampler arguments) for each contiguous row group.\n",
    "train_groups = [\n",
    "    (0, 392, rng.laplace, (0, 1)),\n",
    "    (392, 488, rng.chisquare, (4,)),\n",
    "    (488, 656, rng.chisquare, (2,)),\n",
    "    (656, 800, rng.standard_t, (4,)),\n",
    "]\n",
    "test_groups = [\n",
    "    (0, 98, rng.laplace, (0, 1)),\n",
    "    (98, 122, rng.chisquare, (4,)),\n",
    "    (122, 164, rng.chisquare, (2,)),\n",
    "    (164, 200, rng.standard_t, (4,)),\n",
    "]\n",
    "x_train_ = np.zeros((n_train, n_feat))\n",
    "x_test_ = np.zeros((n_test, n_feat))\n",
    "for features, groups in ((x_train_, train_groups), (x_test_, test_groups)):\n",
    "    for start, stop, sampler, params in groups:\n",
    "        # One sized draw per group replaces one sampler call per row.\n",
    "        features[start:stop] = sampler(*params, size=(stop - start, n_feat))\n",
    "\n",
    "# Hold out a random 20% of the training rows for validation.\n",
    "shuf = rng.permutation(n_train)\n",
    "valid_pct = 0.2\n",
    "valid_ct = int(n_train * valid_pct)\n",
    "valid_inds_ = np.array(shuf[:valid_ct], dtype='int64')\n",
    "train_inds_ = np.array(shuf[valid_ct:], dtype='int64')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic split that the save cell writes to adult.npz:\n",
    "# 8000 training / 2000 test rows with 58 features each.\n",
    "# Rows form four contiguous groups (label x attribute), and every group draws\n",
    "# its features from a different distribution.\n",
    "SEED = 0  # fixed seed so the draws and the train/valid split are reproducible\n",
    "rng = np.random.default_rng(SEED)\n",
    "n_train, n_test, n_feat = 8000, 2000, 58\n",
    "\n",
    "# One-hot labels: column 0 marks the first 4880 / 1220 rows, column 1 the rest.\n",
    "y_train_ = np.zeros((n_train, 2))\n",
    "y_train_[:4880, 0] = 1\n",
    "y_train_[:, 1] = 1 - y_train_[:, 0]\n",
    "y_test_ = np.zeros((n_test, 2))\n",
    "y_test_[:1220, 0] = 1\n",
    "y_test_[:, 1] = 1 - y_test_[:, 0]\n",
    "\n",
    "# Binary attribute: set on the leading part of each label group.\n",
    "attr_train_ = np.zeros((n_train, 1))\n",
    "attr_train_[:3920] = 1\n",
    "attr_train_[4880:6560] = 1\n",
    "attr_test_ = np.zeros((n_test, 1))\n",
    "attr_test_[:980] = 1\n",
    "attr_test_[1220:1640] = 1\n",
    "\n",
    "# (start, stop, sampler, sampler arguments) for each contiguous row group.\n",
    "# An earlier, commented-out variant drew the first training group from normal(0.1, 1).\n",
    "train_groups = [\n",
    "    (0, 3920, rng.chisquare, (4,)),\n",
    "    (3920, 4880, rng.chisquare, (1,)),\n",
    "    (4880, 6560, rng.standard_t, (4,)),\n",
    "    (6560, 8000, rng.standard_t, (1,)),\n",
    "]\n",
    "test_groups = [\n",
    "    (0, 980, rng.chisquare, (4,)),\n",
    "    (980, 1220, rng.chisquare, (1,)),\n",
    "    (1220, 1640, rng.standard_t, (4,)),\n",
    "    (1640, 2000, rng.standard_t, (1,)),\n",
    "]\n",
    "x_train_ = np.zeros((n_train, n_feat))\n",
    "x_test_ = np.zeros((n_test, n_feat))\n",
    "for features, groups in ((x_train_, train_groups), (x_test_, test_groups)):\n",
    "    for start, stop, sampler, params in groups:\n",
    "        # One sized draw per group replaces one sampler call per row.\n",
    "        features[start:stop] = sampler(*params, size=(stop - start, n_feat))\n",
    "\n",
    "# Hold out a random 20% of the training rows for validation.\n",
    "shuf = rng.permutation(n_train)\n",
    "valid_pct = 0.2\n",
    "valid_ct = int(n_train * valid_pct)\n",
    "valid_inds_ = np.array(shuf[:valid_ct], dtype='int64')\n",
    "train_inds_ = np.array(shuf[valid_ct:], dtype='int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the generated arrays to adult.npz under the same keys the loading\n",
    "# cells at the top of the notebook read back.\n",
    "arrays_to_save = dict(\n",
    "    x_train=x_train_, x_test=x_test_,\n",
    "    y_train=y_train_, y_test=y_test_,\n",
    "    attr_train=attr_train_, attr_test=attr_test_,\n",
    "    train_inds=train_inds_, valid_inds=valid_inds_,\n",
    ")\n",
    "np.savez(\"adult.npz\", **arrays_to_save)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "571643379aecffa424456530f10e872c4e3c6ba56933130a5a1c673ee3d5896a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
