{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "temporal-watershed",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:25.807349Z",
     "start_time": "2021-09-25T08:56:25.801536Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "completed-chance",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:28.621842Z",
     "start_time": "2021-09-25T08:56:26.049699Z"
    }
   },
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "operational-physiology",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.580638Z",
     "start_time": "2021-09-25T08:56:28.624803Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee548aec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):\n",
    "    if torch.is_tensor(inputs):\n",
    "        inputs = [inputs]\n",
    "    else:\n",
    "        inputs = list(inputs)\n",
    "    grads = torch.autograd.grad(\n",
    "        outputs,\n",
    "        inputs,\n",
    "        grad_outputs,\n",
    "        allow_unused=True,\n",
    "        retain_graph=retain_graph,\n",
    "        create_graph=create_graph,\n",
    "    )\n",
    "    grads = [x if x is not None else torch.zeros_like(y) for x, y in zip(grads, inputs)]\n",
    "    return torch.cat([x.contiguous().view(-1) for x in grads])\n",
    "\n",
    "\n",
    "def compute_kernels(f, xtr, parameters=None):\n",
    "    if parameters is None:\n",
    "        parameters = list(f.parameters())\n",
    "\n",
    "    ktrtr = xtr.new_zeros(len(xtr), len(xtr))\n",
    "\n",
    "    params = []\n",
    "    current = []\n",
    "    for p in sorted(parameters, key=lambda p: p.numel(), reverse=True):\n",
    "        current.append(p)\n",
    "        if sum(p.numel() for p in current) > 2e9 // (8 * (len(xtr))):\n",
    "            if len(current) > 1:\n",
    "                params.append(current[:-1])\n",
    "                current = current[-1:]\n",
    "            else:\n",
    "                params.append(current)\n",
    "                current = []\n",
    "    if len(current) > 0:\n",
    "        params.append(current)\n",
    "\n",
    "    for i, p in enumerate(params):\n",
    "        jtr = xtr.new_empty(len(xtr), sum(u.numel() for u in p))  # (P, N~)\n",
    "\n",
    "        for j, x in enumerate(xtr):\n",
    "            jtr[j] = gradient(f(x[None]), p)  # (N~)\n",
    "\n",
    "        ktrtr.add_(jtr @ jtr.t())\n",
    "        del jtr\n",
    "\n",
    "    return ktrtr\n",
    "\n",
    "def rotation_o(u, t, deg=False):\n",
    "    if deg == True:\n",
    "        t = np.deg2rad(t)\n",
    "    R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])\n",
    "    return np.dot(R, u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "atomic-characteristic",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T08:56:30.600770Z",
     "start_time": "2021-09-25T08:56:30.586043Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_kernel(alpha, beta, u, finetune, arch):\n",
    "    kernel = []\n",
    "    taus = []\n",
    "    tau_dots = []\n",
    "    inner_product = []\n",
    "    for i in range(360):\n",
    "        Ru = rotation_o(u, i*np.pi/180)\n",
    "        H = hardtree_viz(np.vstack([u, Ru]), alpha=alpha, beta=beta, finetune=finetune, arch=arch)\n",
    "        kernel.append(H[1,0])\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "    return kernel, inner_product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a59d62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.environ.get(\"GPU\"):\n",
    "    device = os.environ.get(\"GPU\") if torch.cuda.is_available() else \"cpu\"\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "\n",
    "class InnerNode:\n",
    "    def __init__(self, config, depth, asym=False):\n",
    "        self.config = config\n",
    "        self.leaf = False\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        self.prob = None\n",
    "        self.path_prob = None\n",
    "        self.left = None\n",
    "        self.right = None\n",
    "        self.leaf_accumulator = []\n",
    "        self.asym = asym\n",
    "\n",
    "        self.build_child(depth)\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        return (\n",
    "            0.5\n",
    "            * torch.erf(\n",
    "                self.config[\"scale\"]\n",
    "                * (\n",
    "                    torch.matmul(x, self.fc.weight.t())\n",
    "                    + self.config[\"bias_scale\"] * self.fc.bias\n",
    "                )\n",
    "            )\n",
    "            + 0.5\n",
    "        )\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        self.prob = self.forward(x)  # probability of selecting right node\n",
    "        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]\n",
    "        self.path_prob = path_prob\n",
    "        left_leaf_accumulator = self.left.calc_prob(x, path_prob * (1 - self.prob))\n",
    "        right_leaf_accumulator = self.right.calc_prob(x, path_prob * self.prob)\n",
    "        self.leaf_accumulator.extend(left_leaf_accumulator)\n",
    "        self.leaf_accumulator.extend(right_leaf_accumulator)\n",
    "        return self.leaf_accumulator\n",
    "\n",
    "    def reset(self):\n",
    "        self.leaf_accumulator = []\n",
    "        self.penalties = []\n",
    "        self.left.reset()\n",
    "        self.right.reset()\n",
    "\n",
    "\n",
    "class SparseInnerNode(InnerNode):\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        super().__init__(config, depth, asym)\n",
    "        if feature_index is None:\n",
    "            self.feature_index = np.random.randint(self.config[\"input_dim\"])\n",
    "        else:\n",
    "            self.feature_index = feature_index\n",
    "\n",
    "        self.fc = nn.Linear(1, self.config[\"n_tree\"], bias=True).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        return (\n",
    "            0.5\n",
    "            * torch.erf(\n",
    "                self.config[\"scale\"]\n",
    "                * (\n",
    "                    torch.matmul(\n",
    "                        x[:, self.feature_index].unsqueeze(dim=1), self.fc.weight.t()\n",
    "                    )\n",
    "                    + self.config[\"bias_scale\"] * self.fc.bias\n",
    "                )\n",
    "            )\n",
    "            + 0.5\n",
    "        )  # -> [batch_size, n_tree]\n",
    "\n",
    "\n",
    "class SparseFinetuneInnerNode(InnerNode):\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        super().__init__(config, depth, asym)\n",
    "        if feature_index is None:\n",
    "            self.feature_index = np.random.randint(self.config[\"input_dim\"])\n",
    "        else:\n",
    "            self.feature_index = feature_index\n",
    "\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for i, w_per_tree in enumerate(self.fc.weight):\n",
    "                for j, w in enumerate(w_per_tree):\n",
    "                    if j != feature_index:\n",
    "                        self.fc.weight[i][j] *= 0\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = SparseFinetuneInnerNode(\n",
    "                    self.config, depth + 1, asym=self.asym\n",
    "                )\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "\n",
    "class LeafNode:\n",
    "    def __init__(self, config):\n",
    "        self.config = config\n",
    "        self.leaf = True\n",
    "        self.param = nn.Parameter(\n",
    "            torch.randn(self.config[\"output_dim\"], self.config[\"n_tree\"]).to(device)\n",
    "        )  # [n_class, n_tree]\n",
    " \n",
    "    def forward(self):\n",
    "        return self.param\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        path_prob = path_prob.to(device)  # [batch_size, n_tree]\n",
    "\n",
    "        Q = self.forward()\n",
    "        Q = Q.expand(\n",
    "            (path_prob.size()[0], self.config[\"output_dim\"], self.config[\"n_tree\"])\n",
    "        )  # -> [batch_size, n_class, n_tree]\n",
    "        return [[path_prob, Q]]\n",
    "\n",
    "    def reset(self):\n",
    "        pass\n",
    "\n",
    "\n",
    "class SoftTree(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        max_depth: int,\n",
    "        scale: float,\n",
    "        bias_scale: float,\n",
    "        n_tree: int,\n",
    "        asym: bool = False,\n",
    "        sparse: bool = False,\n",
    "    ):\n",
    "        super(SoftTree, self).__init__()\n",
    "        config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"max_depth\": max_depth,\n",
    "            \"scale\": scale,\n",
    "            \"bias_scale\": bias_scale,\n",
    "            \"n_tree\": n_tree,\n",
    "        }\n",
    "        self.config = config\n",
    "        if sparse:\n",
    "            self.root = SparseInnerNode(config, depth=1, asym=asym)\n",
    "        else:\n",
    "            self.root = InnerNode(config, depth=1, asym=asym)\n",
    "\n",
    "        self.collect_parameters()\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        nodes = [self.root]\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        while nodes:\n",
    "            node = nodes.pop(0)\n",
    "            if node.leaf:\n",
    "                param = node.param\n",
    "                self.param_list.append(param)\n",
    "            else:\n",
    "                fc = node.fc\n",
    "                nodes.append(node.right)\n",
    "                nodes.append(node.left)\n",
    "                self.module_list.append(fc)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "\n",
    "        path_prob_init = torch.Tensor(torch.ones(x.shape[0], self.config[\"n_tree\"]))\n",
    "\n",
    "        leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "        pred = torch.zeros(x.shape[0], self.config[\"output_dim\"]).to(device)\n",
    "        for i, (path_prob, Q) in enumerate(leaf_accumulator):  # 2**depth loop\n",
    "            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "\n",
    "        self.root.reset()\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b4b1cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftTreeExp(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        max_depth: int,\n",
    "        scale: float,\n",
    "        bias_scale: float,\n",
    "        n_tree: int,\n",
    "        asym: bool=False,\n",
    "        sparse: bool=True,\n",
    "        finetune: bool=False,\n",
    "        arch:int=0,\n",
    "    ):\n",
    "        super(SoftTreeExp, self).__init__()\n",
    "        config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"scale\": scale,\n",
    "            \"bias_scale\": bias_scale,\n",
    "            \"n_tree\": n_tree,\n",
    "            \"max_depth\": max_depth\n",
    "        }\n",
    "        self.config = config\n",
    "        \n",
    "        assert sparse # only for sparse tree\n",
    "        assert finetune <= sparse\n",
    "        \n",
    "        if finetune: # AAI\n",
    "           #depth=1\n",
    "            self.root = SparseFinetuneInnerNode(config, depth=1, feature_index=0)\n",
    "            #depth=2\n",
    "            self.root.left = SparseFinetuneInnerNode(config, depth=2, feature_index=0 if arch==0 else 1)\n",
    "            self.root.right = SparseFinetuneInnerNode(config, depth=2, feature_index=0 if arch==0 else 1)\n",
    "        else: # AAA\n",
    "            # depth=1\n",
    "            self.root = SparseInnerNode(config, depth=1, feature_index=0)\n",
    "            #depth=2\n",
    "            self.root.left = SparseInnerNode(config, depth=2, feature_index=0 if arch==0 else 1)\n",
    "            self.root.right = SparseInnerNode(config, depth=2, feature_index=0 if arch==0 else 1)\n",
    " \n",
    "        #depth=3\n",
    "        self.root.left.left = LeafNode(config)\n",
    "        self.root.left.right = LeafNode(config)\n",
    "        self.root.right.left = LeafNode(config)\n",
    "        self.root.right.right = LeafNode(config)\n",
    "\n",
    "        self.collect_parameters()\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        nodes = [self.root]\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        while nodes:\n",
    "            node = nodes.pop(0)\n",
    "            if node.leaf:\n",
    "                param = node.param\n",
    "                self.param_list.append(param)\n",
    "            else:\n",
    "                fc = node.fc\n",
    "                nodes.append(node.right)\n",
    "                nodes.append(node.left)\n",
    "                self.module_list.append(fc)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "\n",
    "        path_prob_init = torch.Tensor(torch.ones(x.shape[0], self.config[\"n_tree\"]))\n",
    "\n",
    "        leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "        pred = torch.zeros(x.shape[0], self.config[\"output_dim\"])\n",
    "        for i, (path_prob, Q) in enumerate(leaf_accumulator):  # 2**depth loop\n",
    "            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "\n",
    "        self.root.reset()\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "spectacular-chicago",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:19:53.664888Z",
     "start_time": "2021-09-25T07:19:53.657662Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel(alpha, beta, n_tree, color, u: tuple=(1,0), finetune: bool=False, arch: int=0):\n",
    "    st = SoftTreeExp(input_dim=2, output_dim=1, max_depth=2, scale=alpha, bias_scale=beta, n_tree=n_tree, sparse=True, finetune=finetune, arch=arch)\n",
    "    res = []\n",
    "    inner_product = []\n",
    "    for i in tqdm(range(180), leave=False):\n",
    "        Ru = rotation_o(u, i * np.pi / 180)\n",
    "        x = torch.Tensor([u, Ru]).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        K = compute_kernels(st, x)\n",
    "        res.append(K[1,0].tolist())\n",
    "        inner_product.append(np.dot(u, Ru))\n",
    "\n",
    "    plt.plot(inner_product, res, color=color, linewidth=1, zorder=1)\n",
    "    return inner_product, res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sensitive-writer",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-25T07:40:03.406986Z",
     "start_time": "2021-09-25T07:33:47.666089Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_kernel_finite(alpha: float, beta: float, u: list, n_seeds: int = 10, finetune: bool = False, colormap=cm.bwr, arch: int=0)->None:\n",
    "    n_tree_combinations = [16, 64, 256, 1024, 4096]\n",
    "    for j, n_tree in enumerate(tqdm(n_tree_combinations, leave=False)):\n",
    "        res_all = []\n",
    "        kernel_list, inner_product_list = get_kernel(alpha, beta, u, finetune, arch)\n",
    "        for i in tqdm(range(n_seeds), leave=False):\n",
    "            inner_product, res = plot_kernel(\n",
    "                alpha, beta, n_tree, color=colormap(j/len(n_tree_combinations)), u=u, finetune=finetune, arch=arch\n",
    "            )\n",
    "            res_all.append(res)\n",
    "    plt.plot(inner_product_list[0:180], kernel_list[0:180],  color=\"black\", linewidth=3, linestyle=\"dotted\", zorder=2)\n",
    "    if u==(1,0):\n",
    "        plt.title(\"$x_i =(1, 0)$\"+f\", Tree architecture={'(A)' if arch==0 else '(B)'}\")\n",
    "    else:\n",
    "        plt.title(\"$x_i =({1}/{\\\\sqrt{2}}, {1}/{\\\\sqrt{2}})$\"+f\", Architecture={'(A)' if arch==0 else '(B)'}\")        \n",
    "    plt.xlabel(\"Inner product of the inputs\")\n",
    "\n",
    "    ylim_min = -0.5\n",
    "    ylim_max = 2.5\n",
    "\n",
    "    plt.ylim(ylim_min, ylim_max)\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    \n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569fa072",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:\n",
    "    tau = 1 / 4 + 1 / (2 * math.pi) * np.arcsin(\n",
    "        ((alpha ** 2) * S)\n",
    "        / (np.sqrt(((alpha ** 2) * diag_i + 0.5) * ((alpha ** 2) * diag_j + 0.5)))\n",
    "    )\n",
    "    return tau\n",
    "\n",
    "\n",
    "def calc_tau_dot(\n",
    "    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array\n",
    ") -> np.array:\n",
    "    tau_dot = (\n",
    "        (alpha ** 2)\n",
    "        / (math.pi)\n",
    "        * 1\n",
    "        / np.sqrt(\n",
    "            (2 * (alpha ** 2) * diag_i + 1) * (2 * (alpha ** 2) * diag_j + 1)\n",
    "            - (4 * (alpha ** 4) * (S ** 2))\n",
    "        )\n",
    "    )\n",
    "    return tau_dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d88d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hardtree_viz(\n",
    "    X: np.array, alpha: float, beta: float, finetune: bool, arch: int\n",
    "):\n",
    "    \n",
    "    assert arch in {0, 1}\n",
    "    \n",
    "    S_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "\n",
    "    for feature_index in range(len(X[0])):\n",
    "        S = np.outer(X[:, feature_index], X[:, feature_index].T) + beta**2\n",
    "        S_all = np.matmul(X, X.T) + beta**2\n",
    "        if finetune:\n",
    "            S_list.append(S_all)\n",
    "        else:\n",
    "            S_list.append(S)\n",
    "\n",
    "        _diag = [S[i, i] for i in range(len(S))]\n",
    "        diag_i = np.array(_diag * len(_diag)).reshape(len(_diag), len(_diag))\n",
    "        diag_j = diag_i.transpose()\n",
    "        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))\n",
    "        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))\n",
    "        \n",
    "    K = np.zeros((X.shape[0], X.shape[0]))\n",
    "    if arch==0:\n",
    "        rulelist = [[0, 0], [0, 0], [0, 0], [0, 0]]\n",
    "    elif arch==1:\n",
    "        rulelist = [[0, 1], [0, 1], [0, 1], [0, 1]]  \n",
    "    H = np.zeros_like(S_list[0])\n",
    "    for rules in rulelist:\n",
    "        \n",
    "        # Internal nodes\n",
    "        for i, s in enumerate(rules):\n",
    "            ts = rules[0:i]+rules[i+1:]\n",
    "            _H_nodes = S_list[s]* tau_dot_list[s]\n",
    "            for t in ts:\n",
    "                _H_nodes *= tau_list[t] # nodes\n",
    "            K+= _H_nodes\n",
    "        _H_leaves = np.ones_like(K)\n",
    "        \n",
    "        # Leaves\n",
    "        for tau in [tau_list[i] for i in rules]:\n",
    "            _H_leaves *= tau\n",
    "        K += _H_leaves\n",
    "    \n",
    "    return K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9991086",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 10\n",
    "alpha = 2.0\n",
    "beta = 0.5\n",
    "\n",
    "arch=0\n",
    "plt.figure(figsize=(18, 10))\n",
    "plt.subplot(2,4,1)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1,0), n_seeds=n_seeds, finetune=False, colormap=cm.PRGn, arch=arch)\n",
    "plt.ylabel(\"Kernel value (AAA)\")\n",
    "\n",
    "plt.subplot(2,4,2)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1/np.sqrt(2), 1/np.sqrt(2)), n_seeds=n_seeds, finetune=False, colormap=cm.PRGn, arch=arch)\n",
    "\n",
    "arch=1\n",
    "plt.subplot(2,4,3)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1,0), n_seeds=n_seeds, finetune=False, colormap=cm.PRGn, arch=arch)\n",
    "\n",
    "plt.subplot(2,4,4)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1/np.sqrt(2), 1/np.sqrt(2)), n_seeds=n_seeds, finetune=False, colormap=cm.PRGn, arch=arch)\n",
    "\n",
    "arch=0\n",
    "plt.subplot(2,4,5)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1,0), n_seeds=n_seeds, finetune=True, colormap=cm.bwr, arch=arch)\n",
    "plt.ylabel(\"Kernel value (AAI)\")\n",
    "\n",
    "plt.subplot(2,4,6)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1/np.sqrt(2), 1/np.sqrt(2)), n_seeds=n_seeds, finetune=True, colormap=cm.bwr, arch=arch)\n",
    "\n",
    "arch=1\n",
    "plt.subplot(2,4,7)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1,0), n_seeds=n_seeds, finetune=True, colormap=cm.bwr, arch=arch)\n",
    "\n",
    "plt.subplot(2,4,8)\n",
    "plot_kernel_finite(alpha=alpha, beta=beta, u=(1/np.sqrt(2), 1/np.sqrt(2)), n_seeds=n_seeds, finetune=True, colormap=cm.bwr, arch=arch)\n",
    "\n",
    "# Legend\n",
    "plt.subplot(2,4,6)\n",
    "patterns = [16, 64, 256, 1024, 4096]\n",
    "for i in range(5):\n",
    "    plt.plot([], [], color=cm.PRGn(i/5), linewidth=1, label=f\"$M={patterns[i]}$\")\n",
    "plt.plot([], [], color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "plt.legend(ncol=3, bbox_to_anchor=(0.5, -0.35), fontsize=12, title=\"AAA\", loc=\"center\", borderaxespad=0)\n",
    "\n",
    "plt.subplot(2,4,7)\n",
    "for i in range(5):\n",
    "    plt.plot([], [], color=cm.bwr(i/5), linewidth=1, label=f\"$M={patterns[i]}$\")\n",
    "plt.plot([], [], color=\"black\", linewidth=3, linestyle=\"dotted\", label=\"$M={\\\\infty}$\")\n",
    "plt.legend(ncol=3, bbox_to_anchor=(0.5, -0.35), fontsize=12, title=\"AAI\", loc=\"center\", borderaxespad=0)\n",
    "\n",
    "plt.savefig(\"./figures/kernels_asymptotic.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fbc41d3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
