{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "65245462",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TF_FORCE_UNIFIED_MEMORY=1\n"
     ]
    }
   ],
   "source": [
    "%env TF_FORCE_UNIFIED_MEMORY=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "74d4816a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sine config (from chelsea finn)\n",
    "config = {}\n",
    "config[\"n_epochs\"] = 70000\n",
    "config[\"n_tasks_per_epoch\"] = 24\n",
    "config[\"K\"] = 10\n",
    "config[\"L\"] = 10\n",
    "config[\"n_updates\"] = 5\n",
    "config[\"n_updates_test\"]= 10\n",
    "config[\"meta_lr\"] = 0.001\n",
    "config[\"inner_lr\"] = 1e-3\n",
    "config[\"data_noise\"] = 0.05 # not in cbfinn but we add it to compare to our algorithm\n",
    "config[\"n_test_tasks\"] = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "38465e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import vmap\n",
    "from jax import numpy as np\n",
    "from flax import struct\n",
    "from typing import Any, Callable\n",
    "from flax import core\n",
    "import optax\n",
    "import models\n",
    "from jax import random\n",
    "from jax import value_and_grad, grad\n",
    "from functools import partial\n",
    "from jax.tree_util import tree_map\n",
    "from jax import jit\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from jax.lax import scan\n",
    "import pickle\n",
    "import dataset_multi_infinite"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b501dcce",
   "metadata": {},
   "source": [
    "## Inner and outer loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "57bfab88",
   "metadata": {},
   "outputs": [],
   "source": [
    "def error_fn(predictions, gt):\n",
    "    return np.mean( (predictions - gt)**2 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "367d8046",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inner_loss(current_params, x_a, y_a, apply_fn):\n",
    "    predictions = apply_fn(current_params, x_a)\n",
    "    \n",
    "    return error_fn(predictions, y_a)\n",
    "\n",
    "def gd_step0(inner_lr, param_value, param_grad):\n",
    "    return param_value - inner_lr * param_grad\n",
    "\n",
    "def inner_updates(current_params, x_a, y_a, n_updates, inner_lr, apply_fn):\n",
    "    def f(parameters, x):\n",
    "        inner_gradients = grad(inner_loss)(parameters, x_a, y_a, apply_fn)\n",
    "        parameters = tree_map(partial(gd_step0, inner_lr), parameters, inner_gradients)\n",
    "        \n",
    "        return parameters, None\n",
    "    \n",
    "    updated_params, _ = scan(f, current_params, None, n_updates)\n",
    "    \n",
    "    return updated_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e73df2ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def outer_loss_single_task(current_params, x_a, y_a, x_b, y_b, n_updates, inner_lr, apply_fn):\n",
    "    updated_params = inner_updates(current_params, x_a, y_a, n_updates, inner_lr, apply_fn)\n",
    "    \n",
    "    predictions = apply_fn(updated_params, x_b)\n",
    "    \n",
    "    return error_fn(predictions, y_b)\n",
    "\n",
    "def outer_loss(current_params, x_a, y_a, x_b, y_b, n_updates, inner_lr, apply_fn):\n",
    "    unaveraged_losses = vmap(partial(outer_loss_single_task, current_params=current_params, n_updates=n_updates, inner_lr=inner_lr, apply_fn=apply_fn))(x_a=x_a, y_a=y_a, x_b=x_b, y_b=y_b)\n",
    "    return np.mean(unaveraged_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "398ed5a0",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3401ab63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_train_batch_fn(key):\n",
    "    return dataset_multi_infinite.get_test_batch(key, config[\"n_tasks_per_epoch\"], config[\"K\"], config[\"L\"], config[\"data_noise\"])\n",
    "    \n",
    "def get_test_batch_fn(key):\n",
    "    return dataset_multi_infinite.get_test_batch(key, config[\"n_test_tasks\"], config[\"K\"], config[\"L\"], config[\"data_noise\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737f51e8",
   "metadata": {},
   "source": [
    "## State"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "584350e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrainStateNoProj(struct.PyTreeNode):\n",
    "    step: int\n",
    "    apply_fn: Callable = struct.field(pytree_node=False)\n",
    "    params: core.FrozenDict[str, Any]\n",
    "    tx_params: optax.GradientTransformation = struct.field(pytree_node=False)\n",
    "    opt_state_params: optax.OptState\n",
    "    inner_lr: float\n",
    "    n_updates: int\n",
    "    \n",
    "    def apply_gradients(self, *, grads_params, **kwargs):\n",
    "        \"\"\"\n",
    "        Updates both the params and the scaling matrix\n",
    "        Also requires new_batch_stats to keep track of what has been seen by the network\n",
    "        \"\"\"\n",
    "        # params part\n",
    "        updates_params, new_opt_state_params = self.tx_params.update(grads_params, self.opt_state_params, self.params)\n",
    "        new_params = optax.apply_updates(self.params, updates_params)\n",
    "\n",
    "        return self.replace(\n",
    "            step=self.step + 1,\n",
    "            params=new_params,\n",
    "            opt_state_params=new_opt_state_params,\n",
    "            **kwargs,\n",
    "        )\n",
    "\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, *, apply_fn, params, tx_params, inner_lr, n_updates, **kwargs):\n",
    "        opt_state_params = tx_params.init(params)\n",
    "        return cls(\n",
    "            step=0,\n",
    "            apply_fn=apply_fn,\n",
    "            params=params,\n",
    "            tx_params=tx_params,\n",
    "            opt_state_params=opt_state_params,\n",
    "            inner_lr = inner_lr,\n",
    "            n_updates = n_updates,\n",
    "            **kwargs,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c42d6d6",
   "metadata": {},
   "source": [
    "## Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b337913e",
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jit, static_argnums=(5, 6, 7))\n",
    "def get_loss_and_gradients(params, x_a, y_a, x_b, y_b, n_updates, inner_lr, apply_fn):\n",
    "    return value_and_grad(outer_loss)(params, x_a, y_a, x_b, y_b, n_updates, inner_lr, apply_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a9cbd323",
   "metadata": {},
   "outputs": [],
   "source": [
    "def step(key, state):\n",
    "    x_a, y_a, x_b, y_b = get_train_batch_fn(key)\n",
    "    \n",
    "    loss, gradients = get_loss_and_gradients(state.params, x_a, y_a, x_b, y_b, state.n_updates, state.inner_lr, state.apply_fn)\n",
    "    \n",
    "    state = state.apply_gradients(grads_params = gradients)\n",
    "    \n",
    "    return state, loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17029a09",
   "metadata": {},
   "source": [
    "## Test during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "86bd7c3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_during_training(key, state):\n",
    "    x_a, y_a, x_b, y_b = get_test_batch_fn(key)\n",
    "    \n",
    "    def f(carry, task):\n",
    "        x_a, y_a, x_b, y_b = task\n",
    "        \n",
    "        updated_params = inner_updates(state.params, x_a, y_a, config[\"n_updates_test\"], state.inner_lr, state.apply_fn)\n",
    "        predictions = state.apply_fn(updated_params, x_b)\n",
    "        \n",
    "        return None, error_fn(predictions, y_b)\n",
    "    \n",
    "    _, errors = scan(f, None, (x_a, y_a, x_b, y_b))\n",
    "    return np.mean(errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ba73c78",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-27 14:03:38.732845: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/pkgs/cudatoolkit-11.3.1-h2bc3f7f_2/lib\n",
      "2022-08-27 14:03:38.732881: W external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "key = random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "75c08a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.small_network(40, \"relu\", 1)\n",
    "\n",
    "def apply_fn(params, inputs):\n",
    "    return model.apply({\"params\": params}, inputs)\n",
    "\n",
    "key, key_init0, key_init1 = random.split(key, 3)\n",
    "batch = get_train_batch_fn(key_init0)\n",
    "init_vars = model.init(key_init1, batch[0][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4aa799f",
   "metadata": {},
   "source": [
    "## Option #1: training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6ad756ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_params = optax.adam(learning_rate = config[\"meta_lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8ee5691d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 | 2.5701 (2.7633 s / epoch)\n",
      "Error: 3.2595062255859375\n",
      "100 | 1.2563 (0.0305 s / epoch)\n",
      "200 | 2.9994 (0.0240 s / epoch)\n",
      "300 | 2.1433 (0.0310 s / epoch)\n",
      "400 | 1.3627 (0.0388 s / epoch)\n",
      "500 | 2.2790 (0.0228 s / epoch)\n",
      "Error: 1.9137661457061768\n",
      "600 | 1.9467 (0.0203 s / epoch)\n",
      "700 | 2.1330 (0.0322 s / epoch)\n",
      "800 | 2.2484 (0.0386 s / epoch)\n",
      "900 | 2.7040 (0.0370 s / epoch)\n",
      "1000 | 1.3453 (0.0396 s / epoch)\n",
      "Error: 1.8821662664413452\n",
      "1100 | 2.0068 (0.0329 s / epoch)\n",
      "1200 | 2.1390 (0.0438 s / epoch)\n",
      "1300 | 0.9381 (0.0302 s / epoch)\n",
      "1400 | 2.0574 (0.0384 s / epoch)\n",
      "1500 | 1.0602 (0.0348 s / epoch)\n",
      "Error: 1.0699690580368042\n",
      "1600 | 1.1588 (0.0372 s / epoch)\n",
      "1700 | 0.7188 (0.0265 s / epoch)\n",
      "1800 | 1.1090 (0.0371 s / epoch)\n",
      "1900 | 0.7507 (0.0426 s / epoch)\n",
      "2000 | 0.5930 (0.0380 s / epoch)\n",
      "Error: 0.6749621629714966\n",
      "2100 | 0.9128 (0.0316 s / epoch)\n",
      "2200 | 0.5817 (0.0259 s / epoch)\n",
      "2300 | 0.6314 (0.0487 s / epoch)\n",
      "2400 | 0.8022 (0.0236 s / epoch)\n",
      "2500 | 1.3697 (0.0379 s / epoch)\n",
      "Error: 0.48777300119400024\n",
      "2600 | 0.7589 (0.0367 s / epoch)\n",
      "2700 | 0.6271 (0.0234 s / epoch)\n",
      "2800 | 0.5513 (0.0360 s / epoch)\n",
      "2900 | 0.5130 (0.0289 s / epoch)\n",
      "3000 | 0.3323 (0.0241 s / epoch)\n",
      "Error: 0.5485170483589172\n",
      "3100 | 0.7185 (0.0262 s / epoch)\n",
      "3200 | 0.6923 (0.0341 s / epoch)\n",
      "3300 | 0.2737 (0.0399 s / epoch)\n",
      "3400 | 0.5327 (0.0346 s / epoch)\n",
      "3500 | 0.6918 (0.0369 s / epoch)\n",
      "Error: 0.5713149309158325\n",
      "3600 | 0.7124 (0.0292 s / epoch)\n",
      "3700 | 0.5824 (0.0330 s / epoch)\n",
      "3800 | 0.6253 (0.0241 s / epoch)\n",
      "3900 | 0.6083 (0.0403 s / epoch)\n",
      "4000 | 0.4309 (0.0318 s / epoch)\n",
      "Error: 0.5112074017524719\n",
      "4100 | 0.5302 (0.0287 s / epoch)\n",
      "4200 | 0.4390 (0.0383 s / epoch)\n",
      "4300 | 0.4699 (0.0431 s / epoch)\n",
      "4400 | 0.2674 (0.0291 s / epoch)\n",
      "4500 | 0.5470 (0.0331 s / epoch)\n",
      "Error: 0.31648656725883484\n",
      "4600 | 0.6125 (0.0221 s / epoch)\n",
      "4700 | 0.5276 (0.0286 s / epoch)\n",
      "4800 | 0.5386 (0.0298 s / epoch)\n",
      "4900 | 0.2889 (0.0349 s / epoch)\n",
      "5000 | 0.8084 (0.0370 s / epoch)\n",
      "Error: 0.36829566955566406\n",
      "5100 | 0.4195 (0.0277 s / epoch)\n",
      "5200 | 0.7641 (0.0354 s / epoch)\n",
      "5300 | 0.4608 (0.0445 s / epoch)\n",
      "5400 | 0.3812 (0.0262 s / epoch)\n",
      "5500 | 0.5793 (0.0391 s / epoch)\n",
      "Error: 0.3548576533794403\n",
      "5600 | 0.5606 (0.0381 s / epoch)\n",
      "5700 | 0.3670 (0.0479 s / epoch)\n",
      "5800 | 0.2527 (0.0242 s / epoch)\n",
      "5900 | 0.5517 (0.0339 s / epoch)\n",
      "6000 | 0.7846 (0.0253 s / epoch)\n",
      "Error: 0.2769606411457062\n",
      "6100 | 0.6792 (0.0363 s / epoch)\n",
      "6200 | 0.6834 (0.0365 s / epoch)\n",
      "6300 | 0.6215 (0.0490 s / epoch)\n",
      "6400 | 0.5858 (0.0253 s / epoch)\n",
      "6500 | 0.7375 (0.0316 s / epoch)\n",
      "Error: 0.41574081778526306\n",
      "6600 | 0.7387 (0.0298 s / epoch)\n",
      "6700 | 0.5582 (0.0399 s / epoch)\n",
      "6800 | 0.6129 (0.0354 s / epoch)\n",
      "6900 | 0.8810 (0.0342 s / epoch)\n",
      "7000 | 0.6131 (0.0404 s / epoch)\n",
      "Error: 0.32844269275665283\n",
      "7100 | 0.2421 (0.0326 s / epoch)\n",
      "7200 | 0.7757 (0.0316 s / epoch)\n",
      "7300 | 0.6859 (0.0363 s / epoch)\n",
      "7400 | 0.6057 (0.0264 s / epoch)\n",
      "7500 | 0.5384 (0.0395 s / epoch)\n",
      "Error: 0.36531639099121094\n",
      "7600 | 0.8670 (0.0272 s / epoch)\n",
      "7700 | 0.8008 (0.0346 s / epoch)\n",
      "7800 | 0.7530 (0.0313 s / epoch)\n",
      "7900 | 0.5404 (0.0268 s / epoch)\n",
      "8000 | 0.4298 (0.0357 s / epoch)\n",
      "Error: 0.44487398862838745\n",
      "8100 | 0.4975 (0.0320 s / epoch)\n",
      "8200 | 0.3348 (0.0379 s / epoch)\n",
      "8300 | 0.2990 (0.0392 s / epoch)\n",
      "8400 | 0.2681 (0.0383 s / epoch)\n",
      "8500 | 0.6216 (0.0233 s / epoch)\n",
      "Error: 0.31822872161865234\n",
      "8600 | 0.5342 (0.0271 s / epoch)\n",
      "8700 | 0.8457 (0.0436 s / epoch)\n",
      "8800 | 0.5767 (0.0429 s / epoch)\n",
      "8900 | 0.4947 (0.0370 s / epoch)\n",
      "9000 | 0.3592 (0.0417 s / epoch)\n",
      "Error: 0.3675214350223541\n",
      "9100 | 0.4006 (0.0419 s / epoch)\n",
      "9200 | 0.3181 (0.0495 s / epoch)\n",
      "9300 | 0.6429 (0.0345 s / epoch)\n",
      "9400 | 0.3873 (0.0374 s / epoch)\n",
      "9500 | 0.4270 (0.0248 s / epoch)\n",
      "Error: 0.45445117354393005\n",
      "9600 | 0.5593 (0.0242 s / epoch)\n",
      "9700 | 0.3938 (0.0421 s / epoch)\n",
      "9800 | 0.3357 (0.0348 s / epoch)\n",
      "9900 | 0.4131 (0.0285 s / epoch)\n",
      "10000 | 0.4783 (0.0404 s / epoch)\n",
      "Error: 0.46549203991889954\n",
      "10100 | 0.3826 (0.0420 s / epoch)\n",
      "10200 | 0.3570 (0.0374 s / epoch)\n",
      "10300 | 0.6803 (0.0365 s / epoch)\n",
      "10400 | 0.3465 (0.0402 s / epoch)\n",
      "10500 | 0.4949 (0.0398 s / epoch)\n",
      "Error: 0.585740864276886\n",
      "10600 | 1.0874 (0.0398 s / epoch)\n",
      "10700 | 0.3438 (0.0379 s / epoch)\n",
      "10800 | 0.7139 (0.0402 s / epoch)\n",
      "10900 | 0.5756 (0.0397 s / epoch)\n",
      "11000 | 0.4875 (0.0428 s / epoch)\n",
      "Error: 0.3968771994113922\n",
      "11100 | 0.3731 (0.0448 s / epoch)\n",
      "11200 | 0.4094 (0.0432 s / epoch)\n",
      "11300 | 0.6054 (0.0296 s / epoch)\n",
      "11400 | 0.5576 (0.0307 s / epoch)\n",
      "11500 | 0.4078 (0.0386 s / epoch)\n",
      "Error: 0.4451444149017334\n",
      "11600 | 0.5678 (0.0437 s / epoch)\n",
      "11700 | 0.5740 (0.0307 s / epoch)\n",
      "11800 | 0.2942 (0.0377 s / epoch)\n",
      "11900 | 0.2799 (0.0414 s / epoch)\n",
      "12000 | 0.6807 (0.0466 s / epoch)\n",
      "Error: 0.4579548239707947\n",
      "12100 | 0.7542 (0.0326 s / epoch)\n",
      "12200 | 0.3932 (0.0363 s / epoch)\n",
      "12300 | 0.7168 (0.0369 s / epoch)\n",
      "12400 | 0.5664 (0.0384 s / epoch)\n",
      "12500 | 0.6101 (0.0375 s / epoch)\n",
      "Error: 0.39547473192214966\n",
      "12600 | 0.6661 (0.0330 s / epoch)\n",
      "12700 | 0.3507 (0.0396 s / epoch)\n",
      "12800 | 0.7067 (0.0346 s / epoch)\n",
      "12900 | 0.6638 (0.0379 s / epoch)\n",
      "13000 | 0.2967 (0.0371 s / epoch)\n",
      "Error: 0.23147620260715485\n",
      "13100 | 0.4766 (0.0401 s / epoch)\n",
      "13200 | 0.6014 (0.0240 s / epoch)\n",
      "13300 | 0.5215 (0.0370 s / epoch)\n",
      "13400 | 0.5118 (0.0369 s / epoch)\n",
      "13500 | 0.5611 (0.0275 s / epoch)\n",
      "Error: 0.5682196617126465\n",
      "13600 | 0.2728 (0.0375 s / epoch)\n",
      "13700 | 0.5496 (0.0276 s / epoch)\n",
      "13800 | 0.5970 (0.0390 s / epoch)\n",
      "13900 | 0.5959 (0.0381 s / epoch)\n",
      "14000 | 0.5577 (0.0388 s / epoch)\n",
      "Error: 0.41671788692474365\n",
      "14100 | 0.5099 (0.0324 s / epoch)\n",
      "14200 | 0.5055 (0.0416 s / epoch)\n",
      "14300 | 0.8173 (0.0294 s / epoch)\n",
      "14400 | 0.6071 (0.0280 s / epoch)\n",
      "14500 | 0.5859 (0.0257 s / epoch)\n",
      "Error: 0.38650912046432495\n",
      "14600 | 0.4290 (0.0430 s / epoch)\n",
      "14700 | 0.3248 (0.0328 s / epoch)\n",
      "14800 | 0.6765 (0.0271 s / epoch)\n",
      "14900 | 0.7322 (0.0285 s / epoch)\n",
      "15000 | 0.6662 (0.0291 s / epoch)\n",
      "Error: 0.42815810441970825\n",
      "15100 | 0.6660 (0.0375 s / epoch)\n",
      "15200 | 0.2878 (0.0417 s / epoch)\n",
      "15300 | 0.6060 (0.0328 s / epoch)\n",
      "15400 | 0.4159 (0.0316 s / epoch)\n",
      "15500 | 0.5749 (0.0281 s / epoch)\n",
      "Error: 0.36233964562416077\n",
      "15600 | 0.4602 (0.0245 s / epoch)\n",
      "15700 | 0.4018 (0.0239 s / epoch)\n",
      "15800 | 0.3944 (0.0322 s / epoch)\n",
      "15900 | 0.3529 (0.0520 s / epoch)\n",
      "16000 | 0.3194 (0.0414 s / epoch)\n",
      "Error: 0.3761681616306305\n",
      "16100 | 0.3667 (0.0381 s / epoch)\n",
      "16200 | 0.5561 (0.0333 s / epoch)\n",
      "16300 | 0.4523 (0.0325 s / epoch)\n",
      "16400 | 0.3890 (0.0275 s / epoch)\n",
      "16500 | 0.4527 (0.0285 s / epoch)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exception ignored in: <function WeakKeyDictionary.__init__.<locals>.remove at 0x7ff09bff2040>\n",
      "Traceback (most recent call last):\n",
      "  File \"/state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/lib/python3.8/weakref.py\", line 345, in remove\n",
      "    def remove(k, selfref=ref(self)):\n",
      "KeyboardInterrupt: \n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "state = TrainStateNoProj.create(apply_fn=apply_fn, params=init_vars[\"params\"], tx_params=optimizer_params, inner_lr=config[\"inner_lr\"], n_updates=config[\"n_updates\"])\n",
    "\n",
    "losses = []\n",
    "errors_test = []\n",
    "\n",
    "for epoch_index in range(config[\"n_epochs\"]):\n",
    "    t = time.time_ns()\n",
    "    key, subkey = random.split(key)\n",
    "    state, current_loss = step(subkey, state)\n",
    "    \n",
    "    if epoch_index % 100 == 0:\n",
    "        print(f\"{epoch_index} | {current_loss:.4f} ({(time.time_ns() - t) / 10**9:.4f} s / epoch)\")\n",
    "        \n",
    "    if epoch_index % 500 == 0:\n",
    "        # test time\n",
    "        key, subkey = random.split(key)\n",
    "        mse_test = test_during_training(subkey, state)\n",
    "        errors_test.append(mse_test)\n",
    "        print(f\"Error: {mse_test}\")\n",
    "    \n",
    "    losses.append(current_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "60ee4c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = {}\n",
    "output[\"trained_params\"] = state.params\n",
    "output[\"losses\"]=losses\n",
    "output[\"errors_test\"]=errors_test\n",
    "output[\"config\"] = config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "208d2f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"logs_final/maml_mixed.pickle\", \"wb\") as handle:\n",
    "    pickle.dump(output, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03f2177a",
   "metadata": {},
   "source": [
    "## Option #2: loading a previously trained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0b2aa1ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"logs_final/maml_mixed.pickle\", \"rb\") as handle:\n",
    "    output = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "429e0381",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = output[\"config\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "bf7abf40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write i=0 to test the prediction on a sine; write i=-1 to test on a line\n",
    "i = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "1d09aeea",
   "metadata": {},
   "outputs": [],
   "source": [
    "key, subkey = random.split(key)\n",
    "x_a, y_a, x_b, y_b = get_train_batch_fn(subkey)\n",
    "updated_params = inner_updates(output[\"trained_params\"], x_a[i], y_a[i], config[\"n_updates_test\"], config[\"inner_lr\"], apply_fn)\n",
    "predictions = apply_fn(updated_params, x_b[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "1df0b9d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f03ec419ee0>"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD6CAYAAAC8sMwIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAsVUlEQVR4nO3de3RV5Zn48e9zTq4nV8iNS8gFRJQ7klAUUQ9eq7bWtjrVdH7ttE5WWXW1ndVKLyw1QJnVMq6202kHB9vRupqpMrbY39j6q1pCEZ3WgCICctUEApgbkNvJ7Zzz/v7YuQEJJJx9rnk+a2Wds/fZefdzAnnynHe/+33FGINSSqno5Qh3AEoppQKjiVwppaKcJnKllIpymsiVUirKaSJXSqkop4lcKaWiXJwdjYhIDdAG+ACvMabEjnaVUkpdmi2JvI/bGNM0mgOzs7NNUVGRjadWSqnYt2vXriZjTM75++1M5KNWVFTEzp07w3FqpZSKWiJSO9x+u/rIDfCKiOwSkXKb2lRKKTUKdlXky4wxJ0UkF3hVRA4YY7YPPaAvwZcDFBQU2HRapZRStlTkxpiTfY8NwBZgyTDHbDLGlBhjSnJyLujiUUopdZkCrshFJAVwGGPa+p7fBqwdazu9vb3U1dXR1dUVaEgxLSkpifz8fOLj48MdilIqQtjRtZIHbBGR/vb+yxjz/8baSF1dHWlpaRQVFdHXljqPMYbm5mbq6uooLi4OdzhKqQgRcCI3xnwALAi0na6uLk3ilyAiZGVl0djYGO5QRm/DBjh6FD73OXC7rX1VVfDcczBjBqxaFd74lIoBYRl+OBJN4pcWdT+j0lJYvx6ef56Ke94BoOLFT4EIbNkS3tiUihF6i74KLrcbXnwRjGHNs8WsebZ4MIn3V+hKqYBoIh/io48+4nOf+xwzZsxg9uzZ3HnnnRw6dGjM7fzkJz/B4/FcVgxnz57l3//93y/reyOW2w1f//rg9te+pklcKRtFbyKvrISiInA4rMfKyoCaM8Zw7733ctNNN3H06FH279/PP//zP1NfXz/mtjSRD6qosApwWTc4kEnWrUXEek0pFbjoTOSVlVBeDrW1YIz1WF4eUDKvqqoiPj6er3zlKwP7Fi5cyPXXX88jjzzC3LlzmTdvHs8//zwA27Zt46abbuKzn/0sV111FWVlZRhj+OlPf8rJkydxu924+6rOV155hWuvvZZrrrmG++67j/b2dmpra5k5cyZNTU34/X6WL1/OK6+8wne+8x2OHj3KwoULeeSRRwL7OUWAihurMOkZmIzMgX392xU3VoUvMKViiTEm5F+LFy8259u/f/8F+0ZUWGiMlcLP/SosHH0b5/nXf/1X841vfOOC/S+88IK55ZZbjNfrNR999JGZNm2aOXnypKmqqjLp6enm+PHjxufzmaVLl5rXX3+9L7xC09jYaIwxprGx0Sxfvty0t7cbY4z5wQ9+YNasWWOMMeapp54yn/nMZ8yGDRtMeXm5McaYDz/80MyZM+eisY7pZxVuP/yhMeXlxmzdOvDPZLZutfb98IeDx/3619a/n4j1+OtfhylgpSIXsNMMk1MjatTKqB07Nrb9AdixYwcPPPAATqeTvLw8brzxRqqrq0lPT2fJkiXk5+cDVvVeU1PD9ddff873//Wvf2X//v0sW7YMgJ6eHq699loAHnroIf77v/+bJ598kt27d9see0QYMrzw8cf7nrjd5/aR93/C6u+O6v+EBVBWFpo4lYpi0ZnICwqsX/bh9l+mOXPm8MILL1yw3/ojOLzExMSB506nE6/XO+z333rrrfzmN7+54DWPx0NdXR0A7e3tpKWlXU7oUWPEPvHVqweTeD+Px9qviVypS4rOPvL168HlOnefy2Xtv0wrVqygu7ubp556amBfdXU1EyZM4Pnnn8fn89HY2Mj27dtZsuSCqWTOkZaWRltbGwBLly7ljTfe4MiRI4CVvPtHwnz729+m
rKyMtWvX8o//+I8XfO+4cZFPWHpBVKlLi85EXlYGmzZBYaE1JKKw0NoOoHoTEbZs2cKrr77KjBkzmDNnDhUVFTz44IPMnz+fBQsWsGLFCjZs2MCkSZMu2lZ5eTkf//jHcbvd5OTk8Mwzz/DAAw8wf/58li5dyoEDB/jLX/5CdXX1QDJPSEjg6aefJisri2XLljF37tyYuNg5KiN9kiooYM2a0IaiVDSSi3UdBEtJSYk5f2GJ999/n6uvvjrksUSjmPtZnd9HDpCYCL/8JfL5MozBuq2/ulpv6VfjmojsMsMspRmdFbmKLed9wqpI+Rekuwv5vPUJSwRkhZuKffeFOVClIpMmchUZysqgpgb8firav4XZWoXJtuatN9k5mK1VVPxKZ3xUajiayFVkcrth5Urr+cqVeku/UhehiVxFpqoq2LiRx2+wHqnSu0CVGokmchV5qqrg/vth82Yq/uKGzZutbU3mSg1LE7mKPNXVVvLu705x9yXz6urwxqVUhNJEPkRdXR333HMPM2fOZPr06Tz88MN0d3eHO6zxZ9WqC/vE3W4deqjUCGxL5CLiFJF3ROQlu9oc0YYNF37Mrqqy9l8mYwyf/vSn+dSnPsXhw4c5fPgwnZ2drLIhefh8voDbUEqpkdhZkX8deN/G9kZWWnpun2l/n2pp6WU3uXXrVpKSkviHf/gHwJo75cc//jHPPvssP/vZz3j44YcHjr377rvZtm0bMPwUtQBFRUWsXbuW66+/nh/84Adcc801A99/+PBhFi9efNmxKqXUULYkchHJB+4CfmFHe5fkHnIB7LHHBi6MBTJEbd++fRck1/T0dIqKioadDAugqamJ73//+7z22mu8/fbblJSU8KMf/Wjg9aSkJHbs2MHq1avJyMgYmOHw6aef5otf/OJlx6qUUkPZVZH/BFgF+G1q79L6xxmvW2fLOGNjzLALG19sCoOhU9QuXLiQX/3qV9QOmZXx7/7u7waeP/TQQzz99NP4fD6ef/55HnzwwYDiVUrZYEg37cAEbQF204ZDwIlcRO4GGowxuy5xXLmI7BSRnY2NjYGedmCcMY8+ass44zlz5nD+/C+tra3U19eTlZWF3z/4N6qrqwsYnKJ29+7d7N69m/379/PLX/5y4LiUlJSB55/5zGd4+eWXeemll1i8eDFZWVkBxauUskFpKdxzD0yaZE3QNmmStR1AN2042FGRLwM+KSI1wHPAChH59fkHGWM2GWNKjDElOTk5gZ1xyDhj1q61ZZzxzTffjMfj4dlnnwWsC5Tf/OY3efjhhykuLmb37t34/X6OHz/OW2+9BVx8itrzJSUlcfvtt7Ny5cqBfnilVJidPAk9PdC/Nm99vbV98mR44xqjgBO5Mea7xph8Y0wR8DlgqzHm8wFHdjFBGGfcP43tCy+8wMyZM8nKysLhcLB69WqWLVtGcXEx8+bN41vf+tbAhcuRpqgdSVlZGSLCbbfddtlxKqXsU/HVRmuCNqwuVMEg3V1UfNWGXoMQsnUaWxG5CfiWMebuix0XDdPYvvnmmzzwwAP87ne/s22EyRNPPEFLSwvr1q0LqJ1I+1kpFbWGXBcTDAYZ3O8P3SW/0RppGltbl3ozxmwDttnZZrhcd91151y4DNS9997L0aNH2bp1q21tKqUCUFk58mu5uaGLwwbRuWZnFNqyZUu4Q1BKDbV69Tmbj1MxuBFls23qLfpKqfHpvLViKxiyruAwi6VHMk3kSqnxaaS1YgsLQxuHDTSRK6XGp/XrweU6d5/LZe2PMprIlVLj03lrxVJYaG2XlYU7sjHTRD6E0+lk4cKFzJ07l/vuuw/P0FXdx+iLX/wiL7zwgo3RKaVsN2StWGpqojKJQwwk8oH5EWyQnJzM7t272bt3LwkJCTz55JPnvK7T0SqlIlHUJ/I1ay59zOVYvnw5R44cYdu2bbjdbh588EHmzZuHz+fjkUceobS0lPnz5/Mf//EfgDXvysMPP8zs2bO56667aGho
CE5gSil1Hh1HPgyv18vLL7/MHXfcAcBbb73F3r17KS4uZtOmTWRkZFBdXU13dzfLli3jtttu45133uHgwYO899571NfXM3v2bL70pS+F+Z0opcaDqKzIKyqsaxMy5G5akcC7WTo7O1m4cCElJSUUFBTw5S9/GYAlS5ZQXFwMWAtJPPvssyxcuJCPfexjNDc3c/jwYbZv384DDzyA0+lkypQprFixIrBglFJqlKKyIq+oGEzaImDXdDH9feTnGzodrTGGf/u3f+P2228/55g//vGPw85nrpRSwRaVFXk43X777WzcuJHe3l4ADh06REdHBzfccAPPPfccPp+PU6dOURXg/OhKKTVaUVmRD/X446E930MPPURNTQ3XXHMNxhhycnJ48cUXuffee9m6dSvz5s3jyiuv5MYbbwxtYEqpccvWaWxHKxqmsY1k+rNSanwaaRpb7VpRSqkop4lcKaWinCZypZSKcgEnchFJEpG3RORdEdknIkG611IppdRw7Bi10g2sMMa0i0g8sENEXjbG/NWGtpVSSl1CwIncWMNe2vs24/u+Qj8URimlxilb+shFxCkiu4EG4FVjzN/saDfUhk5j+4lPfIKzZ8+GOySllLokWxK5McZnjFkI5ANLRGTu+ceISLmI7BSRnY2NjXac1nZDp7GdOHEiP//5z8MdklJKXZKto1aMMWeBbcAdw7y2yRhTYowpycnJsfO0QXHttddy4sQJwJr98LrrrmPRokVcd911HDx4EIA777yTPXv2ALBo0SLWrl0LwKOPPsovfvGL8ASulBp3Au4jF5EcoNcYc1ZEkoFbgB8G2u5r++sv2FeY5WJmXhpen59tBy+s6qfnpDA9J5WuXh87Djed89ots/NGfW6fz8ef//zngdkPr7rqKrZv305cXByvvfYa3/ve9/jtb3/LDTfcwOuvv05RURFxcXG88cYbAOzYsYPPf/7zY3m7Sil12eyoyCcDVSKyB6jG6iN/yYZ2Q65/GtusrCxOnz7NrbfeCkBLSwv33Xcfc+fO5Z/+6Z/Yt28fYC0+sX37dnbs2MFdd91Fe3s7Ho+HmpoaZs2aFb43smEDnD9pV1WVtV8pFXPsGLWyB1hkQyznuFgFHed0XPT1pHjnmCrwfv195C0tLdx99938/Oc/52tf+xqPPvoobrebLVu2UFNTw0033QRAaWkpO3fuZPr06dx66600NTXx1FNPsXjx4jGf21alpXD//bB5M7jdVhLv31ZKxRy9s3MYGRkZ/PSnP+WJJ56gt7eXlpYWpk6dCsAzzzwzcFxCQgLTpk1j8+bNLF26lOXLl/PEE0+wfPnyMEXex+22kvb998Njj52b1JVSMUcT+QgWLVrEggULeO6551i1ahXf/e53WbZs2QULMC9fvpy8vDxcLhfLly+nrq4u/IkcrKS9ciWsW2c9ahJXKmbpNLZRaFQ/q/7ulJUrYeNGrciVigE6je14MqRPvMKxdrCbRVctUiomaSKPRdXVAxX4mjUM9plXV4c7MqVUEETUUm/GGF3A+BJG1RW2atWF+9xu7VpRKkZFTEWelJREc3Pz6BLVOGWMobm5maSkpIseV1EBItYXDD6vqAh6iEqpMIiYi529vb3U1dXR1dUV8niiSVJSEvn5+cTHx4/qeBHQv41KxYaRLnZGTNdKfHw8xcXF4Q5DKaWiTsR0rajgePzxcEeglAo2TeQxTvvFlYp9msiVUirKaSJXSqkop4lcKaWinCbyWFBZCUVF4HBYj5WV4Y5IKRVCETP8UF2mykooLwePx9qurbW2AcrKwheXUipktCKPdqtXDybxfh6PtV8pNS5oIo92x46Nbb9SKuYEnMhFZJqIVInI+yKyT0S+bkdgapQKCsa2XykVc+yoyL3AN40xVwNLga+KyGwb2lWjsX49uFzn7nO5rP1KqXEh4ERujDlljHm773kb8D4wNdB21SiVlcGmTVBYaM2QVVhobeuFTqXGDVtHrYhIEbAI+Jud7apL
KCvTxK3UOGbbxU4RSQV+C3zDGNM6zOvlIrJTRHY2NjbadVqllBr3bEnkIhKPlcQrjTG/G+4YY8wmY0yJMaYkJyfHjtMqpZTCnlErAvwSeN8Y86PAQ1JKKTUWdlTky4C/B1aIyO6+rzttaFepoPD7DceaPWw72MCh+jbAWkavoa2LHq8/zNEpNXYBX+w0xuwAdMVkFfF6vH4ON7RxpKGdjm4fKYlOCiZaQzc9PT5e298AQHKCg4zkeDKS4ynKSiErNTGcYSt1STrXiho3Xj/cSH1rN5MyEllcOIGpmclI3wrViXEObpyVQ4unl5ZO6+toQwfZqYlkpSbS3N5Ndc0ZctISyElNIjc9kaR4Z5jfkVIWTeQqJnX1+qg708mJs50sKZpIcoKT+fmZOB3CxJSEC46PczqYmpnM1MzkgX3GmIGFq/0G4p3C0YYODn7UDkB6chzLZ+aQkTy6hbCVChZN5CpmeHq8vH+qjYbWLs54egFISXTS3u0lOcFJTtrYukhEhL6CnZy0RG6+Og+/33Da00N9axeNbd2kJFhV+b6TLTS2dXNlXhqTM5IGKn2lQkETuYpqbV29+PyGTJdVZR9paCM7NZH5+RlMzUxmwjDVdyAcDiE7NZHs8/rN4xwOTnf0sO1gIymJToqzUyjOTiEtSat1FXyayFXUeq+uhfdOtDA5Iwn3Vbm4EuL47OJpOB2hr4ZnTUpjZm4qx894+KCxg70nWjnr6eWGK/WeCRV8mshV1PH7DX/78DQfNnVQnJ3CgmkZA6+FI4n3cziEwqwUCrNS8PR46fVZHewd3V6OnfZw1aQ07XJRQaGJXEWVXp+fHYebONXSxbypGczLz7j0N4WBK2HwV+vDpg721LVQd6aT5TOzdbSLsp0uLKGijt8YlhRPDHoSr6iwp525UzO4dkYWpzu62Xqgga5enz0NK9VHTP/4qhAqKSkxO3fuDPl5VfTq8foRgXhn6GoPEbDz1+NUSyfbDzWSnhSP+6pcrczVmInILmNMyfn7tSJXEa/H62fbwQa2H4ruWTMnZyRzw5U5OB2DwxqVsoMmchWxunp9vHv8LL/ffYLTHT1cmZcW9HNWVFiVeH+i7X9uVzfL5IxkbpszicQ4J70+P03t3fY0rMY17VpREam/G8Lnh2kTk5k9OT3kc57Y3bVyvl21pzlc386iggnMmhT8P1Iq+o3UtaKjVlRE8fR4cSXEkZuWxJV5aczITSU9Rm+qmTc1k/ZuH7tqz/BBYztzp2aQPyFZhyiqMdOuFRURunp9VB1s4NX99Xh9fpwOYVHBhLAm8ccfD277CXEObpiZzdLpE/H6Da8fbuK9Ey3BPamKSVqRq7Crb+3izaNN9Hj9LCqYENabetiwAUpLwe0e7BevqoLqali1yvbTiQjTc1Ipzk7h2GmPTpmrLotW5CpsjDHsPdHC1gMNxDsd3D5nElfmhfnux9JSuP9+K3mD9Xj//db+IBKx7gpNTYyjx+tnT91Z/P7QX79S0UkrchU2xljVeGGWi9KiiSEdIz4itxs2b7aS98qVsHGjte12hyyEj1q62HuilV6fn8WFE0N2XhW9bEnkIvKfwN1AgzFmrh1tqth11tNDYpyT5AQnN16ZQ1wkJPCh3G4ria9bB48+GtIkDlCQ5WJWexoHP7KWoZszJUNvHlIXZddv0DPAHTa1pWJYi6eX195vYFftGYDIS+Jgdads3Ggl8Y0bB7tZQmjRtExm5qVyqL6d/3n3JPWtXSGPQUUPWypyY8x2ESmyoy0Vuzw9XrYdasAhsLAgM9zhDK+/T7y/O8XtPnc7RBwOobRoIlfmprHvZAsTXPbOq65iSwSWQyoWWbfZN9Lt9XPTrFxSEyP08kx19blJu7/PvLo6LOFkuOK57opsEuIc9Pr8/O2DZp10S10gZL9NIlIOlAMUFBSE6rQqQrxz7Aytnb3cOCtn2DUzI8ZwQwz7K/Mwa+nspbbZQ0NbNyuuyiUlUv8YqmF5erwYQ1D+3UJW
kRtjNhljSowxJTk5umrKeDNnagbLrshmckbypQ9Ww8pOTeSmq3Lo6vXx6v56Wjp7wx2SGgNB2HeyFa/Pb3vb2rWiguajli7+9kEzxhhSE+OYNtEV7pCiXm5aErdcnYffGF7dX0+zTroVNZITnCwpnhiUC/y2tCgivwH+F5glInUi8mU72lXRqb3by5tHmth6oIGm9h66vfZXIOPZhJQEbp2dR0ZyPAlxWotFOq/Pz5tHm4L6CcquUSsP2NGOim49Xj/vnWjhcH0bDhHmTk3n6snpkXGjT4xJS4rn1tl5gHWH7P5TrVyRm0pinI43jzT7T7VS0+RhRk4qGcnBmTtIr5Yo2zgEjp/2UJydwrz8jHPWrVTB09zRw566Fg7Vt1FaNJH8CdqFFSlaPL3sP9lKUbaLvPSkoJ1HSyUVkLOeHqoONOD3G+KcDu6aP5mPTc/SJB5C2amJ3N63WMX2Q028eaRJhyhGgF6fn9ePNJIQ5+CagglBPZcmcnXZGlq7BkZPdPYlDu1GCY+JKQncMWcS8/MzOHbaw1+ifFm8WLDvZCttXV6WXZEd9CkWtGxSl+X4aQ9vHm0iJTEO9ywd0xwJHA4ZWJyi16czJ4bb3Cnp5KQlBrVLpZ+WT2rMaps72HGkiUxXArdcnadJPMJkuhLISUvEGMPh+ragjFtWw/P7DUca2mjv9hLndDA1MzT3TehvoBqziSkJFGWlUFo0ITInvVIANLX3UF1zhpbOXkqKdDrcYDLGUNvsYc+JFtq7vMzP9zN3akbIzq+JXI3aybOdTMlMJi0pnmtnZIU7HHUJOWmJzJpkTYc7OTM5ZNXheFN3xsOeuhbOenrJdMVz46yckP+stZxSl9TW1cu2gw1sO9hITVNHuMNRY7BwWiaZrnidbCuITpzpxOc3LLsii4/PnRSWP5iayNWIen1+3j1+lj/sOUVDWzfXFGZSmKVjlPsNrOkZwZwO4boZWfR4/VTXnA53ODGhsa2bP79fT1Pf9AiLCiZw17zJFGalhG2ZwqjqWvH7DYcb2pmek6LD3EJg+6FG6lu7Kcp2sWjaBJIT9K7BodasiY5knulKYEnxRNKSgnNX4XhxpqOHd+vOcvJsF0nxDjp7rE84kTBNQlQl8jOeHnbVWhdvlhTrxZtgONPRQ3pyPE6HMC8/g/kIOWm6snu0m56TOvDc7zc4HGFc4HqoDRusha2HThNcVWXN/z7clMJh8taHpznS0E68U1gwLYNZeWkRdaE/ciIZhazURK6anMaRhnZOnO0Mdzgx58TZTl7dX8+7dWcBa6Y9TeLnqqgAEesLBp9HQ2UO8O7xs7z6fj0d3d5wh2IpLbVWYOpfTq9/habS0vDGRf/84dZ4/PTkOOZMSeeehVOZMyUjopI4RFkiB1iQrxdvgqGmqYPXDzWSnhzH7Mnp4Q4nYlW4NmC2VtH3+40xYLZWUeHaEN7ARmmCK4EWTy9/2HOKvSdawj/GvH8Fpvvvh8ceC+myep4eL+19f9B6vH521Z5hV+1p3jzaRNXBBv7v7pMcP20VjFdNSmfBtMyI6EYZTmRGdRFOh3DtdOvizd4TLeEOJyYcaWjjzaPNZKcmsuKqPF2x/WIiuIIcjYIsF3fNn8yUzGT21LXwP3tOhv/TrdsNK1fCunXWY5CTuM9v2HeyhZfePTWwCHiPz8/RhnY+aOygsa2b7l4fM3JTyUqN4NWshoiqPvJ+E1ISWFQwgUyXdfGm2+tDkIj9axnJunp9vHPsLFMyk7j+iuyI+8gYcYZUkI/fsDksCzMHKiUxjutnZlPf2sX7p1pJ6vu9Od3RQ31rFzNyUkP7u1RVBRs3wqOPWo9BXFqvq9fH1gMNnPX0kj8hmWsKrcmsUhPjuL90WlDOGQpRmcgBZk1KG3i+90QLHzR2MD0nlZl5qaTr1flRS4p3ctvsSaQmxeGMlAtgka6vgqxYt8JKPlGUxIfKS086Zx6Qk2c72VPXwnsnWpiRk8q8qRnBT+j9n2g2b6bi
L24qNruD9sfR5ze8friJ1s5ebrgyO6am+42J8qs4O5Upmckcrm/jpXdPsfVAffg/Lka4w/Vt7D/ZClgrtWsSH4PzK8j+bpYoN3dqBnfMnUT+hGQO1bfxyv6PaPEEeV3Q6uqBpL1mDYOfeKqrbT9Vr8+P3xiunZEVU0kcQPqvyoZSSUmJ2blzp+3tdvb4ONrYzpGGdvLSkwZuI+/q9Wm/7xC1zR28caSZqROSuWFmdthuYohKQypI3O4Lt2NEQ2sXO440cUVuKvPzM0NyThEIVjrq8fpJiHNgjInq/+8isssYU3L+frvW7LxDRA6KyBER+Y4dbV6O5AQnc6dm8MkFU7imMBOw+v1efOcEbxxpoqGtK1yhRYxTLZ3879FmctMSWTYjK6r/U4fFkAoSCGoFGU656Ul8fO5k5k6xJn5q7x4cimenYA/nPOvp4dX99bxxpKmv/dj8/x5wRS4iTuAQcCtQB1QDDxhj9o/0PcGqyIfT0e3lwEdtfNDYTq/PkOmKZ2ZuKsXZKePuwl5jWzdVBxpITYrjlqvz9OKwGpVur48/vneKiSmJXDcjK2h3VdtVkbd3e6lv7aK5vYejje0kOB0smJbJFbmpl/7mCBfMinwJcMQY84Expgd4DrjHhnZtkZIYx+LCCdy7aCpLiifiEHjn+Fn6/7/0hnscbQh5erwkJThxz8rVJK5GLTHOyZwpGZzsu2Hsw6aOgdvTw83vN9S3dvH2sTO0dVn9+U1t3fztg9N80NjOFbmp3L1gckwk8YuxY9TKVOD4kO064GPnHyQi5UA5QEFBgQ2nHZs4p4MrclO5IjeVjm7vQFXxyr56EuMczMxLZdoEV+TcunyZ2rp6OX66k85eH06HIEBqUhwzclIpzEphSmayzlOjxuzKvDTSk+J582gT/3u0GYDPLJ5KYpyTvSdaONXSRVpSHNmpieSkJpKeHDfmbozHHx/dcX6/ob6ti2PNHurOdNLt9eMQyE5JJC0pnimZyXxiwWRSEuKi/vd5tOxI5MP9pC74gGSM2QRsAqtrxYbzXrb+FW38fkNxdgqHG9p440gzSfFnBpJ9NC0ePPRi7s6aM5xq6SLOKfj9Br+xFoKYnm3NzKZJXF2uSRlJ3LtoKmc8vZzu6CExzvo/J2IlgboznXzQaE1zPDElnjvmTh5T+xfrFzfG4OnxkZIYR6/fz18ONuJwCPmZyUyb6GJSRtLA/+2EOMe4+8RpR7aqA4aOpM8HTtrQbtA5HMLsKelcPTmNUy1dHKpvY++JVtKT4inKjsPnNzgkMi+QdPX6qG32UNvcQXNHD/cumkpSvJNFBZkscTrO+WM1XqoSFXwiwsSUBCamDN7xOGdKBnP6Loq2dvXS0NqNzz9Yq1XXnGZSehJTM5Mv6/9ia1cvfz3ajIhw6+w8EuOc3Hx1HhNTEnTYbB87Enk1MFNEioETwOeAB21oN2REhCmZyUzJTKa920tyX3V74KNWPmjsYGaedXG0vwIJp7OeHt4+dob61m6MgUxXPPPzB5eUynSde0uxJnEVSulJ8efckNfZ46PujIfD9e0kxTu4Mi+NqyenjyoBG2NNW7372FlEYHHfXZiATuZ2noATuTHGKyIPA38CnMB/GmP2BRxZmKQOWUg4IzmexDgHb9eeZc/xFgqyXFyZl3ZONRJsXp+fE2c7SY53kptufXxs7/Yxe3I6RVkpZLj0LlYVuZITnNyzYCqnWrs40tDOnroWPnjxT9z0k8dIP3oICgpg/XooKzvn+zw9Xl4/3ERzew+TM5NYWpyl8+FfhC0dwcaYPwJ/tKOtSJI/wUX+BBdnOno4VN9GbbOHzl4f7lm5QT2vMYaTLV3UNnVQd6YTr99QlOUiNz2JlMQ4PrlgSlDPr5SdHA5hat+aoaee+S8O/uxnpBw9DMZgamuR8nLrwLIyfH6D0yEkxjlxiLB0+sRz5lJXw4upOzuDrcfrp9vrIy0pHk+Pl6a2HgqCsPTZ1gP1fNTSTUKcg4KJLgqz
XOSmJUZkX71SY1JUBLW1APQ6nLwycylXNNeRMyGFAy//hca2bu6eP0X7vkcw0jjy6BmaEQGGXg3fe6KVIw3tlHonMDMv7RLfea6Obi+Nbd20d3vx9Pjo6PHi6fZx89W5JMU7mZmbxszctMu+OKRUxOlfCejYsYFdXkccrt5udk29CkSIO93JjNzUgapcjZ4m8su0uHACnh4v1TVnaGzrJjHewYycVDJdCbR29VLb5MFnDJ09Pjp7vXR0+7j+imwmpCRwqqWLtz60FsJNjHOQkugkLSkOf9+no2kTY2tCH6UG5nHPzYX6egCSvd24P9jJ8fRcPEXTKVo0JSIGFEQjTeSXyekQbpiZw1s1pzlxphO/MUzKSCbTBW1dXt470YJDrIs9SfFOJrgSrPkkKivJf/Rxck7VkzIph7jvr7vgQo9SMad/Tpp77oG4OPAOLjU3zdsOqx4GTeKXTfvIQ6myEsrLweMZ3OdywaZNmszV+PDYY9ZKQBkZ0No64qgVNbygzn6oRmn16nOTOFjbq1eHJx6lQmnoPO7x8fDnP0NNjSZxG2giD6UhF3pGtV+pWDF03va1awcXXI6RRTnCTRN5KI00WVgYJhFTKqTGyTzu4aKJPJTWr7f6xIdyuaz9SsWyVasGkvjA5Fhut7VfBUwTeSiVlVkXNgsLrSnjCgv1Qqcad9asCXcEsUeHH4ZaWZkmbqWUrbQiV0oFXbDX5hzvdBy5Uiqk7FqbczzSceRKKRWjNJErpUJqtGtzqtHTRK6UCintF7efJnKllIpyASVyEblPRPaJiF9ELuiAV0opFXyBVuR7gU8D222IRSml1GUI6IYgY8z7gC5BppRSYaR95EopFeUuWZGLyGvApGFeWm2M+f1oTyQi5UA5QIHO9qeUUra5ZCI3xtxix4mMMZuATWDd2WlHm0oppbRrRSmlol6gww/vFZE64FrgDyLyJ3vCUkopNVqBjlrZAmyxKRallFKXQbtW7FRZCUVF4HBYj5WV4Y5IKTUO6MISdqmshPJy8His7dpaaxt0IQmlVFBpRW6X1asHk3g/j8far5RSQaSJ3C7Hjo1tv1JK2UQTuV1GuslJb35SSgWZJnK7rF8PLte5+1wua79SSgWRJnK7lJXBpk1QWGgtSlhYaG3rhU6lVJDpqBU7lZVp4lZKhZxW5EopFeU0kSulVJTTRK6UUlFOE7lSSkU5TeRKKRXlNJErpVSU00SulFJRThO5UkpFOU3kSikV5TSRK6VUlAt0zc5/EZEDIrJHRLaISKZNcSmllBqlQCvyV4G5xpj5wCHgu4GHpJRSaiwCSuTGmFeMMd6+zb8C+YGHpJRSaizs7CP/EvDySC+KSLmI7BSRnY2NjTaeVimlxrdLTmMrIq8Bk4Z5abUx5vd9x6wGvMCIy8YbYzYBmwBKSkrMZUWrlFLqApdM5MaYWy72uoh8AbgbuNkYowlaKaVCLKCFJUTkDuDbwI3GGM+ljldKKWW/QPvIfwakAa+KyG4RedKGmJRSSo1BoKNWrjDGTDPGLOz7+opdgSmlAlNREe4IVKjonZ1Kxag1a8IdgQoVTeRKKRXlNJErFUMqKkDE+oLB59rNEtskHCMGS0pKzM6dO0N+XqXGExHQAcGxRUR2GWNKzt+vFblSSkU5TeRKxajHHw93BCpUNJErFaNG1S9eWQlFReBwWI+VI86yoSJYQHd2KqWiWGUllJeDp++m7NpaaxugrCx8cakx04pcqfGoshK+8IXBJN7P44HVq8MTk7psmsiVGm/6K3Gfb/jXjx0LbTwqYNGfyLWPT6mxWb36wkp8qIKC0MWibBHdiby/sqittQbM9vfxaTJX49WGDVBVde6+qiprf7+LVdwuF6xfH5zYVNBEdyIfrrLQPj41npWWwv33Dybzqipru7R08JiRKm6nEzZt0gudUSi6E/lIlYX28anxyu2GzZut5P3YY9bj5s3W/n7r11uV91AuF/zqV5rEo1R0J/KRKgvt41PjmdsN
K1fCunXW49AkDlay3rQJCgut+/gLC7USj3LRnchHqiy0j0+NZ1VVsHEjPPqo9Xh+nzlYSbumBvx+61GTeFSL7kSulYVS5+rvE9+8GdauHexmGS6Zq5gRUCIXkXUisqdvmbdXRGSKXYGNmlYWSg2qrj63T7y/z7y6OrxxqaAKaBpbEUk3xrT2Pf8aMHs0y73pNLZKKTV2QZnGtj+J90kBdPZjpZQKsYAnzRKR9cD/AVoA9yUOV0opZbNLVuQi8pqI7B3m6x4AY8xqY8w0oBJ4+CLtlIvIThHZ2djYaN87UEqpcc62pd5EpBD4gzFm7qWO1T5ypZQau6D0kYvIzCGbnwQOBNKeUkqpsQt01MpvgVmAH6gFvmKMOTGK72vsOz7aZANN4Q4ihMbb+wV9z+NFtL7nQmNMzvk7betaGQ9EZOdwH2ti1Xh7v6DvebyItfcc3Xd2KqWU0kSulFLRThP52GwKdwAhNt7eL+h7Hi9i6j1rH7lSSkU5rciVUirKaSK/DCLyLRExIpId7liCTUT+RUQO9M1yuUVEMsMdU7CIyB0iclBEjojId8IdT7CJyDQRqRKR90Vkn4h8PdwxhYKIOEXkHRF5Kdyx2EUT+RiJyDTgVmC8rCf3KjDXGDMfOAR8N8zxBIWIOIGfAx8HZgMPiMjs8EYVdF7gm8aYq4GlwFfHwXsG+DrwfriDsJMm8rH7MbCKcTLTozHmFWOMt2/zr0B+OOMJoiXAEWPMB8aYHuA54J4wxxRUxphTxpi3+563YSW3qeGNKrhEJB+4C/hFuGOxkybyMRCRTwInjDHvhjuWMPkS8HK4gwiSqcDxIdt1xHhSG0pEioBFwN/CHEqw/QSrEPOHOQ5bBTyNbawRkdeAScO8tBr4HnBbaCMKvou9Z2PM7/uOWY31UbwylLGFkAyzb1x86hKRVOC3wDfOW2MgpojI3UCDMWaXiNwU5nBspYn8PMaYW4bbLyLzgGLgXREBq4vhbRFZYoz5KIQh2m6k99xPRL4A3A3cbGJ3vGodMG3Idj5wMkyxhIyIxGMl8UpjzO/CHU+QLQM+KSJ3AklAuoj82hjz+TDHFTAdR36ZRKQGKDHGROPEO6MmIncAPwJuNMbE7ETyIhKHdTH3ZuAEUA08aIzZF9bAgkisiuRXwGljzDfCHE5I9VXk3zLG3B3mUGyhfeTqUn4GpAGv9i2y/WS4AwqGvgu6DwN/wrrotzmWk3ifZcDfAyv6/m1391WrKspoRa6UUlFOK3KllIpymsiVUirKaSJXSqkop4lcKaWinCZypZSKcprIlVIqymkiV0qpKKeJXCmlotz/BxKFcBUYmovkAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter the observed points for entry i, plus the predictions at the query inputs.\n",
    "plt.plot(x_a[i], y_a[i], \"ro\", label=\"Context\")\n",
    "plt.plot(x_b[i], y_b[i], \"rx\", label=\"Query\")\n",
    "plt.plot(x_b[i], predictions, \"+b\", label=\"Pred\")\n",
    "\n",
    "# Evaluate apply_fn with output[\"trained_params\"] on 100 evenly spaced inputs in [-5, 5];\n",
    "# the trailing axis is added because apply_fn presumably expects column-shaped inputs (TODO confirm).\n",
    "x_grid = np.linspace(-5, 5, 100)\n",
    "y_grid = apply_fn(output[\"trained_params\"], x_grid[:, np.newaxis])\n",
    "plt.plot(x_grid, y_grid, \"--\", label=\"Raw\", alpha=0.4)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03509a1f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
