{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import results\n",
    "import rendering\n",
    "import difficulty\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_difficulties(difficulties, num_means, iterations):\n",
    "    \"\"\"Save one trajectory figure per difficulty as 'plots/d=<d>_trajectory.pdf'.\n",
    "\n",
    "    Every figure has a wide left panel showing one value per distribution\n",
    "    returned by difficulty.get_distributions, overlaid with the dashed line of\n",
    "    the distribution means, and a narrow right panel with a fixed Gaussian\n",
    "    density drawn sideways.\n",
    "\n",
    "    Args:\n",
    "        difficulties: iterable of numeric difficulty values; one figure is\n",
    "            created and saved for each.\n",
    "        num_means: forwarded to difficulty.get_distributions; also divides\n",
    "            iterations to obtain the segment length.\n",
    "        iterations: iteration count used for the segment length and for the\n",
    "            x-axis ticks of the trajectory panel.\n",
    "    \"\"\"\n",
    "    font_size = 16\n",
    "    bound = 1.3\n",
    "    y_ticks = [-1, -.5, 0, .5, 1]\n",
    "    segment_length = iterations // num_means\n",
    "\n",
    "    # The density panel is the same in every figure, so evaluate it once.\n",
    "    # p_x is the pdf of a zero-mean Gaussian with standard deviation 0.5.\n",
    "    x = torch.linspace(-bound * 1.2, bound * 1.2, 100)\n",
    "    p_x = 2 / np.sqrt(2 * np.pi) * torch.exp(-2 * x ** 2)\n",
    "\n",
    "    for d in difficulties:\n",
    "        distributions = difficulty.get_distributions(num_means, segment_length, d, 1)\n",
    "\n",
    "        fig = plt.figure()\n",
    "        fig.subplots_adjust(bottom=0.2, right=0.8)\n",
    "        gs = fig.add_gridspec(1, 5)\n",
    "\n",
    "        # Left panel (4 of 5 grid columns): trajectory and its mean.\n",
    "        traj_ax = fig.add_subplot(gs[0, :4])\n",
    "\n",
    "        # NOTE(review): dist(1) presumably draws a single sample -- confirm in difficulty.py.\n",
    "        trajectory = [dist(1) for dist in distributions]\n",
    "        traj_ax.plot(trajectory, linewidth=0, marker='o', alpha=0.3, label='$X_t$')\n",
    "\n",
    "        means = [dist.mean for dist in distributions]\n",
    "        traj_ax.plot(means, color='black', linestyle='--', linewidth=2, label='$S_t = E[X_t]$')\n",
    "\n",
    "        traj_ax.set_yticks(y_ticks)\n",
    "        # Ticks follow the requested iteration count instead of a fixed [0, 500, 1000].\n",
    "        traj_ax.set_xticks([0, iterations // 2, iterations])\n",
    "        traj_ax.set_ylim((-bound, bound))\n",
    "        traj_ax.spines['right'].set_color('none')\n",
    "        traj_ax.spines['top'].set_color('none')\n",
    "        traj_ax.set_xlabel('Training Iterations', fontsize=font_size)\n",
    "\n",
    "        traj_ax.legend(fontsize=font_size, loc='lower right', title_fontsize=font_size, markerscale=2, framealpha=0.9)\n",
    "\n",
    "        # Right panel (last grid column): density drawn along the shared y range.\n",
    "        equil_ax = fig.add_subplot(gs[0, 4:])\n",
    "\n",
    "        equil_ax.fill_betweenx(x, p_x, color='gray', alpha=0.4, label='$\\\\mathcal{N}(0, \\\\xi^2)$')\n",
    "        equil_ax.plot(torch.linspace(0, 0.8, 50), torch.zeros(50), color='gray', linestyle='--', linewidth=2)\n",
    "\n",
    "        equil_ax.annotate('$\\\\mathcal{N}(0, \\\\xi^2)$', xy=(0.25, -1), fontsize=font_size)\n",
    "        equil_ax.set_yticks(y_ticks)\n",
    "        equil_ax.set_yticklabels([])\n",
    "        equil_ax.set_ylim((-bound, bound))\n",
    "        equil_ax.spines['right'].set_color('none')\n",
    "        equil_ax.spines['top'].set_color('none')\n",
    "        equil_ax.spines['bottom'].set_color('none')\n",
    "        equil_ax.set_xlabel('Equil.\\nDensity', fontsize=font_size)\n",
    "        equil_ax.set_xticks([])\n",
    "        equil_ax.set_xlim((0, 0.9))\n",
    "\n",
    "        fig.savefig(f'plots/d={d:.2f}_trajectory.pdf')\n",
    "\n",
    "show_difficulties(difficulties=torch.linspace(0, 0.98, 20).tolist(), num_means=10, iterations=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def construct_steady_loss_cube(basename, configuration, difficulties, learning_rates, num_runs, steady_state_fraction=0.125):\n",
    "    \"\"\"Collect steady-state test losses into a tensor indexed [learning rate, run, difficulty].\n",
    "\n",
    "    For every (learning rate, run) pair the file '<configuration>.<run>.<learning rate>.pt'\n",
    "    under basename is loaded through the results module. Each entry of the loaded\n",
    "    sequence is unpacked as a 3-tuple whose middle element (the test losses) is\n",
    "    reduced with rendering.mean_steady_state_loss.\n",
    "\n",
    "    Entries that are never written stay at zero.\n",
    "    \"\"\"\n",
    "    cube = torch.zeros((len(learning_rates), num_runs, len(difficulties)))\n",
    "    for lr_pos, lr in enumerate(learning_rates):\n",
    "        for run_id in range(num_runs):\n",
    "            path = results.full_path(basename, f'{configuration}.{run_id}.{lr}.pt')\n",
    "            per_difficulty = results.load_results(path)\n",
    "\n",
    "            for diff_pos, entry in enumerate(per_difficulty):\n",
    "                # Only the test losses are needed; steps and train losses are ignored.\n",
    "                _, test_losses, _ = entry\n",
    "                cube[lr_pos, run_id, diff_pos] = rendering.mean_steady_state_loss(\n",
    "                    test_losses,\n",
    "                    steady_state_fraction=steady_state_fraction\n",
    "                )\n",
    "    return cube\n",
    "\n",
    "\n",
    "def get_means_stderrs(steady_losses, p_value=0.001):\n",
    "    # Construct a tensor indexed [difficulty, learning rate] containing the mean steady state losses\n",
    "    # 1.96 -> 95% -> p=0.05\n",
    "    # 2.81 -> 99.5% -> p=0.005\n",
    "    # 3.3 -> 99.9% -> p=0.001\n",
    "    # 5.2 -> 99.99999% -> p=1e-7\n",
    "    stderr_multipliers = {\n",
    "        0.05: 1.96,\n",
    "        0.005: 2.81,\n",
    "        0.001: 3.3,\n",
    "        1e-7: 5.2,\n",
    "    }\n",
    "    return (\n",
    "        steady_losses.mean(dim=1).T,\n",
    "        stderr_multipliers[p_value] * steady_losses.std(dim=1).T / np.sqrt(num_runs),\n",
    "    )\n",
    "    \n",
    "\n",
    "def plot_sensitivity(basename, configurations, difficulties, learning_rates, num_runs, plot_params, p_value):\n",
    "    \"\"\"Save one learning-rate sensitivity plot per difficulty to 'plots/d=<d>_sensitivity.pdf'.\n",
    "\n",
    "    Each plot shows, for every configuration, the mean steady-state loss against\n",
    "    the learning rate on log-log axes, with a shaded band of +/- the half-width\n",
    "    returned by get_means_stderrs. The current figure is cleared after each save.\n",
    "\n",
    "    Args:\n",
    "        basename: forwarded to construct_steady_loss_cube to locate result files.\n",
    "        configurations: iterable of configuration names; each must be a key of plot_params.\n",
    "        difficulties: sequence of numeric difficulty values, one plot each.\n",
    "        learning_rates: sequence of learning rates used as x values.\n",
    "        num_runs: number of runs per (configuration, learning rate).\n",
    "        plot_params: mapping configuration -> {'color': ..., 'name': ...}.\n",
    "        p_value: forwarded to get_means_stderrs to pick the band width.\n",
    "    \"\"\"\n",
    "    font_size = 16\n",
    "\n",
    "    # Load everything up front so each result file is read once, not once per difficulty.\n",
    "    steady_loss_means = {}\n",
    "    steady_loss_stderrs = {}\n",
    "    for configuration in configurations:\n",
    "        steady_losses = construct_steady_loss_cube(basename, configuration, difficulties, learning_rates, num_runs)\n",
    "        steady_loss_means[configuration], steady_loss_stderrs[configuration] = get_means_stderrs(steady_losses, p_value)\n",
    "\n",
    "    # The loop variable is called diff_value so it does not shadow the imported difficulty module.\n",
    "    for diff_idx, diff_value in enumerate(difficulties):\n",
    "        for configuration in configurations:\n",
    "            style = plot_params[configuration]\n",
    "            means = steady_loss_means[configuration][diff_idx]\n",
    "            stderrs = steady_loss_stderrs[configuration][diff_idx]\n",
    "            plt.plot(\n",
    "                learning_rates,\n",
    "                means,\n",
    "                color=style['color'],\n",
    "                label=style['name'],\n",
    "            )\n",
    "            plt.fill_between(\n",
    "                learning_rates,\n",
    "                means - stderrs,\n",
    "                means + stderrs,\n",
    "                color=style['color'],\n",
    "                alpha=0.3\n",
    "            )\n",
    "\n",
    "        plt.legend(loc='lower right', frameon=False, fontsize=font_size, title=f'$d={diff_value:.2f}$', title_fontsize=font_size)\n",
    "        plt.xlabel('Learning Rate', fontsize=font_size)\n",
    "        plt.ylabel('Mean\\nFinal\\nLoss', rotation=0, labelpad=10, ha='center', va='center', fontsize=font_size)\n",
    "        plt.xscale('log')\n",
    "        plt.yscale('log')\n",
    "        # Set the ticks only after switching to log scale: changing an axis scale\n",
    "        # installs that scale's default tick locator and would discard earlier ticks.\n",
    "        plt.yticks([1e-3, 1e-2, 1e0])\n",
    "        plt.ylim((1e-3, 1))\n",
    "\n",
    "        plt.gca().spines['right'].set_color('none')\n",
    "        plt.gca().spines['top'].set_color('none')\n",
    "\n",
    "        plt.subplots_adjust(left=0.16, bottom=0.15, right=0.95, top=0.95)\n",
    "\n",
    "        plt.savefig(f'plots/d={diff_value:.2f}_sensitivity.pdf')\n",
    "        plt.clf()\n",
    "\n",
    "# Per-configuration display settings: line colour and legend label.\n",
    "plot_params = dict(\n",
    "    lta_results=dict(color='black', name='LTA'),\n",
    "    relu_results=dict(color='red', name='Relu'),\n",
    "    # relu_large_results=dict(color='green', name='Relu (large)'),\n",
    ")\n",
    "\n",
    "# Experiment selection shared by the plotting cells below.\n",
    "basename = '2_layer_batching_50'\n",
    "configurations = plot_params.keys()\n",
    "difficulties = torch.linspace(0, 0.98, 20).tolist()\n",
    "learning_rates = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6]\n",
    "num_runs = 10\n",
    "shading_p_value = 0.05\n",
    "\n",
    "plot_sensitivity(basename, configurations, difficulties, learning_rates, num_runs, plot_params, shading_p_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def optimal_learning_rates(basename, configuration, difficulties, learning_rates, num_runs):\n",
    "    \"\"\"Find, for each difficulty, the learning rate with the lowest mean steady-state loss.\n",
    "\n",
    "    Returns:\n",
    "        A pair (indices, rates): indices[i] is the position in learning_rates that\n",
    "        minimises the mean loss at difficulty i, and rates[i] is that learning rate.\n",
    "    \"\"\"\n",
    "    cube = construct_steady_loss_cube(basename, configuration, difficulties, learning_rates, num_runs)\n",
    "    # Only the means matter here; the error half-widths are discarded.\n",
    "    mean_table = get_means_stderrs(cube)[0]\n",
    "\n",
    "    best_indices = []\n",
    "    best_rates = []\n",
    "    for diff_idx in range(len(difficulties)):\n",
    "        best = np.argmin(mean_table[diff_idx])\n",
    "        best_indices.append(best)\n",
    "        best_rates.append(learning_rates[best])\n",
    "\n",
    "    return best_indices, best_rates\n",
    "\n",
    "\n",
    "def plot_optimal_difficulty_sweep(basename, configurations, difficulties, learning_rates, num_runs, plot_params, p_value):\n",
    "    \"\"\"Plot the best achievable mean steady-state loss against difficulty and save it.\n",
    "\n",
    "    For every configuration and difficulty, the learning rate with the lowest mean\n",
    "    loss is selected. The resulting curve is drawn with a shaded band of +/- the\n",
    "    half-width from get_means_stderrs, plus a dashed horizontal reference line at\n",
    "    the value for the first difficulty. The figure is written to\n",
    "    'plots/difficulty_sweep_<basename>.pdf'.\n",
    "\n",
    "    Args:\n",
    "        basename: forwarded to construct_steady_loss_cube; also used in the output file name.\n",
    "        configurations: iterable of configuration names; each must be a key of plot_params.\n",
    "        difficulties: sequence of numeric difficulty values used as x values.\n",
    "        learning_rates: sequence of learning rates that were swept.\n",
    "        num_runs: number of runs per (configuration, learning rate).\n",
    "        plot_params: mapping configuration -> {'color': ..., 'name': ...}.\n",
    "        p_value: forwarded to get_means_stderrs to pick the band width.\n",
    "    \"\"\"\n",
    "    font_size = 16\n",
    "\n",
    "    for configuration in configurations:\n",
    "        style = plot_params[configuration]\n",
    "        steady_losses = construct_steady_loss_cube(basename, configuration, difficulties, learning_rates, num_runs)\n",
    "        mean_table, err_table = get_means_stderrs(steady_losses, p_value)\n",
    "\n",
    "        # Select the loss-minimising learning rate per difficulty from the table just\n",
    "        # computed. Calling optimal_learning_rates here would rebuild the loss cube\n",
    "        # and therefore load every result file a second time.\n",
    "        lr_indices = [np.argmin(mean_table[diff_idx]) for diff_idx in range(len(difficulties))]\n",
    "\n",
    "        means = torch.tensor(\n",
    "            [mean_table[diff_idx][lr_idx] for diff_idx, lr_idx in enumerate(lr_indices)]\n",
    "        )\n",
    "        errs = torch.tensor(\n",
    "            [err_table[diff_idx][lr_idx] for diff_idx, lr_idx in enumerate(lr_indices)]\n",
    "        )\n",
    "\n",
    "        plt.plot(\n",
    "            difficulties,\n",
    "            means,\n",
    "            color=style['color'],\n",
    "            label=style['name'],\n",
    "        )\n",
    "        plt.fill_between(\n",
    "            difficulties,\n",
    "            means - errs,\n",
    "            means + errs,\n",
    "            color=style['color'],\n",
    "            alpha=0.3\n",
    "        )\n",
    "        # Dashed reference line at the loss of the first difficulty.\n",
    "        plt.plot(\n",
    "            difficulties,\n",
    "            [means[0]] * len(difficulties),\n",
    "            linestyle='--',\n",
    "            color=style['color'],\n",
    "            alpha=0.5\n",
    "        )\n",
    "\n",
    "    plt.gca().spines['right'].set_color('none')\n",
    "    plt.gca().spines['top'].set_color('none')\n",
    "\n",
    "    plt.legend(loc='upper left', frameon=False, fontsize=font_size)\n",
    "    plt.xlabel('Covariate Shift Difficulty', fontsize=font_size)\n",
    "    plt.ylabel('Mean\\nFinal\\nLoss', rotation=0, ha='center', fontsize=font_size)\n",
    "    plt.yticks([0, 0.05, 0.15])\n",
    "    plt.ylim((-0.01, 0.2))\n",
    "\n",
    "    plt.subplots_adjust(left=0.16, bottom=0.15, right=0.95, top=0.95)\n",
    "\n",
    "    plt.savefig(f'plots/difficulty_sweep_{basename}.pdf')\n",
    "\n",
    "plot_optimal_difficulty_sweep(\n",
    "    basename, \n",
    "    configurations, \n",
    "    difficulties, \n",
    "    learning_rates, \n",
    "    num_runs, \n",
    "    plot_params,\n",
    "    shading_p_value,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smooth(t, window_size):\n",
    "    return t.unfold(0, window_size, 1).mean(dim=1)\n",
    "\n",
    "def plot_learning_curve(basename, configuration, run, low_diff_lr, high_diff_lr, high_diff, high_diff_idx, plot_params, show_yticks=True):\n",
    "    \"\"\"Plot smoothed test-loss curves for difficulty 0 and one higher difficulty, then save.\n",
    "\n",
    "    The low-difficulty curve is entry 0 of the run trained with low_diff_lr. The\n",
    "    high-difficulty curve is entry high_diff_idx of the run trained with\n",
    "    high_diff_lr. Both are smoothed with a 50-step moving average. The figure is\n",
    "    written to 'plots/diff_sweep_learning_curve_<configuration>.pdf' and the\n",
    "    current figure is cleared afterwards.\n",
    "\n",
    "    Args:\n",
    "        basename: forwarded to results.full_path to locate result files.\n",
    "        configuration: configuration name; must be a key of plot_params.\n",
    "        run: run identifier used in the result file name.\n",
    "        low_diff_lr: learning rate of the run used for the difficulty-0 curve.\n",
    "        high_diff_lr: learning rate of the run used for the high-difficulty curve.\n",
    "        high_diff: difficulty value shown in the legend for the high-difficulty curve.\n",
    "        high_diff_idx: index of that difficulty within the loaded results.\n",
    "        plot_params: mapping configuration -> {'color': ..., 'name': ...}.\n",
    "        show_yticks: when False, hide the y tick labels and the y axis label.\n",
    "    \"\"\"\n",
    "    font_size = 16\n",
    "    window_size = 50\n",
    "    color = plot_params[configuration]['color']\n",
    "\n",
    "    low_filename = results.full_path(basename, f'{configuration}.{run}.{low_diff_lr}.pt')\n",
    "    low_results_by_diff = results.load_results(low_filename)\n",
    "    _, low_test_losses, _ = low_results_by_diff[0]\n",
    "\n",
    "    # The high-difficulty curve must come from the run trained with high_diff_lr.\n",
    "    # The file name was previously built from low_diff_lr, which ignored this argument.\n",
    "    high_filename = results.full_path(basename, f'{configuration}.{run}.{high_diff_lr}.pt')\n",
    "    high_results_by_diff = results.load_results(high_filename)\n",
    "    steps, high_test_losses, _ = high_results_by_diff[high_diff_idx]\n",
    "\n",
    "    # Smoothing drops window_size - 1 points, so trim the step axis to match.\n",
    "    steps = steps[window_size - 1:]\n",
    "\n",
    "    plt.plot(\n",
    "        steps,\n",
    "        smooth(high_test_losses, window_size),\n",
    "        color=color,\n",
    "        label=f'd={high_diff:0.2f}',\n",
    "        linewidth=2,\n",
    "    )\n",
    "    plt.plot(\n",
    "        steps,\n",
    "        smooth(low_test_losses, window_size),\n",
    "        linestyle='--',\n",
    "        color=color,\n",
    "        label='d=0',\n",
    "        linewidth=2,\n",
    "        alpha=0.5,\n",
    "    )\n",
    "\n",
    "    if not show_yticks:\n",
    "        plt.gca().axes.yaxis.set_ticklabels([])\n",
    "    else:\n",
    "        plt.ylabel('Loss\\nOver\\nEquilib.\\nDist.', rotation=0, labelpad=10, ha='center', va='center', fontsize=font_size)\n",
    "    plt.xlabel('Training Iterations', fontsize=font_size)\n",
    "    plt.legend(frameon=False, fontsize=font_size)\n",
    "\n",
    "    plt.gca().spines['right'].set_color('none')\n",
    "    plt.gca().spines['top'].set_color('none')\n",
    "    plt.xticks(list(range(0, 25000, 5000)))\n",
    "    plt.ylim((0, 0.4))\n",
    "    plt.yticks([0, 0.1, 0.3, 0.4])\n",
    "    plt.subplots_adjust(left=0.16, bottom=0.15, right=0.95, top=0.95)\n",
    "\n",
    "    plt.savefig(f'plots/diff_sweep_learning_curve_{configuration}.pdf')\n",
    "    plt.clf()\n",
    "\n",
    "\n",
    "plot_learning_curve(basename, 'lta_results', run=0, low_diff_lr=1e-5, high_diff_lr=5e-5, high_diff=difficulties[17], high_diff_idx=17, plot_params=plot_params)\n",
    "plot_learning_curve(basename, 'relu_results', run=0, low_diff_lr=1e-3, high_diff_lr=1e-3, high_diff=difficulties[17], high_diff_idx=17, plot_params=plot_params, show_yticks=False)\n",
    "# plot_learning_curve(basename, 'relu_large_results', 0, 1e-3, 1e-3, difficulties[17], 17, plot_params, show_yticks=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
