{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90e91622-adda-40e5-9461-38f99cb7861a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4677626f-1300-4e5c-875e-91ebbb2c12fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%sql postgresql://postgres:PASSWORD_REDACTED@IP_REDACTED/neural_net"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b808c76-ee5a-4a51-87eb-fa61b57f109b",
   "metadata": {},
   "source": [
    "# Description\n",
    "\n",
    "Neural net trained with SGD with momentum (SGDm), mapping $n$ input dimensions to $1$ output, with observation noise and stochastic covariate shift: the mean of the domain (input) sampling distribution changes randomly over time.\n",
    "\n",
    "Batch size is greater than 1, because minibatches are overwhelmingly common when training neural nets.\n",
    "\n",
    "The architecture is fixed, because the goal is only to demonstrate resonance-like behaviour on a non-quadratic loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4aa3d5bc-a23d-45f6-b2e4-648bee7f04ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "from contextlib import closing\n",
    "\n",
    "import psycopg2 as pg\n",
    "from psycopg2 import sql\n",
    "from psycopg2.extras import execute_values\n",
    "\n",
    "# Parameters shared by every run in this sweep; each queued row adds its own\n",
    "# 'momentum' and 'mean_switching_interval' on top of these.\n",
    "base_args = {\n",
    "    'sweep_name': 'nonresonant',\n",
    "\n",
    "    # Learning params\n",
    "    'step_size': 0.01,\n",
    "    'iterations': 20000,\n",
    "\n",
    "    # Problem params\n",
    "    'nn_width': 20,\n",
    "    'nn_layers': 2,\n",
    "    'nn_init_variance': 0.1,\n",
    "    'input_dimensions': 2,\n",
    "    'observation_noise_variance': 0.1,\n",
    "    'batch_size': 10,\n",
    "    'test_set_size': 100,\n",
    "\n",
    "    # Sampling params\n",
    "    'instantaneous_domain_variance': 0.25,\n",
    "    'mean_switching_variance': 0.4,\n",
    "\n",
    "    # Runner metadata\n",
    "    'complete': False,\n",
    "}\n",
    "\n",
    "runs_per_config = 20\n",
    "momenta = [0.8, 0.85, 0.875, 0.9, 0.925, 0.95]\n",
    "mean_switching_intervals = range(1, 101)\n",
    "# Without this guard, re-running the cell queues the whole sweep a second time.\n",
    "# Set to True only to deliberately append another batch of runs.\n",
    "requeue_existing_sweep = False\n",
    "\n",
    "# One row per (repeat, momentum, interval), in the same order as a nested loop:\n",
    "# runs_per_config * len(momenta) * len(mean_switching_intervals) rows in total.\n",
    "sweep_name = base_args['sweep_name']\n",
    "columns = list(base_args) + ['momentum', 'mean_switching_interval']\n",
    "rows = [\n",
    "    tuple(base_args.values()) + (m, t)\n",
    "    for _ in range(runs_per_config)\n",
    "    for m, t in itertools.product(momenta, mean_switching_intervals)\n",
    "]\n",
    "\n",
    "# The password is read from the environment instead of being hard-coded.\n",
    "# closing() guarantees the connection is closed even if an insert fails, and\n",
    "# `with connection` commits on success / rolls back on error.\n",
    "with closing(pg.connect(\n",
    "    user='postgres',\n",
    "    password=os.environ['PGPASSWORD'],\n",
    "    database='neural_net',\n",
    "    host='IP_REDACTED',\n",
    ")) as connection:\n",
    "    with connection, connection.cursor() as cursor:\n",
    "        cursor.execute(\n",
    "            'select count(*) from runs where sweep_name = %s', (sweep_name,)\n",
    "        )\n",
    "        (already_queued,) = cursor.fetchone()\n",
    "        if already_queued and not requeue_existing_sweep:\n",
    "            print(f'{already_queued} runs for sweep {sweep_name!r} already exist; '\n",
    "                  'skipping insert (set requeue_existing_sweep = True to add more).')\n",
    "        else:\n",
    "            # Column names are quoted as SQL identifiers, values are passed by the\n",
    "            # driver, and execute_values batches rows instead of one round trip each.\n",
    "            insert = sql.SQL('insert into runs ({}) values %s').format(\n",
    "                sql.SQL(', ').join(map(sql.Identifier, columns))\n",
    "            )\n",
    "            execute_values(cursor, insert.as_string(cursor), rows, page_size=1000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a33a534e-a2e1-4697-9c93-0d4e01b16ec0",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/neural_net\n",
      "3 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>sweep_name</th>\n",
       "        <th>complete</th>\n",
       "        <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>nonresonant</td>\n",
       "        <td>True</td>\n",
       "        <td>8000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>mean_variance</td>\n",
       "        <td>True</td>\n",
       "        <td>21175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>nonresonant</td>\n",
       "        <td>False</td>\n",
       "        <td>4000</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[('nonresonant', True, 8000),\n",
       " ('mean_variance', True, 21175),\n",
       " ('nonresonant', False, 4000)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "-- Number of queued runs per sweep, split by completion status.\n",
    "-- ORDER BY makes the displayed table order stable between runs.\n",
    "select sweep_name, complete, count(*)\n",
    "from runs\n",
    "group by sweep_name, complete\n",
    "order by sweep_name, complete;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e767e77e-f881-42fd-91c5-f8e8bcc26d8b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
