{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb96d678",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:25.807349Z",
     "start_time": "2021-09-25T08:56:25.801536Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    \"\"\"Apply the notebook-wide matplotlib font settings used for the paper figures.\n",
    "\n",
    "    Resets rcParams to matplotlib's default style, then selects a 15 pt\n",
    "    Times New Roman text font and the STIX font set for mathtext. Safe to call\n",
    "    repeatedly.\n",
    "    \"\"\"\n",
    "    import matplotlib\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    # style.use(\"default\") resets every rcParam, so it has to come first.\n",
    "    # NOTE: the original set text.usetex=True and font.family='serif' *before*\n",
    "    # this reset (inside a loop that ran twice), which silently undid both; the\n",
    "    # figures have always been rendered with mathtext, so that is kept here.\n",
    "    plt.style.use(\"default\")\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "    plt.rcParams[\"font.family\"] = \"Times New Roman\"  # text font family\n",
    "    plt.rcParams[\"mathtext.fontset\"] = \"stix\"  # math font set\n",
    "\n",
    "    # Remove the 'roman' entry from the font-weight table and rebuild the font\n",
    "    # cache -- presumably the usual workaround for \"Times New Roman\" rendering\n",
    "    # in bold. Best effort only: the key is already gone on repeat calls\n",
    "    # (KeyError) and the private _rebuild() is absent in newer matplotlib\n",
    "    # releases (AttributeError). Other errors are no longer swallowed.\n",
    "    try:\n",
    "        del matplotlib.font_manager.weight_dict['roman']\n",
    "        matplotlib.font_manager._rebuild()\n",
    "    except (KeyError, AttributeError):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbbc0585",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:28.621842Z",
     "start_time": "2021-09-25T08:56:26.049699Z"
    }
   },
   "outputs": [],
   "source": [
    "# Apply the shared matplotlib font setup before any figure is created.\n",
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d982c62",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.580638Z",
     "start_time": "2021-09-25T08:56:28.624803Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import numpy as np\n",
    "import ntk\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bb61b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.600770Z",
     "start_time": "2021-09-25T08:56:30.586043Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel(alpha, max_depth, mode=\"tree\", n_angles=360):\n",
    "    \"\"\"Sample the soft-tree NTK between u = (1, 0) and its rotations.\n",
    "\n",
    "    For every depth d in 1 .. max_depth-1 and every rotation of u by i degrees\n",
    "    (i in range(n_angles)), ntk.tree is evaluated on the stacked pair [u, R u]\n",
    "    and the off-diagonal (1, 0) entries are collected.\n",
    "\n",
    "    Args:\n",
    "        alpha: scale forwarded to ntk.tree.\n",
    "        max_depth: exclusive upper bound of the depths evaluated.\n",
    "        mode: only \"tree\" is implemented. \"mlp_relu\" / \"mlp_erf\" pass the\n",
    "            assertion but raise NotImplementedError (see get_kernel_mlp).\n",
    "        n_angles: number of 1-degree rotation steps (360 = full circle).\n",
    "\n",
    "    Returns:\n",
    "        (kernel_list, tau_list, tau_dot_list, inner_product_list): one list per\n",
    "        depth, each holding n_angles values.\n",
    "    \"\"\"\n",
    "    assert mode in (\"tree\", \"mlp_relu\", \"mlp_erf\")\n",
    "    if mode != \"tree\":\n",
    "        # Previously a bare `NotImplementedError` expression (never raised),\n",
    "        # which then failed later with an unbound `H`.\n",
    "        raise NotImplementedError(f\"get_kernel does not support mode={mode!r}; use get_kernel_mlp\")\n",
    "\n",
    "    def rotation_o(u, t, deg=False):\n",
    "        \"\"\"Rotate the 2-D vector u by angle t (radians, or degrees if deg).\"\"\"\n",
    "        if deg:\n",
    "            t = np.deg2rad(t)\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return np.dot(R, u)\n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "    # The rotated inputs and their inner products with u do not depend on the\n",
    "    # depth, so compute them once rather than once per depth.\n",
    "    rotated = [rotation_o(u, i * np.pi / 180) for i in range(n_angles)]\n",
    "    inner_product = [np.dot(u, Ru) for Ru in rotated]\n",
    "\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = [], [], [], []\n",
    "    for depth in range(1, max_depth):\n",
    "        kernel, taus, tau_dots = [], [], []\n",
    "        for Ru in rotated:\n",
    "            H, tau, tau_dot = ntk.tree(np.vstack([u, Ru]), max_depth=depth, alpha=alpha)\n",
    "            # H is indexed per depth; entry (1, 0) pairs u with Ru.\n",
    "            kernel.append(H[depth - 1, 1, 0])\n",
    "            taus.append(tau[1, 0])\n",
    "            tau_dots.append(tau_dot[1, 0])\n",
    "        kernel_list.append(kernel)\n",
    "        tau_list.append(taus)\n",
    "        tau_dot_list.append(tau_dots)\n",
    "        inner_product_list.append(list(inner_product))  # one list per depth, as before\n",
    "    return kernel_list, tau_list, tau_dot_list, inner_product_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6721b8e8",
   "metadata": {},
   "source": [
    "# 異なるスケーリングごとでの可視化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae1fd756",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:38.956907Z",
     "start_time": "2021-09-25T08:56:30.604580Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "# Left to right: T, T-dot, normalized Theta^(1), normalized Theta^(2).\n",
    "ax1 = fig.add_subplot(143)\n",
    "ax2 = fig.add_subplot(144)\n",
    "ax3 = fig.add_subplot(141)\n",
    "ax4 = fig.add_subplot(142)\n",
    "\n",
    "max_alpha = 50\n",
    "alpha_step = 0.25  # the curves use alpha = alpha_step * k, k = 1 .. max_alpha-1\n",
    "\n",
    "for k in range(1, max_alpha):\n",
    "    color = cm.jet(k / max_alpha)\n",
    "    # A single max_depth=3 call yields depth 1 and depth 2. Its depth-1 entries\n",
    "    # (kernel, tau, tau_dot) are what the former separate max_depth=2 calls\n",
    "    # returned (assumes ntk.tree is deterministic), so this halves the work.\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(k * alpha_step, max_depth=3)\n",
    "    x = inner_product_list[0]\n",
    "    theta1, theta2 = kernel_list[0], kernel_list[1]\n",
    "    ax1.plot(x, np.asarray(theta1) / max(theta1), color=color, linewidth=1)\n",
    "    ax2.plot(x, np.asarray(theta2) / max(theta2), color=color, linewidth=1)\n",
    "    ax3.plot(x, tau_list[0], color=color, linewidth=1)\n",
    "    ax4.plot(x, tau_dot_list[0], color=color, linewidth=1)\n",
    "\n",
    "ax1.set_title(r\"Normalized $\\Theta^{(1)}(x_i, x_j)$\")\n",
    "ax2.set_title(r\"Normalized $\\Theta^{(2)}(x_i, x_j)$\")\n",
    "ax3.set_title(r\"${\\mathcal{T}}(x_i, x_j)$\")\n",
    "ax4.set_title(r\"$\\dot{\\mathcal{T}}(x_i, x_j)$\")\n",
    "for ax in (ax1, ax2, ax3, ax4):\n",
    "    ax.set_xlabel(\"Inner product of the inputs\")  # one capitalisation everywhere\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "# Colour bar matching the lines: jet(k / max_alpha) == jet(norm(alpha_step * k)).\n",
    "cax = plt.axes([0.14, -0.05, 0.75, 0.04])\n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha * alpha_step)\n",
    "# Ticks limited to the colour-bar range [0, 12.5]; the old list ran to 19.\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet), cax=cax, orientation='horizontal', label=r\"$\\alpha$\", ticks=list(range(1, int(max_alpha * alpha_step) + 1, 2)))\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/different_alpha.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c63a4dd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:41.109441Z",
     "start_time": "2021-09-25T08:56:38.958850Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(4, 4))\n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "max_alpha = 50\n",
    "alpha_step = 0.25  # alpha = alpha_step * k, k = 1 .. max_alpha-1\n",
    "\n",
    "for k in range(1, max_alpha):\n",
    "    # tau does not depend on depth, so the cheapest call (depth 1 only) suffices.\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(k * alpha_step, max_depth=2)\n",
    "    ax.plot(inner_product_list[0], [2 * t for t in tau_list[0]], color=cm.jet(k / max_alpha), linewidth=1)\n",
    "ax.set_ylabel(r\"${2\\mathcal{T}}(x_i, x_j)$\")\n",
    "ax.set_xlabel(\"Inner product of the inputs\")\n",
    "ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "cax = plt.axes([1.00, 0.2, 0.04, 0.74])\n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha * alpha_step)\n",
    "# Ticks limited to the colour-bar range [0, 12.5]; the old list ran to 19.\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet), cax=cax, orientation='vertical', label=r\"$\\alpha$\", ticks=list(range(1, int(max_alpha * alpha_step) + 1, 2)))\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/tau.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3ce9dd0",
   "metadata": {},
   "source": [
    "## MLPとの比較"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe4a32c6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:41.129821Z",
     "start_time": "2021-09-25T08:56:41.117760Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel_mlp(alpha, max_depth, mode=\"mlp_relu\", n_angles=360):\n",
    "    \"\"\"Sample an MLP NTK between u = (1, 0) and its rotations.\n",
    "\n",
    "    Counterpart of get_kernel for ntk.mlp_relu / ntk.mlp_erf: for every depth\n",
    "    d in 1 .. max_depth-1 and every rotation of u by i degrees\n",
    "    (i in range(n_angles)), the (1, 0) entry of the kernel on [u, R u] is kept.\n",
    "\n",
    "    Args:\n",
    "        alpha: scale forwarded to ntk.mlp_erf (unused for \"mlp_relu\").\n",
    "        max_depth: exclusive upper bound of the depths evaluated.\n",
    "        mode: \"mlp_relu\" or \"mlp_erf\". The default used to be \"tree\", which\n",
    "            the assertion below always rejected.\n",
    "        n_angles: number of 1-degree rotation steps (360 = full circle).\n",
    "\n",
    "    Returns:\n",
    "        (kernel_list, inner_product_list): one list per depth, each holding\n",
    "        n_angles values.\n",
    "    \"\"\"\n",
    "    assert mode in (\"mlp_relu\", \"mlp_erf\")\n",
    "\n",
    "    def rotation_o(u, t, deg=False):\n",
    "        \"\"\"Rotate the 2-D vector u by angle t (radians, or degrees if deg).\"\"\"\n",
    "        if deg:\n",
    "            t = np.deg2rad(t)\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return np.dot(R, u)\n",
    "\n",
    "    def kernel_stack(X, depth):\n",
    "        \"\"\"Evaluate the selected MLP NTK on the stacked inputs X.\"\"\"\n",
    "        if mode == \"mlp_relu\":\n",
    "            return ntk.mlp_relu(X, max_depth=depth)\n",
    "        return ntk.mlp_erf(X, max_depth=depth, alpha=alpha)\n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "    # Rotations and inner products do not depend on depth: compute them once.\n",
    "    rotated = [rotation_o(u, i * np.pi / 180) for i in range(n_angles)]\n",
    "    inner_product = [np.dot(u, Ru) for Ru in rotated]\n",
    "\n",
    "    kernel_list, inner_product_list = [], []\n",
    "    for depth in range(1, max_depth):\n",
    "        # The unused all-zero tau / tau_dot placeholders of the original are dropped.\n",
    "        kernel_list.append([kernel_stack(np.vstack([u, Ru]), depth)[depth - 1, 1, 0] for Ru in rotated])\n",
    "        inner_product_list.append(list(inner_product))\n",
    "    return kernel_list, inner_product_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95392d2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:59.860093Z",
     "start_time": "2021-09-25T08:56:41.139113Z"
    }
   },
   "outputs": [],
   "source": [
    "# Normalized kernels per depth index (colour encodes the index):\n",
    "# ReLU-MLP NTK on the left, soft-tree NTK on the right.\n",
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "fig, (ax_mlp, ax_tree) = plt.subplots(1, 2, figsize=(10, 4))\n",
    "\n",
    "kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=\"mlp_relu\")\n",
    "# NOTE(review): the MLP panel starts at index 1 while the tree panel below\n",
    "# starts at index 0; confirm the shallowest MLP depth is meant to be omitted.\n",
    "for i in range(1, max_depth - 1):\n",
    "    curve = kernel_list[i]\n",
    "    ax_mlp.plot(inner_product_list[0], [j / max(curve) for j in curve], color=cm.cool(i / max_depth))\n",
    "ax_mlp.set_title(\"MLP(ReLU)\")\n",
    "ax_mlp.grid(linestyle=\"dotted\")\n",
    "\n",
    "kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=\"tree\")\n",
    "for i in range(0, max_depth - 2):\n",
    "    curve = kernel_list[i]\n",
    "    ax_tree.plot(inner_product_list[0], [j / max(curve) for j in curve], color=cm.cool(i / max_depth))\n",
    "ax_tree.set_title(\"Tree\")\n",
    "ax_tree.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa4acd93",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:59.913692Z",
     "start_time": "2021-09-25T08:56:59.863660Z"
    }
   },
   "outputs": [],
   "source": [
    "# The first line of the file is a header. From each whitespace-separated row\n",
    "# the first and last fields are dropped (presumably an index and the\n",
    "# regression target -- verify against the data file).\n",
    "with open(\"../data/abalone/abalone_R.dat\", \"r\") as fh:  # handle is now closed\n",
    "    f = fh.readlines()[1:]\n",
    "X = np.asarray([[float(v) for v in line.split()[1:-1]] for line in f])\n",
    "X = X[0:1000]  # first 1000 samples only (presumably to keep the kernel matrices small)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ba8fc3c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:57:01.455040Z",
     "start_time": "2021-09-25T08:56:59.916641Z"
    }
   },
   "outputs": [],
   "source": [
    "# Analytic ReLU-MLP and soft-tree NTKs on the abalone subset (max_depth=10).\n",
    "H_mlp = ntk.mlp_relu(X, max_depth=10)\n",
    "H_tree = ntk.tree(X, alpha=2.0, max_depth=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c26cf57",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T09:05:39.628514Z",
     "start_time": "2021-09-25T09:05:39.610377Z"
    }
   },
   "outputs": [],
   "source": [
    "# Expected layout: (depth, n_samples, n_samples) -- confirm.\n",
    "np.shape(H_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1e65231",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T09:04:18.311551Z",
     "start_time": "2021-09-25T09:04:12.350992Z"
    }
   },
   "outputs": [],
   "source": [
    "# Spectrum and condition number of the MLP NTK for depth indices 1..9.\n",
    "# The kernel matrix is assumed symmetric, so eigvalsh is used: it returns real\n",
    "# eigenvalues in ascending order, whereas eigvals may return complex values in\n",
    "# arbitrary order (breaking max/min and scrambling the spectrum plot).\n",
    "res = []\n",
    "for i in range(1, 10):\n",
    "    v_mlp = np.linalg.eigvalsh(H_mlp[i])\n",
    "    res.append(v_mlp[-1] / v_mlp[0])  # largest / smallest eigenvalue\n",
    "    # Descending spectrum vs. rank 1..n, so no point is lost on the log x-axis.\n",
    "    plt.plot(np.arange(1, len(v_mlp) + 1), v_mlp[::-1])\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71801d50",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T09:05:52.038416Z",
     "start_time": "2021-09-25T09:05:52.030179Z"
    }
   },
   "outputs": [],
   "source": [
    "# ntk.tree returns a tuple; element 0 is the kernel stack.\n",
    "np.shape(H_tree[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de5cd4c4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T09:03:42.288972Z",
     "start_time": "2021-09-25T09:03:36.696124Z"
    }
   },
   "outputs": [],
   "source": [
    "# Same spectrum analysis for the soft-tree NTK (element 0 of ntk.tree's tuple).\n",
    "# eigvalsh (assumes a symmetric kernel matrix) gives real, ascending\n",
    "# eigenvalues; eigvals could return complex, unordered ones.\n",
    "res = []\n",
    "for i in range(1, 10):\n",
    "    v_tree = np.linalg.eigvalsh(H_tree[0][i])\n",
    "    res.append(v_tree[-1] / v_tree[0])  # largest / smallest eigenvalue\n",
    "    # Descending spectrum vs. rank 1..n, so no point is lost on the log x-axis.\n",
    "    plt.plot(np.arange(1, len(v_tree) + 1), v_tree[::-1])\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ec4a605",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T09:03:44.983541Z",
     "start_time": "2021-09-25T09:03:44.758331Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): `res` is reassigned by both spectrum cells above; run top to\n",
    "# bottom this shows the tree condition numbers (depth indices 1..9).\n",
    "plt.plot(range(len(res)), res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d65e7e2",
   "metadata": {},
   "source": [
    "# 異なる深さごとでの可視化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1407c29e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:45:39.225772Z",
     "start_time": "2021-09-25T06:45:13.418277Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "axes = [fig.add_subplot(1, 4, k + 1) for k in range(4)]\n",
    "\n",
    "max_depth = 31\n",
    "# One normalisation over depths 1 .. max_depth-1 shared by the line colours and\n",
    "# the colour bar (lines previously used depth / max_depth, which did not match).\n",
    "norm = matplotlib.colors.Normalize(vmin=1, vmax=max_depth - 1)\n",
    "\n",
    "for ax, alpha in zip(axes, range(1, 5)):\n",
    "    scale = 2 ** (alpha - 2)  # 0.5, 1, 2, 4\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(scale, max_depth)\n",
    "    for depth, kernel in enumerate(kernel_list, start=1):\n",
    "        ax.plot(inner_product_list[0], np.asarray(kernel) / max(kernel), color=cm.cool(norm(depth)))\n",
    "    ax.set_title(\"Normalized \" + r\"$\\Theta^{(d)}(x_i, x_j), \\alpha=$\" + f\"{scale}\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "    ax.set_xlabel(\"Inner product of the inputs\")\n",
    "    ax.set_ylim(-0.6, 1.1)\n",
    "\n",
    "cax = plt.axes([0.14, -0.05, 0.75, 0.04])\n",
    "# Ticks cover all depths (the old list stopped at 19 although depths reach 30).\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"Depth\", ticks=list(range(1, max_depth, 2)))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./figures/different_depth.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7738540c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:57:54.384700Z",
     "start_time": "2021-09-25T06:57:38.918177Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "# Left to right: T, T-dot, normalized Theta^(2) over alpha, Theta^(d) over depth.\n",
    "ax1 = fig.add_subplot(143)\n",
    "ax2 = fig.add_subplot(144)\n",
    "ax3 = fig.add_subplot(141)\n",
    "ax4 = fig.add_subplot(142)\n",
    "\n",
    "max_alpha = 50\n",
    "alpha_step = 0.25  # the alpha sweep uses alpha = alpha_step * k, k = 1 .. max_alpha-1\n",
    "max_depth = 31\n",
    "scale = 1  # fixed alpha of the depth sweep (formerly 2 ** (alpha - 2) with alpha = 2)\n",
    "\n",
    "# Depth sweep at fixed alpha (rightmost panel).\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(scale, max_depth)\n",
    "for depth, kernel in enumerate(kernel_list, start=1):\n",
    "    ax2.plot(inner_product_list[0], np.asarray(kernel) / max(kernel), color=cm.cool(depth / max_depth))\n",
    "ax2.set_title(\"Normalized \" + r\"$\\Theta^{(d)}(x_i, x_j), \\alpha=$\" + f\"{scale}\")\n",
    "ax2.set_ylim(-0.1, 1.1)\n",
    "\n",
    "# Alpha sweep. A single max_depth=3 call yields depth 2 for Theta and depth 1\n",
    "# for tau / tau_dot, which the former separate max_depth=2 calls returned\n",
    "# (assumes ntk.tree is deterministic).\n",
    "for k in range(1, max_alpha):\n",
    "    color = cm.jet(k / max_alpha)\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(k * alpha_step, max_depth=3)\n",
    "    x = inner_product_list[0]\n",
    "    ax1.plot(x, np.asarray(kernel_list[1]) / max(kernel_list[1]), color=color, linewidth=1)\n",
    "    ax3.plot(x, tau_list[0], color=color, linewidth=1)\n",
    "    ax4.plot(x, tau_dot_list[0], color=color, linewidth=1)\n",
    "ax1.set_title(r\"Normalized $\\Theta^{(2)}(x_i, x_j)$\")\n",
    "ax1.set_ylim(-0.1, 1.1)\n",
    "ax3.set_title(r\"${\\mathcal{T}}(x_i, x_j)$\")\n",
    "ax4.set_title(r\"$\\dot{\\mathcal{T}}(x_i, x_j)$\")\n",
    "for ax in (ax1, ax2, ax3, ax4):\n",
    "    ax.set_xlabel(\"Inner product of the inputs\")  # one capitalisation everywhere\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "alpha_norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha * alpha_step)\n",
    "cax = plt.axes([0.05, -0.05, 0.68, 0.04])\n",
    "# Ticks limited to the colour-bar range [0, 12.5]; the old list ran to 19.\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=alpha_norm, cmap=cm.jet), cax=cax, orientation='horizontal', label=r\"$\\alpha$\", ticks=list(range(1, int(max_alpha * alpha_step) + 1, 2)))\n",
    "\n",
    "depth_norm = matplotlib.colors.Normalize(vmin=0, vmax=max_depth)  # matches depth / max_depth above\n",
    "cax = plt.axes([0.79, -0.05, 0.19, 0.04])\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=depth_norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"$d$\", ticks=[5, 10, 15, 20, 25])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/different_param.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00457a44",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9cfea81",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:37:56.422196Z",
     "start_time": "2021-09-25T06:37:43.341234Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "axes = [fig.add_subplot(1, 4, k + 1) for k in range(4)]\n",
    "\n",
    "max_depth = 21\n",
    "# Shared by the line colours and the colour bar (lines previously used\n",
    "# depth / max_depth, which did not match the colour bar's 1 .. max_depth-1 range).\n",
    "norm = matplotlib.colors.Normalize(vmin=1, vmax=max_depth - 1)\n",
    "\n",
    "for ax, alpha in zip(axes, range(4, 8)):\n",
    "    scale = 2 ** (alpha - 2)  # 4, 8, 16, 32\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(scale, max_depth)\n",
    "    for depth, kernel in enumerate(kernel_list, start=1):\n",
    "        ax.plot(inner_product_list[0], np.asarray(kernel) / max(kernel), color=cm.cool(norm(depth)))\n",
    "    ax.set_title(\"Normalized \" + r\"$\\Theta^{(d)}(x_i, x_j), \\alpha=$\" + f\"{scale}\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "    ax.set_xlabel(\"Inner product of the inputs\")\n",
    "    # Non-positive values are not drawn on a log axis; the negated-kernel\n",
    "    # figure in the following cell shows the negative part.\n",
    "    ax.set_yscale(\"log\")\n",
    "\n",
    "cax = plt.axes([0.14, -0.05, 0.75, 0.04])\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"Depth\", ticks=list(range(1, max_depth, 2)))\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8853d9b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:38:09.667725Z",
     "start_time": "2021-09-25T06:37:56.423975Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 4))\n",
    "axes = [fig.add_subplot(1, 4, k + 1) for k in range(4)]\n",
    "\n",
    "max_depth = 21\n",
    "# Shared by the line colours and the colour bar (lines previously used\n",
    "# depth / max_depth, which did not match the colour bar's 1 .. max_depth-1 range).\n",
    "norm = matplotlib.colors.Normalize(vmin=1, vmax=max_depth - 1)\n",
    "\n",
    "for ax, alpha in zip(axes, range(4, 8)):\n",
    "    scale = 2 ** (alpha - 2)  # 4, 8, 16, 32\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(scale, max_depth)\n",
    "    for depth, kernel in enumerate(kernel_list, start=1):\n",
    "        # Negated so the negative part of the normalized kernel shows on the log axis.\n",
    "        ax.plot(inner_product_list[0], -np.asarray(kernel) / max(kernel), color=cm.cool(norm(depth)))\n",
    "    ax.set_title(\"Normalized \" + r\"$\\Theta^{(d)}(x_i, x_j), \\alpha=$\" + f\"{scale}\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "    ax.set_xlabel(\"Inner product of the inputs\")\n",
    "    ax.set_yscale(\"log\")\n",
    "\n",
    "cax = plt.axes([0.14, -0.05, 0.75, 0.04])\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"Depth\", ticks=list(range(1, max_depth, 2)))\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d150d858",
   "metadata": {},
   "source": [
    "# 異なるスケーリングでのErfの可視化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3c169f8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:38:11.447588Z",
     "start_time": "2021-09-25T06:38:09.669510Z"
    }
   },
   "outputs": [],
   "source": [
    "max_alpha = 50\n",
    "alpha_step = 0.25  # alpha = alpha_step * k, k = 1 .. max_alpha-1\n",
    "x = np.linspace(-3, 3, 100000)  # shared evaluation grid (was rebuilt every iteration)\n",
    "\n",
    "fig = plt.figure(figsize=(5, 4))\n",
    "ax = fig.add_subplot(111)\n",
    "# Scaled erf gate 0.5 * erf(alpha * z) + 0.5 for each alpha.\n",
    "for k in range(1, max_alpha):\n",
    "    ax.plot(x, 0.5 * special.erf(alpha_step * k * x) + 0.5, color=cm.jet(k / max_alpha))\n",
    "ax.set_xlabel(r\"${w}_{m,n}^{\\top} {x}_i$\")\n",
    "ax.set_ylabel(r\"$\\sigma({w}_{m,n}^{\\top} {x}_i)$\")\n",
    "\n",
    "# Logistic sigmoid for reference (np.exp replaces math.e ** -x; `import math` was redundant).\n",
    "ax.plot(x, 1 / (1 + np.exp(-x)), color=\"m\", linestyle=\"dotted\", linewidth=3)\n",
    "\n",
    "# tight_layout runs before the colour-bar axes is added, as in the original.\n",
    "plt.tight_layout()\n",
    "ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "cax = plt.axes([1.00, 0.2, 0.04, 0.74])\n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_alpha * alpha_step)\n",
    "# Ticks limited to the colour-bar range [0, 12.5]; the old list ran to 19.\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet), cax=cax, orientation='vertical', label=r\"$\\alpha$\", ticks=list(range(1, int(max_alpha * alpha_step) + 1, 2)))\n",
    "plt.savefig(\"./figures/scaling.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a7fa9d",
   "metadata": {},
   "source": [
    "## 有限モデルとの比較"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee07abc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:52.879177Z",
     "start_time": "2021-09-25T07:19:52.533511Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from models import SoftTree\n",
    "from utils import compute_kernels, rotation_o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05003f0c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:53.140587Z",
     "start_time": "2021-09-25T07:19:53.135127Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5a3b96",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:53.664888Z",
     "start_time": "2021-09-25T07:19:53.657662Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel(alpha, n_tree, color, is_first=False, depth=3):\n",
    "    \"\"\"Plot the empirical NTK of one freshly initialised finite SoftTree ensemble.\n",
    "\n",
    "    u = (1, 0) is paired with its rotations by 0..179 degrees; for each pair the\n",
    "    (1, 0) entry of compute_kernels(...) is plotted against <u, R u> on the\n",
    "    current pyplot axes.\n",
    "\n",
    "    Args:\n",
    "        alpha: SoftTree ``scale``.\n",
    "        n_tree: number of trees in the ensemble (M).\n",
    "        color: line colour.\n",
    "        is_first: attach the legend label \"$M=<n_tree>$\" to this curve only.\n",
    "        depth: SoftTree ``max_depth``. Was hard-coded to 3; exposed so callers\n",
    "            can keep it equal to the depth of the analytic kernel they overlay.\n",
    "\n",
    "    Returns:\n",
    "        (inner_product, res): the 180 x values and plotted kernel values.\n",
    "    \"\"\"\n",
    "    st = SoftTree(input_dim=2, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # loop-invariant\n",
    "    u = (1, 0)\n",
    "    res = []\n",
    "    inner_product = []\n",
    "    for i in tqdm(range(180), leave=False):\n",
    "        Ru = rotation_o(u, i * np.pi / 180)\n",
    "        # NOTE(review): only the inputs are moved to `device`, not `st`;\n",
    "        # confirm compute_kernels handles this when CUDA is available.\n",
    "        x = torch.Tensor([u, Ru]).to(device)\n",
    "        K = compute_kernels(st, x)\n",
    "        res.append(K[1, 0].tolist())\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "    plot_kwargs = {\"color\": color, \"linewidth\": 1}\n",
    "    if is_first:\n",
    "        plot_kwargs[\"label\"] = f\"$M={n_tree}$\"\n",
    "    plt.plot(inner_product, res, **plot_kwargs)\n",
    "    return inner_product, res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cf3b10",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:40:03.406986Z",
     "start_time": "2021-09-25T07:33:47.666089Z"
    }
   },
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "n_tree_grid = [2 ** (i + 5) for i in range(8)]  # ensemble sizes M = 32 .. 4096\n",
    "rng = np.random.default_rng(0)  # inputs were previously drawn without a seed\n",
    "error_ylabel = r\"$\\frac{\\left|| \\widehat{H}_0 - {H}  \\right||_F}{||\\widehat{H}_0||_F}$\"\n",
    "\n",
    "\n",
    "def _unit_inputs(n_dataset, n_features):\n",
    "    \"\"\"Return n_dataset Gaussian vectors of length n_features, each scaled to unit L2 norm.\"\"\"\n",
    "    raw = rng.standard_normal((n_dataset, n_features))\n",
    "    return torch.Tensor(raw / np.linalg.norm(raw, axis=1, keepdims=True))\n",
    "\n",
    "\n",
    "def _relative_errors(x, n_features, depth, alpha):\n",
    "    \"\"\"Relative Frobenius error of finite-M empirical NTKs against ntk.tree.\n",
    "\n",
    "    Returns an array of shape (n_seeds, len(n_tree_grid)).\n",
    "    \"\"\"\n",
    "    K_infinite, _, _ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "    K_infinite = K_infinite[-1]  # last depth slice\n",
    "    errors = []\n",
    "    for seed in tqdm(range(n_seeds), leave=False):\n",
    "        # `seed` used to be ignored, so repetitions were not reproducible.\n",
    "        # Assumes SoftTree draws its initial weights from torch's RNG.\n",
    "        torch.manual_seed(seed)\n",
    "        row = []\n",
    "        for n_tree in tqdm(n_tree_grid, leave=False):\n",
    "            st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree)\n",
    "            K_empirical = compute_kernels(st, x)\n",
    "            row.append(torch.norm((K_empirical - K_infinite) / torch.norm(K_empirical)).tolist())\n",
    "        errors.append(row)\n",
    "    return np.array(errors)\n",
    "\n",
    "\n",
    "def _finish_error_panel(title):\n",
    "    \"\"\"Add the M^(-1/2) guide line and the shared log-log styling to the current axes.\"\"\"\n",
    "    plt.plot(n_tree_grid, [0.1 * (m ** -0.5) for m in n_tree_grid], linestyle=\"dashed\", color=\"black\", label=\"trend: $M^{-1/2}$\")\n",
    "    plt.xscale(\"log\")\n",
    "    plt.yscale(\"log\")\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    plt.title(title)\n",
    "    plt.xlabel(\"$M$\")\n",
    "    plt.ylabel(error_ylabel)\n",
    "    plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left', prop={'size': 12})\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(13, 7))\n",
    "\n",
    "# Left: empirical NTK of finite ensembles (n_seeds draws per M) vs. the M = infinity limit.\n",
    "# (A per-M mean squared error used to be computed here and then discarded.)\n",
    "plt.subplot(1, 2, 1)\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "# The analytic curve does not depend on n_tree: compute it once, outside the loop.\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha, depth + 1)\n",
    "for j, n_tree in enumerate(n_tree_combinations):\n",
    "    for i in tqdm(range(n_seeds), leave=False):\n",
    "        torch.manual_seed(i)  # reproducible initialisation (assumes torch's RNG)\n",
    "        plot_kernel(alpha, n_tree, color=cm.bwr(j / len(n_tree_combinations)), is_first=(i == 0))\n",
    "plt.plot(inner_product_list[0][0:180], kernel_list[depth - 1][0:180], color=\"black\", linewidth=3, linestyle=\"dotted\", label=r\"$M={\\infty}$\")\n",
    "plt.title(r\"$\\alpha=2$, $d=3$\")\n",
    "plt.xlabel(\"Inner product of the inputs\")\n",
    "plt.ylabel(r\"$\\widehat{\\Theta}_0(x_i, x_j)$\")\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.ylim(-0.3, 2.5)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# Top right: vary alpha at fixed depth.\n",
    "plt.subplot(2, 2, 2)\n",
    "n_dataset = 50\n",
    "n_features = 5\n",
    "depth = 3\n",
    "x = _unit_inputs(n_dataset, n_features)\n",
    "for j, alpha in enumerate(tqdm((0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0), leave=False)):\n",
    "    errors = _relative_errors(x, n_features, depth, alpha)\n",
    "    plt.errorbar(n_tree_grid, errors.mean(axis=0), yerr=errors.std(axis=0), color=cm.magma((j + 1) / 9), capsize=3, label=r\"$\\alpha=$\" + f\"{alpha}\")\n",
    "_finish_error_panel(r\"Change $\\alpha$ with fixed $d=3$\")\n",
    "\n",
    "# Bottom right: vary depth at fixed alpha.\n",
    "plt.subplot(2, 2, 4)\n",
    "alpha = 2.0\n",
    "x = _unit_inputs(n_dataset, n_features)\n",
    "for j, depth in enumerate(tqdm(range(1, 7), leave=False)):\n",
    "    errors = _relative_errors(x, n_features, depth, alpha)\n",
    "    plt.errorbar(n_tree_grid, errors.mean(axis=0), yerr=errors.std(axis=0), color=cm.viridis((j + 1) / 6), capsize=3, label=\"$d=$\" + f\"{int(depth)}\")\n",
    "_finish_error_panel(r\"Change $d$ with fixed $\\alpha=2$\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./figures/finite.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "044e1ec2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:20:06.871338Z",
     "start_time": "2021-09-25T07:19:58.115994Z"
    }
   },
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "# Number of leading sweep points compared with the M=infinity kernel (presumably the\n",
    "# 0-179 degree half of a 1-degree sweep, as in get_kernel_mlp -- TODO confirm for get_kernel).\n",
    "n_points = 180\n",
    "\n",
    "fig = plt.figure(figsize=(13, 7))\n",
    "plt.subplot(1,2,1)\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "\n",
    "# The M=infinity kernel depends only on (alpha, depth): compute it once, outside the loops.\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha, depth+1)\n",
    "kernel_infinite = kernel_list[depth-1][0:n_points]\n",
    "\n",
    "n_tree_combinations = [16]\n",
    "for j, n_tree in enumerate(n_tree_combinations):\n",
    "    res_all = []\n",
    "    for i in tqdm(range(n_seeds), leave=False):\n",
    "        inner_product, res = plot_kernel(alpha, n_tree, color=cm.bwr(j/len(n_tree_combinations)), is_first=(i==0))\n",
    "        res_all.append(res)\n",
    "    # Mean over seeds of the MSE between each finite-M curve and the M=infinity curve.\n",
    "    mse = np.mean([mean_squared_error(res_seed, kernel_infinite) for res_seed in res_all])\n",
    "    print(f\"M={n_tree}: MSE vs. M=infinity (mean over {n_seeds} seeds) = {mse:.3e}\")\n",
    "plt.plot(inner_product_list[0][0:n_points], kernel_infinite, color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "plt.title(f\"$\\\\alpha={alpha:g}$, depth $={depth}$\")\n",
    "plt.xlabel(\"Inner product of the inputs\")\n",
    "plt.ylabel(\"$\\\\widehat{\\\\Theta}_0(x_i, x_j)$\")\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.grid(linestyle=\"dotted\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64298298",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:43:47.542758Z",
     "start_time": "2021-09-25T06:43:47.539371Z"
    }
   },
   "outputs": [],
   "source": [
    "# Two raw (not unit-normalised) random inputs for a single-entry spot check.\n",
    "n_features = 5\n",
    "n_dataset = 2\n",
    "x = torch.Tensor(np.random.randn(n_dataset, n_features))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35ede989",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:43:47.790644Z",
     "start_time": "2021-09-25T06:43:47.547089Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the finite-width tree from the current depth/alpha instead of reusing an `st`\n",
    "# left over from an earlier loop, whose max_depth need not match the `depth` that\n",
    "# ntk.tree(...) is evaluated at in the next cell.\n",
    "n_tree_check = 2**12\n",
    "st = SoftTree(input_dim=n_features, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree_check)\n",
    "compute_kernels(st, x)[0][1].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674aadb4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T06:43:47.816981Z",
     "start_time": "2021-09-25T06:43:47.805191Z"
    }
   },
   "outputs": [],
   "source": [
    "# Limiting tree kernel for the same inputs; only the first returned value is used.\n",
    "K_infinite, _, _ = ntk.tree(np.array(x), max_depth=depth, alpha=alpha)\n",
    "# Index -1 is taken as the kernel to compare, as in the finite-width depth sweep above\n",
    "# (presumably the kernel at max_depth).\n",
    "K_infinite_01 = K_infinite[-1][0][1]\n",
    "print(K_infinite_01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2060685e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:57:18.332535Z",
     "start_time": "2021-08-16T14:57:18.318872Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel_mlp(alpha, max_depth, mode=\"mlp_relu\"):\n",
    "    \"\"\"Sweep the kernel from ntk.mlp_relu / ntk.mlp_erf over 2-D unit vectors 0..359 degrees apart.\n",
    "\n",
    "    For each depth d in 1 .. max_depth-1 and each angle i in 0 .. 359 degrees, the\n",
    "    inputs are u = (1, 0) and u rotated by i degrees, and the kernel value is read\n",
    "    from the ``ntk.mlp_relu`` / ``ntk.mlp_erf`` output as ``H[d-1, 1, 0]``.\n",
    "\n",
    "    Args:\n",
    "        alpha: Passed to ``ntk.mlp_erf``; unused for ``mode=\"mlp_relu\"``.\n",
    "        max_depth: Exclusive upper bound of the depths swept.\n",
    "        mode: ``\"mlp_relu\"`` or ``\"mlp_erf\"``.\n",
    "\n",
    "    Returns:\n",
    "        (kernel_list, inner_product_list), each a list with one 360-element list per\n",
    "        depth. ``inner_product_list[k][i]`` is the inner product of u with its rotation,\n",
    "        i.e. cos(i degrees), and is identical for every depth.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``mode`` is not supported (NotImplementedError under ``python -O``).\n",
    "    \"\"\"\n",
    "    assert mode in (\"mlp_relu\", \"mlp_erf\")\n",
    "\n",
    "    def rotation_o(u, t, deg=False):\n",
    "        # Convert an angle given in degrees to radians.\n",
    "        if deg:\n",
    "            t = np.deg2rad(t)\n",
    "\n",
    "        # 2-D rotation matrix.\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return np.dot(R, u)\n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "\n",
    "    kernel_list = []\n",
    "    inner_product_list = []\n",
    "    for depth in range(1, max_depth, 1):\n",
    "        kernel = []\n",
    "        inner_product = []\n",
    "        for i in range(360):\n",
    "            Ru = rotation_o(u, i*np.pi/180)\n",
    "            X = np.vstack([u, Ru])\n",
    "            if mode == \"mlp_relu\":\n",
    "                H = ntk.mlp_relu(X, max_depth=depth)\n",
    "            elif mode == \"mlp_erf\":\n",
    "                H = ntk.mlp_erf(X, max_depth=depth, alpha=alpha)\n",
    "            else:\n",
    "                # Only reachable when asserts are stripped (python -O).\n",
    "                raise NotImplementedError(mode)\n",
    "            # Rows of X are (u, Ru), so entry [1, 0] is presumably K(Ru, u) at this depth.\n",
    "            kernel.append(H[depth-1, 1, 0])\n",
    "            inner_product.append(np.dot(u, Ru))\n",
    "        kernel_list.append(kernel)\n",
    "        inner_product_list.append(inner_product)\n",
    "    return kernel_list, inner_product_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2382f184",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:57:44.619334Z",
     "start_time": "2021-08-16T14:57:18.361075Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "\n",
    "# (mode, panel title, kernel_list indices to draw). Curves are coloured by list index.\n",
    "# NOTE(review): the MLP panels skip kernel_list[0] while the tree panel starts at 0, so\n",
    "# equal colours need not mean equal depth across panels -- confirm this is intentional.\n",
    "panels = [\n",
    "    (\"mlp_relu\", \"MLP(ReLU)\", range(1, max_depth-1)),\n",
    "    (\"mlp_erf\", \"MLP(Erf)\", range(1, max_depth-1)),\n",
    "    (\"tree\", \"Tree\", range(max_depth-2)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "for ax, (mode, title, indices) in zip(axes, panels):\n",
    "    if mode == \"tree\":\n",
    "        kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    else:\n",
    "        kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    for i in indices:\n",
    "        ax.plot(inner_product_list[0], kernel_list[i], color=cm.cool(i/max_depth))\n",
    "    ax.set(title=title, xlabel=\"Inner product of the inputs\", ylabel=\"Kernel value\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3453d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:58:10.451525Z",
     "start_time": "2021-08-16T14:57:44.622552Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "\n",
    "# (mode, panel title, kernel_list indices to draw). Curves are coloured by list index.\n",
    "# NOTE(review): the MLP panels skip kernel_list[0] while the tree panel starts at 0 -- confirm.\n",
    "panels = [\n",
    "    (\"mlp_relu\", \"MLP(ReLU)\", range(1, max_depth-1)),\n",
    "    (\"mlp_erf\", \"MLP(Erf)\", range(1, max_depth-1)),\n",
    "    (\"tree\", \"Tree\", range(max_depth-2)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "for ax, (mode, title, indices) in zip(axes, panels):\n",
    "    if mode == \"tree\":\n",
    "        kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    else:\n",
    "        kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    for i in indices:\n",
    "        # Normalise each curve by its own maximum (computed once per curve, not per element).\n",
    "        peak = max(kernel_list[i])\n",
    "        ax.plot(inner_product_list[0], [v/peak for v in kernel_list[i]], color=cm.cool(i/max_depth))\n",
    "    ax.set(title=f\"{title} (normed)\", xlabel=\"Inner product of the inputs\", ylabel=\"Kernel value / max\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1ad8940",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:58:25.439980Z",
     "start_time": "2021-08-16T14:58:10.454504Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "\n",
    "# (mode, panel title, kernel_list indices to draw). Curves are coloured by list index.\n",
    "# NOTE(review): the MLP panels skip kernel_list[0] while the tree panel starts at 0 -- confirm.\n",
    "panels = [\n",
    "    (\"mlp_relu\", \"MLP(ReLU)\", range(1, max_depth-1)),\n",
    "    (\"mlp_erf\", \"MLP(Erf)\", range(1, max_depth-1)),\n",
    "    (\"tree\", \"Tree\", range(max_depth-2)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "for ax, (mode, title, indices) in zip(axes, panels):\n",
    "    if mode == \"tree\":\n",
    "        kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    else:\n",
    "        kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    for i in indices:\n",
    "        # Normalise by the curve's own maximum (computed once), without mutating kernel_list.\n",
    "        peak = max(kernel_list[i])\n",
    "        normed = [v/peak for v in kernel_list[i]]\n",
    "        # Forward differences between neighbouring sweep points 1..178.\n",
    "        diffs = [normed[j] - normed[j+1] for j in range(1, 179)]\n",
    "        ax.plot(inner_product_list[0][1:179], diffs, color=cm.cool(i/max_depth))\n",
    "    ax.set(title=f\"{title} (normed)\", xlabel=\"Inner product of the inputs\", ylabel=\"Forward difference of normed kernel\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fa91c3b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:58:25.447976Z",
     "start_time": "2021-08-16T14:56:17.301Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "\n",
    "# (mode, panel title, kernel_list indices to draw). Curves are coloured by list index.\n",
    "# NOTE(review): the MLP panels skip kernel_list[0] while the tree panel starts at 0 -- confirm.\n",
    "panels = [\n",
    "    (\"mlp_relu\", \"MLP(ReLU)\", range(1, max_depth-1)),\n",
    "    (\"mlp_erf\", \"MLP(Erf)\", range(1, max_depth-1)),\n",
    "    (\"tree\", \"Tree\", range(max_depth-2)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "for ax, (mode, title, indices) in zip(axes, panels):\n",
    "    if mode == \"tree\":\n",
    "        kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    else:\n",
    "        kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    for i in indices:\n",
    "        # Normalise by the curve's own maximum (computed once), without mutating kernel_list.\n",
    "        peak = max(kernel_list[i])\n",
    "        normed = [v/peak for v in kernel_list[i]]\n",
    "        # Forward differences between neighbouring sweep points 1..178.\n",
    "        diffs = [normed[j] - normed[j+1] for j in range(1, 179)]\n",
    "        ax.plot(inner_product_list[0][1:179], diffs, color=cm.cool(i/max_depth))\n",
    "    # The differences are signed: non-positive values cannot be shown faithfully on a log\n",
    "    # axis (matplotlib clips or masks them). The next cell plots their absolute value.\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.set(title=f\"{title} (normed)\", xlabel=\"Inner product of the inputs\", ylabel=\"Forward difference of normed kernel\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c56f40b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-08-16T14:58:25.451837Z",
     "start_time": "2021-08-16T14:56:17.303Z"
    },
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "alpha = 4.0\n",
    "\n",
    "# (mode, panel title, kernel_list indices to draw). Curves are coloured by list index.\n",
    "# NOTE(review): the MLP panels skip kernel_list[0] while the tree panel starts at 0 -- confirm.\n",
    "panels = [\n",
    "    (\"mlp_relu\", \"MLP(ReLU)\", range(1, max_depth-1)),\n",
    "    (\"mlp_erf\", \"MLP(Erf)\", range(1, max_depth-1)),\n",
    "    (\"tree\", \"Tree\", range(max_depth-2)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "for ax, (mode, title, indices) in zip(axes, panels):\n",
    "    if mode == \"tree\":\n",
    "        kernel_list, _, _, inner_product_list = get_kernel(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    else:\n",
    "        kernel_list, inner_product_list = get_kernel_mlp(alpha=alpha, max_depth=max_depth, mode=mode)\n",
    "    for i in indices:\n",
    "        # Normalise by the curve's own maximum (computed once), without mutating kernel_list.\n",
    "        peak = max(kernel_list[i])\n",
    "        normed = [v/peak for v in kernel_list[i]]\n",
    "        # Magnitude of forward differences between neighbouring sweep points 1..178,\n",
    "        # so every value is representable on the log axis (exact zeros excepted).\n",
    "        diffs = [abs(normed[j] - normed[j+1]) for j in range(1, 179)]\n",
    "        ax.plot(inner_product_list[0][1:179], diffs, color=cm.cool(i/max_depth))\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.set(title=f\"{title} (normed)\", xlabel=\"Inner product of the inputs\", ylabel=\"|Forward difference| of normed kernel\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a1e6d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
