{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from utils.dict_utils import dict_to_defaultdict\n",
    "from utils.data_loaders import get_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. The column-aligned trailing comments list alternative\n",
    "# values (presumably the other runs of this sweep).\n",
    "\n",
    "# data\n",
    "dataset = 'cifar2_binary'\n",
    "num_classes = get_shape(dataset)[-1]\n",
    "train_size = 1024  # 1024, 1024, 1024, 1024, 1024, 1024, 8192, 8192\n",
    "\n",
    "# architecture\n",
    "num_hidden = 1\n",
    "bias = False\n",
    "normalization = 'none'\n",
    "activation = 'lrelu'\n",
    "\n",
    "# optimisation\n",
    "optimizer = 'sgd'\n",
    "# display name of the optimizer; unknown ids are shown unchanged\n",
    "optimizer_name = {'sgd': 'GD', 'rmsprop': 'RMSProp'}.get(optimizer, optimizer)\n",
    "num_seeds = 10\n",
    "num_epochs = 100   # 2000, 100,  500,  25,   125,  6,    125,  6\n",
    "batch_size = 1024  # 1024, 1024, 256,  256,  64,   64,   512,  512\n",
    "lr = 0.02\n",
    "\n",
    "# optimizer steps per epoch (integer division); turns epoch indices into step counts\n",
    "steps_per_epoch = train_size // batch_size\n",
    "\n",
    "title = 'train size = {}, batch size = {}'.format(train_size, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding this configuration's results:\n",
    "#   results/equiv_models_test/<dataset>_<train_size>/<model config>/<optimizer config>\n",
    "data_dir_name = '{}_{}'.format(dataset, train_size)\n",
    "model_dir_name = 'num_hidden={}_activation={}_bias={}_normalization={}'.format(\n",
    "    num_hidden, activation, bias, normalization\n",
    ")\n",
    "optim_dir_name = '{}_lr={}_batch_size={}_num_epochs={}'.format(optimizer, lr, batch_size, num_epochs)\n",
    "log_dir = os.path.join('results', 'equiv_models_test', data_dir_name, model_dir_name, optim_dir_name)\n",
    "log_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested results store, indexed as\n",
    "#   results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch]\n",
    "# The first six levels are defaultdicts, so looking up a missing run yields None at [key].\n",
    "def _nested_defaultdict(depth):\n",
    "    \"\"\"Return `depth` nested defaultdict levels whose innermost default value is None.\"\"\"\n",
    "    if depth == 0:\n",
    "        return None\n",
    "    return defaultdict(lambda: _nested_defaultdict(depth - 1))\n",
    "\n",
    "results_all_path = os.path.join(log_dir, 'results_all.dat')\n",
    "results_all = _nested_defaultdict(6)\n",
    "if os.path.exists(results_all_path):\n",
    "    # NOTE(review): pickle.load can execute arbitrary code -- only load result files you produced.\n",
    "    with open(results_all_path, 'rb') as f:\n",
    "        results_all = dict_to_defaultdict(pickle.load(f), results_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which widths / parameterizations to compare; only one hidden layer is supported.\n",
    "correction_epochs = [0]\n",
    "if num_hidden != 1:\n",
    "    raise NotImplementedError\n",
    "\n",
    "real_widths = [128, 65536]\n",
    "# Scaling modes to plot. The commented-out tuples are extra modes grouped by region.\n",
    "# NOTE(review): the plotting cells call scaling_mode.startswith(...) and look modes up in\n",
    "# scaling_mode_names, so enabling a tuple requires adapting both.\n",
    "scaling_modes = [\n",
    "    # zero-dimensional regions:\n",
    "    'mean_field',\n",
    "    'ntk',\n",
    "    'mean_field_simple_init_corrected',\n",
    "    'default',\n",
    "    'default_sym',\n",
    "    # one-dimensional regions:\n",
    "    # (-0.75, 0.5),   # MF-NTK\n",
    "    # (-0.5, 0.25),   # NTK-sym_default\n",
    "    # (-0.75, 0.75),  # sym_default-MF\n",
    "    # (-1.5, 1.5),    # post-MF\n",
    "    # (0, 0),         # pre-sym_default\n",
    "    # (-1, 0.5),      # post-NTK\n",
    "    # (0, -0.5),      # pre-NTK\n",
    "    # two-dimensional regions:\n",
    "    # (-0.66, 0.5),   # barycenter\n",
    "    # (-1, 0.75),     # post-MF-NTK\n",
    "    # (-10, 9.75),    # post-MF-NTK\n",
    "    # (0, -0.25),     # pre-NTK-sym_default, far\n",
    "    # (10, -10.25),   # pre-NTK-sym_default, far\n",
    "    # ill-defined:\n",
    "    # (1, 1),         # diverge\n",
    "    # (-1, -1),       # stagnate\n",
    "]\n",
    "# Short names shown in legends.\n",
    "scaling_mode_names = {\n",
    "    'mean_field': 'MF',\n",
    "    'ntk': 'NTK',\n",
    "    'mean_field_simple_init_corrected': 'IC-MF',\n",
    "    'default': 'default',\n",
    "    'default_sym': 'sym-default',\n",
    "}\n",
    "ref_widths = [128]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# enlarge the default font size for every figure below\n",
    "plt.rcParams['font.size'] = 18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ewma(a, alpha):\n",
    "    \"\"\"Exponentially weighted moving average along the first axis.\n",
    "\n",
    "    av[0] = a[0] and av[i] = alpha * a[i] + (1 - alpha) * av[i-1].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a : array_like\n",
    "        Values to smooth; for multi-dimensional arrays the recursion runs over axis 0.\n",
    "    alpha : float\n",
    "        Weight of the current sample (alpha=1 returns an unchanged copy).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Smoothed float copy of `a`; the input itself is never modified.\n",
    "    \"\"\"\n",
    "    # Work on a float copy: writing into `a` would overwrite the caller's data,\n",
    "    # and an integer array would silently truncate the averages.\n",
    "    av_a = np.array(a, dtype=float)\n",
    "    for i in range(1, len(av_a)):\n",
    "        av_a[i] = av_a[i] * alpha + av_a[i-1] * (1 - alpha)\n",
    "    return av_a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_curve(scaling_mode, ref_width, real_width, correction_epoch, key, \n",
    "               idx=None, threshold=1000, smoothening_factor=0, label=None, **kwargs):\n",
    "    \"\"\"Plot the seed-mean of a logged per-epoch metric with a +/- 1 std band.\n",
    "\n",
    "    Reads results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch]\n",
    "    for seed in range(num_seeds) and epoch in range(num_epochs) (notebook globals) and\n",
    "    plots it against the step count (epoch + 1) * steps_per_epoch.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    key : str\n",
    "        Metric name, e.g. 'test_losses' or 'train_accs'.\n",
    "    idx : ignored\n",
    "        Kept only for backward compatibility.\n",
    "    threshold : float\n",
    "        Values are clipped to [-threshold, threshold] before plotting.\n",
    "    smoothening_factor : float\n",
    "        0 (default) plots the raw values. Otherwise the curves are EWMA-smoothed in log\n",
    "        space with alpha = 1 - smoothening_factor, running from the last epoch backwards.\n",
    "    label : str, optional\n",
    "        Legend label of the mean curve (the band is unlabelled).\n",
    "    **kwargs\n",
    "        Forwarded to both plt.plot and plt.fill_between.\n",
    "    \"\"\"\n",
    "    data = np.array([\n",
    "        [\n",
    "            results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch]\n",
    "            for epoch in range(num_epochs)\n",
    "        ] for seed in range(num_seeds)\n",
    "    ])\n",
    "    data = np.clip(data, -threshold, threshold)\n",
    "    if smoothening_factor != 0:\n",
    "        # Smooth only when asked to: even with alpha = 1 the EWMA computes 0 * previous,\n",
    "        # so a single NaN (diverged run) or zero value (log -> -inf) would turn every\n",
    "        # earlier epoch of that seed into NaN.\n",
    "        data = np.exp(ewma(np.log(data.T)[::-1], alpha=1-smoothening_factor)[::-1].T)\n",
    "    data_mean = data.mean(axis=0)\n",
    "    data_std = data.std(axis=0)\n",
    "    steps = np.arange(1, num_epochs+1) * steps_per_epoch\n",
    "    plt.plot(steps, data_mean, label=label, **kwargs)\n",
    "    plt.fill_between(steps, data_mean - data_std, data_mean + data_std, alpha=0.3, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared plot styling: one line style per key modifier, one tab10 colour per scaling mode.\n",
    "linestyles = [\n",
    "    'solid', 'dashed', 'dotted', 'dashdot',\n",
    "    (0, (1, 5)),  # (offset, (on, off)) dash patterns\n",
    "    (0, (3, 5)),\n",
    "]\n",
    "cmap = plt.get_cmap('tab10')\n",
    "\n",
    "# Metrics for the next figure cell, with matching (optional) y-limits.\n",
    "key_bases = ['test_losses', 'test_accs', 'train_losses', 'train_accs']\n",
    "key_modifiers = [('', '')]  # (prefix, suffix) added around each key base\n",
    "ylims = [\n",
    "    (0.35, 0.45),  # test_losses\n",
    "    (0.8, 0.9),    # test_accs\n",
    "    (0.6, 0.7),    # train_losses\n",
    "    (0.7, 1.0),    # train_accs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per metric: every scaling mode at the non-reference widths, plus the\n",
    "# 'default' model at width == ref_width as the black dashed reference.\n",
    "to_draw = None  # None = no filtering by scaling mode\n",
    "\n",
    "for key_base, ylim in zip(key_bases, ylims):\n",
    "    _ = plt.figure(figsize=(12,6))\n",
    "\n",
    "    plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "    #plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "    plt.grid(True)\n",
    "\n",
    "    if key_base.endswith('_losses'):\n",
    "        plt.yscale('log')\n",
    "        plt.ylabel(\"CE loss\")\n",
    "    elif key_base.endswith('_accs'):\n",
    "        plt.ylabel(\"accuracy\")\n",
    "    plt.xscale('log')\n",
    "\n",
    "    for ref_width in ref_widths:\n",
    "        for real_width in real_widths[::-1]:\n",
    "            for k, scaling_mode in enumerate(scaling_modes):\n",
    "                if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):\n",
    "                    continue\n",
    "                if (ref_width == real_width) and (scaling_mode != 'default'):\n",
    "                    continue\n",
    "                for correction_epoch in (\n",
    "                    correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]\n",
    "                ):\n",
    "                    for i, key_mod in enumerate(key_modifiers):\n",
    "                        key = key_mod[0] + key_base + key_mod[1]\n",
    "                        # only the first key modifier gets a legend entry\n",
    "                        if (scaling_mode == 'default') and (ref_width == real_width):\n",
    "                            draw_curve(\n",
    "                                scaling_mode, None, \n",
    "                                real_width, correction_epoch, key, color='black', \n",
    "                                linestyle='dashed', lw=3,\n",
    "                                label='reference; d*={}'.format(ref_width) if i == 0 else None\n",
    "                            )\n",
    "                        else:\n",
    "                            draw_curve(\n",
    "                                scaling_mode, ref_width if scaling_mode != 'default' else None, \n",
    "                                real_width, correction_epoch, key, color=cmap(k), \n",
    "                                linestyle=linestyles[i], lw=3,\n",
    "                                label=scaling_mode_names[scaling_mode] if i == 0 else None\n",
    "                            )\n",
    "\n",
    "    # Labels are attached to the mean curves when they are drawn. Passing a bare list\n",
    "    # of strings to legend() pairs them with artists in creation order, which in newer\n",
    "    # matplotlib includes the fill_between bands and so mislabels the curves.\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_logits(\n",
    "    scaling_mode, ref_width, real_width, correction_epoch, \n",
    "    add_displacement=0, mul_displacement=1,\n",
    "    label=None, **kwargs\n",
    "):\n",
    "    \"\"\"Plot the mean absolute test logit per epoch: seed mean with a +/- 1 std band.\n",
    "\n",
    "    Per-seed values are shifted by `add_displacement` and then scaled by\n",
    "    `mul_displacement` before the statistics over seeds are taken.\n",
    "    `label` goes to the mean line only; other kwargs go to plt.plot and plt.fill_between.\n",
    "    \"\"\"\n",
    "    logits = np.array([\n",
    "        [\n",
    "            results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_logits'][epoch]\n",
    "            for epoch in range(num_epochs)\n",
    "        ]\n",
    "        for seed in range(num_seeds)\n",
    "    ]).squeeze(axis=-1)\n",
    "\n",
    "    # per (seed, epoch): mean |logit| over the last axis (presumably test points)\n",
    "    per_seed = (np.abs(logits).mean(axis=-1) + add_displacement) * mul_displacement\n",
    "\n",
    "    centre = per_seed.mean(axis=0)\n",
    "    spread = per_seed.std(axis=0)\n",
    "    steps = np.arange(1, num_epochs + 1) * steps_per_epoch\n",
    "    plt.plot(steps, centre, label=label, **kwargs)\n",
    "    plt.fill_between(steps, centre - spread, centre + spread, alpha=0.3, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_label(scaling_mode, ref_width=None, real_width=None):\n",
    "    \"\"\"Legend label for a curve.\n",
    "\n",
    "    'reference' is used verbatim; any other mode is looked up in scaling_mode_names.\n",
    "    When both widths are given, ': $d = 2^{p}$' is appended with p = int(log2(width)),\n",
    "    where width is ref_width for the reference curve and real_width otherwise.\n",
    "    \"\"\"\n",
    "    is_reference = scaling_mode == 'reference'\n",
    "    if ref_width is not None and real_width is not None:\n",
    "        exponent = int(np.log2(ref_width if is_reference else real_width))\n",
    "        appendix = ': ' + r'$d = 2^{}{}{}$'.format('{', exponent, '}')\n",
    "    else:\n",
    "        appendix = ''\n",
    "    name = scaling_mode if is_reference else scaling_mode_names[scaling_mode]\n",
    "    return name + appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test vs. train curves (key_modifiers index 0 -> linestyles[0] 'solid', 1 -> 'dashed')\n",
    "# for the modes in to_draw, plus the 'default' run at width == ref_width in black.\n",
    "key_bases = ['losses', 'accs']\n",
    "key_modifiers = [('test_',''), ('train_','')]\n",
    "# ylims still comes from the styling cell above; zip below only pairs its first two entries.\n",
    "# ylims = [(0.35,0.45), (0.8, 0.9), (0.6,0.7), (0.7,1.0)]\n",
    "\n",
    "# NOTE(review): to_draw stays set after this cell; later cells that comment out their own\n",
    "# assignment use whatever value was set last.\n",
    "to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected',]\n",
    "\n",
    "for key_base, ylim in zip(key_bases, ylims):\n",
    "    _ = plt.figure(figsize=(6,4))\n",
    "\n",
    "    plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "    # plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "    plt.grid(True)\n",
    "\n",
    "    if key_base == 'losses':\n",
    "        plt.ylabel(\"CE loss\")\n",
    "    elif key_base == 'accs':\n",
    "        plt.ylabel(\"accuracy\")\n",
    "    #plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "\n",
    "    for ref_width in ref_widths:\n",
    "        for real_width in real_widths[::-1]:\n",
    "            for k, scaling_mode in enumerate(scaling_modes):\n",
    "                if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):\n",
    "                    continue\n",
    "                if (ref_width == real_width) and (scaling_mode != 'default'):\n",
    "                    continue\n",
    "                for correction_epoch in (\n",
    "                    correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]\n",
    "                ):\n",
    "                    # only the test curve (i == 0) of each mode carries a legend label\n",
    "                    for i, key_mod in enumerate(key_modifiers):\n",
    "                        key = key_mod[0] + key_base + key_mod[1]\n",
    "                        if (scaling_mode == 'default') and (ref_width == real_width):\n",
    "                            draw_curve(\n",
    "                                scaling_mode, None, \n",
    "                                real_width, correction_epoch, key, color='black', \n",
    "                                linestyle=linestyles[i], lw=3, label=get_label('reference', ref_width, real_width) if i == 0 else None\n",
    "                            )\n",
    "                        else:\n",
    "                            draw_curve(\n",
    "                                scaling_mode, ref_width if scaling_mode != 'default' else None, \n",
    "                                real_width, correction_epoch, key, color=cmap(k), \n",
    "                                linestyle=linestyles[i], lw=3, label=get_label(scaling_mode, ref_width, real_width) if i == 0 else None\n",
    "                            )\n",
    "                    \n",
    "    # legend is shown on the loss figure only\n",
    "    if key_base == 'losses':\n",
    "        plt.legend()\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean |logit| over test points vs. training step for each mode in to_draw, plus the\n",
    "# 'default' run at width == ref_width (black, dotted).\n",
    "to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected',]\n",
    "\n",
    "_ = plt.figure(figsize=(6,4))\n",
    "\n",
    "plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "#plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "plt.grid(True)\n",
    "\n",
    "plt.ylabel('mean abs logit, ' + r\"$\\mathbb{E}_x |f(x)|$\")\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "\n",
    "for ref_width in ref_widths:\n",
    "    for real_width in real_widths[::-1]:\n",
    "        for k, scaling_mode in list(enumerate(scaling_modes)):\n",
    "            if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):\n",
    "                continue\n",
    "            if (ref_width == real_width) and (scaling_mode != 'default'):\n",
    "                continue\n",
    "            for correction_epoch in (\n",
    "                correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]\n",
    "            ):\n",
    "                if (scaling_mode == 'default') and (ref_width == real_width):\n",
    "                    draw_logits(\n",
    "                        scaling_mode, None, \n",
    "                        real_width, correction_epoch, color='black',\n",
    "                        linestyle='dotted',\n",
    "                        lw=3, label=get_label('reference', ref_width, real_width)\n",
    "                    )\n",
    "                else:\n",
    "                    draw_logits(\n",
    "                        scaling_mode, ref_width if scaling_mode != 'default' else None, \n",
    "                        real_width, correction_epoch, color=cmap(k), \n",
    "                        lw=3, label=get_label(scaling_mode, ref_width, real_width),\n",
    "                        # scales mode k by 1.05**k so overlapping curves stay visible;\n",
    "                        # plotted values are therefore offset, not exact\n",
    "                        mul_displacement=1.05**k\n",
    "                    )\n",
    "\n",
    "# legend disabled; each curve still carries its label\n",
    "# plt.legend()\n",
    "plt.title(title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.special import kl_div, expit, digamma, gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.special import betaln, digamma\n",
    "\n",
    "def mean_and_var(a, **kwargs):\n",
    "    \"\"\"Return (mean, unbiased ddof=1 variance) of `a`; kwargs such as axis go to both.\"\"\"\n",
    "    return np.mean(a, **kwargs), np.var(a, ddof=1, **kwargs)\n",
    "\n",
    "def kl_between_normals(mean, var, mean_ref, var_ref):\n",
    "    \"\"\"Elementwise KL( N(mean, var) || N(mean_ref, var_ref) ).\"\"\"\n",
    "    return 0.5 * (var / var_ref + (mean_ref - mean) ** 2 / var_ref - 1 - np.log(var / var_ref))\n",
    "\n",
    "def estimate_beta_distr_params(mean, var, eps=1e-2):\n",
    "    \"\"\"Method-of-moments Beta(alpha, beta) fit (intended for values in [0, 1]).\n",
    "\n",
    "    `eps` is added to both sides of the moment ratio (keeps var -> 0 finite) and both\n",
    "    parameters are clipped to [eps, 1/eps].\n",
    "    \"\"\"\n",
    "    moment_ratio = (mean * (1 - mean) + eps) / (var + eps) - 1\n",
    "    alpha = np.clip(mean * moment_ratio, a_min=eps, a_max=1/eps)\n",
    "    beta = np.clip((1 - mean) * moment_ratio, a_min=eps, a_max=1/eps)\n",
    "    return alpha, beta\n",
    "\n",
    "def B(alpha, beta):\n",
    "    \"\"\"Beta function Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta).\n",
    "\n",
    "    Evaluated as exp(betaln(...)): the direct Gamma ratio overflows (inf or nan) for\n",
    "    large arguments such as B(100, 100), which lies within the [eps, 1/eps] range\n",
    "    produced by estimate_beta_distr_params.\n",
    "    \"\"\"\n",
    "    return np.exp(betaln(alpha, beta))\n",
    "\n",
    "def kl_between_betas(alpha, beta, alpha_ref, beta_ref):\n",
    "    \"\"\"Elementwise KL( Beta(alpha, beta) || Beta(alpha_ref, beta_ref) ).\n",
    "\n",
    "    The normalizer term log(B_ref / B) is computed as a difference of log-Beta values\n",
    "    so that it stays finite for large parameters.\n",
    "    \"\"\"\n",
    "    log_norm_ratio = betaln(alpha_ref, beta_ref) - betaln(alpha, beta)\n",
    "    return (log_norm_ratio\n",
    "            + (alpha - alpha_ref) * digamma(alpha)\n",
    "            + (beta - beta_ref) * digamma(beta)\n",
    "            - (alpha - alpha_ref + beta - beta_ref) * digamma(alpha + beta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discrepancy_between_logits(logits, logits_ref, discrepancy_type, seed_axis=0):\n",
    "    \"\"\"Per-point discrepancy between two seed ensembles of logits.\n",
    "\n",
    "    Statistics are taken over `seed_axis`, so the result has the input shape with that\n",
    "    axis removed.\n",
    "\n",
    "    discrepancy_type:\n",
    "        'logit': KL(model || ref) between normal fits of the logits;\n",
    "        'prob':  KL between method-of-moments Beta fits of sigmoid(logits);\n",
    "        'class': |P(logit > 0) - P_ref(logit > 0)|, frequencies taken over seeds.\n",
    "\n",
    "    Raises ValueError (naming the offending value) for any other discrepancy_type.\n",
    "    \"\"\"\n",
    "    if discrepancy_type == 'logit':\n",
    "        mean, var = mean_and_var(logits, axis=seed_axis)\n",
    "        mean_ref, var_ref = mean_and_var(logits_ref, axis=seed_axis)\n",
    "        return kl_between_normals(mean, var, mean_ref, var_ref)\n",
    "    if discrepancy_type == 'prob':\n",
    "        mean, var = mean_and_var(expit(logits), axis=seed_axis)\n",
    "        mean_ref, var_ref = mean_and_var(expit(logits_ref), axis=seed_axis)\n",
    "        alpha, beta = estimate_beta_distr_params(mean, var)\n",
    "        alpha_ref, beta_ref = estimate_beta_distr_params(mean_ref, var_ref)\n",
    "        return kl_between_betas(alpha, beta, alpha_ref, beta_ref)\n",
    "    if discrepancy_type == 'class':\n",
    "        prob = np.mean((logits > 0).astype(float), axis=seed_axis)\n",
    "        prob_ref = np.mean((logits_ref > 0).astype(float), axis=seed_axis)\n",
    "        return np.abs(prob - prob_ref)\n",
    "    raise ValueError(\n",
    "        'unknown discrepancy_type {!r}; expected logit, prob or class'.format(discrepancy_type)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_mean_discrepancy(scaling_mode, ref_width, real_width, correction_epoch, discrepancy_type, **kwargs):\n",
    "    \"\"\"Plot, per epoch, the test-point average of the ensemble discrepancy to the reference.\n",
    "\n",
    "    The reference ensemble is the 'default' model logged under\n",
    "    results_all['default'][None][None][ref_widths[0]]. The per-point discrepancy comes from\n",
    "    discrepancy_between_logits (statistics over seeds) and is then averaged over the last\n",
    "    axis. kwargs (color, lw, label, ...) go to plt.plot.\n",
    "    \"\"\"\n",
    "    def load_test_logits(mode, ref, corr_epoch, width):\n",
    "        # (num_seeds, num_epochs, ...) array of logged test logits; trailing singleton axis dropped\n",
    "        return np.array([[\n",
    "            results_all[mode][ref][corr_epoch][width][seed]['test_logits'][epoch]\n",
    "            for epoch in range(num_epochs)\n",
    "        ] for seed in range(num_seeds)]).squeeze(axis=-1)\n",
    "\n",
    "    test_logits = load_test_logits(scaling_mode, ref_width, correction_epoch, real_width)\n",
    "    test_logits_ref = load_test_logits('default', None, None, ref_widths[0])\n",
    "\n",
    "    # after reducing over seeds: presumably (num_epochs, num_test_points)\n",
    "    data = discrepancy_between_logits(test_logits, test_logits_ref, discrepancy_type=discrepancy_type)\n",
    "    plt.plot(np.arange(1, num_epochs+1) * steps_per_epoch, data.mean(axis=-1), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discrepancy to the 'default' reference ensemble for every scaling mode at the\n",
    "# non-reference widths; one figure per discrepancy type.\n",
    "discrepancy_types = ['logit', 'prob', 'class']\n",
    "ylabels = ['KL(limit, ref)', 'KL(limit, ref)', '|p_limit - p_ref|']\n",
    "\n",
    "for discrepancy_type, ylabel in zip(discrepancy_types, ylabels):\n",
    "    _ = plt.figure(figsize=(12,6))\n",
    "\n",
    "    plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "    #plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "    # all three discrepancy types are shown on log-log axes\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    plt.grid(True)\n",
    "    plt.ylabel(ylabel)\n",
    "\n",
    "    for ref_width in ref_widths:\n",
    "        for real_width in real_widths[::-1]:\n",
    "            if real_width == ref_width:\n",
    "                continue  # the reference width is the comparison target itself\n",
    "            for k, scaling_mode in enumerate(scaling_modes):\n",
    "                if scaling_mode == 'mean_field':\n",
    "                    epochs_to_draw = correction_epochs\n",
    "                elif scaling_mode.startswith('mean_field'):\n",
    "                    epochs_to_draw = [0]\n",
    "                else:\n",
    "                    epochs_to_draw = [None]\n",
    "                for correction_epoch in epochs_to_draw:\n",
    "                    draw_mean_discrepancy(\n",
    "                        scaling_mode, None if scaling_mode == 'default' else ref_width,\n",
    "                        real_width, correction_epoch, discrepancy_type=discrepancy_type,\n",
    "                        color=cmap(k), lw=3, label=get_label(scaling_mode)\n",
    "                    )\n",
    "\n",
    "    plt.legend()\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_tangent_kernels(\n",
    "    scaling_mode, ref_width, real_width, correction_epoch, \n",
    "    add_displacement=0, mul_displacement=1,\n",
    "    label=None, layer='sum', relative=False, **kwargs\n",
    "):\n",
    "    \"\"\"Plot the mean |diagonal tangent kernel| per epoch: seed mean with a +/- 1 std band.\n",
    "\n",
    "    layer selects along the second-to-last (layer) axis: 'input' -> index 0,\n",
    "    'output' -> index 1, 'sum' -> sum over that axis; 'hidden' raises\n",
    "    NotImplementedError and anything else ValueError.\n",
    "    relative=True divides each epoch's kernel by that of the first logged epoch.\n",
    "    Per-seed values are shifted by add_displacement, then scaled by mul_displacement.\n",
    "    label goes to the mean line only; other kwargs go to plt.plot and plt.fill_between.\n",
    "    \"\"\"\n",
    "    # NOTE(review): [::2] keeps every other logged entry per epoch -- presumably one of\n",
    "    # each interleaved pair; confirm against the logging code.\n",
    "    kernels = np.array([\n",
    "        [\n",
    "            results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_tangent_kernels'][epoch][::2]\n",
    "            for epoch in range(num_epochs)\n",
    "        ]\n",
    "        for seed in range(num_seeds)\n",
    "    ])\n",
    "\n",
    "    if layer in ('input', 'output'):\n",
    "        kernels = kernels[..., ('input', 'output').index(layer), :]\n",
    "    elif layer == 'sum':\n",
    "        kernels = np.sum(kernels, axis=-2)\n",
    "    elif layer == 'hidden':\n",
    "        raise NotImplementedError\n",
    "    else:\n",
    "        raise ValueError\n",
    "\n",
    "    if relative:\n",
    "        kernels = kernels / kernels[:, 0:1]\n",
    "\n",
    "    per_seed = (np.mean(np.abs(kernels), axis=-1) + add_displacement) * mul_displacement\n",
    "    centre = per_seed.mean(axis=0)\n",
    "    spread = per_seed.std(axis=0)\n",
    "    steps = np.arange(1, num_epochs + 1) * steps_per_epoch\n",
    "    plt.plot(steps, centre, label=label, **kwargs)\n",
    "    plt.fill_between(steps, centre - spread, centre + spread, alpha=0.5, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean diagonal tangent kernel over training, first relative to the first logged epoch\n",
    "# (relative=True), then in absolute terms (relative=False).\n",
    "# NOTE(review): with the assignment below commented out, to_draw is whatever an earlier\n",
    "# cell left behind.\n",
    "# to_draw = ['mean_field', 'ntk', 'default_sym']\n",
    "\n",
    "for relative in [True, False]:\n",
    "    for layer in ['sum']:\n",
    "        _ = plt.figure(figsize=(6,4))\n",
    "#         plt.title(layer + (' relative' if relative else ''))\n",
    "\n",
    "        plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "        #plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "        plt.grid(True)\n",
    "\n",
    "        plt.yscale('log')\n",
    "        plt.xscale('log')\n",
    "        if relative:\n",
    "            # NOTE(review): the label says K_init, but draw_tangent_kernels divides by the\n",
    "            # first *logged* epoch -- confirm that entry is the initialization.\n",
    "            plt.ylabel(r\"$\\mathbb{E}_x (K(x,x) / K_{init}(x,x))$\")\n",
    "            # explicit tick labels at 1 and 2 on the log axis\n",
    "            plt.yticks([1,2], ['1', '2'])\n",
    "        else:\n",
    "            plt.ylabel('mean diag kernel, ' + r\"$\\mathbb{E}_x K(x,x)$\")\n",
    "\n",
    "        for ref_width in ref_widths:\n",
    "            for real_width in real_widths[::-1]:\n",
    "                for k, scaling_mode in enumerate(scaling_modes):\n",
    "                    if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):\n",
    "                        continue\n",
    "                    if (ref_width == real_width) and (scaling_mode != 'default'):\n",
    "                        continue\n",
    "                    for correction_epoch in (\n",
    "                        correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]\n",
    "                    ):\n",
    "                        if (scaling_mode == 'default') and (ref_width == real_width):\n",
    "                            draw_tangent_kernels(\n",
    "                                scaling_mode, None, \n",
    "                                real_width, correction_epoch, color='black', layer=layer, relative=relative,\n",
    "                                linestyle='dotted', lw=3, label=get_label('reference', ref_width, real_width)\n",
    "                            )\n",
    "                        else:\n",
    "                            draw_tangent_kernels(\n",
    "                                scaling_mode, ref_width if scaling_mode != 'default' else None, \n",
    "                                real_width, correction_epoch, color=cmap(k), layer=layer, relative=relative,\n",
    "                                lw=3, label=get_label(scaling_mode, ref_width, real_width),\n",
    "                                # scales mode k by 1.03**k so overlapping curves stay visible;\n",
    "                                # plotted values are therefore offset, not exact\n",
    "                                mul_displacement=1.03**k\n",
    "                            )\n",
    "\n",
    "        # legend disabled; each curve still carries its label\n",
    "        #plt.legend()\n",
    "        plt.title(title)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_logits_by_tangent_kernels(\n",
    "    scaling_mode, ref_width, real_width, correction_epoch, \n",
    "    add_displacement=0, mul_displacement=1,\n",
    "    label=None, layer='sum', relative=False, **kwargs\n",
    "):\n",
    "    \"\"\"Plot mean |f(x) / K(x, x)| over test points per epoch: seed mean with a +/- 1 std band.\n",
    "\n",
    "    layer: 'input' / 'output' pick index 0 / 1 of the kernel's second-to-last axis,\n",
    "        'sum' adds over it; 'hidden' raises NotImplementedError, anything else ValueError.\n",
    "    relative: if True, the kernel is first divided by its value at the first logged\n",
    "        epoch -- the same normalization draw_tangent_kernels applies.\n",
    "    Per-seed values are shifted by add_displacement, then scaled by mul_displacement.\n",
    "    label goes to the mean line only; other kwargs go to plt.plot and plt.fill_between.\n",
    "    \"\"\"\n",
    "    test_logits = [[\n",
    "        results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_logits'][epoch] \n",
    "        for epoch in range(num_epochs)\n",
    "    ] for seed in range(num_seeds)]\n",
    "    test_logits = np.array(test_logits).squeeze(axis=-1)\n",
    "\n",
    "    # NOTE(review): [::2] keeps every other logged kernel entry -- confirm which ones.\n",
    "    test_tangent_kernels = [[\n",
    "        results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_tangent_kernels'][epoch][::2]\n",
    "        for epoch in range(num_epochs)\n",
    "    ] for seed in range(num_seeds)]\n",
    "    test_tangent_kernels = np.array(test_tangent_kernels)\n",
    "\n",
    "    if layer == 'input':\n",
    "        test_tangent_kernels = test_tangent_kernels[...,0,:]\n",
    "    elif layer == 'output':\n",
    "        test_tangent_kernels = test_tangent_kernels[...,1,:]\n",
    "    elif layer == 'sum':\n",
    "        test_tangent_kernels = np.sum(test_tangent_kernels, axis=-2)\n",
    "    elif layer == 'hidden':\n",
    "        raise NotImplementedError\n",
    "    else:\n",
    "        raise ValueError\n",
    "\n",
    "    if relative:\n",
    "        # honour the flag (it used to be accepted and ignored): normalize by the first logged epoch\n",
    "        test_tangent_kernels = test_tangent_kernels / test_tangent_kernels[:,0:1]\n",
    "\n",
    "    # NOTE(review): kernels appear to be logged for only the first n test points, so the\n",
    "    # logits are truncated to the same n -- confirm against the logging code.\n",
    "    n_points = test_tangent_kernels.shape[-1]\n",
    "    data = np.mean(np.abs(test_logits[...,:n_points] / test_tangent_kernels), axis=-1)\n",
    "    data += add_displacement\n",
    "    data *= mul_displacement\n",
    "\n",
    "    data_mean = data.mean(axis=0)\n",
    "    data_std = data.std(axis=0)\n",
    "    steps = np.arange(1, num_epochs+1) * steps_per_epoch\n",
    "    plt.plot(steps, data_mean, label=label, **kwargs)\n",
    "    plt.fill_between(steps, data_mean - data_std, data_mean + data_std, alpha=0.3, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio |f(x) / K(x,x)| over training, one figure per kernel layer selection.\n",
    "# NOTE(review): with the assignment below commented out, to_draw is whatever an earlier\n",
    "# cell left behind.\n",
    "# to_draw = ['mean_field', 'ntk', 'default_sym']\n",
    "\n",
    "for relative in [False]:\n",
    "    for layer in ['sum', 'input', 'output']:\n",
    "        _ = plt.figure(figsize=(6,4))\n",
    "#         plt.title(layer + (' relative' if relative else ''))\n",
    "\n",
    "        plt.xlabel('training step, k+{}'.format(steps_per_epoch))\n",
    "        #plt.ylim(ylim) # uncomment to adjust y-limits manually\n",
    "        plt.grid(True)\n",
    "\n",
    "        plt.ylabel(r\"$\\mathbb{E}_x |f(x) / K(x,x)|$\")\n",
    "        plt.yscale('log')\n",
    "        plt.xscale('log')\n",
    "#         plt.yticks([1,2], ['1', '2'])\n",
    "\n",
    "        for ref_width in ref_widths:\n",
    "            for real_width in real_widths[::-1]:\n",
    "                for k, scaling_mode in enumerate(scaling_modes):\n",
    "                    if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):\n",
    "                        continue\n",
    "                    if (ref_width == real_width) and (scaling_mode != 'default'):\n",
    "                        continue\n",
    "                    for correction_epoch in (\n",
    "                        correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]\n",
    "                    ):\n",
    "                        if (scaling_mode == 'default') and (ref_width == real_width):\n",
    "                            draw_logits_by_tangent_kernels(\n",
    "                                scaling_mode, None, \n",
    "                                real_width, correction_epoch, color='black', layer=layer, relative=relative,\n",
    "                                linestyle='dotted', lw=3, label=get_label('reference', ref_width, real_width)\n",
    "                            )\n",
    "                        else:\n",
    "                            draw_logits_by_tangent_kernels(\n",
    "                                scaling_mode, ref_width if scaling_mode != 'default' else None, \n",
    "                                real_width, correction_epoch, color=cmap(k), layer=layer, relative=relative,\n",
    "                                lw=3, label=get_label(scaling_mode, ref_width, real_width),\n",
    "                                # scales mode k by 1.05**k so overlapping curves stay visible;\n",
    "                                # plotted values are therefore offset, not exact\n",
    "                                mul_displacement=1.05**k\n",
    "                            )\n",
    "\n",
    "        # legend disabled; each curve still carries its label\n",
    "        # plt.legend()\n",
    "        plt.title(title)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
