{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17367bce",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:14.421909Z",
     "start_time": "2021-10-04T01:27:14.391662Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4276dfd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:15.320061Z",
     "start_time": "2021-10-04T01:27:14.436100Z"
    }
   },
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3625183",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:16.702671Z",
     "start_time": "2021-10-04T01:27:15.327853Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import numpy as np\n",
    "import ntk\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "from models import SoftTree\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2cd359",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:16.729347Z",
     "start_time": "2021-10-04T01:27:16.706982Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel(alpha, max_depth):\n",
    "    def rotation_o(u, t, deg=False):    \n",
    "        if deg == True:\n",
    "            t = np.deg2rad(t)\n",
    "\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return  np.dot(R, u)    \n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "    \n",
    "    kernel_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []    \n",
    "    inner_product_list = []\n",
    "     \n",
    "    for depth in range(1,max_depth,1):\n",
    "        kernel = []\n",
    "        taus = []\n",
    "        tau_dots = []\n",
    "        inner_product = []\n",
    "        for i in range(360):\n",
    "            Ru = rotation_o(u, i*np.pi/180)\n",
    "            H, tau, tau_dot = ntk.tree(np.vstack([u, Ru]), max_depth=depth, alpha=alpha)\n",
    "            kernel.append(H[depth-1, 1,0])\n",
    "            taus.append(tau[1,0])\n",
    "            tau_dots.append(tau_dot[1,0])\n",
    "            inner_product.append(np.dot(u, Ru))\n",
    "        kernel_list.append(kernel)\n",
    "        tau_list.append(taus)\n",
    "        tau_dot_list.append(tau_dots)\n",
    "        inner_product_list.append(inner_product)\n",
    "    return kernel_list, tau_list, tau_dot_list, inner_product_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98ba3e84",
   "metadata": {},
   "source": [
    "#### utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f8050b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:16.831966Z",
     "start_time": "2021-10-04T01:27:16.732463Z"
    }
   },
   "outputs": [],
   "source": [
    "def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):\n",
    "    \"\"\"\n",
    "    Compute the gradient of `outputs` with respect to `inputs`\n",
    "    gradient(x.sum(), x)\n",
    "    gradient((x * y).sum(), [x, y])\n",
    "    \"\"\"\n",
    "    if torch.is_tensor(inputs):\n",
    "        inputs = [inputs]\n",
    "    else:\n",
    "        inputs = list(inputs)\n",
    "    grads = torch.autograd.grad(\n",
    "        outputs,\n",
    "        inputs,\n",
    "        grad_outputs,\n",
    "        allow_unused=True,\n",
    "        retain_graph=retain_graph,\n",
    "        create_graph=create_graph,\n",
    "    )\n",
    "    grads = [x if x is not None else torch.zeros_like(y) for x, y in zip(grads, inputs)]\n",
    "    return torch.cat([x.contiguous().view(-1) for x in grads])\n",
    "\n",
    "\n",
    "def compute_kernels(f, xtr, parameters=None):\n",
    "    if parameters is None:\n",
    "        parameters = list(f.parameters())\n",
    "\n",
    "    ktrtr = xtr.new_zeros(len(xtr), len(xtr))\n",
    "\n",
    "    params = []\n",
    "    current = []\n",
    "    for p in sorted(parameters, key=lambda p: p.numel(), reverse=True):\n",
    "        current.append(p)\n",
    "        if sum(p.numel() for p in current) > 2e9 // (8 * (len(xtr))):\n",
    "            if len(current) > 1:\n",
    "                params.append(current[:-1])\n",
    "                current = current[-1:]\n",
    "            else:\n",
    "                params.append(current)\n",
    "                current = []\n",
    "    if len(current) > 0:\n",
    "        params.append(current)\n",
    "\n",
    "    for i, p in enumerate(params):\n",
    "        jtr = xtr.new_empty(len(xtr), sum(u.numel() for u in p))  # (P, N~)\n",
    "\n",
    "        for j, x in enumerate(xtr):\n",
    "            jtr[j] = gradient(f(x[None]), p)  # (N~)\n",
    "\n",
    "        ktrtr.add_(jtr @ jtr.t())\n",
    "        del jtr\n",
    "\n",
    "    return ktrtr\n",
    "\n",
    "\n",
    "def rotation_o(u, t, deg=False):\n",
    "    if deg == True:\n",
    "        t = np.deg2rad(t)\n",
    "\n",
    "    R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])\n",
    "    return np.dot(R, u)\n",
    "\n",
    "def plot_kernel(alpha, n_tree, color, is_first=False):\n",
    "    st = SoftTree(input_dim=2, output_dim=1, max_depth=3, scale=alpha, n_tree=n_tree)\n",
    "    u = (1, 0)\n",
    "    res = []\n",
    "    inner_product = []\n",
    "    for i in tqdm(range(180), leave=False):\n",
    "        Ru = rotation_o(u, i * np.pi / 180)\n",
    "        x = torch.Tensor([u, Ru]).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        K = compute_kernels(st, x)\n",
    "        res.append(K[1,0].tolist())\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "    if is_first:\n",
    "        plt.plot(inner_product, res, color=color, linewidth=1, label=f\"$M={n_tree}$\")\n",
    "    else:\n",
    "        plt.plot(inner_product, res, color=color, linewidth=1)        \n",
    "    return inner_product, res\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a7c6353",
   "metadata": {},
   "source": [
    "## Figure2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66151f55",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:27:19.419920Z",
     "start_time": "2021-10-04T01:27:16.838025Z"
    }
   },
   "outputs": [],
   "source": [
    "max_alpha = 50\n",
    "fig = plt.figure(figsize=(5, 4))\n",
    "ax = fig.add_subplot(111)\n",
    "for alpha in range(1, max_alpha, 1):\n",
    "    x = np.linspace(-3, 3, 100000)\n",
    "    ax.plot(x, 0.5*special.erf(0.25*alpha*x)+0.5, color=cm.jet(alpha/max_alpha))\n",
    "    ax.set_xlabel('${w}_{m,n}^{\\\\top} {x}_i$')\n",
    "    ax.set_ylabel('$\\sigma({w}_{m,n}^{\\\\top} {x}_i)$')\n",
    "\n",
    "e = math.e\n",
    "y = 1 / (1 + e**-x)\n",
    "plt.plot(x, y, color=\"m\", linestyle=\"dotted\", linewidth=3)\n",
    "\n",
    "plt.tight_layout()\n",
    "ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "\n",
    "cax = plt.axes([1.00, 0.2, 0.04, 0.74]) \n",
    "cmap = matplotlib.cm.cool\n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha*0.25)\n",
    "\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet), cax=cax, orientation='vertical', label=\"$\\\\alpha$\", ticks = [1,3,5,7,9,11,13,15,17,19])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfba50af",
   "metadata": {},
   "source": [
    "## Figure3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4bb1a2a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:28:48.936902Z",
     "start_time": "2021-10-04T01:27:19.432787Z"
    }
   },
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "\n",
    "fig = plt.figure(figsize=(15, 5))\n",
    "plt.subplot(1,3,1)\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "for j, n_tree in enumerate(n_tree_combinations):\n",
    "    res_all = []\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha ,depth+1)\n",
    "    for i in tqdm(range(n_seeds), leave=False):\n",
    "        inner_product, res = plot_kernel(alpha, n_tree, color=cm.bwr(j/len(n_tree_combinations)), is_first=(i==0))\n",
    "        res_all.append(res)\n",
    "    mse = 0\n",
    "    for i in range(len(res_all)):\n",
    "        mse+=mean_squared_error(res_all[i], kernel_list[depth-1][0:180])\n",
    "    mse/=len(res_all)\n",
    "plt.plot(inner_product_list[0][0:180], kernel_list[depth-1][0:180],  color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "plt.title(\"$d=3$, $\\\\alpha=2.0$\")\n",
    "plt.xlabel(\"Inner product of the inputs\")\n",
    "plt.ylabel(\"$\\widehat{\\Theta}_0(x_i, x_j)$\")\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.ylim(-0.3, 2.5)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# -----\n",
    "plt.subplot(1,3,2)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, alpha in enumerate(tqdm((0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.magma((j+1)/9), capsize=3, label=\"$d=3, \\\\alpha=$\"+f\"{alpha}\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.title(\"Change $\\\\alpha$ with fixed $d=3$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "# --------\n",
    "\n",
    "plt.subplot(1,3,3)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, depth in enumerate(tqdm(range(1,7,1), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.viridis((j+1)/6), capsize=3, label=f\"$d={int(depth)}, \\\\alpha=2.0$\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.title(\"Change $d$ with fixed $\\\\alpha=2.0$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9eda84d8",
   "metadata": {},
   "source": [
    "## Figure4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9da440c2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:29:15.312646Z",
     "start_time": "2021-10-04T01:28:48.962281Z"
    }
   },
   "outputs": [],
   "source": [
    "def circle_transform(angle_vec):\n",
    "    cos_tensor = torch.cos(angle_vec)\n",
    "    sin_tensor = torch.sin(angle_vec)\n",
    "    return torch.stack((cos_tensor, sin_tensor), -1)\n",
    "\n",
    "n_features = 2\n",
    "n_dataset = 10\n",
    "\n",
    "train_data = torch.Tensor([np.random.randn(n_features) for i in range(n_dataset)])\n",
    "target_data = torch.tensor(np.random.randn(train_data.shape[0]))\n",
    "test_data = torch.Tensor([np.random.randn(n_features) for i in range(10)])\n",
    "\n",
    "def train_net(net, n_epochs, input_data, target, lr, initial_train):\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "    for epoch in range(n_epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(input_data.double())-initial_train.unsqueeze(1)\n",
    "        loss = criterion(outputs.view(-1), target)/2\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "def analytical_evolution_MSE(t, lr, H_train, H_test, initial_train, initial_test, target_data):\n",
    "    n_train = len(initial_train)\n",
    "\n",
    "    # first compute the exponential of the matrix (using eigendecomposition):\n",
    "    lam, P = np.linalg.eig(H_train)  # eig decomposition\n",
    "    lam = lam.astype(dtype='float64')\n",
    "\n",
    "    H_train_inv = np.dot(P, np.dot(np.diag(lam**(-1)), P.transpose()))\n",
    "\n",
    "    # note that you need to rescale the time by n_train, as the 2 paper use different convention for the loss function\n",
    "    # I am using np arrays, not torch tensors\n",
    "    exp_matrix = np.dot(\n",
    "        P, np.dot(np.diag(np.exp(-lr * t * lam / n_train)), P.transpose()))\n",
    "\n",
    "    # compute the prediction on train set\n",
    "    pred_train = target_data.cpu().numpy() + np.dot(exp_matrix,\n",
    "                                                    (initial_train - target_data).cpu().detach().numpy())\n",
    "\n",
    "    # compute the intermediate matrix used both in prediction on test set and weights evolution\n",
    "    tmp = np.dot(np.eye(lam.size) - exp_matrix,\n",
    "                 (initial_train - target_data).cpu().detach().numpy())\n",
    "    tmp = np.dot(H_train_inv, tmp)\n",
    "\n",
    "    # compute prediction on test set\n",
    "    pred_test = np.dot(H_test, tmp)\n",
    "    pred_test = initial_test.detach().cpu().numpy().reshape(-1) - pred_test\n",
    "\n",
    "    return pred_train, pred_test\n",
    "\n",
    "\n",
    "alpha = 2.0\n",
    "depth = 3\n",
    "\n",
    "H_analytical_train = ntk.tree(train_data.numpy(), max_depth=depth, alpha=alpha)[0][depth-1]\n",
    "H_analytical_test = ntk.tree(torch.cat([train_data, test_data]).numpy(), max_depth=depth, alpha=alpha)[0][depth-1][len(train_data):, 0:len(train_data)]\n",
    "\n",
    "ptrain_empiricals, ptest_empiricals = [], []\n",
    "for n_tree in (16, 1024):\n",
    "    ptrain_empirical, ptest_empirical, ptrain_analytical, ptest_analytical = [], [], [], []\n",
    "    \n",
    "    t_max = 1000\n",
    "    t_step = 10\n",
    "    lr = 0.1\n",
    "    t_list = np.arange(t_step, t_max+t_step, t_step)\n",
    "\n",
    "    st = SoftTree(input_dim=train_data.shape[1], output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "    initial_train=st.forward(train_data).reshape(-1)\n",
    "    initial_test=st.forward(test_data).reshape(-1)\n",
    "\n",
    "    ptrain_analytical.append(torch.zeros_like(initial_train).detach().numpy())\n",
    "    ptrain_empirical.append(torch.zeros_like(initial_train).detach().numpy())\n",
    "\n",
    "    ptest_analytical.append(torch.zeros_like(initial_test).detach().numpy())\n",
    "    ptest_empirical.append(torch.zeros_like(initial_test).detach().numpy())\n",
    "\n",
    "    for t in tqdm(t_list):\n",
    "        train_net(st, t_step, train_data, target_data, lr, initial_train.detach())\n",
    "\n",
    "        ptrain_empirical.append(st.forward(train_data).detach().cpu().numpy().reshape(-1)-initial_train.detach().numpy())\n",
    "        ptest_empirical.append(st.forward(test_data).detach().cpu().numpy().reshape(-1)-initial_test.detach().numpy())\n",
    "\n",
    "        pred_train, pred_test = analytical_evolution_MSE(\n",
    "            t=t,\n",
    "            lr=lr,\n",
    "            H_train=H_analytical_train, \n",
    "            H_test=H_analytical_test,\n",
    "            initial_train=torch.zeros_like(initial_train), \n",
    "            initial_test=torch.zeros_like(initial_test), \n",
    "            target_data=target_data\n",
    "        )\n",
    "        ptrain_analytical.append(pred_train)\n",
    "        ptest_analytical.append(pred_test)\n",
    "    ptrain_empiricals.append(ptrain_empirical)\n",
    "    ptest_empiricals.append(ptest_empirical)\n",
    "    \n",
    "cmap = plt.cm.nipy_spectral\n",
    "t_list = np.arange(0, t_max+t_step, t_step)\n",
    "\n",
    "plt.figure(figsize=(15, 5))\n",
    "plt.subplot(1,2,1)\n",
    "for i in range(len(ptrain_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"Train output\")\n",
    "plt.ylabel(\"$f(x_i, w, \\pi)$\")\n",
    "\n",
    "plt.legend(loc=\"lower left\")    \n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.subplot(1,2,2)\n",
    "for i in range(len(ptest_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"Test output\")\n",
    "plt.ylabel(\"$f(x_i, w, \\pi)$\")\n",
    "\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.grid(linestyle=\"dotted\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71476d20",
   "metadata": {},
   "source": [
    "## Figure6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "368654c7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-04T01:29:41.126508Z",
     "start_time": "2021-10-04T01:29:15.320121Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "ax1 = fig.add_subplot(143)\n",
    "ax2 = fig.add_subplot(144)\n",
    "ax3 = fig.add_subplot(141)\n",
    "ax4 = fig.add_subplot(142)\n",
    "\n",
    "max_alpha = 50\n",
    "\n",
    "alpha = 2\n",
    "max_depth = 31\n",
    "\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel((2**(alpha-2)), max_depth)\n",
    "\n",
    "for depth in range(1, len(kernel_list)+1, 1):\n",
    "    im = eval(f\"ax{alpha}\").plot(inner_product_list[0], [i/max(kernel_list[depth-1]) for i in kernel_list[depth-1]], color=cm.cool(depth/max_depth))    \n",
    "eval(f\"ax{alpha}\").set_title(\"Normalized \"+\"$\\Theta^{(d)}(x_i, x_j), \\\\alpha=$\"+f\"{(2**(alpha-2))}\")\n",
    "eval(f\"ax{alpha}\").grid(linestyle=\"dotted\")\n",
    "eval(f\"ax{alpha}\").set_xlabel(\"Inner product of the inputs\")\n",
    "eval(f\"ax{alpha}\").set_ylim(-0.1,1.1)\n",
    "\n",
    "depth =  2\n",
    "for alpha in range(1, max_alpha, 1):\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha*0.25, max_depth=depth+1)\n",
    "    ax1.plot(inner_product_list[0], [i/max(kernel_list[::-1][0]) for i in kernel_list[::-1][0]], label=depth, color=cm.jet(alpha/max_alpha), linewidth=1)    \n",
    "    ax1.set_title(\"Normalized $\\Theta^{(2)}(x_i, x_j)$\")\n",
    "    ax1.set_xlabel(\"inner product of the inputs\")\n",
    "    ax1.set_ylim(-0.1,1.1)\n",
    "    \n",
    "ax1.grid(linestyle=\"dotted\")\n",
    "\n",
    "for alpha in range(1, max_alpha, 1):\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha*0.25, max_depth=2) # depth does not affect the result\n",
    "    ax3.plot(inner_product_list[0], [i for i in tau_list[::-1][0]], label=depth, color=cm.jet(alpha/max_alpha), linewidth=1)    \n",
    "    ax3.set_title(\"${\\mathcal{T}}(x_i, x_j)$\")\n",
    "    ax3.set_xlabel(\"inner product of the inputs\")\n",
    "    ax3.grid(linestyle=\"dotted\")\n",
    "    \n",
    "    ax4.plot(inner_product_list[0], [i for i in tau_dot_list[::-1][0]], label=depth, color=cm.jet(alpha/max_alpha), linewidth=1)    \n",
    "    ax4.set_title(\"$\\dot{\\mathcal{T}}(x_i, x_j)$\")\n",
    "    ax4.set_xlabel(\"inner product of the inputs\")\n",
    "    ax4.grid(linestyle=\"dotted\")\n",
    "    \n",
    "cax = plt.axes([0.05, -0.05, 0.68, 0.04]) \n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha*0.25)\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet), cax=cax, orientation='horizontal', label=\"$\\\\alpha$\", ticks = [1,3,5,7,9,11,13,15,17,19])\n",
    "\n",
    "cax = plt.axes([0.79, -0.05, 0.19, 0.04]) \n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_depth)\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"$d$\", ticks = [5,10,15,20,25])\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b8ed559",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
