{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import zlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import difficulty\n",
    "import results\n",
    "\n",
    "from sweeps import sweep_difficulties\n",
    "from networks import LTALayer, ReluLayer\n",
    "from training import target_function\n",
    "from rendering import animate_training, mean_steady_state_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Default parameters. The `parameters` cell tag lets papermill inject overrides\n",
    "# in a cell placed directly after this one.\n",
    "# Run identifier: used to derive the RNG seed and included in every output filename.\n",
    "run = -1\n",
    "# Name passed to `results.full_path` / `results.ensure_results_dir` to locate outputs.\n",
    "basename = 'unknown'\n",
    "# Adam learning rate for the ReLU model; also part of the saved-results filename.\n",
    "learning_rate = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adler32 maps the run id to a reproducible 32-bit seed (unlike the built-in hash()\n",
    "# of a string, which is randomized per process).\n",
    "rng_seed = zlib.adler32(str(run).encode('utf8'))\n",
    "print(f'Run {run}')\n",
    "print(f'Using seed {rng_seed}')\n",
    "# Seed torch first, then NumPy's global generator.\n",
    "for seed_fn in (torch.manual_seed, np.random.seed):\n",
    "    seed_fn(rng_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 20 evenly spaced covariate-shift difficulties from 0 to 0.98, as Python floats.\n",
    "difficulties = torch.linspace(0, 0.98, 20).tolist()\n",
    "\n",
    "# Shared configuration, forwarded as keyword arguments to sweep_difficulties\n",
    "# and results.ensure_results_dir.\n",
    "experiment_params = dict(\n",
    "    training_iterations=20000,\n",
    "    num_means=50,\n",
    "    difficulties=difficulties,\n",
    "    bound=1,\n",
    "    test_size=100,\n",
    "    measurement_interval=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def show_difficulties(difficulty_values=None, params=None):\n",
    "    \"\"\"Plot, for each difficulty, the sequence of means and one sampled trajectory.\n",
    "\n",
    "    Each row corresponds to one difficulty. The left column plots ``dist.mean`` and\n",
    "    the right column plots ``dist(1)`` for every distribution returned by\n",
    "    ``difficulty.get_distributions``.\n",
    "\n",
    "    Args:\n",
    "        difficulty_values: difficulties to plot; defaults to the notebook-level\n",
    "            ``difficulties`` list.\n",
    "        params: experiment configuration; defaults to ``experiment_params``.\n",
    "            Reads ``training_iterations`` and ``num_means``.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib ``Figure``, so it can be saved or annotated further.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``difficulty_values`` is empty.\n",
    "    \"\"\"\n",
    "    if difficulty_values is None:\n",
    "        difficulty_values = difficulties\n",
    "    if params is None:\n",
    "        params = experiment_params\n",
    "    num_d = len(difficulty_values)\n",
    "    if num_d == 0:\n",
    "        raise ValueError('difficulty_values must contain at least one difficulty')\n",
    "\n",
    "    num_means = params['num_means']\n",
    "    # Each of the num_means distributions spans an equal share of the training run.\n",
    "    segment_length = params['training_iterations'] // num_means\n",
    "\n",
    "    fig, ax = plt.subplots(num_d, 2, squeeze=False, figsize=(15, 20), sharex=True, sharey=True)\n",
    "    fig.subplots_adjust(wspace=0, hspace=0)\n",
    "    diff_string = ', '.join(f'{d:.2f}' for d in difficulty_values)\n",
    "    fig.suptitle(f'means, trajectories d∈{{{diff_string}}}')\n",
    "    for i, d in enumerate(difficulty_values):\n",
    "        # NOTE(review): the literal 1 here and the (-1, 1) y-limits presumably mirror\n",
    "        # params['bound'] (currently 1) -- confirm before changing the bound.\n",
    "        distributions = difficulty.get_distributions(num_means, segment_length, d, 1)\n",
    "\n",
    "        ax[i, 0].plot([dist.mean for dist in distributions])\n",
    "        ax[i, 0].set_ylim((-1, 1))\n",
    "\n",
    "        ax[i, 1].plot([dist(1) for dist in distributions])\n",
    "        ax[i, 1].set_ylim((-1, 1))\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "# The trailing ';' suppresses the returned Figure's repr so the inline backend shows it once.\n",
    "show_difficulties();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def model_optimizer_init():\n",
    "    \"\"\"Build a fresh two-layer ReLU network and an Adam optimizer over its parameters.\n",
    "\n",
    "    Passed to ``sweep_difficulties`` as a factory; returns ``(model, optimizer)``.\n",
    "    \"\"\"\n",
    "    # Width 200 is chosen so the parameter count exceeds that of the LTA network.\n",
    "    hidden = 200\n",
    "    layers = [\n",
    "        ReluLayer(1, hidden, hidden),\n",
    "        ReluLayer(hidden, 1, hidden),\n",
    "    ]\n",
    "    net = torch.nn.Sequential(*layers)\n",
    "    return net, torch.optim.Adam(net.parameters(), lr=learning_rate)\n",
    "\n",
    "\n",
    "relu_results = sweep_difficulties(model_optimizer_init=model_optimizer_init, **experiment_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): 0.125 is the second argument of mean_steady_state_loss; presumably the\n",
    "# fraction of each test-loss curve treated as steady state -- confirm.\n",
    "steady_state_fraction = 0.125\n",
    "\n",
    "# One performance value per difficulty; zip keeps the original pairing and truncation.\n",
    "relu_performances = [\n",
    "    mean_steady_state_loss(segment_test_losses, steady_state_fraction)\n",
    "    for _difficulty, (_steps, segment_test_losses, _train_losses) in zip(difficulties, relu_results)\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(difficulties, relu_performances, label='Relu Large')\n",
    "ax.set_title('Mean steady state loss on stationary distribution')\n",
    "ax.set_xlabel('Covariate shift difficulty')\n",
    "ax.set_ylabel('Mean steady state test loss')\n",
    "ax.set_xlim((0, 1))\n",
    "ax.legend()\n",
    "\n",
    "# Make sure the results directory exists before writing into it; otherwise, on a fresh\n",
    "# run, savefig can fail before the sweep results are saved by the last cell.\n",
    "# NOTE(review): assumes ensure_results_dir creates the directory and is safe to call twice.\n",
    "results.ensure_results_dir(basename, **experiment_params)\n",
    "fig.savefig(results.full_path(basename, f'comparison_by_difficulty_large.{run}.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Persist the raw sweep output; the filename encodes the run id and the learning rate.\n",
    "results.ensure_results_dir(basename, **experiment_params)\n",
    "relu_results_path = results.full_path(basename, f'relu_large_results.{run}.{learning_rate}.pt')\n",
    "results.save_results(relu_results, relu_results_path)\n"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
