{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:27.744893Z",
     "start_time": "2021-09-27T03:15:27.738004Z"
    }
   },
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'  # font familyの設定\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'  # math fontの設定\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:32.310591Z",
     "start_time": "2021-09-27T03:15:27.749292Z"
    }
   },
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.145364Z",
     "start_time": "2021-09-27T03:15:32.313901Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "import copy\n",
    "import ntk\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import warnings \n",
    "from scipy import special\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.colors as colors\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# from models import SoftTree\n",
    "from utils import compute_kernels, rotation_o\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "if os.environ.get(\"GPU\"):\n",
    "    device = os.environ.get(\"GPU\") if torch.cuda.is_available() else \"cpu\"\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "\n",
    "class InnerNode:\n",
    "    def __init__(self, config, depth):\n",
    "        self.config = config\n",
    "        self.leaf = False\n",
    "        self.fc = nn.Linear(\n",
    "            self.config[\"input_dim\"], self.config[\"n_tree\"], bias=False\n",
    "        ).to(device)\n",
    "        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0\n",
    "        self.prob = None\n",
    "        self.path_prob = None\n",
    "        self.left = None\n",
    "        self.right = None\n",
    "        self.leaf_accumulator = []\n",
    "\n",
    "    def forward(self, x):  # decision function\n",
    "        return (\n",
    "            0.5 * torch.erf(self.config[\"scale\"] * self.fc(x)) + 0.5\n",
    "        )  # -> [batch_size, n_tree]\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        self.prob = self.forward(x)  # probability of selecting right node\n",
    "        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]\n",
    "        self.path_prob = path_prob\n",
    "        left_leaf_accumulator = self.left.calc_prob(x, path_prob * (1 - self.prob))\n",
    "        right_leaf_accumulator = self.right.calc_prob(x, path_prob * self.prob)\n",
    "        self.leaf_accumulator.extend(left_leaf_accumulator)\n",
    "        self.leaf_accumulator.extend(right_leaf_accumulator)\n",
    "        return self.leaf_accumulator\n",
    "\n",
    "    def reset(self):\n",
    "        self.leaf_accumulator = []\n",
    "        self.penalties = []\n",
    "        self.left.reset()\n",
    "        self.right.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeafNode:\n",
    "    def __init__(self, config):\n",
    "        self.config = config\n",
    "        self.leaf = True\n",
    "        self.param = nn.Parameter(\n",
    "            torch.randn(self.config[\"output_dim\"], self.config[\"n_tree\"]).to(device)\n",
    "        )  # [n_class, n_tree]\n",
    "        # self.param.requires_grad = False  # Freeze\n",
    "\n",
    "    def forward(self):\n",
    "        return self.param\n",
    "\n",
    "    def calc_prob(self, x, path_prob):\n",
    "        path_prob = path_prob.to(device)  # [batch_size, n_tree]\n",
    "\n",
    "        Q = self.forward()\n",
    "        Q = Q.expand(\n",
    "            (path_prob.size()[0], self.config[\"output_dim\"], self.config[\"n_tree\"])\n",
    "        )  # -> [batch_size, n_class, n_tree]\n",
    "        return [[path_prob, Q]]\n",
    "\n",
    "    def reset(self):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftTreeExp(nn.Module):\n",
    "    \"\"\"Ensemble of ``n_tree`` soft decision trees that all share one fixed shape.\n",
    "\n",
    "    ``mode`` selects the (asymmetric, 6-leaf) topology:\n",
    "\n",
    "    * ``mode=1`` (shape A): the root's left child has two leaves; its right\n",
    "      child has two inner children, each with two leaves.\n",
    "    * ``mode=2`` (shape B): each child of the root has a leaf on the left and\n",
    "      an inner node with two leaves on the right.\n",
    "\n",
    "    The prediction sums ``path_prob * leaf_value`` over leaves and trees and\n",
    "    divides by ``sqrt(n_tree)``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features.\n",
    "        output_dim: number of outputs per sample.\n",
    "        scale: multiplier inside every inner node's ``erf`` decision function.\n",
    "        n_tree: number of trees in the ensemble.\n",
    "        mode: tree shape, 1 or 2.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``mode`` is neither 1 nor 2.\n",
    "    \"\"\"\n",
    "\n",
    "    # Tree topologies: ``None`` is a leaf, a ``(left, right)`` pair an inner node.\n",
    "    _SHAPES = {\n",
    "        1: ((None, None), ((None, None), (None, None))),\n",
    "        2: ((None, (None, None)), (None, (None, None))),\n",
    "    }\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        output_dim: int,\n",
    "        scale: float,\n",
    "        n_tree: int,\n",
    "        mode: int = 0,\n",
    "    ):\n",
    "        super(SoftTreeExp, self).__init__()\n",
    "        self.config = {\n",
    "            \"input_dim\": input_dim,\n",
    "            \"output_dim\": output_dim,\n",
    "            \"scale\": scale,\n",
    "            \"n_tree\": n_tree,\n",
    "        }\n",
    "        if mode not in self._SHAPES:\n",
    "            raise NotImplementedError(f\"mode must be 1 or 2, got {mode!r}\")\n",
    "        self.root = self._build_tree(self._SHAPES[mode])\n",
    "        self.collect_parameters()\n",
    "\n",
    "    def _build_tree(self, shape):\n",
    "        \"\"\"Create the nodes for ``shape`` level by level, left to right.\n",
    "\n",
    "        Level order (root, then each depth left-to-right) fixes the order in\n",
    "        which the nodes draw their random initial parameters.\n",
    "        \"\"\"\n",
    "        from collections import deque\n",
    "\n",
    "        root = InnerNode(self.config, depth=1)\n",
    "        pending = deque([(root, shape, 1)])\n",
    "        while pending:\n",
    "            node, (left_shape, right_shape), depth = pending.popleft()\n",
    "            for side, child_shape in ((\"left\", left_shape), (\"right\", right_shape)):\n",
    "                if child_shape is None:\n",
    "                    child = LeafNode(self.config)\n",
    "                else:\n",
    "                    child = InnerNode(self.config, depth=depth + 1)\n",
    "                    pending.append((child, child_shape, depth + 1))\n",
    "                setattr(node, side, child)\n",
    "        return root\n",
    "\n",
    "    def collect_parameters(self):\n",
    "        \"\"\"Register every node's tensors on this module so ``parameters()`` sees them.\n",
    "\n",
    "        Walks the tree breadth-first (right child queued before left) and stores\n",
    "        inner-node ``fc`` layers in ``self.module_list`` and leaf parameters in\n",
    "        ``self.param_list``.\n",
    "        \"\"\"\n",
    "        from collections import deque\n",
    "\n",
    "        self.module_list = nn.ModuleList()\n",
    "        self.param_list = nn.ParameterList()\n",
    "        pending = deque([self.root])\n",
    "        while pending:\n",
    "            node = pending.popleft()\n",
    "            if node.leaf:\n",
    "                self.param_list.append(node.param)\n",
    "            else:\n",
    "                self.module_list.append(node.fc)\n",
    "                pending.append(node.right)\n",
    "                pending.append(node.left)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Evaluate the ensemble.\n",
    "\n",
    "        Args:\n",
    "            x: input that reshapes to [batch_size, input_dim] (for example\n",
    "               [batch_size, input_dim] or [batch_size, 1, input_dim]).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [batch_size, output_dim].\n",
    "        \"\"\"\n",
    "        x = x.reshape(x.shape[0], self.config[\"input_dim\"])\n",
    "        path_prob_init = torch.ones(x.shape[0], self.config[\"n_tree\"])\n",
    "        try:\n",
    "            leaf_accumulator = self.root.calc_prob(x, path_prob_init)\n",
    "            pred = torch.zeros(x.shape[0], self.config[\"output_dim\"]).to(device)\n",
    "            for path_prob, Q in leaf_accumulator:\n",
    "                # [batch, 1, n_tree] * [batch, n_class, n_tree], summed over trees\n",
    "                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)\n",
    "        finally:\n",
    "            # Clear per-call node state even if evaluation raised.\n",
    "            self.root.reset()\n",
    "\n",
    "        pred /= np.sqrt(self.config[\"n_tree\"])  # NTK scaling\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.156771Z",
     "start_time": "2021-09-27T03:15:35.147632Z"
    }
   },
   "outputs": [],
   "source": [
    "def circle_transform(angle_vec):\n",
    "    cos_tensor = torch.cos(angle_vec)\n",
    "    sin_tensor = torch.sin(angle_vec)\n",
    "    return torch.stack((cos_tensor, sin_tensor), -1)\n",
    "\n",
    "n_features = 2\n",
    "n_dataset = 10\n",
    "\n",
    "train_data = torch.Tensor([np.random.randn(n_features) for i in range(n_dataset)])\n",
    "target_data = torch.tensor(np.random.randn(train_data.shape[0]))\n",
    "test_data = torch.Tensor([np.random.randn(n_features) for i in range(10)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "source": [
    "## Tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:35.163984Z",
     "start_time": "2021-09-27T03:15:35.158843Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_net(net, n_epochs, input_data, target, lr, initial_train):\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "    for epoch in range(n_epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(input_data.double())-initial_train.unsqueeze(1)\n",
    "        loss = criterion(outputs.view(-1), target)/2\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-27T03:15:45.786139Z",
     "start_time": "2021-09-27T03:15:45.780542Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from utils import compute_kernels, rotation_o\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 2.0  # `scale` of every inner node's erf decision function\n",
    "t_max = 1000  # total number of gradient-descent iterations\n",
    "t_step = 10  # iterations between recorded snapshots\n",
    "lr = 0.1\n",
    "t_list = np.arange(t_step, t_max + t_step, t_step)  # snapshot iterations t_step..t_max\n",
    "\n",
    "# Put the data on the same device as the model parameters.\n",
    "train_x = train_data.to(device)\n",
    "test_x = test_data.to(device)\n",
    "train_y = target_data.to(device)\n",
    "\n",
    "\n",
    "def offset_output(model, x, initial):\n",
    "    \"\"\"Return ``model(x) - initial`` as a flat numpy array (change since initialisation).\"\"\"\n",
    "    with torch.no_grad():\n",
    "        return (model(x).reshape(-1) - initial).cpu().numpy()\n",
    "\n",
    "\n",
    "ptrain_empiricals1, ptest_empiricals1 = [], []\n",
    "ptrain_empiricals2, ptest_empiricals2 = [], []\n",
    "\n",
    "for n_tree in (16, 4096):\n",
    "    # Shape A (mode=1) and shape B (mode=2) ensembles of the same width.\n",
    "    st1 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, n_tree=n_tree, mode=1)\n",
    "    st2 = SoftTreeExp(input_dim=train_data.shape[1], output_dim=1, scale=alpha, n_tree=n_tree, mode=2)\n",
    "\n",
    "    # Outputs at initialisation; train_net fits model(x) - initial to the targets.\n",
    "    with torch.no_grad():\n",
    "        initial_train1 = st1(train_x).reshape(-1)\n",
    "        initial_test1 = st1(test_x).reshape(-1)\n",
    "        initial_train2 = st2(train_x).reshape(-1)\n",
    "        initial_test2 = st2(test_x).reshape(-1)\n",
    "\n",
    "    # Train trajectories include t=0 (offset exactly zero); test trajectories\n",
    "    # start at the first snapshot, t=t_step.\n",
    "    ptrain_empirical1 = [torch.zeros_like(initial_train1).cpu().numpy()]\n",
    "    ptrain_empirical2 = [torch.zeros_like(initial_train2).cpu().numpy()]\n",
    "    ptest_empirical1, ptest_empirical2 = [], []\n",
    "\n",
    "    for t in tqdm(t_list):\n",
    "        train_net(st1, t_step, train_x, train_y, lr, initial_train1)\n",
    "        train_net(st2, t_step, train_x, train_y, lr, initial_train2)\n",
    "\n",
    "        ptrain_empirical1.append(offset_output(st1, train_x, initial_train1))\n",
    "        ptest_empirical1.append(offset_output(st1, test_x, initial_test1))\n",
    "        ptrain_empirical2.append(offset_output(st2, train_x, initial_train2))\n",
    "        ptest_empirical2.append(offset_output(st2, test_x, initial_test2))\n",
    "\n",
    "    ptrain_empiricals1.append(ptrain_empirical1)\n",
    "    ptest_empiricals1.append(ptest_empirical1)\n",
    "    ptrain_empiricals2.append(ptrain_empirical2)\n",
    "    ptest_empiricals2.append(ptest_empirical2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-point outputs (offset from their initial values) during training:\n",
    "# shape A solid, shape B dashed, one panel per ensemble width.\n",
    "cmap = plt.cm.nipy_spectral\n",
    "panel_titles = (\"$M=16$\", \"$M=4096$\")  # same order as the n_tree loop (16, 4096)\n",
    "\n",
    "\n",
    "def with_origin(trajectory):\n",
    "    \"\"\"Stack snapshots into [n_snapshots + 1, n_points], prepending a t=0 row of zeros.\n",
    "\n",
    "    The recorded test snapshots start at t = t_step; at t = 0 the offset from\n",
    "    initialisation is zero by construction.\n",
    "    \"\"\"\n",
    "    snapshots = np.asarray(trajectory)\n",
    "    return np.vstack([np.zeros((1, snapshots.shape[1])), snapshots])\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 3))\n",
    "for panel, (ax, title) in enumerate(zip(axes, panel_titles)):\n",
    "    traj_a = with_origin(ptest_empiricals1[panel])\n",
    "    traj_b = with_origin(ptest_empiricals2[panel])\n",
    "    # Derive the x-axis from the data so it always matches the snapshot count.\n",
    "    t_list = t_step * np.arange(traj_a.shape[0])\n",
    "    n_points = traj_a.shape[1]\n",
    "    for i in range(n_points):\n",
    "        color = cmap(i / n_points)\n",
    "        label_a = {\"label\": \"Shape: A\"} if i == 0 else {}\n",
    "        label_b = {\"label\": \"Shape: B\"} if i == 0 else {}\n",
    "        ax.plot(t_list, traj_a[:, i], color=color, **label_a)\n",
    "        ax.plot(t_list, traj_b[:, i], linestyle=\"dashed\", color=color, linewidth=5, alpha=0.3, **label_b)\n",
    "    ax.set_xlabel(r\"$\\tau$ (iteration)\")\n",
    "    ax.set_title(title)\n",
    "    ax.set_ylim(-1.5, 1.5)\n",
    "    ax.legend(loc=\"upper left\")\n",
    "    ax.grid(linestyle=\"dotted\")\n",
    "axes[0].set_ylabel(r\"$f(x_i, w, \\pi)$\")\n",
    "\n",
    "os.makedirs(\"./figures\", exist_ok=True)\n",
    "fig.savefig(\"./figures/trajectory.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
