{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from centralized_solvers import eigengame, parallel_deflation\n",
    "from utils import *\n",
    "\n",
    "# Seed NumPy's global random state once, before any experiment cell runs.\n",
    "np.random.seed(seed=41)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and Preprocess MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "# Experiment configuration shared by the cells below.\n",
    "num_trials = 10  # number of repeated runs per setting\n",
    "r = 30  # number of leading eigenpairs kept as ground truth and requested from the solvers\n",
    "\n",
    "\n",
    "def step_size_func(idx):\n",
    "    # Step-size schedule handed to the solvers: 0 for every index.\n",
    "    return 0\n",
    "\n",
    "\n",
    "# MNIST via torchvision (train=True by default), downloaded into './' if not already present.\n",
    "mnist_data = torchvision.datasets.MNIST('./', download=True)\n",
    "\n",
    "# One flattened image per row, scaled by 1/255.\n",
    "trn_img_flt = np.array([np.array(img).flatten() for img, _ in mnist_data]) / 255.\n",
    "trn_img_mean = trn_img_flt.mean(axis=0)\n",
    "trn_img_centered = trn_img_flt - trn_img_mean.reshape((1, -1))\n",
    "\n",
    "# Covariance of the centered data, normalized by the number of rows.\n",
    "mat = trn_img_centered.T @ trn_img_centered / trn_img_centered.shape[0]\n",
    "\n",
    "# np.linalg.eigh returns eigenvalues in ascending order, so reverse before taking the top r.\n",
    "evals_raw, evecs_raw = np.linalg.eigh(mat)\n",
    "true_evals = evals_raw[::-1][:r]\n",
    "true_evecs = evecs_raw[:, ::-1][:, :r].T  # one eigenvector per row"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallel Deflation-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.0012911806414984868\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.0007581033047346986\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.0017219689803699082\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.08568833012720048\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.0033488808258883315\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.00026956639105399577\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.002020231562018738\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.0014138823110246626\n",
      "Running Trial #9\n",
      "Trial #9 Error: 5.968110946590731e-05\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.00978324677440718\n"
     ]
    }
   ],
   "source": [
    "# Parallel deflation with the 'pw' update, repeated num_trials times with L = 500, T = 1.\n",
    "L, T = 500, 1\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 9.511821003990935e-06\n",
      "Running Trial #2\n",
      "Trial #2 Error: 2.6115134195799512e-05\n",
      "Running Trial #3\n",
      "Trial #3 Error: 3.234525558147667e-05\n",
      "Running Trial #4\n",
      "Trial #4 Error: 1.2876362316179243e-06\n",
      "Running Trial #5\n",
      "Trial #5 Error: 6.120369666192785e-07\n",
      "Running Trial #6\n",
      "Trial #6 Error: 6.160634560337176e-06\n",
      "Running Trial #7\n",
      "Trial #7 Error: 1.846781170240909e-05\n",
      "Running Trial #8\n",
      "Trial #8 Error: 7.939117177244544e-06\n",
      "Running Trial #9\n",
      "Trial #9 Error: 6.5524795991680816e-06\n",
      "Running Trial #10\n",
      "Trial #10 Error: 1.8719203803372405e-06\n"
     ]
    }
   ],
   "source": [
    "# Parallel deflation with the 'pw' update, repeated num_trials times with L = 300, T = 3.\n",
    "L, T = 300, 3\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T3, max_err_std_T3 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T3, avg_err_std_T3 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 2.4408356754897522e-06\n",
      "Running Trial #2\n",
      "Trial #2 Error: 8.855553364358597e-06\n",
      "Running Trial #3\n",
      "Trial #3 Error: 5.273748101160096e-06\n",
      "Running Trial #4\n",
      "Trial #4 Error: 1.7431174297201346e-06\n",
      "Running Trial #5\n",
      "Trial #5 Error: 2.2276114631740875e-06\n",
      "Running Trial #6\n",
      "Trial #6 Error: 2.240383515010395e-06\n",
      "Running Trial #7\n",
      "Trial #7 Error: 2.0968321874009244e-06\n",
      "Running Trial #8\n",
      "Trial #8 Error: 1.971576235862463e-06\n",
      "Running Trial #9\n",
      "Trial #9 Error: 4.445443716697782e-07\n",
      "Running Trial #10\n",
      "Trial #10 Error: 1.680333061100298e-06\n"
     ]
    }
   ],
   "source": [
    "# Parallel deflation with the 'pw' update, repeated num_trials times with L = 200, T = 5.\n",
    "L, T = 200, 5\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# (mean, std) curves keyed by the T setting; arrays are converted to plain lists for JSON.\n",
    "max_err_results = {\n",
    "    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),\n",
    "    'T=3': (max_err_mean_T3.tolist(), max_err_std_T3.tolist()),\n",
    "    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),\n",
    "}\n",
    "avg_err_results = {\n",
    "    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),\n",
    "    'T=3': (avg_err_mean_T3.tolist(), avg_err_std_T3.tolist()),\n",
    "    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),\n",
    "}\n",
    "\n",
    "# The file holds JSON despite its .txt extension.\n",
    "payload = {'max_result': max_err_results, 'avg_result': avg_err_results}\n",
    "with open('parallel_deflation_mnist.txt', 'w+') as jfile:\n",
    "    json.dump(payload, jfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EigenGame-Alpha-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.0013855998608143286\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.000780804400791811\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.0165827110447389\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.006610207604548143\n",
      "Running Trial #5\n",
      "Trial #5 Error: 0.02401295088768233\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.005669936512737841\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.00023028584947130785\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.0027082466731954513\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.00010240437484804006\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.002617239526788341\n"
     ]
    }
   ],
   "source": [
    "# EigenGame with the 'alpha' update, repeated num_trials times with L = 500, T = 1.\n",
    "L, T = 500, 1\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='alpha')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 6.786166153392613e-07\n",
      "Running Trial #2\n",
      "Trial #2 Error: 2.589116958093376e-06\n",
      "Running Trial #3\n",
      "Trial #3 Error: 1.923577733409992e-06\n",
      "Running Trial #4\n",
      "Trial #4 Error: 7.434171249429239e-07\n",
      "Running Trial #5\n",
      "Trial #5 Error: 6.789437383350442e-07\n",
      "Running Trial #6\n",
      "Trial #6 Error: 2.3654860059004343e-07\n",
      "Running Trial #7\n",
      "Trial #7 Error: 1.3818768405162847e-07\n",
      "Running Trial #8\n",
      "Trial #8 Error: 2.601223797730691e-06\n",
      "Running Trial #9\n",
      "Trial #9 Error: 2.9405545403869946e-06\n",
      "Running Trial #10\n",
      "Trial #10 Error: 5.267681671121682e-06\n"
     ]
    }
   ],
   "source": [
    "# EigenGame with the 'alpha' update, repeated num_trials times with L = 200, T = 5.\n",
    "L, T = 200, 5\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='alpha')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# (mean, std) curves keyed by the T setting; arrays are converted to plain lists for JSON.\n",
    "max_err_results = {\n",
    "    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),\n",
    "    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),\n",
    "}\n",
    "avg_err_results = {\n",
    "    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),\n",
    "    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),\n",
    "}\n",
    "\n",
    "# The file holds JSON despite its .txt extension.\n",
    "payload = {'max_result': max_err_results, 'avg_result': avg_err_results}\n",
    "with open('eigengame_alpha_mnist.txt', 'w+') as jfile:\n",
    "    json.dump(payload, jfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EigenGame-Mu-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 0.0002828524228628293\n",
      "Running Trial #2\n",
      "Trial #2 Error: 0.0008641836877627376\n",
      "Running Trial #3\n",
      "Trial #3 Error: 0.0022671544850690266\n",
      "Running Trial #4\n",
      "Trial #4 Error: 0.0010584118205372256\n",
      "Running Trial #5\n",
      "Trial #5 Error: 1.7724303300665546e-05\n",
      "Running Trial #6\n",
      "Trial #6 Error: 0.0024958343802720657\n",
      "Running Trial #7\n",
      "Trial #7 Error: 0.017461269201645743\n",
      "Running Trial #8\n",
      "Trial #8 Error: 0.00012372311472815897\n",
      "Running Trial #9\n",
      "Trial #9 Error: 0.003511987358313588\n",
      "Running Trial #10\n",
      "Trial #10 Error: 0.0024374102225675287\n"
     ]
    }
   ],
   "source": [
    "# EigenGame with the 'mu' update, repeated num_trials times with L = 500, T = 1.\n",
    "L, T = 500, 1\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='mu')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Trial #1\n",
      "Trial #1 Error: 5.515442944562956e-07\n",
      "Running Trial #2\n",
      "Trial #2 Error: 5.09698133959981e-07\n",
      "Running Trial #3\n",
      "Trial #3 Error: 6.025304382787862e-07\n",
      "Running Trial #4\n",
      "Trial #4 Error: 8.407622852060422e-06\n",
      "Running Trial #5\n",
      "Trial #5 Error: 5.790400274910167e-07\n",
      "Running Trial #6\n",
      "Trial #6 Error: 9.805371687628092e-07\n",
      "Running Trial #7\n",
      "Trial #7 Error: 7.477978162145817e-06\n",
      "Running Trial #8\n",
      "Trial #8 Error: 1.059177011399501e-05\n",
      "Running Trial #9\n",
      "Trial #9 Error: 1.0330277688763746e-06\n",
      "Running Trial #10\n",
      "Trial #10 Error: 3.0264401453122348e-06\n"
     ]
    }
   ],
   "source": [
    "# EigenGame with the 'mu' update, repeated num_trials times with L = 200, T = 5.\n",
    "L, T = 200, 5\n",
    "\n",
    "# One row per trial; L + 1 columns assumes the solver returns L + 1 snapshots (TODO confirm).\n",
    "all_max_errs = np.zeros((num_trials, L + 1))\n",
    "all_avg_errs = np.zeros((num_trials, L + 1))\n",
    "for trial_idx in range(num_trials):\n",
    "    trial_no = trial_idx + 1\n",
    "    print(f'Running Trial #{trial_no}')\n",
    "    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='mu')\n",
    "    max_errs = [compute_max_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_max_errs[trial_idx] = np.array(max_errs)\n",
    "    avg_errs = [compute_avg_error(true_evecs, est) for est in evecs_hist]\n",
    "    all_avg_errs[trial_idx] = np.array(avg_errs)\n",
    "    print(f'Trial #{trial_no} Error: {all_max_errs[trial_idx, -1]}')\n",
    "\n",
    "# Mean and standard deviation over trials for each column of the error tables.\n",
    "max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)\n",
    "avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# (mean, std) curves keyed by the T setting; arrays are converted to plain lists for JSON.\n",
    "max_err_results = {\n",
    "    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),\n",
    "    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),\n",
    "}\n",
    "avg_err_results = {\n",
    "    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),\n",
    "    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),\n",
    "}\n",
    "\n",
    "# The file holds JSON despite its .txt extension.\n",
    "payload = {'max_result': max_err_results, 'avg_result': avg_err_results}\n",
    "with open('eigengame_mu_mnist.txt', 'w+') as jfile:\n",
    "    json.dump(payload, jfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
