{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:27.744893Z",
     "start_time": "2021-09-27T03:15:27.738004Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'  # font familyの設定\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'  # math fontの設定\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:32.310591Z",
     "start_time": "2021-09-27T03:15:27.749292Z"
    }
   },
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.145364Z",
     "start_time": "2021-09-27T03:15:32.313901Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "import copy\n",
    "import ntk\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from models import SoftTree\n",
    "from utils import compute_kernels, rotation_o\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.156771Z",
     "start_time": "2021-09-27T03:15:35.147632Z"
    }
   },
   "outputs": [],
   "source": [
    "def circle_transform(angle_vec):\n",
    "    cos_tensor = torch.cos(angle_vec)\n",
    "    sin_tensor = torch.sin(angle_vec)\n",
    "    return torch.stack((cos_tensor, sin_tensor), -1)\n",
    "\n",
    "n_features = 2\n",
    "n_dataset = 10\n",
    "\n",
    "train_data = torch.Tensor([np.random.randn(n_features) for i in range(n_dataset)])\n",
    "target_data = torch.tensor(np.random.randn(train_data.shape[0]))\n",
    "test_data = torch.Tensor([np.random.randn(n_features) for i in range(10)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.163984Z",
     "start_time": "2021-09-27T03:15:35.158843Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_net(net, n_epochs, input_data, target, lr, initial_train):\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "    for epoch in range(n_epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(input_data.double())-initial_train.unsqueeze(1)\n",
    "        loss = criterion(outputs.view(-1), target)/2\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.173329Z",
     "start_time": "2021-09-27T03:15:35.166262Z"
    }
   },
   "outputs": [],
   "source": [
    "def analytical_evolution_MSE(t, lr, H_train, H_test, initial_train, initial_test, target_data):\n",
    "    n_train = len(initial_train)\n",
    "\n",
    "    # first compute the exponential of the matrix (using eigendecomposition):\n",
    "    lam, P = np.linalg.eig(H_train)  # eig decomposition\n",
    "    lam = lam.astype(dtype='float64')\n",
    "\n",
    "    H_train_inv = np.dot(P, np.dot(np.diag(lam**(-1)), P.transpose()))\n",
    "\n",
    "    # note that you need to rescale the time by n_train, as the 2 paper use different convention for the loss function\n",
    "    # I am using np arrays, not torch tensors\n",
    "    exp_matrix = np.dot(\n",
    "        P, np.dot(np.diag(np.exp(-lr * t * lam / n_train)), P.transpose()))\n",
    "\n",
    "    # compute the prediction on train set\n",
    "    pred_train = target_data.cpu().numpy() + np.dot(exp_matrix,\n",
    "                                                    (initial_train - target_data).cpu().detach().numpy())\n",
    "\n",
    "    # compute the intermediate matrix used both in prediction on test set and weights evolution\n",
    "    tmp = np.dot(np.eye(lam.size) - exp_matrix,\n",
    "                 (initial_train - target_data).cpu().detach().numpy())\n",
    "    tmp = np.dot(H_train_inv, tmp)\n",
    "\n",
    "    # compute prediction on test set\n",
    "    pred_test = np.dot(H_test, tmp)\n",
    "    pred_test = initial_test.detach().cpu().numpy().reshape(-1) - pred_test\n",
    "\n",
    "    return pred_train, pred_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:45.776401Z",
     "start_time": "2021-09-27T03:15:35.178289Z"
    }
   },
   "outputs": [],
   "source": [
    "alpha = 2.0\n",
    "depth = 3\n",
    "\n",
    "H_analytical_train = ntk.tree(train_data.numpy(), max_depth=depth, alpha=alpha)[0][depth-1]\n",
    "H_analytical_test = ntk.tree(torch.cat([train_data, test_data]).numpy(), max_depth=depth, alpha=alpha)[0][depth-1][len(train_data):, 0:len(train_data)]\n",
    "\n",
    "ptrain_empiricals, ptest_empiricals = [], []\n",
    "for n_tree in (16, 1024):\n",
    "    ptrain_empirical, ptest_empirical, ptrain_analytical, ptest_analytical = [], [], [], []\n",
    "    \n",
    "    t_max = 1000\n",
    "    t_step = 10\n",
    "    lr = 0.1\n",
    "    t_list = np.arange(t_step, t_max+t_step, t_step)\n",
    "\n",
    "    st = SoftTree(input_dim=train_data.shape[1], output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "    initial_train=st.forward(train_data).reshape(-1)\n",
    "    initial_test=st.forward(test_data).reshape(-1)\n",
    "\n",
    "    ptrain_analytical.append(torch.zeros_like(initial_train).detach().numpy())\n",
    "    ptrain_empirical.append(torch.zeros_like(initial_train).detach().numpy())\n",
    "\n",
    "    ptest_analytical.append(torch.zeros_like(initial_test).detach().numpy())\n",
    "    ptest_empirical.append(torch.zeros_like(initial_test).detach().numpy())\n",
    "\n",
    "    for t in tqdm(t_list):\n",
    "        train_net(st, t_step, train_data, target_data, lr, initial_train.detach())\n",
    "\n",
    "        ptrain_empirical.append(st.forward(train_data).detach().cpu().numpy().reshape(-1)-initial_train.detach().numpy())\n",
    "        ptest_empirical.append(st.forward(test_data).detach().cpu().numpy().reshape(-1)-initial_test.detach().numpy())\n",
    "\n",
    "        pred_train, pred_test = analytical_evolution_MSE(\n",
    "            t=t,\n",
    "            lr=lr,\n",
    "            H_train=H_analytical_train, \n",
    "            H_test=H_analytical_test,\n",
    "            initial_train=torch.zeros_like(initial_train), \n",
    "            initial_test=torch.zeros_like(initial_test), \n",
    "            target_data=target_data\n",
    "        )\n",
    "        ptrain_analytical.append(pred_train)\n",
    "        ptest_analytical.append(pred_test)\n",
    "    ptrain_empiricals.append(ptrain_empirical)\n",
    "    ptest_empiricals.append(ptest_empirical)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:45.786139Z",
     "start_time": "2021-09-27T03:15:45.780542Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from models import SoftTree\n",
    "from utils import compute_kernels, rotation_o\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:45.800909Z",
     "start_time": "2021-09-27T03:15:45.789092Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel(alpha, max_depth, mode=\"tree\"):\n",
    "    assert mode in (\"tree\", \"mlp_relu\", \"mlp_erf\")\n",
    "    def rotation_o(u, t, deg=False):    \n",
    "        # 度数単位の角度をラジアンに変換\n",
    "        if deg == True:\n",
    "            t = np.deg2rad(t)\n",
    "\n",
    "        # 回転行列\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return  np.dot(R, u)    \n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "    \n",
    "    kernel_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []    \n",
    "    inner_product_list = []\n",
    "     \n",
    "    for depth in range(1,max_depth,1):\n",
    "        kernel = []\n",
    "        taus = []\n",
    "        tau_dots = []\n",
    "        inner_product = []\n",
    "        for i in range(360):\n",
    "            Ru = rotation_o(u, i*np.pi/180)\n",
    "            if mode == \"tree\":\n",
    "                H, tau, tau_dot = ntk.tree(np.vstack([u, Ru]), max_depth=depth, alpha=alpha)\n",
    "            else:\n",
    "                NotImplementedError\n",
    "            kernel.append(H[depth-1, 1,0])\n",
    "            taus.append(tau[1,0])\n",
    "            tau_dots.append(tau_dot[1,0])\n",
    "            inner_product.append(np.dot(u, Ru))\n",
    "        kernel_list.append(kernel)\n",
    "        tau_list.append(taus)\n",
    "        tau_dot_list.append(tau_dots)\n",
    "        inner_product_list.append(inner_product)\n",
    "    return kernel_list, tau_list, tau_dot_list, inner_product_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:45.812388Z",
     "start_time": "2021-09-27T03:15:45.803755Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel(alpha, n_tree, color, is_first=False):\n",
    "    st = SoftTree(input_dim=2, output_dim=1, max_depth=3, scale=alpha, n_tree=n_tree)\n",
    "    u = (1, 0)\n",
    "    res = []\n",
    "    inner_product = []\n",
    "    for i in tqdm(range(180), leave=False):\n",
    "        Ru = rotation_o(u, i * np.pi / 180)\n",
    "        x = torch.Tensor([u, Ru]).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        K = compute_kernels(st, x)\n",
    "        res.append(K[1,0].tolist())\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "    if is_first:\n",
    "        plt.plot(inner_product, res, color=color, linewidth=1, label=f\"$M={n_tree}$\")\n",
    "    else:\n",
    "        plt.plot(inner_product, res, color=color, linewidth=1)        \n",
    "    return inner_product, res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-26T12:02:57.297033Z",
     "start_time": "2021-09-26T11:55:51.161499Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "\n",
    "fig = plt.figure(figsize=(15, 10))\n",
    "plt.subplot(1,3,1)\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "for j, n_tree in enumerate(n_tree_combinations):\n",
    "    res_all = []\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha ,depth+1)\n",
    "    for i in tqdm(range(n_seeds), leave=False):\n",
    "        inner_product, res = plot_kernel(alpha, n_tree, color=cm.bwr(j/len(n_tree_combinations)), is_first=(i==0))\n",
    "        res_all.append(res)\n",
    "    mse = 0\n",
    "    for i in range(len(res_all)):\n",
    "        mse+=mean_squared_error(res_all[i], kernel_list[depth-1][0:180])\n",
    "    mse/=len(res_all)\n",
    "plt.plot(inner_product_list[0][0:180], kernel_list[depth-1][0:180],  color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "#plt.title(\"$\\\\alpha=2$, $d=3$\")\n",
    "plt.xlabel(\"Inner product of the inputs\")\n",
    "plt.ylabel(\"$\\widehat{\\Theta}_0(x_i, x_j)$\")\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.ylim(-0.3, 2.5)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# -----\n",
    "plt.subplot(2,3,2)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, alpha in enumerate(tqdm((0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.magma((j+1)/9), capsize=3, label=\"$d=3, \\\\alpha=$\"+f\"{alpha}\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "#plt.title(\"Change $\\\\alpha$ with fixed $d=3$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "#plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left', prop={'size': 12})\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "# --------\n",
    "\n",
    "plt.subplot(2,3,3)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, depth in enumerate(tqdm(range(1,7,1), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.viridis((j+1)/6), capsize=3, label=f\"$d={int(depth)}, \\\\alpha=2.0$\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "#plt.title(\"Change $d$ with fixed $\\\\alpha=2$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "# ------------\n",
    "\n",
    "cmap = plt.cm.nipy_spectral\n",
    "t_list = np.arange(0, t_max+t_step, t_step)\n",
    "\n",
    "plt.subplot(2,3,5)\n",
    "for i in range(len(ptrain_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.ylabel(\"Train output\")\n",
    "#plt.title(\"$\\\\alpha=2$, $d=3$\")\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.subplot(2,3,6)\n",
    "for i in range(len(ptest_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.ylabel(\"Test output\")\n",
    "#plt.title(\"$\\\\alpha=2$, $d=3$\")\n",
    "\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/finite.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:24:13.683398Z",
     "start_time": "2021-09-27T03:15:45.815175Z"
    }
   },
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "\n",
    "fig = plt.figure(figsize=(15, 5))\n",
    "plt.subplot(1,3,1)\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "for j, n_tree in enumerate(n_tree_combinations):\n",
    "    res_all = []\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha ,depth+1)\n",
    "    for i in tqdm(range(n_seeds), leave=False):\n",
    "        inner_product, res = plot_kernel(alpha, n_tree, color=cm.bwr(j/len(n_tree_combinations)), is_first=(i==0))\n",
    "        res_all.append(res)\n",
    "    mse = 0\n",
    "    for i in range(len(res_all)):\n",
    "        mse+=mean_squared_error(res_all[i], kernel_list[depth-1][0:180])\n",
    "    mse/=len(res_all)\n",
    "plt.plot(inner_product_list[0][0:180], kernel_list[depth-1][0:180],  color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "plt.title(\"$d=3$, $\\\\alpha=2.0$\")\n",
    "plt.xlabel(\"Inner product of the inputs\")\n",
    "plt.ylabel(\"$\\widehat{\\Theta}_0(x_i, x_j)$\")\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.ylim(-0.3, 2.5)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# -----\n",
    "plt.subplot(1,3,2)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, alpha in enumerate(tqdm((0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.magma((j+1)/9), capsize=3, label=\"$d=3, \\\\alpha=$\"+f\"{alpha}\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.title(\"Change $\\\\alpha$ with fixed $d=3$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "#plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left', prop={'size': 12})\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "# --------\n",
    "\n",
    "plt.subplot(1,3,3)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "x = torch.Tensor([(j/np.sqrt(sum(j**2))) for j in np.array([np.random.randn(n_features) for i in range(n_dataset)])])\n",
    "\n",
    "norm = colors.LogNorm(vmin=0.5, vmax=100)\n",
    "\n",
    "for j, depth in enumerate(tqdm(range(1,7,1), leave=False)):\n",
    "    res, n_tree_list = [], []\n",
    "    K_infinite,_,_ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]\n",
    "\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        res_seed = []\n",
    "        for i in tqdm(range(8), leave=False):\n",
    "            n_tree = 2**(i+5)\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            res_seed.append(torch.norm((K_empirical - K_infinite)/torch.norm(K_empirical)).tolist())\n",
    "            n_tree_list.append(n_tree)\n",
    "        res.append(res_seed)\n",
    "    plt.errorbar(\n",
    "        sorted(np.unique(n_tree_list)), np.array(res).mean(axis=0), \n",
    "        yerr=np.array(res).std(axis=0), \n",
    "        color=cm.viridis((j+1)/6), capsize=3, label=f\"$d={int(depth)}, \\\\alpha=2.0$\"\n",
    "    )\n",
    "\n",
    "plt.plot(sorted(np.unique(n_tree_list)), [0.1*(i**-0.5) for i in sorted(np.unique(n_tree_list))], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.title(\"Change $d$ with fixed $\\\\alpha=2.0$\")\n",
    "plt.xlabel(\"$M$\")\n",
    "plt.ylabel(\"$\\\\frac{\\left|| \\\\widehat{H}_0 - {H}  \\\\right||_F}{||\\\\widehat{H}_0||_F}$\")\n",
    "\n",
    "plt.legend(prop={'size': 12}, loc=\"lower left\")\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/finite.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:24:15.574921Z",
     "start_time": "2021-09-27T03:24:13.685756Z"
    }
   },
   "outputs": [],
   "source": [
    "cmap = plt.cm.nipy_spectral\n",
    "t_list = np.arange(0, t_max+t_step, t_step)\n",
    "\n",
    "plt.figure(figsize=(15, 5))\n",
    "plt.subplot(1,2,1)\n",
    "for i in range(len(ptrain_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptrain_analytical)[:,i], color=cmap(i/len(ptrain_analytical[0])), marker=\"\", linestyle=\"solid\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptrain_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"Train output\")\n",
    "plt.legend(loc=\"lower left\")    \n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.subplot(1,2,2)\n",
    "for i in range(len(ptest_analytical[0])):\n",
    "    if i==0:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), label=\"Analytical\", alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=16$\")\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptrain_analytical[0])), label=\"$M=1024$\")\n",
    "    else:\n",
    "        plt.plot(t_list, np.array(ptest_analytical)[:,i], color=cmap(i/len(ptest_analytical[0])), alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[0])[:,i], linestyle=\"dotted\", color=cmap(i/len(ptest_analytical[0])))\n",
    "        plt.plot(t_list, np.array(ptest_empiricals[1])[:,i], linestyle=\"dashed\", color=cmap(i/len(ptest_analytical[0])))\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"Test output\")\n",
    "\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.savefig(\"./figures/trajectory.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
