{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sweeps import sweep_learning_rates\n",
    "from networks import LTALayer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Sweep a range of network capacities to find a sufficiently parametrized (but not overparametrized)\n",
    "LTA network with two of the single layer LTA nets stacked together with hidden activation\n",
    "between them of dimension $k$.\n",
    "\n",
    "Start with reasonable sizes from single layer, including some slightly lower. Intuition\n",
    "dictates we should just find a k in this range by adding depth.  Also omit learning rates\n",
    "which are consistently too large for LTA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "learning_rates = [5e-4, 1e-4, 5e-5, 1e-5, 5e-6]\n",
    "\n",
    "final_losses = {}\n",
    "for k in [10, 15, 20, 30, 40, 50, 60, 70, 80]:\n",
    "\n",
    "    print(f'------\\nk = {k}\\n------')\n",
    "    pre_tiling_width = k\n",
    "    bins = k\n",
    "    eta = 1.0 / k\n",
    "    width_between_layers = k\n",
    "    model_init = lambda: torch.nn.Sequential(\n",
    "        LTALayer(1, width_between_layers, pre_tiling_width, bins, eta, -1, 1),\n",
    "        LTALayer(width_between_layers, 1, pre_tiling_width, bins, eta, -1, 1),\n",
    "    )\n",
    "\n",
    "    final_loss_means, final_loss_vars = sweep_learning_rates(\n",
    "        model_init=model_init,\n",
    "        training_iterations = 20000,\n",
    "        num_means = 1,\n",
    "        diff = 0,\n",
    "        bound = 1,\n",
    "        test_size = 100,\n",
    "        learning_rates = learning_rates,\n",
    "        measurement_interval=100,\n",
    "        final_intervals_to_average=5,\n",
    "        test_grid=False,\n",
    "    )\n",
    "    final_losses[k] = {'mean': torch.tensor(final_loss_means), 'var': torch.tensor(final_loss_vars)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cm = plt.get_cmap('inferno')\n",
    "for i, (k, losses) in enumerate(final_losses.items()):\n",
    "    means = losses['mean']\n",
    "    vars = losses['var']\n",
    "    if k==40:\n",
    "        width=3\n",
    "    else:\n",
    "        width=1\n",
    "    plt.plot(learning_rates, means, linewidth=width, label=f'k={k}', alpha=0.7, color=cm(1 - i / len(final_losses)))\n",
    "    plt.fill_between(learning_rates, means - vars, means + vars, alpha=0.3, color=cm(1 - i / len(final_losses)))\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "k=40 is the winner!\n",
    "\n",
    "again to be sure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "learning_rates = [5e-4, 1e-4, 5e-5, 1e-5, 5e-6]\n",
    "\n",
    "final_losses = {}\n",
    "for k in [10, 15, 20, 30, 40, 50, 60, 70, 80]:\n",
    "\n",
    "    print(f'------\\nk = {k}\\n------')\n",
    "    pre_tiling_width = k\n",
    "    bins = k\n",
    "    eta = 1.0 / k\n",
    "    width_between_layers = k\n",
    "    model_init = lambda: torch.nn.Sequential(\n",
    "        LTALayer(1, width_between_layers, pre_tiling_width, bins, eta, -1, 1),\n",
    "        LTALayer(width_between_layers, 1, pre_tiling_width, bins, eta, -1, 1),\n",
    "    )\n",
    "\n",
    "    final_loss_means, final_loss_vars = sweep_learning_rates(\n",
    "        model_init=model_init,\n",
    "        training_iterations = 20000,\n",
    "        num_means = 1,\n",
    "        diff = 0,\n",
    "        bound = 1,\n",
    "        test_size = 100,\n",
    "        learning_rates = learning_rates,\n",
    "        measurement_interval=100,\n",
    "        final_intervals_to_average=5,\n",
    "        test_grid=False,\n",
    "    )\n",
    "    final_losses[k] = {'mean': torch.tensor(final_loss_means), 'var': torch.tensor(final_loss_vars)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cm = plt.get_cmap('inferno')\n",
    "for i, (k, losses) in enumerate(final_losses.items()):\n",
    "    means = losses['mean']\n",
    "    vars = losses['var']\n",
    "    if k==40:\n",
    "        width=3\n",
    "    else:\n",
    "        width=1\n",
    "    plt.plot(learning_rates, means, linewidth=width, label=f'k={k}', alpha=0.7, color=cm(1 - i / len(final_losses)))\n",
    "    plt.fill_between(learning_rates, means - vars, means + vars, alpha=0.3, color=cm(1 - i / len(final_losses)))\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "40 is the winner!  smaller nets are clearly underparametrized, larger ones are not clearly better.\n",
    "\n",
    "\n",
    "Same story as last time, so let's roll with this number."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
