{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# set up\n",
    "import time\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "import io\n",
    "\n",
    "\n",
    "def calc_numpy_eig(cov):\n",
    "    eigenvalues, eigenvectors = jnp.linalg.eigh(cov)\n",
    "    idx = jnp.argsort(abs(eigenvalues))[::-1]  # Indices for descending sort\n",
    "    eigenvalues = eigenvalues[idx]\n",
    "    eigenvectors = eigenvectors[:, idx]\n",
    "    return eigenvalues, eigenvectors\n",
    "\n",
    "def model(v, cov, V1):\n",
    "    rewards = jnp.dot(cov, v)\n",
    "    penalties = jnp.zeros_like(rewards)\n",
    "    for j in range(np.size(V1[:, V1.any(0)], axis=1)):\n",
    "        vj = V1[:, j]\n",
    "        penalties += (v @ cov @ vj) * vj\n",
    "    return rewards - penalties\n",
    "\n",
    "def update(v, cov, V1, lr=0.5):\n",
    "    dv = model(v, cov, V1)\n",
    "    dv_R = dv - jnp.dot(dv, v) * v\n",
    "    vhat = v + lr * dv_R\n",
    "    return (vhat / jnp.linalg.norm(vhat))\n",
    "\n",
    "def calc_eigenvector(cov_broadcast, k, l, V1_broadcast):\n",
    "    v_k = np.array(V1_broadcast.value[:, k])\n",
    "    if k <= l:\n",
    "        V_k = np.hstack((V1_broadcast.value[:, :k+1], np.zeros((cov_broadcast.value.shape[0], cov_broadcast.value.shape[1] - k - 1))))\n",
    "        v_k = update(v_k, cov_broadcast.value, V_k)\n",
    "    return v_k\n",
    "\n",
    "def overall_error_matrices(M1, M2, K):\n",
    "    diags = jnp.diag(M1[:, :K].T @ M2[:, :K])\n",
    "    error_arr = jnp.abs(2 - 2 * jnp.abs(diags))\n",
    "    return jnp.sqrt(sum(error_arr) / K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "X = sc.binaryFiles(\"s3://pcabk/Y_exp.npy\").map(lambda file_content: np.load(io.BytesIO(file_content[1]))).collect()[0] \n",
    "X /= jnp.linalg.norm(X)\n",
    "\n",
    "cov = jnp.dot(jnp.transpose(X), X)\n",
    "cov_broadcast = sc.broadcast(cov)\n",
    "_, q = calc_numpy_eig(cov)\n",
    "\n",
    "K = 32\n",
    "L = 5000\n",
    "\n",
    "v = jnp.array([[1.0] for i in range(cov.shape[1])])\n",
    "v /= jnp.linalg.norm(v)\n",
    "v0 = jnp.array([[1.0] for i in range(cov.shape[1])])\n",
    "v0 /= jnp.linalg.norm(v0)\n",
    "V1 = np.zeros_like(cov)\n",
    "V1[:, 0] = v.T\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken 1800.3959312438965"
     ]
    }
   ],
   "source": [
    "# run the experiment\n",
    "\n",
    "t = time.time()\n",
    "time_steps = []\n",
    "errors = []\n",
    "\n",
    "# l is number of communication rounds    \n",
    "for l in range(L):\n",
    "    # Broadcast V1 to the workers\n",
    "    V1_broadcast = sc.broadcast(V1)\n",
    "    \n",
    "    # compute each vectors in parallel\n",
    "    eigenvectors = sc.parallelize(range(K)).map(\n",
    "        lambda k: calc_eigenvector(cov_broadcast, k, l, V1_broadcast)\n",
    "    ).collect()\n",
    "\n",
    "    # Update V1 with the computed eigenvectors\n",
    "    for k, v in enumerate(eigenvectors):\n",
    "        V1[:, k] = v\n",
    "\n",
    "    if l < K - 1:\n",
    "        V1[:, l + 1] = v0.T\n",
    "        \n",
    "    # compute aggregated error\n",
    "    \n",
    "    V1_broadcast.unpersist()\n",
    "    current_time = time.time() - t\n",
    "    time_steps.append(current_time)\n",
    "    errors.append(overall_error_matrices(V1, q, K))\n",
    "    \n",
    "    if current_time > 1800:\n",
    "        break\n",
    "    \n",
    "print(\"time taken\", time.time() - t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "errors = [float(error) for error in errors]\n",
    "errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
