{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a3ebb53-9db0-4796-97b6-33d208454a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3fb175ac-e8a7-493d-a987-c4d06e37afbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%sql postgresql://postgres:PASSWORD_REDACTED@IP_REDACTED/periodic_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0cfbefc4-1e73-4385-a5bd-bc99f56cd002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import jax.numpy as jnp\n",
    "import psycopg2 as pg\n",
    "\n",
    "from jax import random\n",
    "\n",
    "# Queue one `runs` row per (configuration, repeat) for the sweep.\n",
    "# Values shared by every row are in `base_args`; the swept values are listed below.\n",
    "\n",
    "# To get 95% within bound, set variance = (bound / 2) ** 2\n",
    "base_args = {\n",
    "    'sweep_name': 'initial_sweep',\n",
    "\n",
    "    # Learning params\n",
    "    'iterations': 10000,\n",
    "\n",
    "    # Problem params\n",
    "    'output_dimensions': 1,\n",
    "    'target_weight_variance': 0.25,\n",
    "    'weight_init_variance': 0.25,\n",
    "    'observation_noise_variance': 0.0,\n",
    "    'batch_size': 10,\n",
    "\n",
    "    # Sampling params\n",
    "    'instantaneous_domain_variance': 0.25,\n",
    "\n",
    "    # Runner metadata\n",
    "    'complete': False,\n",
    "}\n",
    "\n",
    "# Number of identical rows inserted for each configuration.\n",
    "runs_per_config = 5\n",
    "\n",
    "# Swept values; the cartesian product below has 6 * 1 * 1 * 99 * 3 = 1782 configurations.\n",
    "input_dimensions = range(1, 7)\n",
    "step_sizes = [0.001]\n",
    "momenta = [0.95]\n",
    "mean_periods = range(1, 100)\n",
    "mean_position_displacements = [0.325, 0.35, 0.375]#[0.2, 0.3, 0.4]\n",
    "\n",
    "# NOTE(review): credentials are hardcoded in this call; consider supplying them from the\n",
    "# environment (e.g. os.environ) or a secrets store instead of the notebook source.\n",
    "connection = pg.connect(user='postgres', password='PASSWORD_REDACTED', database='periodic_mean', host='IP_REDACTED')\n",
    "try:\n",
    "    # `with connection` commits when the block succeeds and rolls back when it raises, so a\n",
    "    # failed insert never leaves a half-written sweep. It does not close the connection,\n",
    "    # which is why the surrounding try/finally is still needed.\n",
    "    with connection, connection.cursor() as cursor:\n",
    "        for _ in range(runs_per_config):\n",
    "            for dimensions, step_size, momentum, period, displacement in itertools.product(\n",
    "                input_dimensions, step_sizes, momenta, mean_periods, mean_position_displacements\n",
    "            ):\n",
    "                row = {\n",
    "                    **base_args,\n",
    "                    'input_dimensions': dimensions,\n",
    "                    'step_size': step_size,\n",
    "                    'momentum': momentum,\n",
    "                    'mean_period': period,\n",
    "                    'mean_position_displacement': displacement,\n",
    "                }\n",
    "                # Column names come from the fixed keys above, so they are spliced in with\n",
    "                # AsIs; the values tuple is adapted and escaped by psycopg2.\n",
    "                cursor.execute(\n",
    "                    'insert into runs (%s) values %s',\n",
    "                    (\n",
    "                        pg.extensions.AsIs(','.join(row.keys())),\n",
    "                        tuple(row.values()),\n",
    "                    ),\n",
    "                )\n",
    "finally:\n",
    "    # Always release the connection, even when an insert failed.\n",
    "    connection.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f6b6f5c8-9152-4a80-a8b8-8bb10e60bda8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/periodic_mean\n",
      "2 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>sweep_name</th>\n",
       "        <th>complete</th>\n",
       "        <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>initial_sweep</td>\n",
       "        <td>False</td>\n",
       "        <td>12920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>initial_sweep</td>\n",
       "        <td>True</td>\n",
       "        <td>10840</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[('initial_sweep', False, 12920), ('initial_sweep', True, 10840)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "select sweep_name, complete, count(*) from runs group by sweep_name, complete;"
   ]
  },
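  {
   "cell_type": "markdown",
   "id": "ba4ffb22-8255-4392-9fb3-1ac9fd6b91e1",
   "metadata": {},
   "source": [
    "# DELETION. Just in case.\n",
    "\n",
    "**Destructive:** the SQL cell below deletes all `runs` with `sweep_name = 'initial_sweep'` together with their `losses` rows. Run it only when you intend to reset the sweep."
   ]
  },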
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "221d47d5-fe33-45b2-8ab8-0e56c1924b67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://postgres:***@IP_REDACTED/periodic_mean\n",
      "4670000 rows affected.\n",
      "2970 rows affected.\n",
      "1 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "    <tr>\n",
       "        <th>?column?</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td>None</td>\n",
       "    </tr>\n",
       "</table>"
      ],
      "text/plain": [
       "[(None,)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "\n",
    "delete from losses \n",
    "using runs \n",
    "where\n",
    "runs.run_id = losses.run_id\n",
    "and sweep_name = 'initial_sweep';\n",
    "\n",
    "delete from runs\n",
    "where sweep_name = 'initial_sweep';\n",
    "\n",
    "select null;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a0b075d4-66f0-4615-b47e-c8080e979128",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm ../sweep_notebooks/*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e23ee9c-0a6d-462a-b98f-54b5e45b2e05",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
