{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Pin the Lightning stack to the 1.x API used below (training_epoch_end, custom TrainingEpochLoop via\n",
        "# fit_loop.connect, Trainer(check_val_every_n_epoch=None)) and to torchmetrics' task-less Accuracy().\n",
        "!pip install -q 'pytorch-lightning>=1.7,<2.0' 'torchmetrics<0.11'\n",
        "!pip install -q git+https://www.github.com/google/neural-tangents\n",
        "\n",
        "import os\n",
        "import sys\n",
        "\n",
        "# Remove any previous checkout so the eigenlearning helper library is fetched fresh.\n",
        "if os.path.isdir('/content/eigenlearning'):\n",
        "    !rm -r '/content/eigenlearning'\n",
        "## [REDACTED EIGENLEARNING LIBRARY IMPORT]\n",
        "sys.path.insert(0,'/content/eigenlearning')\n",
        "\n",
        "# Pickle protocol 5: the pickle5 backport is only needed before Python 3.8; newer stdlib pickle supports it.\n",
        "if sys.version_info < (3, 8):\n",
        "    !pip3 install pickle5\n",
        "    import pickle5 as pickle\n",
        "else:\n",
        "    import pickle\n",
        "\n",
        "# Results are read from and written to Google Drive.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive', force_remount=True)"
      ],
      "metadata": {
        "id": "pQlptQqjVnlq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import neural_tangents as nt\n",
        "import numpy as np\n",
        "import jax\n",
        "from jax import numpy as jnp\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import time\n",
        "\n",
        "import image_datasets\n",
        "import measures\n",
        "import powerlaws\n",
        "import utils\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torchvision import transforms\n",
        "from torchvision.datasets import CIFAR10\n",
        "from torchvision.models import resnet18\n",
        "from torch.utils.data import Dataset, DataLoader, random_split\n",
        "\n",
        "import pytorch_lightning as pl\n",
        "import torchmetrics\n",
        "\n",
        "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n",
        "from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping\n",
        "from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger\n",
        "\n",
        "import shutil\n",
        "\n",
        "def rcsetup():\n",
        "    \"\"\"Apply the shared matplotlib rc style: 150 dpi white figures, STIX General text, Computer Modern math.\"\"\"\n",
        "    plt.rc(\"figure\", dpi=150, facecolor=(1, 1, 1), figsize=(6, 3.5))\n",
        "    plt.rc(\"font\", family='stixgeneral', size=18)\n",
        "    plt.rc(\"axes\", titlesize=19)\n",
        "    # plt.rc(\"axes\", facecolor=(1, .99, .95))\n",
        "    plt.rc(\"mathtext\", fontset='cm')\n",
        "\n",
        "def get_plot_color(ind, ncolors=10):\n",
        "    \"\"\"Return the RGB color of curve `ind` out of `ncolors` evenly spaced hues.\n",
        "\n",
        "    Hues are spread over [0, 0.8] of the HSV wheel at full saturation and value 0.7;\n",
        "    `ind` wraps around modulo `ncolors`.\n",
        "    \"\"\"\n",
        "    from matplotlib.colors import hsv_to_rgb\n",
        "    hues = np.linspace(0, 0.8, ncolors)\n",
        "    return hsv_to_rgb((hues[ind % ncolors], 1, .7))\n",
        "\n",
        "plt.rcdefaults()\n",
        "rcsetup()"
      ],
      "metadata": {
        "id": "IopPF9VaVz3l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Data generation"
      ],
      "metadata": {
        "id": "FInfOTT1WZSl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def powerlaw_lambdas_v2s(alpha, beta, M, normalize_lambdas=False):\n",
        "    \"\"\"Synthetic power-law eigensystem over modes k = 1..M.\n",
        "\n",
        "    Returns:\n",
        "        lambdas: eigenvalues k**-alpha (rescaled to sum to 1 if `normalize_lambdas`).\n",
        "        v2s: squared target coefficients k**-beta, normalized to sum to 1.\n",
        "        f_terms: dict mapping the 0-based mode index i to the target coefficient sqrt(v2s[i]).\n",
        "    \"\"\"\n",
        "    ks = np.arange(1, M + 1, dtype=float)\n",
        "\n",
        "    lambdas = np.power(ks, -alpha)\n",
        "    if normalize_lambdas:\n",
        "        lambdas = lambdas / np.sum(lambdas)\n",
        "\n",
        "    v2s = np.power(ks, -beta)\n",
        "    v2s = v2s / np.sum(v2s)\n",
        "\n",
        "    f_terms = dict(enumerate(np.sqrt(v2s)))\n",
        "    return lambdas, v2s, f_terms\n",
        "\n",
        "def theory_mses(lambdas, f_terms, deltas_over_n, ns, M):\n",
        "    \"\"\"Evaluate the eigenlearning theory predictions on a grid of train sizes n and ridges.\n",
        "\n",
        "    For every n in `ns` the ridges are `n * deltas_over_n` (elementwise, so `deltas_over_n` should be a\n",
        "    numpy array), and `measures.learning_measure_predictions` is called with the spectrum `lambdas` and\n",
        "    target coefficients `f_terms`. Prints one '.' per ridge and a newline per n as progress.\n",
        "\n",
        "    Returns a dict mapping (n, ridge) -> prediction dict (later read for 'mse' and 'train_mse'), plus the\n",
        "    metadata entries 'M', 'ns', 'eigenvals' (= lambdas) and 'deltas_over_n'. The float ridges are used\n",
        "    directly as keys, so lookups must recompute them the same way (n times deltas_over_n).\n",
        "    \"\"\"\n",
        "    results = {}\n",
        "    for n_train in ns:\n",
        "        ridges = n_train * deltas_over_n\n",
        "        for ridge in ridges:\n",
        "            preds = measures.learning_measure_predictions(None, None, n=n_train, ridge=ridge,\n",
        "                                                          f_terms=f_terms, lambdas=lambdas, mults=1)\n",
        "            results[(n_train, ridge)] = preds\n",
        "            print('.', end='')\n",
        "        print()\n",
        "    # Metadata needed by the plotting cells to rebuild the (n, ridge) keys and axes.\n",
        "    results['M'] = M\n",
        "    results['ns'] = ns\n",
        "    results['eigenvals'] = lambdas\n",
        "    results['deltas_over_n'] = deltas_over_n\n",
        "    return results"
      ],
      "metadata": {
        "id": "umUEAa3sWlA9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Synthetic experiments"
      ],
      "metadata": {
        "id": "6PZeTaqgWeqz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Synthetic power-law task: eigenvalues k^-alpha and target power k^-beta over M modes.\n",
        "alpha = 2\n",
        "beta = alpha\n",
        "M = 500000\n",
        "lambdas, v2s, f_terms = powerlaw_lambdas_v2s(alpha, beta, M)\n",
        "\n",
        "ns = [16, 125, 1000, 8000, 256000]\n",
        "# 30 log-spaced values of ridge/n from 1e2 down to 1e-12.\n",
        "deltas_over_n = np.array([10**i for i in np.linspace(2, -12, 30)])\n",
        "\n",
        "theory_results = theory_mses(lambdas, f_terms, deltas_over_n, ns, M)\n",
        "fname = \"r_spectrum_\" + str(alpha)\n",
        "# Read back by the synthetic plotting cell, which rebuilds this path from its own `expo` (must equal alpha).\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/{}_theory_30k_orders2.pickle\".format(fname), \"wb\") as handle:\n",
        "    pickle.dump(theory_results, handle, protocol=pickle.HIGHEST_PROTOCOL)"
      ],
      "metadata": {
        "id": "yYjoMIsPWxip"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## MNIST KRR (theory vs empirical)"
      ],
      "metadata": {
        "id": "lv4KDSimW_7J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from jax import random\n",
        "from image_datasets import get_image_eigendata\n",
        "\n",
        "# Number of kernel eigenmodes: M <= 16000 is computed here via get_image_eigendata; M == 30000 is loaded from a cached pickle.\n",
        "M = 30000\n",
        "\n",
        "if M <= 16000:\n",
        "    _, _, kernel_fn = utils.get_net_fns(width=500, d_out=1, n_hidden_layers=4)\n",
        "    key, subkey = random.split(np.uint32([0,17]), 2)\n",
        "    classes = [[0,1,2,3,4], [5,6,7,8,9]]\n",
        "    eigendata = get_image_eigendata('mnist', M, kernel_fn, classes)\n",
        "elif M == 30000:\n",
        "    with open(\"/content/drive/My Drive/eigenlearning DB/kernel/mnist_30k_eigendata.pickle\", \"rb\") as handle:\n",
        "        eigendata = pickle.load(handle)\n",
        "else:\n",
        "    # Fail loudly instead of using an undefined (or stale, from an earlier run of this cell) `eigendata`.\n",
        "    raise ValueError(\"No eigendata source for M={}: use M <= 16000 or the cached M == 30000.\".format(M))\n",
        "\n",
        "lambdas = eigendata['eigenvals']\n",
        "f_terms = eigendata['f_terms']\n",
        "# Same ridge grid and train sizes as the empirical KRR cell, so the (n, ridge) keys of both result dicts match.\n",
        "deltas_over_n = np.array([10**i for i in np.linspace(2, -8, 30)])\n",
        "ns = [16, 125, 1000, 8000]\n",
        "\n",
        "results = theory_mses(lambdas, f_terms, deltas_over_n, ns, M)\n",
        "# NOTE(review): the MNIST plotting cell loads 'r_mnist_theory_30k_04_59_orders2.pickle'; confirm which name is intended.\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_theory_30k_04_59.pickle\", \"wb\") as handle:\n",
        "    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)"
      ],
      "metadata": {
        "id": "d2E9dUkHXBhX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Empirical kernel ridge regression on MNIST (class groups 0-4 vs 5-9), n_trials draws per (n, ridge).\n",
        "# Must use the same ns / deltas_over_n grid as the theory cell so the (n, ridge) keys line up when plotting.\n",
        "ns = [16, 125, 1000, 8000]\n",
        "deltas_over_n = np.array([10**i for i in np.linspace(2, -8, 30)])\n",
        "n_trials = 10\n",
        "\n",
        "results = {}\n",
        "results['ns'] = ns\n",
        "results['deltas_over_n'] = deltas_over_n\n",
        "results['n_trials'] = n_trials\n",
        "\n",
        "classes = [[0,1,2,3,4],[5,6,7,8,9]]\n",
        "net_fns = utils.get_net_fns(width=500, d_out=1, n_hidden_layers=4)\n",
        "for n in ns:\n",
        "    for delta_over_n in deltas_over_n:\n",
        "        # Same float as theory_mses' n * deltas_over_n, so it can serve as a shared dict key.\n",
        "        delta = delta_over_n * n\n",
        "        stats = measures.learning_measure_statistics(net_fns, 'mnist', n,\n",
        "                                            classes=classes, pred_type='kernel',\n",
        "                                            n_trials=n_trials, n_test=2000, ridge=delta,\n",
        "                                            compute_train_measures=True)\n",
        "        # Presumably (mean, std) pairs per measure: the plotting cell unpacks result['mse'] that way.\n",
        "        results[(n, delta)] = stats['kernel']\n",
        "        print('.', end='')\n",
        "    print()\n",
        "\n",
        "# NOTE(review): the MNIST plotting cell loads 'r_mnist_orders2_04_59.pickle'; confirm which name is intended.\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_04_59.pickle\", \"wb\") as handle:\n",
        "    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)"
      ],
      "metadata": {
        "id": "T7FR8En_Xv2u"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Neural networks"
      ],
      "metadata": {
        "id": "QkBUB8zlY-ML"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "48F1P-wvB7FQ"
      },
      "outputs": [],
      "source": [
        "def modified_resnet():\n",
        "    \"\"\"Build a 10-class ResNet-18 with a CIFAR-style stem for 32x32 inputs.\n",
        "\n",
        "    The first convolution becomes a 3x3, stride-1, padding-1 conv (64 channels, no bias) and the initial\n",
        "    max-pool is replaced by an identity, so the stem no longer downsamples.\n",
        "    \"\"\"\n",
        "    net = resnet18(num_classes=10)\n",
        "    stem_conv = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=1, padding=1, bias=False)\n",
        "    net.conv1 = stem_conv\n",
        "    net.maxpool = nn.Identity()\n",
        "    return net\n",
        "\n",
        "class SoftError(torchmetrics.Metric):\n",
        "    \"\"\"Soft classification error: the average over samples of 1 - (probability given to the true class).\n",
        "\n",
        "    `update(preds, target)` picks `preds[k, target[k]]` for every row k, so `preds` holds per-class scores\n",
        "    with one row per sample (validation_step passes the exponentiated log-softmax outputs, i.e.\n",
        "    probabilities) and `target` holds integer class labels.\n",
        "    \"\"\"\n",
        "    full_state_update = False\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # Both states are sum-reduced. Despite its name, `correct` accumulates the soft *error* mass.\n",
        "        for state_name in (\"correct\", \"total\"):\n",
        "            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n",
        "\n",
        "    def update(self, preds: torch.Tensor, target: torch.Tensor):\n",
        "        \"\"\"Accumulate sum(1 - p_true) and the sample count for one batch.\"\"\"\n",
        "        p_true = preds[range(preds.shape[0]), target]\n",
        "        batch_count = target.numel()\n",
        "        self.total += batch_count\n",
        "        self.correct += batch_count - p_true.sum()\n",
        "\n",
        "    def compute(self):\n",
        "        \"\"\"Return the mean soft error over all accumulated samples.\"\"\"\n",
        "        mean_soft_error = self.correct / self.total\n",
        "        return mean_soft_error\n",
        "\n",
        "\n",
        "class LitResnet(pl.LightningModule):\n",
        "    \"\"\"LightningModule training `modified_resnet` on a random `dataset_size`-image subset of CIFAR-10.\n",
        "\n",
        "    Minimizes the NLL of log-softmax outputs with SGD (momentum 0.9) under a per-epoch cosine LR schedule,\n",
        "    and validates on the full CIFAR-10 test set, logging TestError, SoftError and TrainError. Reads the\n",
        "    notebook globals BATCH_SZ and MAX_STEPS. Written against the Lightning 1.x API (e.g. training_epoch_end).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, lr=0.1, dataset_size=50000):\n",
        "        \"\"\"lr: SGD learning rate; dataset_size: number of CIFAR-10 training images to train on.\"\"\"\n",
        "        super().__init__()\n",
        "\n",
        "        # NOTE(review): this generator is never used; random_split in train_dataloader draws from the global\n",
        "        # torch RNG. Pass generator=self.rng there if every trial should see the same subset.\n",
        "        self.rng = torch.Generator().manual_seed(40)\n",
        "        self.lr = lr\n",
        "\n",
        "        self.n_classes = 10\n",
        "        self.dims = (3, 32, 32)\n",
        "        self.datasize = dataset_size\n",
        "\n",
        "        self.model = modified_resnet()\n",
        "\n",
        "        # Classification error as the composed metric 1 - accuracy.\n",
        "        # NOTE(review): torchmetrics >= 0.11 requires Accuracy(task='multiclass', num_classes=10).\n",
        "        self.test_error = 1 - torchmetrics.Accuracy()\n",
        "        self.train_error = 1 - torchmetrics.Accuracy()\n",
        "        self.soft_error = SoftError()\n",
        "\n",
        "        # Number of training steps taken; TrainError is only logged once this is positive.\n",
        "        self.train_step_num = 0\n",
        "        \n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return per-class log-probabilities (log-softmax over dim 1 of the ResNet outputs).\"\"\"\n",
        "        out = self.model(x)\n",
        "        return F.log_softmax(out, dim=1)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        # `logits` are the log-probabilities returned by forward(), hence nll_loss.\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        # Accumulated here but only logged from validation_step.\n",
        "        self.train_error.update(preds, y)\n",
        "\n",
        "        self.log(\"TrainLoss\", loss)\n",
        "        self.train_step_num += 1\n",
        "        return loss\n",
        "\n",
        "    def training_epoch_end(self, outs):\n",
        "        # Explicit no-op hook. NOTE(review): Lightning 2.0 removed this hook and errors if it is defined.\n",
        "        pass\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        self.test_error.update(preds, y)\n",
        "        # SoftError needs probabilities, so undo the log of the log-softmax.\n",
        "        self.soft_error.update(torch.exp(logits), y)\n",
        "\n",
        "        # Metric objects (not values) are logged, leaving compute/reset to Lightning.\n",
        "        # NOTE(review): train_error is updated in training_step but logged here, so its accumulation window\n",
        "        # follows Lightning's reset of validation-logged metrics — confirm that is the intended TrainError.\n",
        "        self.log(\"TestError\", self.test_error, prog_bar=True)\n",
        "        self.log(\"SoftError\", self.soft_error)\n",
        "        if self.train_step_num > 0:\n",
        "            self.log(\"TrainError\", self.train_error)\n",
        "\n",
        "    def predict_step(self, batch, batch_idx, dataloader_idx=0):\n",
        "        # Returns the raw ResNet outputs, not the log-softmax applied in forward().\n",
        "        x, y = batch\n",
        "        y_hat = self.model(x)\n",
        "        return y_hat\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = torch.optim.SGD(\n",
        "            self.parameters(),\n",
        "            lr=self.lr,\n",
        "            momentum=0.9,\n",
        "        )\n",
        "        # Cosine horizon in epochs = MAX_STEPS * BATCH_SZ / datasize. It uses the global MAX_STEPS, not the\n",
        "        # per-run max_steps given to the training loop, so shorter runs stop before the LR reaches its minimum.\n",
        "        max_epochs = MAX_STEPS / self.datasize * BATCH_SZ\n",
        "        scheduler_dict = {\n",
        "            \"scheduler\": torch.optim.lr_scheduler.CosineAnnealingLR(\n",
        "                optimizer,\n",
        "                T_max=max_epochs\n",
        "            ),\n",
        "            \"interval\": \"epoch\",\n",
        "        }\n",
        "        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler_dict}\n",
        "\n",
        "    ####################\n",
        "    # DATA RELATED HOOKS\n",
        "    ####################\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download data\n",
        "        CIFAR10(\"cifar10/\", download=True)\n",
        "    \n",
        "    def setup(self, stage):\n",
        "        # Training images get random horizontal flips and 4-pixel-padded random 32x32 crops;\n",
        "        # test images are only converted to tensors.\n",
        "        train_transforms = transforms.Compose([\n",
        "            transforms.RandomHorizontalFlip(),\n",
        "            transforms.RandomCrop(32, padding=4),\n",
        "            transforms.ToTensor(),\n",
        "        ])\n",
        "        self.traindata = CIFAR10(\"cifar10/\", train=True, download=True,\n",
        "                                 transform=train_transforms)\n",
        "        self.testdata = CIFAR10(\"cifar10/\", train=False, download=True,\n",
        "                                transform=transforms.ToTensor())\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        # Random `datasize`-image subset of the training set, drawn with the global torch RNG.\n",
        "        small_trainset, _ = random_split(self.traindata,\n",
        "                                         [self.datasize, len(self.traindata)-self.datasize])\n",
        "        return DataLoader(small_trainset, batch_size=BATCH_SZ, num_workers=2, shuffle=True)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.testdata, batch_size=BATCH_SZ, num_workers=2)\n",
        "\n",
        "\n",
        "from pytorch_lightning.loops.epoch import TrainingEpochLoop\n",
        "class LogValLoop(TrainingEpochLoop):\n",
        "    \"\"\"Training epoch loop that runs validation on an explicit schedule of global batch indices.\n",
        "\n",
        "    After `set_val_schedule(sched)` with an ascending sequence of global batch indices, validation runs on\n",
        "    the first batch whose `total_batch_idx` reaches the next scheduled index, and that entry is consumed.\n",
        "    Once the schedule is exhausted no further scheduled validation happens. With no schedule (None), the\n",
        "    decision falls back to the val_check_interval / last-batch logic.\n",
        "    \"\"\"\n",
        "\n",
        "    # Default so `_should_check_val_fx` also works when `set_val_schedule` is never called.\n",
        "    val_sched = None\n",
        "\n",
        "    def set_val_schedule(self, val_schedule):\n",
        "        \"\"\"Set the ascending global batch indices at which to validate, or None to use the default logic.\"\"\"\n",
        "        self.val_sched = val_schedule\n",
        "\n",
        "    def _should_check_val_fx(self) -> bool:\n",
        "        \"\"\"Decide if we should run validation.\"\"\"\n",
        "        # The checks up to `is_val_check_batch` appear to be copied from Lightning 1.x's TrainingEpochLoop.\n",
        "        if not self._should_check_val_epoch():\n",
        "            return False\n",
        "\n",
        "        # val_check_batch is inf for iterable datasets with no length defined\n",
        "        is_infinite_dataset = self.trainer.val_check_batch == float(\"inf\")\n",
        "        is_last_batch = self.batch_progress.is_last_batch\n",
        "        if is_last_batch and is_infinite_dataset:\n",
        "            return True\n",
        "\n",
        "        if self.trainer.should_stop:\n",
        "            return True\n",
        "\n",
        "        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch\n",
        "        is_val_check_batch = is_last_batch\n",
        "        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:\n",
        "            is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0\n",
        "        elif self.trainer.val_check_batch != float(\"inf\"):\n",
        "            # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches\n",
        "            # else condition it based on the batch_idx of the current epoch\n",
        "            current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx\n",
        "            is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0\n",
        "\n",
        "        # Custom schedule: when set, it alone decides (after the early returns above). Validate once the global\n",
        "        # batch index reaches the next scheduled value and consume that entry; an exhausted schedule gives an\n",
        "        # infinite threshold, disabling further scheduled validation.\n",
        "        if self.val_sched is not None:\n",
        "            threshold = self.val_sched[0] if len(self.val_sched) > 0 else np.inf\n",
        "            if self.total_batch_idx < threshold:\n",
        "                return False\n",
        "            self.val_sched = self.val_sched[1:]\n",
        "            return True\n",
        "        return is_val_check_batch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Global training configuration: LitResnet reads BATCH_SZ (dataloaders) and MAX_STEPS (cosine LR horizon).\n",
        "# DEVICE is only reported here; the Trainer picks its accelerator itself (accelerator=\"auto\").\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "BATCH_SZ = 128 if torch.cuda.is_available() else 64\n",
        "MAX_STEPS = 15000\n",
        "\n",
        "torch.manual_seed(42)\n",
        "print(DEVICE)"
      ],
      "metadata": {
        "id": "3plTEHBQZLu2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5LrUbwk2s5S_"
      },
      "outputs": [],
      "source": [
        "# Inline TensorBoard view of the lightning_logs/ directory written by the training cell.\n",
        "%reload_ext tensorboard\n",
        "%tensorboard --logdir=lightning_logs/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eNednse5CEXm"
      },
      "outputs": [],
      "source": [
        "max_steps = [4000, 6000, 10000, 14000]  # per-n training step budget, aligned with `ns`\n",
        "ns = [500, 2000, 8000, 32000]\n",
        "n_trials = 5\n",
        "# Global batch indices at which to validate (log-spaced, so early training is sampled densely).\n",
        "val_sched = np.geomspace(10, 13800, 50).astype(int)\n",
        "\n",
        "exp_details = {\n",
        "    \"ns\": ns,\n",
        "    \"n_trials\": n_trials,\n",
        "    \"val_sched\": val_sched,\n",
        "    \"max_steps\": max_steps,\n",
        "}\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/NN_DB_exp_details.pickle\", \"wb\") as handle:\n",
        "    pickle.dump(exp_details, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
        "\n",
        "for run in range(n_trials):\n",
        "    for i, n in enumerate(ns):\n",
        "        model = LitResnet(lr=0.1, dataset_size=n)\n",
        "\n",
        "        tb_logger = TensorBoardLogger(save_dir=\".\")\n",
        "        csv_logger = CSVLogger(save_dir=\".\")\n",
        "        loggers = [tb_logger, csv_logger]\n",
        "\n",
        "        trainer = pl.Trainer(\n",
        "            accelerator=\"auto\",\n",
        "            devices=1 if torch.cuda.is_available() else None,\n",
        "            callbacks=[TQDMProgressBar(refresh_rate=1),\n",
        "                       LearningRateMonitor(),],\n",
        "            logger=loggers,\n",
        "            log_every_n_steps=20,\n",
        "            val_check_interval=10, check_val_every_n_epoch=None,\n",
        "        )\n",
        "        # Swap in the epoch loop that validates on `val_sched`, with step budget max_steps[i].\n",
        "        val_loop = LogValLoop(max_steps=max_steps[i])\n",
        "        val_loop.set_val_schedule(val_sched)\n",
        "        val_loop.trainer = trainer\n",
        "        trainer.fit_loop.connect(epoch_loop=val_loop)\n",
        "        trainer.fit(model)\n",
        "\n",
        "        # Take metrics.csv from the directory this run's CSVLogger actually wrote, rather than assuming\n",
        "        # /content/lightning_logs/version_<run counter>, which picks the wrong file if lightning_logs/ already\n",
        "        # held versions from earlier sessions or the working directory is not /content.\n",
        "        exp_name = '{}'.format(n)\n",
        "        shutil.move(os.path.join(csv_logger.log_dir, \"metrics.csv\"),\n",
        "                    \"/content/drive/My Drive/eigenlearning DB/metrics_{}_{}.csv\".format(exp_name, run))\n",
        "\n",
        "        del model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_IM3OyZHoqRG"
      },
      "outputs": [],
      "source": [
        "# !rm -rf lightning_logs"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Plotting"
      ],
      "metadata": {
        "id": "Y8EiMkP-Zeb9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from matplotlib.patches import Patch\n",
        "from matplotlib.lines import Line2D\n",
        "from matplotlib.colors import hsv_to_rgb\n",
        "\n",
        "# Re-apply the shared matplotlib style (rcsetup is defined in the setup section) before drawing figures.\n",
        "rcsetup()"
      ],
      "metadata": {
        "id": "Mah-9dlAZ123"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Synthetic experiments"
      ],
      "metadata": {
        "id": "cb_2D5ztZh3b"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Theory-only deep-bootstrap figure for the synthetic power-law task.\n",
        "# `expo` must match the `alpha` used when the theory pickle was written.\n",
        "expo=2\n",
        "fname = \"r_spectrum_\" + str(expo)\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/{}_theory_30k_orders2.pickle\".format(fname), \"rb\") as handle:\n",
        "    results_theory = pickle.load(handle)\n",
        "            \n",
        "fig, ax = plt.subplots()\n",
        "\n",
        "M = results_theory['M']\n",
        "lambdas = results_theory['eigenvals']\n",
        "deltas_over_n = results_theory['deltas_over_n']\n",
        "show_ns = results_theory['ns']\n",
        "\n",
        "skip = []\n",
        "label_pos = [.76, .565, .365, .165]\n",
        "for i, n in enumerate(show_ns):\n",
        "    if n in skip:\n",
        "        continue\n",
        "    # mses / mses_std / train_mses / train_mses_std are unused here: this figure has no empirical curves.\n",
        "    mses, mses_std, train_mses, train_mses_std = [np.zeros(len(deltas_over_n)) for _ in range(4)]\n",
        "    mses_theory, train_mses_theory = [np.zeros(len(deltas_over_n)) for _ in range(2)]\n",
        "    # Keys are rebuilt exactly as theory_mses built them (n * deltas_over_n).\n",
        "    for j, delta in enumerate(deltas_over_n * n):    \n",
        "        result = results_theory[(n,delta)]\n",
        "        mses_theory[j] = result['mse']\n",
        "        train_mses_theory[j] = result['train_mse']\n",
        "\n",
        "    # Correction: (M * test - n * train) / (M - n), presumably turning an average over all M points into an\n",
        "    # average over the M - n points outside the training set (confirm). The 1e-4 avoids dividing by zero at n == M.\n",
        "    mses_theory = 1/(M-n+1e-4) * (M * mses_theory - n * train_mses_theory)\n",
        "\n",
        "    # Effective training time is taken as the inverse of ridge/n.\n",
        "    tau_eff = 1 / deltas_over_n\n",
        "\n",
        "    # Every n but the last gets a label, a dashed 1/kappa_0 marker and test (solid) / train (dotted) curves;\n",
        "    # the last n is drawn alone in black dash-dot as the n -> infinity reference.\n",
        "    if i < len(show_ns) - 1:\n",
        "        color = get_plot_color(i, len(show_ns))\n",
        "        ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=15)\n",
        "        # Closed-form kappa_0 for a k^-expo spectrum; presumably the power-law analogue of measures.find_C used\n",
        "        # in the MNIST figure — confirm.\n",
        "        kappa_0 = ((np.pi / expo) / np.sin(np.pi / expo) / n) ** expo\n",
        "        ax.axvline(1/kappa_0, ls=\"--\", color=color, alpha=0.23)\n",
        "\n",
        "        ax.plot(tau_eff, mses_theory, color=color, ls='-')\n",
        "        ax.plot(tau_eff, train_mses_theory, color=color, ls=':')\n",
        "    else:\n",
        "        color = (0, 0, 0, .75)\n",
        "        ax.plot(tau_eff, mses_theory, color=color, ls='-.')\n",
        "\n",
        "\n",
        "ax.set_xlabel(r'$\\tau_\\mathrm{eff}$')\n",
        "ax.set_ylabel(r'$\\mathcal{E}(f)$')\n",
        "ax.set_xscale('log')\n",
        "ax.set_yscale('log')\n",
        "ax.set_ylim(4e-5,1.3e0)\n",
        "ax.set_xlim(1e-2,1e12)\n",
        "plt.legend(\n",
        "        (\n",
        "            Line2D([0], [0], color='xkcd:navy', alpha=.75, ls=\":\"),\n",
        "            Line2D([0], [0], color='xkcd:navy', alpha=.75, ls=\"-\"),\n",
        "            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls=\"-.\"),),\n",
        "        (\n",
        "            \"train\",\n",
        "            \"test\",\n",
        "            r\"$n\\to\\infty$\"),\n",
        "        framealpha=0.95,\n",
        "        fontsize=15.5,\n",
        "        loc='lower left',\n",
        "\n",
        "    )\n",
        "plt.axhline(1, ls='--', color='black', lw=1, alpha=0.4)\n",
        "plt.tight_layout()\n",
        "# plt.show()\n",
        "plt.savefig(\"KRR_synthetic_deep_bootstrap.pdf\", bbox_inches='tight')"
      ],
      "metadata": {
        "id": "JNI4NIjrZjTa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## MNIST KRR (theory vs empirical)"
      ],
      "metadata": {
        "id": "EkSKPr16aAt3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# MNIST KRR deep-bootstrap figure: theory curves vs. empirical KRR mean +/- std.\n",
        "# NOTE(review): these filenames differ from those the MNIST computation cells write\n",
        "# ('r_mnist_theory_30k_04_59.pickle', 'r_mnist_04_59.pickle') — confirm which files are intended.\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_theory_30k_04_59_orders2.pickle\", \"rb\") as handle:\n",
        "    results_theory = pickle.load(handle)\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_orders2_04_59.pickle\", \"rb\") as handle:\n",
        "    results = pickle.load(handle)\n",
        "            \n",
        "fig, ax = plt.subplots()\n",
        "\n",
        "M = results_theory['M']\n",
        "lambdas = results_theory['eigenvals']\n",
        "deltas_over_n = results_theory['deltas_over_n']\n",
        "show_ns = results_theory['ns']\n",
        "\n",
        "skip = []\n",
        "label_pos = [.89, .7, .43, .15]\n",
        "for i, n in enumerate(show_ns):\n",
        "    if n in skip:\n",
        "        continue\n",
        "    mses, mses_std, train_mses, train_mses_std = [np.zeros(len(deltas_over_n)) for _ in range(4)]\n",
        "    mses_theory, train_mses_theory = [np.zeros(len(deltas_over_n)) for _ in range(2)]\n",
        "    # The theory grid's deltas_over_n is used to look up both dicts, so both runs must share that grid exactly.\n",
        "    for j, delta in enumerate(deltas_over_n * n):\n",
        "        # Empirical entries are (mean, std) pairs.\n",
        "        result = results[(n,delta)]\n",
        "        mses[j], mses_std[j] = result['mse']\n",
        "        train_mses[j], train_mses_std[j] = result['train_mse']\n",
        "    \n",
        "        result = results_theory[(n,delta)]\n",
        "        mses_theory[j] = result['mse']\n",
        "        train_mses_theory[j] = result['train_mse']\n",
        "\n",
        "    color = get_plot_color(i, len(show_ns))\n",
        "    # Correction: same (M * test - n * train) / (M - n) rescaling as the synthetic figure.\n",
        "    mses_theory = 1/(M-n) * (M * mses_theory - n * train_mses_theory)\n",
        "    tau_eff = 1 / deltas_over_n\n",
        "\n",
        "    kappa_0 = measures.find_C(n, lambdas)\n",
        "    ax.axvline(1/kappa_0, ls=\"--\", color=color, alpha=0.4)\n",
        "\n",
        "    # Theory: solid lines (test and train share the same style).\n",
        "    ax.plot(tau_eff, mses_theory, color=color, label=n)\n",
        "    ax.plot(tau_eff, train_mses_theory, color=color)\n",
        "\n",
        "    # Empirical KRR: test dash-dot, train dotted, each with a +/- one-std band.\n",
        "    ax.plot(tau_eff, mses, ls=\"-.\", color=color, alpha=0.5)\n",
        "    ax.fill_between(tau_eff, mses-mses_std, mses+mses_std, color=color, alpha=0.13)\n",
        "    ax.plot(tau_eff, train_mses, ls=\":\", color=color, alpha=0.5)\n",
        "    ax.fill_between(tau_eff, train_mses-train_mses_std, train_mses+train_mses_std, color=color, alpha=0.13)\n",
        "\n",
        "    ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=14)\n",
        "\n",
        "ax.text(0.01, .92, \"B\", color='black', transform=ax.transAxes, fontsize=15, fontweight='bold')\n",
        "ax.set_xlabel(r'$\\tau_\\mathrm{eff}$')\n",
        "ax.set_ylabel(r'$\\mathcal{E}(f)$')\n",
        "ax.set_xscale('log')\n",
        "ax.set_yscale('log')\n",
        "ax.set_ylim(10e-2,1.3e0)\n",
        "ax.set_xlim(1e-2,1e6)\n",
        "plt.legend(\n",
        "    (\n",
        "        (Line2D([0], [0], color='xkcd:navy', alpha=.5, ls=\":\"), Patch(color='k', alpha=0.13, lw=0)),\n",
        "        (Line2D([0], [0], color='xkcd:navy', alpha=.5, ls=\"-.\"), Patch(color='k', alpha=0.13, lw=0)),\n",
        "        Line2D([0], [0], color='xkcd:navy'),),\n",
        "    (\n",
        "        \"KRR train\",\n",
        "        \"KRR test\",\n",
        "        \"theory\",),\n",
        "     framealpha=0.95, fontsize=13)\n",
        "plt.axhline(1, ls='--', color='black', lw=1, alpha=0.4)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "# plt.savefig(\"KRR_deep_bootstrap.pdf\", bbox_inches='tight')"
      ],
      "metadata": {
        "id": "PFxfcutqZ8pT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Neural networks"
      ],
      "metadata": {
        "id": "b9siK3QkZ3Xs"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QZqoZzH2aepY"
      },
      "outputs": [],
      "source": [
        "import csv\n",
        "from scipy.ndimage import gaussian_filter1d\n",
        "\n",
        "# Load the experiment settings and per-run metrics.csv files written by the NN training cell into\n",
        "# all_data['<n>_<run>'][metric] = (steps, values).\n",
        "with open(\"/content/drive/My Drive/eigenlearning DB/NN_DB_exp_details.pickle\", \"rb\") as handle:\n",
        "    exp_details = pickle.load(handle)\n",
        "\n",
        "ns = exp_details['ns']\n",
        "n_trials = exp_details['n_trials']\n",
        "# The validation schedule: the x-positions (train steps) of the logged validation metrics.\n",
        "trainsteps = exp_details['val_sched']\n",
        "\n",
        "all_data = {}\n",
        "# Run indices to exclude (also honored by make_DB_plot).\n",
        "skip = []\n",
        "for n in ns:\n",
        "    for run in range(n_trials):\n",
        "        if run in skip:\n",
        "            continue\n",
        "        exp_name = str(n)\n",
        "        fname = \"/content/drive/My Drive/eigenlearning DB/metrics_{}_{}.csv\".format(exp_name, run)\n",
        "        with open(fname, 'r', newline='') as csvfile:\n",
        "            reader = csv.reader(csvfile, delimiter=',', quotechar='|')\n",
        "            lines = list(reader)\n",
        "            header, table = lines[0], lines[1:]\n",
        "            run_data = {}\n",
        "            for i, metric in enumerate(header):\n",
        "                # Keep only rows where this metric was logged (non-empty cell).\n",
        "                # NOTE(review): assumes column 1 of the metrics CSV is the step counter — confirm against `header`.\n",
        "                steps = [float(line[1]) for line in table if line[i]!='']\n",
        "                data = [float(line[i]) for line in table if line[i]!='']\n",
        "                run_data[metric] = (steps, data)\n",
        "            all_data[\"{}_{}\".format(exp_name, run)] = run_data\n",
        "\n",
        "def make_DB_plot(all_data, log=True, train=True, error=False, save_path=\"NN_deep_bootstrap.pdf\"):\n",
        "    \"\"\"Plot the NN deep-bootstrap figure: mean test (solid) and train (dotted) error vs. train steps per n.\n",
        "\n",
        "    For each n in the global `ns`, the TestError / TrainError series of every run in range(n_trials) not\n",
        "    listed in `skip` are read from all_data['<n>_<run>'][metric] = (steps, values). Each series is padded\n",
        "    with its last value (or truncated) to the x grid `trainsteps` + [1e5], then averaged across runs and\n",
        "    smoothed with a sigma=1 Gaussian filter. Reads the globals `ns`, `n_trials`, `skip` and `trainsteps`.\n",
        "\n",
        "    Args:\n",
        "        all_data: per-run metrics as built by the loading code above.\n",
        "        log: use log-log axes (else linear).\n",
        "        train: also draw the train-error curves.\n",
        "        error: shade +/- one standard deviation across runs.\n",
        "        save_path: file the figure is saved to.\n",
        "    \"\"\"\n",
        "    ax = plt.gca()\n",
        "    x = np.hstack((trainsteps, [1e5])).astype(int)\n",
        "    label_pos = [.79, .645, .38, .165]\n",
        "    for i, n in enumerate(ns):\n",
        "        color = get_plot_color(i, len(ns))\n",
        "        for metric in [\"TestError\", \"TrainError\"]:\n",
        "            y = []\n",
        "            for run in range(n_trials):\n",
        "                if run in skip:\n",
        "                    continue\n",
        "                exp_name = '{}'.format(n)\n",
        "                _, data = all_data[\"{}_{}\".format(exp_name, run)][metric]\n",
        "                # Runs with a smaller max_steps stop validating early: hold their last value out to len(x).\n",
        "                # Truncate too, so every row has exactly len(x) entries and np.mean / np.std get a 2-D array.\n",
        "                data = (data + data[-1:]*(len(x)-len(data)))[:len(x)]\n",
        "                y.append(data)\n",
        "            yerr = np.std(y, axis=0)\n",
        "            y = np.mean(y, axis=0)\n",
        "            y_smooth = y.copy()\n",
        "            # Leave the first point unsmoothed so the initial error is plotted as measured.\n",
        "            y_smooth[1:] = gaussian_filter1d(y_smooth[1:], sigma=1, mode='nearest')\n",
        "            yerr_smooth = gaussian_filter1d(yerr, sigma=1, mode='nearest')\n",
        "\n",
        "            if metric == \"TestError\":\n",
        "                plt.plot(x, y_smooth, '-', label=n, color=color, alpha=0.85)\n",
        "                if error:\n",
        "                    ax.fill_between(x, y_smooth-yerr_smooth, y_smooth+yerr_smooth, color=color, alpha=0.13)\n",
        "            elif train:\n",
        "                plt.plot(x, y_smooth, ':', color=color)\n",
        "                if error:\n",
        "                    ax.fill_between(x, y_smooth-yerr_smooth, y_smooth+yerr_smooth, color=color, alpha=0.13)\n",
        "\n",
        "        ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=14)\n",
        "\n",
        "    ax.text(0.01, .92, \"A\", color='black', transform=ax.transAxes, fontsize=15, fontweight='bold')\n",
        "    if log:\n",
        "        plt.xscale('log')\n",
        "        plt.yscale('log')\n",
        "        if train:\n",
        "            plt.ylim(7e-2, 1.2e0)\n",
        "        else:\n",
        "            plt.ylim(7e-2, 1e0)\n",
        "        plt.xlim(1e1, 6e4)\n",
        "    else:\n",
        "        plt.ylim(0, 1e0)\n",
        "        plt.xlim(0, 1.3e4)\n",
        "\n",
        "    plt.legend(\n",
        "        (\n",
        "            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls=\":\"),\n",
        "            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls=\"-\"),),\n",
        "        (\n",
        "            \"NN train\",\n",
        "            \"NN test\",),\n",
        "        framealpha=0.95)\n",
        "    plt.xlabel('Train steps')\n",
        "    plt.ylabel(\"Classification Error\")\n",
        "    plt.tight_layout()\n",
        "    # Reference line at 0.9 = error of uniform guessing over 10 classes.\n",
        "    plt.axhline(0.9, ls='--', color='black', lw=1, alpha=0.4)\n",
        "\n",
        "    # plt.show()\n",
        "    plt.savefig(save_path, bbox_inches='tight')\n",
        "\n",
        "make_DB_plot(all_data)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}