{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:27.744893Z",
     "start_time": "2021-09-27T03:15:27.738004Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:32.310591Z",
     "start_time": "2021-09-27T03:15:27.749292Z"
    }
   },
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.145364Z",
     "start_time": "2021-09-27T03:15:32.313901Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "if os.environ.get(\"GPU\"):\n",
    "    device = os.environ.get(\"GPU\") if torch.cuda.is_available() else \"cpu\"\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InnerNode:\n",
    "    def __init__(self, config, depth, asym=False):\n",
    "        self.config = config\n",
    "        self.leaf = False\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        self.prob = None\n",
    "        self.path_prob = None\n",
    "        self.left = None\n",
    "        self.right = None\n",
    "        self.leaf_accumulator = []\n",
    "        self.asym = asym\n",
    "\n",
    "        self.build_child(depth)\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        return (\n",
    "            0.5\n",
    "            * torch.erf(\n",
    "                self.config[\"scale\"]\n",
    "                * (\n",
    "                    torch.matmul(x, self.fc.weight.t())\n",
    "                    + self.config[\"bias_scale\"] * self.fc.bias\n",
    "                )\n",
    "            )\n",
    "            + 0.5\n",
    "        )\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        self.prob = self.forward(x)  # probability of selecting right node\n",
    "        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]\n",
    "        self.path_prob = path_prob\n",
    "        left_leaf_accumulator = self.left.calc_prob(x, path_prob * (1 - self.prob))\n",
    "        right_leaf_accumulator = self.right.calc_prob(x, path_prob * self.prob)\n",
    "        self.leaf_accumulator.extend(left_leaf_accumulator)\n",
    "        self.leaf_accumulator.extend(right_leaf_accumulator)\n",
    "        return self.leaf_accumulator\n",
    "\n",
    "    def reset(self):\n",
    "        self.leaf_accumulator = []\n",
    "        self.penalties = []\n",
    "        self.left.reset()\n",
    "        self.right.reset()\n",
    "\n",
    "\n",
    "class SparseInnerNode(InnerNode):\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        super().__init__(config, depth, asym)\n",
    "        if feature_index is None:\n",
    "            self.feature_index = np.random.randint(self.config[\"input_dim\"])\n",
    "        else:\n",
    "            self.feature_index = feature_index\n",
    "\n",
    "        self.fc = nn.Linear(1, self.config[\"n_tree\"], bias=True).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        return (\n",
    "            0.5\n",
    "            * torch.erf(\n",
    "                self.config[\"scale\"]\n",
    "                * (\n",
    "                    torch.matmul(\n",
    "                        x[:, self.feature_index].unsqueeze(dim=1), self.fc.weight.t()\n",
    "                    )\n",
    "                    + self.config[\"bias_scale\"] * self.fc.bias\n",
    "                )\n",
    "            )\n",
    "            + 0.5\n",
    "        )  # -> [batch_size, n_tree]\n",
    "\n",
    "\n",
    "class SparseFinetuneInnerNode(InnerNode):\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        super().__init__(config, depth, asym)\n",
    "        if feature_index is None:\n",
    "            self.feature_index = np.random.randint(self.config[\"input_dim\"])\n",
    "        else:\n",
    "            self.feature_index = feature_index\n",
    "\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for i, w_per_tree in enumerate(self.fc.weight):\n",
    "                for j, w in enumerate(w_per_tree):\n",
    "                    if j != feature_index:\n",
    "                        self.fc.weight[i][j] *= 0\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = SparseFinetuneInnerNode(\n",
    "                    self.config, depth + 1, asym=self.asym\n",
    "                )\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "\n",
    "class LeafNode:\n",
    "    def __init__(self, config):\n",
    "        self.config = config\n",
    "        self.leaf = True\n",
    "        self.param = nn.Parameter(\n",
    "            torch.randn(self.config[\"output_dim\"], self.config[\"n_tree\"]).to(device)\n",
    "        )  # [n_class, n_tree]\n",
    " \n",
    "    def forward(self):\n",
    "        return self.param\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        path_prob = path_prob.to(device)  # [batch_size, n_tree]\n",
    "\n",
    "        Q = self.forward()\n",
    "        Q = Q.expand(\n",
    "            (path_prob.size()[0], self.config[\"output_dim\"], self.config[\"n_tree\"])\n",
    "        )  # -> [batch_size, n_class, n_tree]\n",
    "        return [[path_prob, Q]]\n",
    "\n",
    "    def reset(self):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftTreeExp(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        max_depth: int,\n",
    "        scale: float,\n",
    "        bias_scale: float,\n",
    "        n_tree: int,\n",
    "        arch: int,\n",
    "        oblivious: bool,\n",
    "        asym: bool=False,\n",
    "        sparse: bool=True,\n",
    "        finetune: bool=False,\n",
    "    ):\n",
    "        super(SoftTreeExp, self).__init__()\n",
    "        config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"scale\": scale,\n",
    "            \"bias_scale\": bias_scale,\n",
    "            \"n_tree\": n_tree,\n",
    "            \"max_depth\": max_depth\n",
    "        }\n",
    "        self.config = config\n",
    "        \n",
    "        assert sparse # only for sparse tree\n",
    "        assert finetune <= sparse\n",
    "        \n",
    "        if finetune: # AAI\n",
    "            if arch==0:\n",
    "                # depth=1\n",
    "                self.root = SparseFinetuneInnerNode(config, depth=1, feature_index=0)\n",
    "                #depth=2\n",
    "                self.root.left = LeafNode(config)\n",
    "                self.root.right = SparseFinetuneInnerNode(config, depth=2, feature_index=1)\n",
    "                #depth=3\n",
    "                self.root.right.left = LeafNode(config)\n",
    "                self.root.right.right = LeafNode(config)\n",
    "            elif arch==1:\n",
    "                # depth=1\n",
    "                self.root = SparseFinetuneInnerNode(config, depth=1, feature_index=0)\n",
    "                # depth=2\n",
    "                self.root.left = LeafNode(config)\n",
    "                self.root.right = LeafNode(config)\n",
    "            elif arch==2:\n",
    "                # depth=1\n",
    "                self.root = SparseFinetuneInnerNode(config, depth=1, feature_index=0)\n",
    "                 #depth=2\n",
    "                self.root.left =  SparseFinetuneInnerNode(config, depth=2, feature_index=1)\n",
    "                self.root.right = SparseFinetuneInnerNode(config, depth=2, feature_index=1)                \n",
    "                if oblivious:\n",
    "                    self.root.right.fc.weight = self.root.left.fc.weight\n",
    "                    self.root.right.fc.bias = self.root.left.fc.bias\n",
    "                #depth=3\n",
    "                self.root.left.left = LeafNode(config)\n",
    "                self.root.left.right = LeafNode(config)\n",
    "                self.root.right.left = LeafNode(config)\n",
    "                self.root.right.right = LeafNode(config)\n",
    "            else:\n",
    "                raise ValueError\n",
    "                \n",
    "        else: # AAA\n",
    "            if arch==0:\n",
    "                # depth=1\n",
    "                self.root = SparseInnerNode(config, depth=1, feature_index=0)\n",
    "                #depth=2\n",
    "                self.root.left = LeafNode(config)\n",
    "                self.root.right = SparseInnerNode(config, depth=2, feature_index=1)\n",
    "                #depth=3\n",
    "                self.root.right.left = LeafNode(config)\n",
    "                self.root.right.right = LeafNode(config)\n",
    "            elif arch==1:\n",
    "                # depth=1\n",
    "                self.root = SparseInnerNode(config, depth=1, feature_index=0)\n",
    "                # depth=2\n",
    "                self.root.left = LeafNode(config)\n",
    "                self.root.right = LeafNode(config)\n",
    "            elif arch==2:\n",
    "                # depth=1\n",
    "                self.root = SparseInnerNode(config, depth=1, feature_index=0)\n",
    "                 #depth=2\n",
    "                self.root.left =  SparseInnerNode(config, depth=2, feature_index=1)\n",
    "                self.root.right = SparseInnerNode(config, depth=2, feature_index=1)                \n",
    "                if oblivious:\n",
    "                    self.root.right.fc.weight = self.root.left.fc.weight\n",
    "                    self.root.right.fc.bias = self.root.left.fc.bias\n",
    "                #depth=3\n",
    "                self.root.left.left = LeafNode(config)\n",
    "                self.root.left.right = LeafNode(config)\n",
    "                self.root.right.left = LeafNode(config)\n",
    "                self.root.right.right = LeafNode(config)\n",
    "            else:\n",
    "                raise ValueError\n",
    "            \n",
    "        self.collect_parameters()\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        nodes = [self.root]\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        while nodes:\n",
    "            node = nodes.pop(0)\n",
    "            if node.leaf:\n",
    "                param = node.param\n",
    "                self.param_list.append(param)\n",
    "            else:\n",
    "                fc = node.fc\n",
    "                nodes.append(node.right)\n",
    "                nodes.append(node.left)\n",
    "                self.module_list.append(fc)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "\n",
    "        path_prob_init = torch.Tensor(torch.ones(x.shape[0], self.config[\"n_tree\"]))\n",
    "\n",
    "        leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "        pred = torch.zeros(x.shape[0], self.config[\"output_dim\"])\n",
    "        for i, (path_prob, Q) in enumerate(leaf_accumulator):  # 2**depth loop\n",
    "            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "\n",
    "        self.root.reset()\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftTreeMerge(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        st1: nn.Module,\n",
    "        st2: nn.Module\n",
    "    ):\n",
    "        super(SoftTreeMerge, self).__init__()\n",
    "        self.st1 = st1\n",
    "        self.st2 = st2\n",
    "\n",
    "    def forward(self, x):\n",
    "        x1 = self.st1.forward(x)\n",
    "        x2 = self.st2.forward(x)\n",
    "        return x1+x2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_diabetes\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False: # real dataset\n",
    "    import pandas as pd\n",
    "    from sklearn.datasets import load_diabetes\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    X, y = load_diabetes(as_frame = True, return_X_y=True)\n",
    "    X = X[['bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6', 'age', 'sex']]\n",
    "    target_scaler = StandardScaler()\n",
    "    target_scaler.fit(y.values.reshape(-1, 1))\n",
    "    y = pd.Series((target_scaler.transform(y.values.reshape(-1, 1)).squeeze()), name=\"y\")\n",
    "\n",
    "    for c in X.columns:\n",
    "        feature_scaler = StandardScaler()\n",
    "        feature_scaler.fit(X[c].values.reshape(-1,1))\n",
    "        X[c] = (feature_scaler.transform(X[c].values.reshape(-1, 1)).squeeze())\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.95)\n",
    "\n",
    "    train_data = torch.Tensor(X_train.values)[0:50]\n",
    "    target_data = torch.tensor(y_train.values)[0:50]\n",
    "    test_data = torch.Tensor(X_test.values)[0:10]\n",
    "else: # random dataset\n",
    "    n_features = 2\n",
    "    n_dataset = 10\n",
    "    train_data = torch.Tensor([np.random.randn(n_features) for i in range(n_dataset)])\n",
    "    target_data = torch.tensor(np.random.randn(train_data.shape[0]))\n",
    "    test_data = torch.Tensor([np.random.randn(n_features) for i in range(10)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "source": [
    "## Tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.163984Z",
     "start_time": "2021-09-27T03:15:35.158843Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_net(net, n_epochs, input_data, target, lr, initial_train):\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "    for epoch in range(n_epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(input_data.double())-initial_train.unsqueeze(1)\n",
    "        loss = criterion(outputs.view(-1), target)/2\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectory(st: nn.Module, linestyle: str, alpha: float, linewidth: float):\n",
    "    ptrain_empirical1, ptest_empirical1= [], []\n",
    "    t_max = 1000\n",
    "    t_step = 10\n",
    "    lr = 0.1\n",
    "    t_list = np.arange(t_step, t_max+t_step, t_step)\n",
    "\n",
    "    initial_train1=st.forward(train_data).reshape(-1)\n",
    "    initial_test1=st.forward(test_data).reshape(-1)\n",
    "\n",
    "    ptrain_empirical1.append(torch.zeros_like(initial_train1).detach().numpy())\n",
    "    ptest_empirical1.append(torch.zeros_like(initial_test1).detach().numpy())\n",
    "\n",
    "    for t in tqdm(t_list):\n",
    "        train_net(st, t_step, train_data, target_data, lr, initial_train1.detach())\n",
    "        ptrain_empirical1.append(st.forward(train_data).detach().cpu().numpy().reshape(-1)-initial_train1.detach().numpy())\n",
    "        ptest_empirical1.append(st.forward(test_data).detach().cpu().numpy().reshape(-1)-initial_test1.detach().numpy())\n",
    "\n",
    "    cmap = plt.cm.nipy_spectral\n",
    "    t_list = np.arange(0, t_max+t_step, t_step)\n",
    "\n",
    "    for i in range(len(ptest_empirical1[0])):\n",
    "        plt.plot(t_list, np.array(ptest_empirical1)[:,i], color=cmap(i/len(ptest_empirical1[0])), linestyle=linestyle, alpha=alpha, linewidth=linewidth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 2.0\n",
    "beta = 0.5\n",
    "depth = -1\n",
    "\n",
    "plt.figure(figsize=(15,6))\n",
    "\n",
    "# ^^^^^^^^^^^^^^^^^^\n",
    "finetune = False # AAA\n",
    "\n",
    "plt.subplot(2,2,1)\n",
    "n_tree = 8\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "\n",
    "plot_trajectory(st, linestyle=\"solid\", linewidth=1, alpha=1.0)\n",
    "\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=1, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=2, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "\n",
    "plot_trajectory(st, linestyle=\"dashed\", linewidth=3, alpha=0.5)\n",
    "\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.ylabel(\"Model output\")\n",
    "plt.title(\"AAA $(M=16)$\")\n",
    "plt.ylim(-2.0, 2.0)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# ---------------------------\n",
    "\n",
    "plt.subplot(2,2,2)\n",
    "n_tree = 512\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "\n",
    "plot_trajectory(st, linestyle=\"solid\", linewidth=1, alpha=1.0)\n",
    "\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=1, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=2, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "plot_trajectory(st, linestyle=\"dashed\", linewidth=3, alpha=0.5)\n",
    "\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"AAA $(M=1024)$\")\n",
    "plt.ylim(-2.0, 2.0)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# ^^^^^^^^^^^^^^^^^^\n",
    "finetune = True # AAI\n",
    "\n",
    "plt.subplot(2,2,3)\n",
    "n_tree = 8\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=1, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=2, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "\n",
    "plot_trajectory(st, linestyle=\"dashed\", linewidth=3, alpha=0.5)\n",
    "\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "plot_trajectory(st, linestyle=\"solid\", linewidth=1, alpha=1.0)\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.ylabel(\"Model output\")\n",
    "plt.title(\"AAI $(M=16)$\")\n",
    "plt.ylim(-2.0, 2.0)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "# ---------------------------\n",
    "\n",
    "plt.subplot(2,2,4)\n",
    "n_tree = 512\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=1, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=2, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "plot_trajectory(st, linestyle=\"dashed\", linewidth=3, alpha=0.5)\n",
    "\n",
    "st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, arch=0, oblivious=True, max_depth=depth, finetune=finetune)\n",
    "st = SoftTreeMerge(st1, st2)\n",
    "plot_trajectory(st, linestyle=\"solid\", linewidth=1, alpha=1.0)\n",
    "plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "plt.title(\"AAI $(M=1024)$\")\n",
    "plt.ylim(-2.0, 2.0)\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.subplot(2,2,4)\n",
    "plt.plot([], [], color=\"black\", label=\"Architecture 1 (Non-Oblivious)\", linestyle=\"dashed\", alpha=0.5, linewidth=3)\n",
    "plt.plot([], [], color=\"black\", label=\"Architecture 2 (Oblivious)\", linestyle=\"solid\", alpha=1.0,  linewidth=1)\n",
    "plt.legend(ncol=3, bbox_to_anchor=(-0.05, -0.45), fontsize=15, loc=\"center\", borderaxespad=0)\n",
    "\n",
    "plt.savefig(\"./figures/trajectory_oblivious_conversion.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
