{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "90e91622-adda-40e5-9461-38f99cb7861a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the `sql` IPython extension (presumably ipython-sql), which provides the %sql / %%sql magics used below.\n",
    "%load_ext sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4677626f-1300-4e5c-875e-91ebbb2c12fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Read the connection URL (including credentials) from the environment instead of\n",
    "# hardcoding it, e.g. DATABASE_URL=postgresql://postgres:<password>@<host>/switching_mean\n",
    "# IPython expands $database_url before handing the line to the %sql magic.\n",
    "database_url = os.environ['DATABASE_URL']\n",
    "%sql $database_url"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a92114bd-c5b2-4269-b1bc-23a08926bb4e",
   "metadata": {},
   "source": [
    "# Description\n",
    "\n",
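    "Show how sensitive the results are to the mean signal amplitude by sweeping the mean switching variance (0.0 to 0.4) crossed with the mean switching interval (1 to 49), with 20 runs per configuration."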
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4aa3d5bc-a23d-45f6-b2e4-648bee7f04ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import psycopg2 as pg\n",
    "import psycopg2.extras\n",
    "\n",
    "# Sweep definition: every run shares base_args and differs only in the two\n",
    "# mean-switching parameters crossed below.\n",
    "# To get 95% within bound, set variance = (bound / 2) ** 2\n",
    "base_args = {\n",
    "    'sweep_name': 'mean_amplitude',\n",
    "\n",
    "    # Learning params\n",
    "    'step_size': 0.01,\n",
    "    'momentum': 0.95,\n",
    "    'iterations': 10000,\n",
    "\n",
    "    # Problem params\n",
    "    'input_dimensions': 5,\n",
    "    'output_dimensions': 1,\n",
    "    'target_weight_variance': 0.25,\n",
    "    'weight_init_variance': 0.25,\n",
    "    'observation_noise_variance': 0.1,\n",
    "    'batch_size': 1,\n",
    "\n",
    "    # Sampling params\n",
    "    'instantaneous_domain_variance': 0.25,\n",
    "\n",
    "    # Runner metadata\n",
    "    'complete': False,\n",
    "}\n",
    "\n",
    "runs_per_config = 20\n",
    "\n",
    "mean_switching_variances = [0.0, 0.1, 0.2, 0.3, 0.4]\n",
    "mean_switching_intervals = range(1, 50)\n",
    "\n",
    "# One row per (repeat, variance, interval). The swept values are appended after\n",
    "# the base_args values, so each tuple lines up with `columns`.\n",
    "columns = [*base_args, 'mean_switching_variance', 'mean_switching_interval']\n",
    "rows = [\n",
    "    (*base_args.values(), variance, interval)\n",
    "    for _ in range(runs_per_config)\n",
    "    for variance, interval in itertools.product(\n",
    "        mean_switching_variances, mean_switching_intervals\n",
    "    )\n",
    "]\n",
    "\n",
    "# Column names are the literal keys above (not user input), so splicing them\n",
    "# into the statement is safe; all values are sent as query parameters.\n",
    "# execute_values expands the single %s into multi-row VALUES lists, so the\n",
    "# sweep is inserted in a few round trips instead of one per row.\n",
    "column_list = ','.join(columns)\n",
    "insert_sql = f'insert into runs ({column_list}) values %s'\n",
    "\n",
    "# NOTE: not idempotent -- re-running this cell inserts another full copy of the sweep.\n",
    "# The connection URL (including credentials) is read from the environment rather\n",
    "# than hardcoded, e.g. DATABASE_URL=postgresql://postgres:<password>@<host>/switching_mean\n",
    "connection = pg.connect(os.environ['DATABASE_URL'])\n",
    "try:\n",
    "    # `with connection` commits if every insert succeeds and rolls back otherwise;\n",
    "    # the cursor is closed when the block exits.\n",
    "    with connection, connection.cursor() as cursor:\n",
    "        pg.extras.execute_values(cursor, insert_sql, rows)\n",
    "finally:\n",
    "    # Always release the connection, even if an insert failed.\n",
    "    connection.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a33a534e-a2e1-4697-9c93-0d4e01b16ec0",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/switching_mean\n",
      "4 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>sweep_name</th>\n",
       "        <th>complete</th>\n",
       "        <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>input_dimensions</td>\n",
       "        <td>False</td>\n",
       "        <td>6615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>mean_amplitude</td>\n",
       "        <td>True</td>\n",
       "        <td>2475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>input_dimensions</td>\n",
       "        <td>True</td>\n",
       "        <td>4455</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>mean_amplitude</td>\n",
       "        <td>False</td>\n",
       "        <td>3675</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[('input_dimensions', False, 6615),\n",
       " ('mean_amplitude', True, 2475),\n",
       " ('input_dimensions', True, 4455),\n",
       " ('mean_amplitude', False, 3675)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "select sweep_name, complete, count(*)\n",
    "-- Number of runs per sweep, split by completion flag. Ordered so the table is identical across re-runs.\n",
    "from runs\n",
    "group by sweep_name, complete\n",
    "order by sweep_name, complete;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e767e77e-f881-42fd-91c5-f8e8bcc26d8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql:///switching_mean\n",
      "3550000 rows affected.\n",
      "2070 rows affected.\n",
      "1 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>?column?</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>None</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[(None,)]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "delete from losses\n",
    "-- Destructive cleanup of the entire mean_amplitude sweep, completed runs included.\n",
    "-- Loss rows are deleted before the runs they belong to (matched on run_id).\n",
    "-- The trailing select null presumably just gives the magic a row-returning result to show.\n",
    "where run_id in (\n",
    "    select run_id\n",
    "    from runs\n",
    "    where sweep_name = 'mean_amplitude'\n",
    ");\n",
    "\n",
    "delete from runs\n",
    "where sweep_name = 'mean_amplitude';\n",
    "\n",
    "select null;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edc22fd9-ae29-422a-9498-b936106ccac0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
