{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from stochastic_solvers import parallel_deflation, eigengame\n",
    "from utils import *\n",
    "\n",
    "np.random.seed(41)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 784)\n"
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "\n",
    "mnist_data = torchvision.datasets.MNIST('./', download=True)\n",
    "\n",
    "trn_img_flt = np.array([np.array(img).flatten() for img, _ in mnist_data]) / 255.\n",
    "trn_img_mean = trn_img_flt.mean(axis=0)\n",
    "trn_img_centered = trn_img_flt - trn_img_mean.reshape((1, -1))\n",
    "print(trn_img_centered.shape)\n",
    "mat = trn_img_centered.T @ trn_img_centered / trn_img_centered.shape[0]\n",
    "evals_raw, evecs_raw = np.linalg.eigh(mat)\n",
    "true_evals = evals_raw[::-1][:30]\n",
    "true_evecs = evecs_raw[:, ::-1][:,:30].T\n",
    "\n",
    "data_gen = lambda sample: random_batch(trn_img_centered, sample)\n",
    "\n",
    "num_trials = 10\n",
    "r = 30\n",
    "batch_size = 1000\n",
    "\n",
    "\n",
    "eval_func = lambda x: compute_avg_error(true_evecs, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parallel Deflation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.07931484685691509\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.09914454240575465\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.09418131663907331\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.21964881389109794\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.07981865075718252\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.08185703833859334\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.10887208842795389\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.08195899120061222\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.08597193398035886\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.2025477628134142\n"
     ]
    }
   ],
   "source": [
    "L = 1200\n",
    "T = 1\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = parallel_deflation(data_gen, r, L, T, step_size_func, batch_size)\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T1 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T1 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T1 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.08677165995138474\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.09958614315861679\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.3381445477950125\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.1355754896618563\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.10306586834036491\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.097766789220794\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.11135735956953237\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.09336796493401384\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.2621280428012151\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.2228245209869005\n"
     ]
    }
   ],
   "source": [
    "L = 240\n",
    "T = 5\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = parallel_deflation(data_gen, r, L, T, step_size_func, batch_size)\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T5 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T5 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T5 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "max_err_results = dict()\n",
    "avg_err_results = dict()\n",
    "\n",
    "max_err_results['T=1'] = (max_err_mean_T1.tolist(), max_err_std_T1.tolist())\n",
    "max_err_results['T=5'] = (max_err_mean_T5.tolist(), max_err_std_T5.tolist())\n",
    "\n",
    "avg_err_results['T=1'] = (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist())\n",
    "avg_err_results['T=5'] = (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist())\n",
    "\n",
    "with open('parallel_deflation_mnist_sto.txt', 'w+') as jfile:\n",
    "    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EigenGame-Mu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.2800819053946179\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.35477575780643655\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.06114400574545103\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.08765750793496763\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.06116958915120261\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.06521674546904341\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.12841713895099455\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.08263263885944094\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.13691939394910063\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.10995566216014638\n"
     ]
    }
   ],
   "source": [
    "L = 1200\n",
    "T = 1\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='mu')\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T1 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T1 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T1 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.20000367110492265\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.18002064143225627\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.11461935797674873\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.26313974013063113\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.19840596289598525\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.12021433784788399\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.08374777631890154\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.09099088260523475\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.1866020109061588\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.2387226931153399\n"
     ]
    }
   ],
   "source": [
    "L = 240\n",
    "T = 5\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='mu')\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T5 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T5 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T5 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "max_err_results = dict()\n",
    "avg_err_results = dict()\n",
    "\n",
    "max_err_results['T=1'] = (max_err_mean_T1.tolist(), max_err_std_T1.tolist())\n",
    "max_err_results['T=5'] = (max_err_mean_T5.tolist(), max_err_std_T5.tolist())\n",
    "\n",
    "avg_err_results['T=1'] = (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist())\n",
    "avg_err_results['T=5'] = (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist())\n",
    "\n",
    "with open('eigengame_mu_mnist_sto.txt', 'w+') as jfile:\n",
    "    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EigenGame-Alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.06680067135104835\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.08383224909477664\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.2521713095341494\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.06849535367003343\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.09019275791369717\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.06280845752427353\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.09847965652348292\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.06646231971721014\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.0873212897821445\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.11153374469673465\n"
     ]
    }
   ],
   "source": [
    "L = 1200\n",
    "T = 1\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='alpha')\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T1 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T1 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T1 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.10539663923046567\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.0805605605666201\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.09776179975499993\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.14264293137552858\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.08702676952366487\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.3419322727578004\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.1886071802060075\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.35532805096160336\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.06182658173880311\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.16130111013789586\n"
     ]
    }
   ],
   "source": [
    "L = 240\n",
    "T = 5\n",
    "step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))\n",
    "\n",
    "all_max_errs = np.zeros((num_trials, L+1))\n",
    "all_avg_errs = np.zeros((num_trials, L+1))\n",
    "for trial_idx in range(num_trials):\n",
    "    print(f'Running Trial #{trial_idx+1}')\n",
    "    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='alpha')\n",
    "    all_max_errs[trial_idx] = np.array([compute_max_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    all_avg_errs[trial_idx] = np.array([compute_avg_error(true_evecs, evecs) for evecs in evecs_hist])\n",
    "    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')\n",
    "\n",
    "max_err_mean_T5 = all_max_errs.mean(axis=0)\n",
    "max_err_std_T5 = all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5 = all_avg_errs.mean(axis=0)\n",
    "avg_err_std_T5 = all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "max_err_results = dict()\n",
    "avg_err_results = dict()\n",
    "\n",
    "max_err_results['T=1'] = (max_err_mean_T1.tolist(), max_err_std_T1.tolist())\n",
    "max_err_results['T=5'] = (max_err_mean_T5.tolist(), max_err_std_T5.tolist())\n",
    "\n",
    "avg_err_results['T=1'] = (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist())\n",
    "avg_err_results['T=5'] = (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist())\n",
    "\n",
    "with open('eigengame_alpha_mnist_sto.txt', 'w+') as jfile:\n",
    "    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
