{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "90e91622-adda-40e5-9461-38f99cb7861a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4677626f-1300-4e5c-875e-91ebbb2c12fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the database connection used by the %sql / %%sql magics in this notebook.\n",
    "# NOTE(review): the password is embedded in this URL and is saved with the notebook;\n",
    "# consider supplying the connection string from outside the notebook instead.\n",
    "%sql postgresql://postgres:PASSWORD_REDACTED@IP_REDACTED/adam"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b808c76-ee5a-4a51-87eb-fa61b57f109b",
   "metadata": {},
   "source": [
    "# Description\n",
    "\n",
    "We anecdotally observe more difficulty in getting ADAM to resonate, observing at worst a lack of convergence, but so far never divergence.  Hence, mean switching variance and instantaneous variance have been adjusted to more extreme values.\n",
    "\n",
    "We vary the parameter `beta_1`, sweeping each value over mean switching intervals from 1 to 100, with all other parameters fixed, on the same linear regression setting as the mean switching problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4aa3d5bc-a23d-45f6-b2e4-648bee7f04ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import jax.numpy as jnp\n",
    "import psycopg2 as pg\n",
    "\n",
    "from jax import random\n",
    "from psycopg2 import sql\n",
    "from psycopg2.extras import execute_values\n",
    "\n",
    "# Settings shared by every run of this sweep. Each inserted row is this dict\n",
    "# plus the two swept values, 'beta_1' and 'mean_switching_interval'.\n",
    "base_args = {\n",
    "    'sweep_name': 'beta_1',\n",
    "\n",
    "    # Learning params\n",
    "    'step_size': 0.01,\n",
    "    'beta_2': 0.999,\n",
    "    'iterations': 10000,\n",
    "\n",
    "    # Problem params\n",
    "    'input_dimensions': 5,\n",
    "    'output_dimensions': 1,\n",
    "    'target_weight_variance': 0.25,\n",
    "    'weight_init_variance': 0.25,\n",
    "    'observation_noise_variance': 0.1,\n",
    "    'batch_size': 1,\n",
    "\n",
    "    # Sampling params\n",
    "    'instantaneous_domain_variance': 0.1,\n",
    "    'mean_switching_variance': 1.0,\n",
    "\n",
    "    # Runner metadata\n",
    "    'complete': False,\n",
    "}\n",
    "\n",
    "runs_per_config = 10\n",
    "beta_1s = [0.9, 0.925, 0.95, 0.975, 0.99]\n",
    "mean_switching_intervals = range(1, 101)\n",
    "\n",
    "# Every row has the same columns in the same order, so the column list and\n",
    "# the value tuples are built once up front and sent to the server in batches.\n",
    "columns = [*base_args, 'beta_1', 'mean_switching_interval']\n",
    "rows = [\n",
    "    (*base_args.values(), b, t)\n",
    "    for _ in range(runs_per_config)\n",
    "    for b, t in itertools.product(beta_1s, mean_switching_intervals)\n",
    "]\n",
    "\n",
    "# Column names are composed as quoted SQL identifiers; the single %s is the\n",
    "# placeholder that execute_values expands into the list of value tuples.\n",
    "insert_query = sql.SQL('insert into runs ({}) values %s').format(\n",
    "    sql.SQL(',').join(map(sql.Identifier, columns))\n",
    ")\n",
    "\n",
    "# NOTE(review): credentials are hardcoded here and saved with the notebook;\n",
    "# consider supplying them from outside the notebook instead.\n",
    "# NOTE(review): this cell inserts unconditionally, so every re-run adds\n",
    "# another len(rows) rows to the runs table.\n",
    "connection = pg.connect(\n",
    "    user='postgres',\n",
    "    password='PASSWORD_REDACTED',\n",
    "    database='adam',\n",
    "    host='IP_REDACTED'\n",
    ")\n",
    "try:\n",
    "    # `with connection` commits on success and rolls back on error, but it\n",
    "    # does not close the connection, hence the surrounding try/finally.\n",
    "    with connection, connection.cursor() as cursor:\n",
    "        execute_values(cursor, insert_query, rows, page_size=1000)\n",
    "finally:\n",
    "    connection.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a33a534e-a2e1-4697-9c93-0d4e01b16ec0",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/adam\n",
      "1 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>sweep_name</th>\n",
       "        <th>complete</th>\n",
       "        <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>beta_1</td>\n",
       "        <td>False</td>\n",
       "        <td>10000</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[('beta_1', False, 10000)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "SELECT sweep_name, complete, COUNT(*)\n",
    "FROM runs\n",
    "GROUP BY sweep_name, complete;"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
