{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:27.744893Z",
     "start_time": "2021-09-27T03:15:27.738004Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    \"\"\"Configure matplotlib fonts for paper-style figures.\n",
    "\n",
    "    The full sequence of settings is applied twice, as the original code did\n",
    "    (presumably a workaround for settings not sticking on the first call in a\n",
    "    fresh kernel -- TODO confirm).\n",
    "    \"\"\"\n",
    "    # Local imports: this function is called before the notebook's import cell.\n",
    "    import matplotlib\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    for _ in range(2):\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        # The default style is selected after the rc calls above (order kept as-is).\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"] = 15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            # Best-effort tweak that relies on font_manager internals.\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except Exception:\n",
    "            # The dict entry or the private helper may be missing; the tweak is\n",
    "            # optional. Unlike a bare except, KeyboardInterrupt and SystemExit\n",
    "            # still propagate.\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:32.310591Z",
     "start_time": "2021-09-27T03:15:27.749292Z"
    }
   },
   "outputs": [],
   "source": [
    "# Apply the matplotlib font settings before any figure is created.\n",
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.145364Z",
     "start_time": "2021-09-27T03:15:32.313901Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "if os.environ.get(\"GPU\"):\n",
    "    device = os.environ.get(\"GPU\") if torch.cuda.is_available() else \"cpu\"\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:\n",
    "    tau = 1 / 4 + 1 / (2 * math.pi) * np.arcsin(\n",
    "        ((alpha ** 2) * S)\n",
    "        / (np.sqrt(((alpha ** 2) * diag_i + 0.5) * ((alpha ** 2) * diag_j + 0.5)))\n",
    "    )\n",
    "    return tau\n",
    "\n",
    "\n",
    "def calc_tau_dot(\n",
    "    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array\n",
    ") -> np.array:\n",
    "    tau_dot = (\n",
    "        (alpha ** 2)\n",
    "        / (math.pi)\n",
    "        * 1\n",
    "        / np.sqrt(\n",
    "            (2 * (alpha ** 2) * diag_i + 1) * (2 * (alpha ** 2) * diag_j + 1)\n",
    "            - (4 * (alpha ** 4) * (S ** 2))\n",
    "        )\n",
    "    )\n",
    "    return tau_dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InnerNode:\n",
    "    \"\"\"Internal (splitting) node shared by all trees of a soft tree ensemble.\n",
    "\n",
    "    A single linear layer stores the split parameters of every tree\n",
    "    (``n_tree`` output units). This is a plain Python object, not an\n",
    "    ``nn.Module``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, depth, asym=False):\n",
    "        \"\"\"Create the node and, recursively, the subtree below it.\n",
    "\n",
    "        Args:\n",
    "            config: dict read for \"input_dim\", \"n_tree\", \"max_depth\", \"scale\"\n",
    "                and \"bias_scale\".\n",
    "            depth: depth of this node; children are inner nodes while\n",
    "                ``depth < config[\"max_depth\"]`` and leaves otherwise.\n",
    "            asym: if True, every right child is a leaf.\n",
    "        \"\"\"\n",
    "        self.config = config\n",
    "        self.leaf = False\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        self.prob = None\n",
    "        self.path_prob = None\n",
    "        self.left = None\n",
    "        self.right = None\n",
    "        self.leaf_accumulator = []\n",
    "        self.penalties = []  # previously only created by reset()\n",
    "        self.asym = asym\n",
    "\n",
    "        self.build_child(depth)\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        \"\"\"Attach the left and right children of this node.\"\"\"\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = InnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        \"\"\"Return ``0.5 * erf(scale * (x W^T + bias_scale * b)) + 0.5``.\"\"\"\n",
    "        return (\n",
    "            0.5\n",
    "            * torch.erf(\n",
    "                self.config[\"scale\"]\n",
    "                * (\n",
    "                    torch.matmul(x, self.fc.weight.t())\n",
    "                    + self.config[\"bias_scale\"] * self.fc.bias\n",
    "                )\n",
    "            )\n",
    "            + 0.5\n",
    "        )\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        \"\"\"Propagate path probabilities down the subtree.\n",
    "\n",
    "        Args:\n",
    "            x: input batch handed to every decision function.\n",
    "            path_prob: probability of reaching this node, [batch_size, n_tree].\n",
    "\n",
    "        Returns:\n",
    "            List of ``[path_prob, Q]`` pairs, one per leaf below this node,\n",
    "            left subtree first.\n",
    "        \"\"\"\n",
    "        self.prob = self.forward(x)  # probability of selecting right node\n",
    "        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]\n",
    "        self.path_prob = path_prob\n",
    "        left_leaves = self.left.calc_prob(x, path_prob * (1 - self.prob))\n",
    "        right_leaves = self.right.calc_prob(x, path_prob * self.prob)\n",
    "        # Rebuild the list instead of extending it, so a call that was not\n",
    "        # preceded by reset() cannot count the same leaves twice.\n",
    "        self.leaf_accumulator = [*left_leaves, *right_leaves]\n",
    "        return self.leaf_accumulator\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Clear the per-call buffers of this node and of its subtree.\"\"\"\n",
    "        self.leaf_accumulator = []\n",
    "        self.penalties = []\n",
    "        self.left.reset()\n",
    "        self.right.reset()\n",
    "\n",
    "\n",
    "class SparseInnerNode(InnerNode):\n",
    "    \"\"\"Inner node whose decision function reads a single input feature.\"\"\"\n",
    "\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        \"\"\"Build the node; a random feature is drawn when none is given.\"\"\"\n",
    "        # The parent constructor also builds the children (via build_child).\n",
    "        super().__init__(config, depth, asym)\n",
    "        self.feature_index = (\n",
    "            np.random.randint(self.config[\"input_dim\"])\n",
    "            if feature_index is None\n",
    "            else feature_index\n",
    "        )\n",
    "\n",
    "        # Replace the dense layer created by the parent with a one-input layer.\n",
    "        single_input_fc = nn.Linear(1, self.config[\"n_tree\"], bias=True).to(device)\n",
    "        self.fc = single_input_fc\n",
    "        nn.init.normal_(single_input_fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(single_input_fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        \"\"\"Attach children: leaves at the bottom, sparse inner nodes above.\"\"\"\n",
    "        if not depth < self.config[\"max_depth\"]:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "            return\n",
    "\n",
    "        self.left = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "        if self.asym:\n",
    "            self.right = LeafNode(self.config)\n",
    "        else:\n",
    "            self.right = SparseInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        \"\"\"Return the erf-based split probability, [batch_size, n_tree].\"\"\"\n",
    "        selected = x[:, self.feature_index].unsqueeze(dim=1)\n",
    "        pre_activation = (\n",
    "            torch.matmul(selected, self.fc.weight.t())\n",
    "            + self.config[\"bias_scale\"] * self.fc.bias\n",
    "        )\n",
    "        return 0.5 * torch.erf(self.config[\"scale\"] * pre_activation) + 0.5\n",
    "\n",
    "\n",
    "class SparseFinetuneInnerNode(InnerNode):\n",
    "    \"\"\"Inner node with a dense layer that starts out reading one feature.\n",
    "\n",
    "    The layer has ``input_dim`` inputs, but at construction every weight\n",
    "    except the column of ``feature_index`` is set to zero. The decision\n",
    "    function is the dense one inherited from ``InnerNode``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, depth, asym=False, feature_index=None):\n",
    "        \"\"\"Build the node; a random feature is drawn when none is given.\"\"\"\n",
    "        super().__init__(config, depth, asym)\n",
    "        if feature_index is None:\n",
    "            self.feature_index = np.random.randint(self.config[\"input_dim\"])\n",
    "        else:\n",
    "            self.feature_index = feature_index\n",
    "\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=True\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Fix: mask with self.feature_index. The raw argument is None for\n",
    "            # nodes created by build_child, which used to zero every weight,\n",
    "            # including the column of the randomly drawn feature.\n",
    "            keep = torch.zeros_like(self.fc.weight)  # [n_tree, input_dim]\n",
    "            keep[:, self.feature_index] = 1.0\n",
    "            self.fc.weight.mul_(keep)\n",
    "\n",
    "    def build_child(self, depth):\n",
    "        \"\"\"Attach children: leaves at the bottom, finetune nodes above.\"\"\"\n",
    "        if depth < self.config[\"max_depth\"]:\n",
    "            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)\n",
    "            if self.asym:\n",
    "                self.right = LeafNode(self.config)\n",
    "            else:\n",
    "                self.right = SparseFinetuneInnerNode(\n",
    "                    self.config, depth + 1, asym=self.asym\n",
    "                )\n",
    "        else:\n",
    "            self.left = LeafNode(self.config)\n",
    "            self.right = LeafNode(self.config)\n",
    "\n",
    "\n",
    "class LeafNode:\n",
    "    \"\"\"Leaf of the soft tree ensemble holding one output vector per tree.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        \"\"\"Create the leaf values from a standard normal draw.\"\"\"\n",
    "        self.config = config\n",
    "        self.leaf = True\n",
    "        initial_values = torch.randn(\n",
    "            self.config[\"output_dim\"], self.config[\"n_tree\"]\n",
    "        ).to(device)\n",
    "        self.param = nn.Parameter(initial_values)  # [n_class, n_tree]\n",
    "\n",
    "    def forward(self):\n",
    "        \"\"\"Return the leaf parameter tensor.\"\"\"\n",
    "        return self.param\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        \"\"\"Return ``[[path_prob, Q]]`` with ``Q`` expanded over the batch.\"\"\"\n",
    "        path_prob = path_prob.to(device)  # [batch_size, n_tree]\n",
    "        batch_size = path_prob.size()[0]\n",
    "\n",
    "        # -> [batch_size, n_class, n_tree]\n",
    "        Q = self.forward().expand(\n",
    "            (batch_size, self.config[\"output_dim\"], self.config[\"n_tree\"])\n",
    "        )\n",
    "        return [[path_prob, Q]]\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Nothing to clear on a leaf.\"\"\"\n",
    "        pass\n",
    "\n",
    "\n",
    "class SoftTree(nn.Module):\n",
    "    \"\"\"Ensemble of ``n_tree`` soft decision trees that share one topology.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features.\n",
    "        output_dim: number of outputs per sample.\n",
    "        max_depth: nodes at depth ``< max_depth`` get inner-node children.\n",
    "        scale: factor applied inside the erf decision function.\n",
    "        bias_scale: factor applied to the bias inside the decision function.\n",
    "        n_tree: number of trees in the ensemble.\n",
    "        asym: if True, every right child is a leaf.\n",
    "        sparse: if True, use ``SparseInnerNode`` instead of ``InnerNode``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        max_depth: int,\n",
    "        scale: float,\n",
    "        bias_scale: float,\n",
    "        n_tree: int,\n",
    "        asym: bool = False,\n",
    "        sparse: bool = False,\n",
    "    ):\n",
    "        super(SoftTree, self).__init__()\n",
    "        config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"max_depth\": max_depth,\n",
    "            \"scale\": scale,\n",
    "            \"bias_scale\": bias_scale,\n",
    "            \"n_tree\": n_tree,\n",
    "        }\n",
    "        self.config = config\n",
    "        if sparse:\n",
    "            self.root = SparseInnerNode(config, depth=1, asym=asym)\n",
    "        else:\n",
    "            self.root = InnerNode(config, depth=1, asym=asym)\n",
    "\n",
    "        self.collect_parameters()\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        \"\"\"Register the tensors of every node in ModuleList/ParameterList.\n",
    "\n",
    "        The nodes are plain objects, so their layers and leaf parameters are\n",
    "        collected here, breadth-first with the right child queued first.\n",
    "        \"\"\"\n",
    "        nodes = [self.root]\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        while nodes:\n",
    "            node = nodes.pop(0)\n",
    "            if node.leaf:\n",
    "                self.param_list.append(node.param)\n",
    "            else:\n",
    "                nodes.append(node.right)\n",
    "                nodes.append(node.left)\n",
    "                self.module_list.append(node.fc)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the ensemble prediction, shape [batch_size, output_dim].\n",
    "\n",
    "        The leaf contributions are summed over leaves and trees and divided\n",
    "        by ``sqrt(n_tree)``.\n",
    "        \"\"\"\n",
    "        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "        # Fix: the node weights live on `device`, so the input must too.\n",
    "        x = x.to(device)\n",
    "\n",
    "        path_prob_init = torch.ones(x.shape[0], self.config[\"n_tree\"])\n",
    "        pred = torch.zeros(x.shape[0], self.config[\"output_dim\"]).to(device)\n",
    "\n",
    "        try:\n",
    "            leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "            for path_prob, Q in leaf_accumulator:  # one entry per leaf\n",
    "                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "        finally:\n",
    "            # Always clear the per-call buffers, even if the pass above raised.\n",
    "            self.root.reset()\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftTreeExp(nn.Module):\n",
    "    \"\"\"Hand-wired sparse soft tree ensemble used by the experiment.\n",
    "\n",
    "    The root splits on feature 0, both of its children split on feature 1,\n",
    "    and the four grandchildren are leaves. ``finetune=True`` builds the\n",
    "    inner nodes from ``SparseFinetuneInnerNode`` (AAI), otherwise from\n",
    "    ``SparseInnerNode`` (AAA).\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features (features 0 and 1 are used).\n",
    "        output_dim: number of outputs per sample.\n",
    "        max_depth: stored in the config read by the node constructors.\n",
    "        scale: factor applied inside the erf decision function.\n",
    "        bias_scale: factor applied to the bias inside the decision function.\n",
    "        n_tree: number of trees in the ensemble.\n",
    "        asym: accepted for signature compatibility; not used here.\n",
    "        sparse: must be True.\n",
    "        finetune: selects the AAI variant when True.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        max_depth: int,\n",
    "        scale: float,\n",
    "        bias_scale: float,\n",
    "        n_tree: int,\n",
    "        asym: bool=False,\n",
    "        sparse: bool=True,\n",
    "        finetune: bool=False\n",
    "    ):\n",
    "        super(SoftTreeExp, self).__init__()\n",
    "        config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"scale\": scale,\n",
    "            \"bias_scale\": bias_scale,\n",
    "            \"n_tree\": n_tree,\n",
    "            \"max_depth\": max_depth\n",
    "        }\n",
    "        self.config = config\n",
    "\n",
    "        assert sparse # only for sparse tree\n",
    "        assert finetune <= sparse\n",
    "\n",
    "        # AAI when finetune is set, AAA otherwise.\n",
    "        node_cls = SparseFinetuneInnerNode if finetune else SparseInnerNode\n",
    "\n",
    "        # depth=1\n",
    "        self.root = node_cls(config, depth=1, feature_index=0)\n",
    "        # depth=2 (replaces the children built by the root's constructor)\n",
    "        self.root.left = node_cls(config, depth=2, feature_index=1)\n",
    "        self.root.right = node_cls(config, depth=2, feature_index=1)\n",
    "\n",
    "        # depth=3\n",
    "        self.root.left.left = LeafNode(config)\n",
    "        self.root.left.right = LeafNode(config)\n",
    "        self.root.right.left = LeafNode(config)\n",
    "        self.root.right.right = LeafNode(config)\n",
    "\n",
    "        self.collect_parameters()\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        \"\"\"Register the tensors of every node in ModuleList/ParameterList.\n",
    "\n",
    "        The nodes are plain objects, so their layers and leaf parameters are\n",
    "        collected here, breadth-first with the right child queued first.\n",
    "        \"\"\"\n",
    "        nodes = [self.root]\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        while nodes:\n",
    "            node = nodes.pop(0)\n",
    "            if node.leaf:\n",
    "                self.param_list.append(node.param)\n",
    "            else:\n",
    "                nodes.append(node.right)\n",
    "                nodes.append(node.left)\n",
    "                self.module_list.append(node.fc)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the ensemble prediction, shape [batch_size, output_dim].\n",
    "\n",
    "        The leaf contributions are summed over leaves and trees and divided\n",
    "        by ``sqrt(n_tree)``.\n",
    "        \"\"\"\n",
    "        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "        # Fix: the node weights live on `device`, so the input must too.\n",
    "        x = x.to(device)\n",
    "\n",
    "        path_prob_init = torch.ones(x.shape[0], self.config[\"n_tree\"])\n",
    "        # Fix: allocate the accumulator on `device`, as SoftTree.forward does.\n",
    "        pred = torch.zeros(x.shape[0], self.config[\"output_dim\"]).to(device)\n",
    "\n",
    "        try:\n",
    "            leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "            for path_prob, Q in leaf_accumulator:  # one entry per leaf\n",
    "                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "        finally:\n",
    "            # Always clear the per-call buffers, even if the pass above raised.\n",
    "            self.root.reset()\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_diabetes\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False: # real dataset\n",
    "    import pandas as pd\n",
    "    from sklearn.datasets import load_diabetes\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    X, y = load_diabetes(as_frame = True, return_X_y=True)\n",
    "    X = X[['bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6', 'age', 'sex']]\n",
    "    target_scaler = StandardScaler()\n",
    "    target_scaler.fit(y.values.reshape(-1, 1))\n",
    "    y = pd.Series((target_scaler.transform(y.values.reshape(-1, 1)).squeeze()), name=\"y\")\n",
    "\n",
    "    for c in X.columns:\n",
    "        feature_scaler = StandardScaler()\n",
    "        feature_scaler.fit(X[c].values.reshape(-1,1))\n",
    "        X[c] = (feature_scaler.transform(X[c].values.reshape(-1, 1)).squeeze())\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.95)\n",
    "\n",
    "    train_data = torch.Tensor(X_train.values)[0:50]\n",
    "    target_data = torch.tensor(y_train.values)[0:50]\n",
    "    test_data = torch.Tensor(X_test.values)[0:10]\n",
    "else: # random dataset\n",
    "    n_features = 2\n",
    "    n_dataset = 10\n",
    "    train_data = torch.Tensor([np.random.randn(n_features) for i in range(n_dataset)])\n",
    "    target_data = torch.tensor(np.random.randn(train_data.shape[0]))\n",
    "    test_data = torch.Tensor([np.random.randn(n_features) for i in range(10)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the shape of the training inputs.\n",
    "train_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the shape of the test inputs.\n",
    "test_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "source": [
    "## Tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.163984Z",
     "start_time": "2021-09-27T03:15:35.158843Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_net(net, n_epochs, input_data, target, lr, initial_train):\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "    for epoch in range(n_epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(input_data.double())-initial_train.unsqueeze(1)\n",
    "        loss = criterion(outputs.view(-1), target)/2\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.173329Z",
     "start_time": "2021-09-27T03:15:35.166262Z"
    }
   },
   "outputs": [],
   "source": [
    "def analytical_evolution_MSE(t, lr, H_train, H_test, initial_train, initial_test, target_data):\n",
    "    n_train = len(initial_train)\n",
    "\n",
    "    # first compute the exponential of the matrix (using eigendecomposition):\n",
    "    lam, P = np.linalg.eig(H_train)  # eig decomposition\n",
    "    lam = lam.astype(dtype='float64')\n",
    "\n",
    "    H_train_inv = np.dot(P, np.dot(np.diag(lam**(-1)), P.transpose()))\n",
    "\n",
    "    # note that you need to rescale the time by n_train, as the 2 paper use different convention for the loss function\n",
    "    # I am using np arrays, not torch tensors\n",
    "    exp_matrix = np.dot(\n",
    "        P, np.dot(np.diag(np.exp(-lr * t * lam / n_train)), P.transpose()))\n",
    "\n",
    "    # compute the prediction on train set\n",
    "    pred_train = target_data.cpu().numpy() + np.dot(exp_matrix,\n",
    "                                                    (initial_train - target_data).cpu().detach().numpy())\n",
    "\n",
    "    # compute the intermediate matrix used both in prediction on test set and weights evolution\n",
    "    tmp = np.dot(np.eye(lam.size) - exp_matrix,\n",
    "                 (initial_train - target_data).cpu().detach().numpy())\n",
    "    tmp = np.dot(H_train_inv, tmp)\n",
    "\n",
    "    # compute prediction on test set\n",
    "    pred_test = np.dot(H_test, tmp)\n",
    "    pred_test = initial_test.detach().cpu().numpy().reshape(-1) - pred_test\n",
    "\n",
    "    return pred_train, pred_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hardtree_viz(\n",
    "    X: np.array, alpha: float, beta: float, finetune: bool\n",
    "):\n",
    "    \"\"\"Assemble the analytical kernel matrix for the fixed four-leaf tree.\n",
    "\n",
    "    Every leaf is described by the feature indices used on its path; all\n",
    "    four leaves use features 0 and 1, so ``X`` needs at least two columns.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array with one row per sample.\n",
    "        alpha: scale passed to ``calc_tau`` / ``calc_tau_dot``.\n",
    "        beta: bias scale; ``beta**2`` is added to every inner product.\n",
    "        finetune: if True, the internal-node terms use the inner product over\n",
    "            all features; otherwise the single-feature one. ``tau`` and\n",
    "            ``tau_dot`` always use the single-feature one.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n_samples, n_samples).\n",
    "    \"\"\"\n",
    "    n_samples = X.shape[0]\n",
    "\n",
    "    # Does not depend on the feature index, so compute it once.\n",
    "    S_all = np.matmul(X, X.T) + beta**2\n",
    "\n",
    "    S_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "    for feature_index in range(len(X[0])):\n",
    "        column = X[:, feature_index]\n",
    "        S = np.outer(column, column) + beta**2\n",
    "        S_list.append(S_all if finetune else S)\n",
    "\n",
    "        diag = np.diag(S)\n",
    "        diag_i = np.tile(diag, (len(diag), 1))  # diag_i[r, c] = S[c, c]\n",
    "        diag_j = diag_i.transpose()  # diag_j[r, c] = S[r, r]\n",
    "        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))\n",
    "        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))\n",
    "\n",
    "    K = np.zeros((n_samples, n_samples))\n",
    "\n",
    "    # One entry per leaf: the feature indices split on along its path.\n",
    "    rulelist = [[0, 1], [0, 1], [0, 1], [0, 1]]\n",
    "    for rules in rulelist:\n",
    "\n",
    "        # Internal nodes: S * tau_dot of one node times tau of the others.\n",
    "        for i, s in enumerate(rules):\n",
    "            ts = rules[0:i]+rules[i+1:]\n",
    "            _H_nodes = S_list[s] * tau_dot_list[s]  # new array, safe to update in place\n",
    "            for t in ts:\n",
    "                _H_nodes *= tau_list[t]\n",
    "            K += _H_nodes\n",
    "\n",
    "        # Leaves: product of tau over every node on the path.\n",
    "        _H_leaves = np.ones_like(K)\n",
    "        for tau in [tau_list[i] for i in rules]:\n",
    "            _H_leaves *= tau\n",
    "        K += _H_leaves\n",
    "\n",
    "    return K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectory(finetune: bool):\n",
    "    \"\"\"Plot test-output trajectories: analytical kernel vs. trained ensembles.\n",
    "\n",
    "    Draws on the current matplotlib axes, one colour per test point: a thick\n",
    "    translucent line for the analytical prediction, a dotted line for the\n",
    "    ensemble with 16 trees and a dashed line for the one with 1024 trees.\n",
    "    Reads the notebook globals ``train_data``, ``test_data`` and\n",
    "    ``target_data``.\n",
    "\n",
    "    Args:\n",
    "        finetune: True for the AAI variant, False for AAA.\n",
    "    \"\"\"\n",
    "    alpha = 2.0\n",
    "    beta = 0.5\n",
    "    # Passed as max_depth; being below the node depths used by SoftTreeExp,\n",
    "    # the node constructors only create leaf children.\n",
    "    depth = -1\n",
    "\n",
    "    t_max = 1000\n",
    "    t_step = 10\n",
    "    lr = 0.1\n",
    "    train_steps = np.arange(t_step, t_max+t_step, t_step)\n",
    "\n",
    "    H_analytical_train = hardtree_viz(train_data.numpy(), alpha=alpha, beta=beta, finetune=finetune)\n",
    "    H_analytical_test = hardtree_viz(torch.cat([train_data, test_data]).numpy(), alpha=alpha, beta=beta, finetune=finetune)[len(train_data):, 0:len(train_data)]\n",
    "\n",
    "    # The analytical curve does not depend on the ensemble size: compute it once.\n",
    "    zero_train = torch.zeros(len(train_data))\n",
    "    zero_test = torch.zeros(len(test_data))\n",
    "    ptest_analytical1 = [zero_test.numpy()]\n",
    "    for t in train_steps:\n",
    "        _, pred_test = analytical_evolution_MSE(\n",
    "            t=t,\n",
    "            lr=lr,\n",
    "            H_train=H_analytical_train,\n",
    "            H_test=H_analytical_test,\n",
    "            initial_train=zero_train,\n",
    "            initial_test=zero_test,\n",
    "            target_data=target_data\n",
    "        )\n",
    "        ptest_analytical1.append(pred_test)\n",
    "\n",
    "    # Empirical curves: test output minus its value at initialisation.\n",
    "    ptest_empiricals1 = []\n",
    "    for n_tree in (16, 1024):\n",
    "        st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta, n_tree=n_tree, max_depth=depth, finetune=finetune)\n",
    "\n",
    "        initial_train1 = st1.forward(train_data).reshape(-1)\n",
    "        initial_test1 = st1.forward(test_data).reshape(-1)\n",
    "        initial_test_np = initial_test1.detach().cpu().numpy()\n",
    "\n",
    "        ptest_empirical1 = [torch.zeros_like(initial_test1).detach().cpu().numpy()]\n",
    "        for t in tqdm(train_steps):\n",
    "            train_net(st1, t_step, train_data, target_data, lr, initial_train1.detach())\n",
    "            ptest_empirical1.append(st1.forward(test_data).detach().cpu().numpy().reshape(-1) - initial_test_np)\n",
    "\n",
    "        ptest_empiricals1.append(ptest_empirical1)\n",
    "\n",
    "    cmap = plt.cm.nipy_spectral\n",
    "    t_list = np.arange(0, t_max+t_step, t_step)\n",
    "\n",
    "    analytical = np.array(ptest_analytical1)\n",
    "    empirical_small = np.array(ptest_empiricals1[0])\n",
    "    empirical_large = np.array(ptest_empiricals1[1])\n",
    "    n_test = len(ptest_analytical1[0])\n",
    "\n",
    "    for i in range(n_test):\n",
    "        color = cmap(i/n_test)\n",
    "        plt.plot(t_list, analytical[:,i], color=color, alpha=0.3, linewidth=5)\n",
    "        plt.plot(t_list, empirical_small[:,i], color=color, linestyle=\"dotted\")\n",
    "        plt.plot(t_list, empirical_large[:,i], color=color, linestyle=\"dashed\")\n",
    "    plt.xlabel(\"$\\\\tau$ (iteration)\")\n",
    "    if not finetune:\n",
    "        plt.ylabel(\"Model output\")\n",
    "    plt.title(f\"{'AAI' if finetune else 'AAA'}\")\n",
    "    plt.ylim(-2.0, 2.0)\n",
    "    plt.grid(linestyle=\"dotted\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two panels side by side: AAA on the left, AAI on the right.\n",
    "plt.figure(figsize=(14,3))\n",
    "plt.subplot(1,2,1)\n",
    "plot_trajectory(finetune=False)\n",
    "plt.subplot(1,2,2)\n",
    "plot_trajectory(finetune=True)\n",
    "\n",
    "# Shared legend built from empty proxy lines on the right panel.\n",
    "plt.subplot(1,2,2)\n",
    "plt.plot([], [], color=\"black\", label=\"Analytical\", linewidth=5, alpha=0.3)\n",
    "plt.plot([], [], color=\"black\", label=\"$M=16$\", linestyle=\"dotted\")\n",
    "plt.plot([], [], color=\"black\", label=\"$M=1024$\", linestyle=\"dashed\")\n",
    "plt.legend(ncol=3, bbox_to_anchor=(-0.1, -0.4), fontsize=15, loc=\"center\", borderaxespad=0)\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"./figures\", exist_ok=True)\n",
    "plt.savefig(\"./figures/trajectory.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
