{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "temporal-watershed",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:25.807349Z",
     "start_time": "2021-09-25T08:56:25.801536Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    \"\"\"Configure matplotlib's global fonts for the figures in this notebook.\n",
    "\n",
    "    Resets matplotlib to its default style, then selects a 15 pt\n",
    "    Times New Roman text font together with the STIX math font set.\n",
    "\n",
    "    NOTE(review): ``plt.style.use(\"default\")`` runs after the two ``plt.rc``\n",
    "    calls and appears to reset the ``usetex``/``serif`` settings they make --\n",
    "    confirm before relying on LaTeX rendering. The call order is unchanged so\n",
    "    that the rendered figures stay the same.\n",
    "    \"\"\"\n",
    "    import matplotlib\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    plt.rc('text', usetex=True)\n",
    "    plt.rc('font', family='serif')\n",
    "\n",
    "    plt.style.use(\"default\")\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "\n",
    "    plt.rcParams['font.family'] = 'Times New Roman'  # set the font family\n",
    "    plt.rcParams['mathtext.fontset'] = 'stix'  # set the math font\n",
    "\n",
    "    # Best-effort font tweak: drop the 'roman' weight alias and rebuild the\n",
    "    # font cache. The key is already gone on a repeated call and the private\n",
    "    # ``_rebuild`` helper may be missing in some matplotlib releases, so\n",
    "    # ordinary failures are ignored. Catching ``Exception`` (not a bare\n",
    "    # ``except``) lets KeyboardInterrupt/SystemExit propagate.\n",
    "    try:\n",
    "        del matplotlib.font_manager.weight_dict['roman']\n",
    "        matplotlib.font_manager._rebuild()\n",
    "    except Exception:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "completed-chance",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:28.621842Z",
     "start_time": "2021-09-25T08:56:26.049699Z"
    }
   },
   "outputs": [],
   "source": [
    "# Apply the notebook-wide matplotlib font settings defined above.\n",
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "operational-physiology",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.580638Z",
     "start_time": "2021-09-25T08:56:28.624803Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import numpy as np\n",
    "import ntk\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "atomic-characteristic",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.600770Z",
     "start_time": "2021-09-25T08:56:30.586043Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel(alpha, max_depth, mode=\"tree\"):\n",
    "    \"\"\"Evaluate the limiting kernel between a unit vector and its rotations.\n",
    "\n",
    "    The reference input ``u = (1, 0)`` is paired with ``u`` rotated by\n",
    "    0, 1, ..., 359 degrees. For every depth in ``1 .. max_depth - 1`` the\n",
    "    ``ntk`` routine selected by ``mode`` is evaluated on each pair.\n",
    "\n",
    "    Args:\n",
    "        alpha: Forwarded unchanged as ``alpha`` to the ``ntk`` routine.\n",
    "        max_depth: Exclusive upper bound of the depths that are evaluated.\n",
    "        mode: ``\"tree\"``, ``\"asymtree\"`` or ``\"inf_asymtree\"``.\n",
    "\n",
    "    Returns:\n",
    "        ``(kernel_list, tau_list, tau_dot_list, inner_product_list)``. Each\n",
    "        is a list with one entry per depth, and each entry is a list of 360\n",
    "        values (one per rotation angle): ``H[depth-1, 1, 0]``, ``tau[1, 0]``,\n",
    "        ``tau_dot[1, 0]`` and ``np.dot(u, Ru)`` respectively.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``mode`` is not one of the recognised names.\n",
    "        NotImplementedError: If ``mode`` is recognised but has no kernel\n",
    "            routine here (``\"mlp_relu\"``, ``\"mlp_erf\"``).\n",
    "    \"\"\"\n",
    "    assert mode in (\"tree\", \"asymtree\", \"inf_asymtree\", \"mlp_relu\", \"mlp_erf\")\n",
    "\n",
    "    # Resolve the kernel routine once, up front. This must be an if/elif\n",
    "    # chain: with independent ``if`` statements the trailing ``else`` would\n",
    "    # also run for \"tree\" and \"asymtree\".\n",
    "    if mode == \"tree\":\n",
    "        viz = ntk.tree_viz\n",
    "    elif mode == \"asymtree\":\n",
    "        viz = ntk.asymtree_viz\n",
    "    elif mode == \"inf_asymtree\":\n",
    "        viz = ntk.inf_asymtree_viz\n",
    "    else:\n",
    "        raise NotImplementedError(f\"mode {mode!r} has no kernel routine in get_kernel\")\n",
    "\n",
    "    def rotation_o(u, t, deg=False):\n",
    "        \"\"\"Apply the 2-D rotation matrix for angle ``t`` (radians, or degrees if ``deg``) to ``u``.\"\"\"\n",
    "        if deg:\n",
    "            t = np.deg2rad(t)\n",
    "\n",
    "        R = np.array([[np.cos(t), -np.sin(t)],\n",
    "                      [np.sin(t),  np.cos(t)]])\n",
    "        return np.dot(R, u)\n",
    "\n",
    "    u = (1.0, 0.0)\n",
    "\n",
    "    # The rotated inputs and their inner products with ``u`` do not depend on\n",
    "    # the depth, so they are computed once instead of once per depth.\n",
    "    rotated_inputs = [rotation_o(u, i * np.pi / 180) for i in range(360)]\n",
    "    inner_products = [np.dot(u, Ru) for Ru in rotated_inputs]\n",
    "\n",
    "    kernel_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "    inner_product_list = []\n",
    "\n",
    "    for depth in range(1, max_depth):\n",
    "        kernel = []\n",
    "        taus = []\n",
    "        tau_dots = []\n",
    "        for Ru in rotated_inputs:\n",
    "            # A fresh 2x2 input is stacked for every call so the routine never\n",
    "            # sees an array shared with another evaluation.\n",
    "            H, tau, tau_dot = viz(np.vstack([u, Ru]), max_depth=depth, alpha=alpha)\n",
    "            kernel.append(H[depth - 1, 1, 0])\n",
    "            taus.append(tau[1, 0])\n",
    "            tau_dots.append(tau_dot[1, 0])\n",
    "        kernel_list.append(kernel)\n",
    "        tau_list.append(taus)\n",
    "        tau_dot_list.append(tau_dots)\n",
    "        # Copy so every depth keeps its own list object, as callers may mutate it.\n",
    "        inner_product_list.append(list(inner_products))\n",
    "    return kernel_list, tau_list, tau_dot_list, inner_product_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be0c5700",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 4\n",
    "max_depth = 21\n",
    "# alpha value handed to get_kernel; evaluates to 4 for alpha = 4.\n",
    "kernel_alpha = 2 ** (alpha - 2)\n",
    "\n",
    "\n",
    "def plot_normalized_kernels(ax, kernels, inner_products):\n",
    "    \"\"\"Plot each depth's kernel curve on ``ax``, scaled so its maximum is 1.\n",
    "\n",
    "    ``kernels`` holds one list of kernel values per depth (depth 1 first);\n",
    "    each curve is coloured by ``depth / max_depth`` on the ``cool`` colormap.\n",
    "    \"\"\"\n",
    "    for depth, kernel in enumerate(kernels, start=1):\n",
    "        peak = max(kernel)  # computed once per curve, not once per point\n",
    "        ax.plot(inner_products, [value / peak for value in kernel], color=cm.cool(depth / max_depth))\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax1 = fig.add_subplot(121)\n",
    "ax2 = fig.add_subplot(122)\n",
    "\n",
    "# Left panel: perfect binary tree, one curve per depth.\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(kernel_alpha, max_depth, mode=\"tree\")\n",
    "plot_normalized_kernels(ax1, kernel_list, inner_product_list[0])\n",
    "ax1.set_title(\"Perfect binary tree\")\n",
    "ax1.grid(linestyle=\"dotted\")\n",
    "ax1.set_xlabel(\"Inner product of the inputs\")\n",
    "ax1.set_ylabel(\"Normalized kernel value, $\\\\alpha=4.0$\")\n",
    "ax1.set_ylim(-0.2,1.1)\n",
    "\n",
    "# Right panel: decision list, one curve per depth.\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(kernel_alpha, max_depth, mode=\"asymtree\")\n",
    "plot_normalized_kernels(ax2, kernel_list, inner_product_list[0])\n",
    "\n",
    "kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(kernel_alpha, max_depth, mode=\"inf_asymtree\")\n",
    "# Select the last (deepest) entry explicitly rather than through a loop\n",
    "# variable left over from an earlier ``for`` statement.\n",
    "deepest_kernel = kernel_list[-1]\n",
    "deepest_peak = max(deepest_kernel)\n",
    "ax2.plot(inner_product_list[0][0:180], [value / deepest_peak for value in deepest_kernel[0:180]], color=\"black\", linestyle=\"dotted\", label=\"Infinite depth limit\")\n",
    "\n",
    "ax2.set_title(\"Decision list\")\n",
    "ax2.grid(linestyle=\"dotted\")\n",
    "ax2.set_xlabel(\"Inner product of the inputs\")\n",
    "ax2.set_ylim(-0.2,1.1)\n",
    "ax2.legend()\n",
    "\n",
    "# Shared horizontal colorbar that maps line colour to depth D.\n",
    "cax = plt.axes([0.13, -0.05, 0.83, 0.04])\n",
    "norm = matplotlib.colors.Normalize(vmin=0, vmax=max_depth)\n",
    "fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.cool), cax=cax, orientation='horizontal', label=\"$D$\", ticks = [5,10,15])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"./figures/different_param_asym.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "polished-message",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:52.879177Z",
     "start_time": "2021-09-25T07:19:52.533511Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from models import SoftTree\n",
    "from utils import compute_kernels, rotation_o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "latter-dealing",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:53.140587Z",
     "start_time": "2021-09-25T07:19:53.135127Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "spectacular-chicago",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:53.664888Z",
     "start_time": "2021-09-25T07:19:53.657662Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel(alpha, n_tree, depth, color, is_first=False, asym=False):\n",
    "    \"\"\"Plot the empirical kernel of one freshly built ``SoftTree`` ensemble.\n",
    "\n",
    "    The kernel is evaluated between ``u = (1, 0)`` and ``u`` rotated by\n",
    "    0..179 degrees, and drawn against the inner product of each pair on the\n",
    "    current matplotlib axes.\n",
    "\n",
    "    Args:\n",
    "        alpha: Passed to ``SoftTree`` as ``scale``.\n",
    "        n_tree: Passed to ``SoftTree`` as ``n_tree``; also shown in the legend label.\n",
    "        depth: Passed to ``SoftTree`` as ``max_depth``.\n",
    "        color: Line colour.\n",
    "        is_first: When true the line gets the legend label ``$M=<n_tree>$``.\n",
    "        asym: Passed to ``SoftTree`` as ``asym``.\n",
    "\n",
    "    Returns:\n",
    "        ``(inner_product, res)``: the 180 inner products and kernel values.\n",
    "    \"\"\"\n",
    "    model = SoftTree(input_dim=2, output_dim=1, max_depth=depth, scale=alpha, n_tree=n_tree, asym=asym)\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    base = (1, 0)\n",
    "    kernel_values, cosines = [], []\n",
    "    for angle_deg in tqdm(range(180), leave=False):\n",
    "        rotated = rotation_o(base, angle_deg * np.pi / 180)\n",
    "        pair = torch.Tensor([base, rotated]).to(device)\n",
    "        gram = compute_kernels(model, pair)\n",
    "        # Off-diagonal entry: kernel between the base and the rotated input.\n",
    "        kernel_values.append(gram[1,0].tolist())\n",
    "        cosines.append(np.dot(base, rotated))\n",
    "\n",
    "    # Only the first curve of a group carries a legend entry.\n",
    "    line_kwargs = {\"color\": color, \"linewidth\": 1}\n",
    "    if is_first:\n",
    "        line_kwargs[\"label\"] = f\"$M={n_tree}$\"\n",
    "    plt.plot(cosines, kernel_values, **line_kwargs)\n",
    "    return cosines, kernel_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sensitive-writer",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:40:03.406986Z",
     "start_time": "2021-09-25T07:33:47.666089Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel_finite(depth: int, alpha: float, asym: bool, n_seeds: int = 10)->None:\n",
    "    \"\"\"Overlay finite-ensemble kernels on the limiting kernel in the current axes.\n",
    "\n",
    "    For each ensemble size ``M`` in ``[16, 64, 256, 1024, 4096]`` the kernel is\n",
    "    drawn ``n_seeds`` times through ``plot_kernel`` (one colour per ``M``).\n",
    "    The curve returned by ``get_kernel`` is then drawn as a dotted black line\n",
    "    labelled ``M = infinity``, and the axes are titled, labelled and given a legend.\n",
    "\n",
    "    Args:\n",
    "        depth: Tree depth used for both the finite and the limiting kernel.\n",
    "        alpha: Forwarded to ``plot_kernel`` and ``get_kernel``.\n",
    "        asym: ``True`` selects the \"asymtree\" kernel and the \"Decision list\"\n",
    "            title; ``False`` selects \"tree\" and \"Perfect binary tree\".\n",
    "        n_seeds: Number of ``plot_kernel`` repetitions per ensemble size.\n",
    "    \"\"\"\n",
    "    n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "\n",
    "    # The limiting kernel does not depend on the ensemble size, so it is\n",
    "    # computed once here instead of once per entry of n_tree_combinations.\n",
    "    kernel_list, tau_list, tau_dot_list, inner_product_list = get_kernel(alpha, depth + 1, mode=\"asymtree\" if asym else \"tree\")\n",
    "    # get_kernel(alpha, depth + 1) covers depths 1..depth, so index depth - 1\n",
    "    # is the requested depth; only the first 180 angles are plotted.\n",
    "    limit_kernel = kernel_list[depth - 1][0:180]\n",
    "    limit_inner_product = inner_product_list[0][0:180]\n",
    "\n",
    "    for j, n_tree in enumerate(tqdm(n_tree_combinations, leave=False)):\n",
    "        color = cm.bwr(j / len(n_tree_combinations))\n",
    "        for i in tqdm(range(n_seeds), leave=False):\n",
    "            # Only the first repetition of each ensemble size gets a legend label.\n",
    "            plot_kernel(alpha, n_tree, depth, color=color, is_first=(i == 0), asym=asym)\n",
    "\n",
    "    plt.plot(limit_inner_product, limit_kernel,  color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "    plt.title(\"Decision list\" if asym else \"Perfect binary tree\")\n",
    "    plt.xlabel(\"Inner product of the inputs\")\n",
    "    if not asym:\n",
    "        # The y-label is set only when asym is False.\n",
    "        plt.ylabel(\"Kernel value, \"+ \"$\\\\alpha=$\"+ str(float(alpha)) + \", $D=$\" + str(int(depth)))\n",
    "    plt.ylim(-0.3, 2.5)\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    plt.legend(loc='upper left', prop={'size': 12})\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "large-commissioner",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left panel: perfect binary tree (asym=False); right panel: decision list (asym=True).\n",
    "plt.figure(figsize=(10, 4))\n",
    "for panel_index, use_asym in enumerate((False, True), start=1):\n",
    "    plt.subplot(1, 2, panel_index)\n",
    "    plot_kernel_finite(depth=5, alpha=4.0, asym=use_asym, n_seeds=10)\n",
    "plt.savefig(\"./figures/kernels_asym.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "loose-terrorist",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
