{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0a3ebb53-9db0-4796-97b6-33d208454a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3fb175ac-e8a7-493d-a987-c4d06e37afbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%sql postgresql://postgres:PASSWORD_REDACTED@IP_REDACTED/periodic_mean"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8ae9909-6324-4f40-80d5-26535de3f3d9",
   "metadata": {},
   "source": [
    "# Description\n",
    "\n",
    "We empirically observe that divergence is amplified as the batch size increases (i.e. as the minibatch gradient approaches the expected gradient). Here we sweep over batch size to demonstrate that effect.\n",
    "\n",
    "We've already demonstrated in the two weight experiments that momentum and step size have the expected effect, so we freeze those parameters.\n",
    "\n",
    "We also observe empirically that increasing the input dimension (in the range 1 to 500) worsens divergence. Since the mean signal amplitude is _fixed_ in this problem, this could simply come from the larger expected sample norm induced by a fixed instantaneous variance as the dimension grows. However, it is more important to show the effect of the S in SGD, since our analysis approximates it away. Hence the input dimension is also fixed here.\n",
    "\n",
    "All other significant parameters (e.g. mean switch interval, instantaneous variance) are chosen so that divergence occurs, for the sake of demonstration. (The fully stochastic case can also be made to diverge, but only with fairly extreme temporal correlation.)"
   ]
  },
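  {
   "cell_type": "markdown",
   "id": "5b1e2c7a-3f4d-4e8a-9c61-2d7f0a9b4e13",
   "metadata": {},
   "source": [
    "The two scalings described above can be illustrated with a minimal sketch, under stated assumptions. It is *not* the sweep runner: it assumes a plain linear-regression stand-in with squared loss and i.i.d. Gaussian inputs `x ~ N(mu, sigma2 * I)` around a fixed mean of norm 1, and it ignores the periodically switching mean entirely. The helper names `expected_sq_norm` and `minibatch_grad_variance` are hypothetical; their defaults merely mirror `base_args` in the next cell. The sketch shows that the expected squared sample norm grows as `||mu||^2 + d * sigma2` even though `||mu||` is fixed, and that the variance of a minibatch-mean gradient shrinks roughly as `1 / batch_size`, so larger batches sit closer to the expected gradient.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in, NOT the sweep runner: linear regression with squared loss,\n",
    "# inputs x ~ N(mu, sigma2 * I_d) drawn i.i.d., with a fixed mean of norm 1.\n",
    "\n",
    "\n",
    "def expected_sq_norm(d, sigma2=0.25, mean_norm=1.0):\n",
    "    # E||x||^2 = ||mu||^2 + d * sigma2: grows with d although ||mu|| is fixed\n",
    "    return mean_norm ** 2 + d * sigma2\n",
    "\n",
    "\n",
    "def minibatch_grad_variance(batch_size, d=5, sigma2=0.25, noise_var=0.1,\n",
    "                            trials=4000, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    mu = np.ones(d) / np.sqrt(d)           # fixed mean with ||mu|| = 1\n",
    "    w_true = rng.normal(0.0, 0.5, size=d)  # target weights, variance 0.25\n",
    "    w = np.zeros(d)                        # evaluate all gradients at one fixed point\n",
    "    x = mu + np.sqrt(sigma2) * rng.normal(size=(trials, batch_size, d))\n",
    "    y = x @ w_true + np.sqrt(noise_var) * rng.normal(size=(trials, batch_size))\n",
    "    # Gradient of 0.5 * (w . x - y)^2 w.r.t. w, averaged over each minibatch\n",
    "    grads = ((x @ w - y)[..., None] * x).mean(axis=1)  # shape (trials, d)\n",
    "    # Total variance across coordinates (the expected gradient has none)\n",
    "    return float(grads.var(axis=0).sum())\n",
    "\n",
    "\n",
    "for d in [1, 5, 500]:\n",
    "    print('d =', d, 'E||x||^2 =', expected_sq_norm(d))\n",
    "for b in [1, 2, 5, 20, 50]:\n",
    "    print('batch_size =', b, 'gradient variance =', minibatch_grad_variance(b))\n",
    "```\n",
    "\n",
    "Under these assumptions the printed gradient variance for `batch_size = 50` should come out roughly 50 times smaller than for `batch_size = 1`; the sketch says nothing by itself about divergence, which is what the sweep probes."
   ]
  },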
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0cfbefc4-1e73-4385-a5bd-bc99f56cd002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import psycopg2 as pg\n",
    "\n",
    "# For Gaussian draws, ~95% fall within ±bound when variance = (bound / 2) ** 2, i.e. bound = 2 sigma\n",
    "base_args = {\n",
    "    'sweep_name': 'small_batch',\n",
    "\n",
    "    # Learning params\n",
    "    'step_size': 0.01,\n",
    "    'momentum': 0.95,\n",
    "    'iterations': 10000,\n",
    "\n",
    "    # Problem params\n",
    "    'input_dimensions': 5,\n",
    "    'output_dimensions': 1,\n",
    "    'target_weight_variance': 0.25,\n",
    "    'weight_init_variance': 0.25,\n",
    "    'observation_noise_variance': 0.1,\n",
    "\n",
    "    # Sampling params\n",
    "    'instantaneous_domain_variance': 0.25,\n",
    "    'mean_position_displacement': 1.0,\n",
    "\n",
    "    # Runner metadata\n",
    "    'complete': False,\n",
    "}\n",
    "\n",
    "runs_per_config = 20\n",
    "\n",
    "batch_sizes = [1, 2, 3, 4, 5, 20, 50]\n",
    "mean_periods = range(101, 121)\n",
    "\n",
    "connection = pg.connect(user='postgres', password='PASSWORD_REDACTED', database='periodic_mean', host='IP_REDACTED')\n",
    "cursor = connection.cursor()\n",
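    "# Queue one pending row in `runs` per (repeat, batch_size, mean_period).\n",
    "for _ in range(runs_per_config):\n",
    "    for b, t in itertools.product(batch_sizes, mean_periods):\n",
    "        row = {**base_args, 'batch_size': b, 'mean_period': t}\n",
    "        cursor.execute(\n",
    "            'insert into runs (%s) values %s',\n",
    "            (\n",
    "                # Column names are the dict's own (trusted) keys, spliced in verbatim\n",
    "                pg.extensions.AsIs(','.join(row.keys())),\n",
    "                # psycopg2 adapts a Python tuple to a parenthesized value list\n",
    "                tuple(row.values()),\n",
    "            ),\n",
    "        )\n",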
    "connection.commit()\n",
    "connection.close()\n"
   ]
  },
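  {
   "cell_type": "markdown",
   "id": "9e4c1d2b-7a6f-4b3e-8d50-1f2a3c4b5d6e",
   "metadata": {},
   "source": [
    "Two numbers in the cell above can be sanity-checked without a database. This is a sketch under an assumption: the comment's rule of thumb holds for Gaussian draws, which we assume here since the cell does not name a distribution, and `fraction_within` is a hypothetical helper. It confirms that `variance = (bound / 2) ** 2` keeps about 95.4% of draws within `±bound` (so each variance of 0.25 in `base_args` corresponds to a bound of 1), and counts the rows one execution of the loop queues.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "\n",
    "def fraction_within(bound, variance):\n",
    "    # For X ~ N(m, variance): P(|X - m| <= bound) = erf(bound / (sigma * sqrt(2)))\n",
    "    sigma = math.sqrt(variance)\n",
    "    return math.erf(bound / (sigma * math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "# Rule of thumb from the cell above: variance = (bound / 2) ** 2, i.e. bound = 2 sigma\n",
    "print(fraction_within(1.0, (1.0 / 2) ** 2))  # about 0.9545\n",
    "\n",
    "# Same grid as the insertion loop: one row per (repeat, batch_size, mean_period)\n",
    "runs_per_config = 20\n",
    "batch_sizes = [1, 2, 3, 4, 5, 20, 50]\n",
    "mean_periods = range(101, 121)\n",
    "grid = list(itertools.product(range(runs_per_config), batch_sizes, mean_periods))\n",
    "print(len(grid))  # 2800\n",
    "```"
   ]
  },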
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f6b6f5c8-9152-4a80-a8b8-8bb10e60bda8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/periodic_mean\n",
      "3 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>sweep_name</th>\n",
       "        <th>complete</th>\n",
       "        <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>small_batch</td>\n",
       "        <td>False</td>\n",
       "        <td>2788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>small_batch</td>\n",
       "        <td>True</td>\n",
       "        <td>28012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>fully_stochastic</td>\n",
       "        <td>True</td>\n",
       "        <td>27000</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[('small_batch', False, 2788),\n",
       " ('small_batch', True, 28012),\n",
       " ('fully_stochastic', True, 27000)]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "select sweep_name, complete, count(*) from runs group by sweep_name, complete;"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba4ffb22-8255-4392-9fb3-1ac9fd6b91e1",
   "metadata": {},
   "source": [
    "# Deletion (just in case)\n",
    "\n",
    "The next cell deletes every `small_batch` run, together with its rows in `losses`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221d47d5-fe33-45b2-8ab8-0e56c1924b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sql\n",
    "\n",
    "delete from losses\n",
    "using runs\n",
    "where runs.run_id = losses.run_id\n",
    "  and runs.sweep_name = 'small_batch';\n",
    "\n",
    "delete from runs\n",
    "where sweep_name = 'small_batch';\n",
    "\n",
    "select null;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a0b075d4-66f0-4615-b47e-c8080e979128",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm ../sweep_notebooks/1_fully_stochastic/*.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e23ee9c-0a6d-462a-b98f-54b5e45b2e05",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
