{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "temporal-watershed",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:25.807349Z",
     "start_time": "2021-09-25T08:56:25.801536Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    \"\"\"Configure matplotlib's global rcParams for this notebook's figures.\n",
    "\n",
    "    The whole configuration is applied twice in a row, exactly as the\n",
    "    original code did (NOTE(review): presumably a workaround for settings\n",
    "    not sticking on a first pass -- confirm whether one pass suffices).\n",
    "\n",
    "    NOTE(review): ``plt.style.use('default')`` runs after the two\n",
    "    ``plt.rc`` calls, so it may reset ``text.usetex`` and the serif family\n",
    "    again; verify against the matplotlib style documentation before relying\n",
    "    on LaTeX text rendering.\n",
    "    \"\"\"\n",
    "    import matplotlib\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    for _ in range(2):\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"] = 15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        # Best-effort tweak of matplotlib's font-manager internals. The\n",
    "        # private ``_rebuild`` helper may be missing depending on the\n",
    "        # matplotlib version, and the 'roman' key is already gone on the\n",
    "        # second pass, so failures here are expected and ignored.\n",
    "        # ``Exception`` (instead of a bare ``except``) lets\n",
    "        # KeyboardInterrupt and SystemExit propagate.\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except Exception:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "completed-chance",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:28.621842Z",
     "start_time": "2021-09-25T08:56:26.049699Z"
    }
   },
   "outputs": [],
   "source": [
    "# Apply the global matplotlib settings defined by set_latex().\n",
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "operational-physiology",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.580638Z",
     "start_time": "2021-09-25T08:56:28.624803Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from scipy import special\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Silence every Python warning for the rest of the session; this also hides\n",
    "# any warnings raised by matplotlib while the figures below are drawn.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46fbc6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotation_o(u, t, deg=False):\n",
    "    if deg == True:\n",
    "        t = np.deg2rad(t)\n",
    "\n",
    "    R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])\n",
    "    return np.dot(R, u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "350c4402",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:\n",
    "    tau = 1 / 4 + 1 / (2 * math.pi) * np.arcsin(\n",
    "        ((alpha ** 2) * S)\n",
    "        / (np.sqrt(((alpha ** 2) * diag_i + 0.5) * ((alpha ** 2) * diag_j + 0.5)))\n",
    "    )\n",
    "    return tau\n",
    "\n",
    "\n",
    "def calc_tau_dot(\n",
    "    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array\n",
    ") -> np.array:\n",
    "    tau_dot = (\n",
    "        (alpha ** 2)\n",
    "        / (math.pi)\n",
    "        * 1\n",
    "        / np.sqrt(\n",
    "            (2 * (alpha ** 2) * diag_i + 1) * (2 * (alpha ** 2) * diag_j + 1)\n",
    "            - (4 * (alpha ** 4) * (S ** 2))\n",
    "        )\n",
    "    )\n",
    "    return tau_dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d88d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hardtree_viz(\n",
    "    X: np.array, alpha: float, beta: float, finetune: bool, arch: int\n",
    "):\n",
    "    \"\"\"Build the kernel matrix plotted for the axis-aligned (hard) trees.\n",
    "\n",
    "    For every feature column ``f`` the matrix\n",
    "    ``S_f = outer(X[:, f], X[:, f]) + beta**2`` is formed and fed to\n",
    "    ``calc_tau`` / ``calc_tau_dot``. The matrix that multiplies ``tau_dot``\n",
    "    is ``S_f`` when ``finetune`` is false and the full-input matrix\n",
    "    ``X @ X.T + beta**2`` when it is true; ``tau`` and ``tau_dot`` themselves\n",
    "    are always computed from ``S_f``.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array indexed as ``X[sample, feature]``.\n",
    "        alpha: Scale passed through to ``calc_tau`` / ``calc_tau_dot``.\n",
    "        beta: Offset; ``beta**2`` is added to every Gram matrix.\n",
    "        finetune: Selects which matrix multiplies ``tau_dot`` (see above).\n",
    "        arch: 0 or 1. ``0`` uses feature 0 for both entries of every rule,\n",
    "            ``1`` uses features 0 and 1 (so ``X`` needs at least two columns).\n",
    "\n",
    "    Returns:\n",
    "        ``(n_samples, n_samples)`` array ``K``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``arch`` is not 0 or 1.\n",
    "    \"\"\"\n",
    "    assert arch in {0, 1}\n",
    "\n",
    "    # The full-input Gram matrix does not depend on the feature index, so it\n",
    "    # is computed once instead of once per feature.\n",
    "    S_all = np.matmul(X, X.T) + beta**2\n",
    "\n",
    "    S_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "\n",
    "    for feature_index in range(len(X[0])):\n",
    "        column = X[:, feature_index]\n",
    "        S = np.outer(column, column) + beta**2\n",
    "        S_list.append(S_all if finetune else S)\n",
    "\n",
    "        # Every row of diag_i is the diagonal of S; diag_j is its transpose.\n",
    "        diag_i = np.tile(np.diag(S), (len(S), 1))\n",
    "        diag_j = diag_i.transpose()\n",
    "        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))\n",
    "        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))\n",
    "\n",
    "    K = np.zeros((X.shape[0], X.shape[0]))\n",
    "    # Each rule is a list of feature indices; four identical rules are summed.\n",
    "    # NOTE(review): presumably one rule per leaf of a depth-2 tree -- confirm.\n",
    "    if arch == 0:\n",
    "        rulelist = [[0, 0], [0, 0], [0, 0], [0, 0]]\n",
    "    else:\n",
    "        rulelist = [[0, 1], [0, 1], [0, 1], [0, 1]]\n",
    "\n",
    "    for rules in rulelist:\n",
    "        # Internal nodes: for each position, tau_dot of that feature times the\n",
    "        # tau of every other position in the rule.\n",
    "        for i, s in enumerate(rules):\n",
    "            _H_nodes = S_list[s] * tau_dot_list[s]\n",
    "            for t in rules[:i] + rules[i + 1:]:\n",
    "                _H_nodes *= tau_list[t]\n",
    "            K += _H_nodes\n",
    "\n",
    "        # Leaves: product of tau over every position in the rule.\n",
    "        _H_leaves = np.ones_like(K)\n",
    "        for s in rules:\n",
    "            _H_leaves *= tau_list[s]\n",
    "        K += _H_leaves\n",
    "\n",
    "    return K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b005cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softtree_viz(X: np.array, alpha:float, beta:float, max_depth:int):\n",
    "    \"\"\"Build the kernel matrix plotted for the soft (oblique) tree.\n",
    "\n",
    "    Uses ``S = X @ X.T + beta**2`` together with ``tau`` / ``tau_dot`` from\n",
    "    ``calc_tau`` / ``calc_tau_dot`` and returns, for ``d = max_depth``,\n",
    "    ``2 * S * 2**(d-1) * d * tau_dot * tau**(d-1) + 2**d * tau**d``.\n",
    "\n",
    "    Only the requested depth is evaluated; the shallower depths that the\n",
    "    previous implementation also computed were never returned.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array with one sample per row.\n",
    "        alpha: Scale passed through to ``calc_tau`` / ``calc_tau_dot``.\n",
    "        beta: Offset; ``beta**2`` is added to the Gram matrix.\n",
    "        max_depth: Depth ``d`` used in the formula; must be at least 1.\n",
    "\n",
    "    Returns:\n",
    "        ``(n_samples, n_samples)`` array.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``max_depth`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if max_depth < 1:\n",
    "        raise ValueError(f\"max_depth must be >= 1, got {max_depth}\")\n",
    "\n",
    "    S = np.matmul(X, X.T) + beta**2\n",
    "    # Every row of diag_i is the diagonal of S; diag_j is its transpose.\n",
    "    diag_i = np.tile(np.diag(S), (len(S), 1))\n",
    "    diag_j = diag_i.transpose()\n",
    "\n",
    "    tau = calc_tau(alpha, S, diag_i, diag_j)\n",
    "    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)\n",
    "\n",
    "    depth = max_depth\n",
    "    return (2 * S * (2 ** (depth - 1)) * depth * tau_dot * tau ** (depth - 1)) + (\n",
    "        (2 ** depth) * (tau ** depth)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "atomic-characteristic",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.600770Z",
     "start_time": "2021-09-25T08:56:30.586043Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel_hard(alpha, beta, u, finetune, arch):\n",
    "    \"\"\"Sample the hard-tree kernel between ``u`` and rotated copies of ``u``.\n",
    "\n",
    "    ``u`` is rotated by 0, 1, ..., 359 degrees; for each angle the kernel\n",
    "    matrix of the pair ``[u, Ru]`` is computed with ``hardtree_viz`` and its\n",
    "    off-diagonal entry ``H[1, 0]`` is recorded.\n",
    "\n",
    "    Args:\n",
    "        alpha, beta, finetune, arch: Forwarded to ``hardtree_viz``.\n",
    "        u: Base 2-D vector.\n",
    "\n",
    "    Returns:\n",
    "        ``(kernel, inner_product)``: two lists of length 360 holding the\n",
    "        kernel value and ``np.dot(u, Ru)`` for each rotation angle.\n",
    "    \"\"\"\n",
    "    kernel = []\n",
    "    inner_product = []\n",
    "    for angle_deg in range(360):\n",
    "        Ru = rotation_o(u, angle_deg * np.pi / 180)\n",
    "        H = hardtree_viz(\n",
    "            np.vstack([u, Ru]), alpha=alpha, beta=beta, finetune=finetune, arch=arch\n",
    "        )\n",
    "        kernel.append(H[1, 0])\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "    return kernel, inner_product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da6fecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kernel_soft(alpha, beta, u, depth=2):\n",
    "    \"\"\"Sample the soft-tree kernel between ``u`` and rotated copies of ``u``.\n",
    "\n",
    "    ``u`` is rotated by 0, 1, ..., 359 degrees; for each angle the kernel\n",
    "    matrix of the pair ``[u, Ru]`` is computed with ``softtree_viz`` and its\n",
    "    off-diagonal entry ``H[1, 0]`` is recorded.\n",
    "\n",
    "    Args:\n",
    "        alpha, beta: Forwarded to ``softtree_viz``.\n",
    "        u: Base 2-D vector.\n",
    "        depth: Forwarded to ``softtree_viz`` as ``max_depth``.\n",
    "\n",
    "    Returns:\n",
    "        List of 360 kernel values, one per rotation angle. Unlike\n",
    "        ``get_kernel_hard`` no inner products are returned.\n",
    "    \"\"\"\n",
    "    kernel = []\n",
    "    for angle_deg in range(360):\n",
    "        Ru = rotation_o(u, angle_deg * np.pi / 180)\n",
    "        H = softtree_viz(np.vstack([u, Ru]), alpha=alpha, beta=beta, max_depth=depth)\n",
    "        kernel.append(H[1, 0])\n",
    "    return kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42536b5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_kernel_panel(position, alpha, beta, finetune, arch, title, ylim, xlabel=None, ylabel=None):\n",
    "    \"\"\"Draw one panel of the 2 x 2 kernel figure on the current pyplot figure.\"\"\"\n",
    "    plt.subplot(2, 2, position)\n",
    "    for i in tqdm(range(15), leave=False):\n",
    "        # Base vector (1, 0) rotated in 6-degree steps: 0, 6, ..., 84 degrees.\n",
    "        Ru = rotation_o((1, 0), 6 * i * np.pi / 180)\n",
    "        kernel_list, inner_product_list = get_kernel_hard(\n",
    "            alpha=alpha, beta=beta, u=Ru, finetune=finetune, arch=arch\n",
    "        )\n",
    "        plt.plot(inner_product_list[0:180], kernel_list[0:180], linewidth=1, color=cm.jet(i / 15))\n",
    "\n",
    "    # The soft-tree reference curve reuses the last base vector and the last\n",
    "    # inner-product list from the loop above.\n",
    "    kernel_list = get_kernel_soft(alpha=alpha, beta=beta, u=Ru, depth=2)\n",
    "    plt.plot(\n",
    "        inner_product_list[0:180],\n",
    "        kernel_list[0:180],\n",
    "        linewidth=3,\n",
    "        color=\"black\",\n",
    "        linestyle=\"dotted\",\n",
    "        label=\"Oblique\",\n",
    "    )\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    plt.title(title)\n",
    "    plt.ylim(*ylim)\n",
    "    plt.xlim(-1.0, 1.0)\n",
    "    if ylabel is not None:\n",
    "        plt.ylabel(ylabel)\n",
    "    if xlabel is not None:\n",
    "        plt.xlabel(xlabel)\n",
    "    plt.legend(loc='upper left')\n",
    "\n",
    "\n",
    "def save_pdf(alpha, beta, ylim_min, ylim_max):\n",
    "    \"\"\"Plot the four hard-tree kernel panels and save them as one PDF.\n",
    "\n",
    "    The panels cover ``finetune`` in {False, True} (titled AAA / AAI) and\n",
    "    ``arch`` in {0, 1} (titled (A) / (B)). Each panel shows 15 hard-tree\n",
    "    curves, coloured by base rotation angle, plus the dotted soft-tree curve.\n",
    "    The figure is written to ``./figures/kernels_{alpha}_{beta}.pdf``; the\n",
    "    ``./figures`` directory is created when it does not exist yet.\n",
    "\n",
    "    Args:\n",
    "        alpha: Forwarded to the kernel functions; also used in the title and\n",
    "            the file name.\n",
    "        beta: Forwarded to the kernel functions; also used in the title and\n",
    "            the file name.\n",
    "        ylim_min: Lower y-axis limit shared by all panels.\n",
    "        ylim_max: Upper y-axis limit shared by all panels.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(15, 7))\n",
    "\n",
    "    x_label = \"Inner product of the inputs\"\n",
    "    y_label = \"Kernel value\"\n",
    "    # (subplot position, finetune, arch, title, xlabel, ylabel)\n",
    "    panels = [\n",
    "        (1, False, 0, \"AAA, Tree architecture=(A)\", None, y_label),\n",
    "        (2, False, 1, \"AAA, Tree architecture=(B)\", None, None),\n",
    "        (3, True, 0, \"AAI, Tree architecture=(A)\", x_label, y_label),\n",
    "        (4, True, 1, \"AAI, Tree architecture=(B)\", x_label, None),\n",
    "    ]\n",
    "    for position, finetune, arch, title, xlabel, ylabel in panels:\n",
    "        _plot_kernel_panel(\n",
    "            position,\n",
    "            alpha,\n",
    "            beta,\n",
    "            finetune,\n",
    "            arch,\n",
    "            title,\n",
    "            (ylim_min, ylim_max),\n",
    "            xlabel=xlabel,\n",
    "            ylabel=ylabel,\n",
    "        )\n",
    "\n",
    "    # Shared colour bar: curve i is drawn with cm.jet(i / 15) and a base angle\n",
    "    # of 6 * i degrees, so the 0-90 normalisation maps colour to angle.\n",
    "    cax = plt.axes([0.145, -0.02, 0.75, 0.02])\n",
    "    norm = matplotlib.colors.Normalize(vmin=0, vmax=90)\n",
    "    plt.colorbar(\n",
    "        matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet),\n",
    "        cax=cax,\n",
    "        orientation='horizontal',\n",
    "        label=\"Rotation angle (degree)\",\n",
    "        ticks=[15, 30, 45, 60, 75],\n",
    "    )\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.suptitle(f\"$\\\\alpha$={alpha}, $\\\\beta$={beta}\", y=1.0)\n",
    "    # savefig does not create missing directories, so make sure it exists.\n",
    "    os.makedirs(\"./figures\", exist_ok=True)\n",
    "    plt.savefig(f\"./figures/kernels_{alpha}_{beta}.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7561103b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure for alpha=2.0, beta=0.5.\n",
    "save_pdf(alpha=2.0, beta=0.5, ylim_min=-0.5, ylim_max=2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f518c84c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure for alpha=1.0, beta=0.5.\n",
    "save_pdf(alpha=1.0, beta=0.5, ylim_min=-0.5, ylim_max=2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4aab577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure for alpha=4.0, beta=0.5.\n",
    "save_pdf(alpha=4.0, beta=0.5, ylim_min=-0.5, ylim_max=2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fb6f4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure for alpha=2.0, beta=0.1.\n",
    "save_pdf(alpha=2.0, beta=0.1, ylim_min=-0.5, ylim_max=2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd0edad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure for alpha=2.0, beta=1.0.\n",
    "save_pdf(alpha=2.0, beta=1.0, ylim_min=-0.5, ylim_max=2.5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
