{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[],"dockerImageVersionId":30588,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\nfrom torchvision import datasets, transforms\nimport numpy as np\ndevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\nprint(device)\n\nclass SDT(nn.Module):\n    \"\"\"Fast implementation of soft decision tree in PyTorch.\n\n    Parameters\n    ----------\n    input_dim : int\n      The number of input dimensions.\n    output_dim : int\n      The number of output dimensions. For example, for a multi-class\n      classification problem with `K` classes, it is set to `K`.\n    depth : int, default=5\n      The depth of the soft decision tree. Since the soft decision tree is\n      a full binary tree, setting `depth` to a large value will drastically\n      increases the training and evaluating cost.\n    lamda : float, default=1e-3\n      The coefficient of the regularization term in the training loss. Please\n      refer to the paper on the formulation of the regularization term.\n    use_cuda : bool, default=False\n      When set to `True`, use GPU to fit the model. Training a soft decision\n      tree using CPU could be faster considering the inherent data forwarding\n      process.\n\n    Attributes\n    ----------\n    internal_node_num_ : int\n      The number of internal nodes in the tree. Given the tree depth `d`, it\n      equals to :math:`2^d - 1`.\n    leaf_node_num_ : int\n      The number of leaf nodes in the tree. 
Given the tree depth `d`, it\n      equals :math:`2^d`.\n    penalty_list : list\n      A list storing the layer-wise coefficients of the regularization term.\n    inner_nodes : torch.nn.Sequential\n      A container that simulates all internal nodes in the soft decision tree.\n      The sigmoid activation function is concatenated to simulate the\n      probabilistic routing mechanism.\n    leaf_nodes : torch.nn.Linear\n      A `nn.Linear` module that simulates all leaf nodes in the tree.\n    \"\"\"\n\n    def __init__(\n            self,\n            input_dim,\n            output_dim,\n            depth=5,\n            lamda=1e-3,\n            use_cuda=False):\n        super(SDT, self).__init__()\n\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n\n        self.depth = depth\n        self.lamda = lamda\n        self.device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n        self._validate_parameters()\n\n        self.internal_node_num_ = 2 ** self.depth - 1\n        self.leaf_node_num_ = 2 ** self.depth\n\n        # Different penalty coefficients for nodes in different layers\n        self.penalty_list = [\n            self.lamda * (2 ** (-depth)) for depth in range(0, self.depth)\n        ]\n\n        # Initialize internal and leaf nodes. The input dimension of the\n        # internal nodes is increased by 1 to serve as the bias.\n        self.inner_nodes = nn.Sequential(\n            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),\n            nn.Sigmoid(),\n        )\n\n        self.leaf_nodes = nn.Linear(self.leaf_node_num_,\n                                    self.output_dim,\n                                    bias=False)\n\n    def forward(self, X, is_training_data=False):\n        _mu, _penalty = self._forward(X)\n        y_pred = self.leaf_nodes(_mu)\n\n        # When `X` is the training data, the model also returns the penalty\n        # to compute the training loss.\n        if is_training_data:\n            return y_pred, _penalty\n        else:\n            return y_pred\n\n    def _forward(self, X):\n        \"\"\"Implementation of the data forwarding process.\"\"\"\n        batch_size = X.size()[0]\n        X = self._data_augment(X)\n\n        path_prob = self.inner_nodes(X)\n        path_prob = torch.unsqueeze(path_prob, dim=2)\n        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)\n\n        _mu = X.data.new(batch_size, 1, 1).fill_(1.0)\n        _penalty = torch.tensor(0.0).to(self.device)\n\n        # Iterate through internal nodes in each layer to compute the final path\n        # probabilities and the regularization term.\n        begin_idx = 0\n        end_idx = 1\n\n        for layer_idx in range(0, self.depth):\n            _path_prob = path_prob[:, begin_idx:end_idx, :]\n\n            # Extract internal nodes in the current layer to compute the\n            # regularization term\n            _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob)\n            _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2)\n\n            _mu = _mu * _path_prob  # update path probabilities\n\n            begin_idx = end_idx\n            end_idx = begin_idx + 2 ** (layer_idx + 1)\n\n        mu = _mu.view(batch_size, self.leaf_node_num_)\n\n        return mu, _penalty\n\n    def _cal_penalty(self, layer_idx, _mu, _path_prob):\n        \"\"\"\n        Compute the regularization term for internal nodes in different layers.\n        \"\"\"\n\n        penalty = torch.tensor(0.0).to(self.device)\n\n        
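# alpha below estimates, over the batch, the probability that a sample\n        # reaching this node's parent is routed to this node; the penalty term\n        # pushes alpha toward 0.5 so that both subtrees receive probability mass.\n        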
batch_size = _mu.size()[0]\n        _mu = _mu.view(batch_size, 2 ** layer_idx)\n        _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))\n\n        for node in range(0, 2 ** (layer_idx + 1)):\n            alpha = torch.sum(\n                _path_prob[:, node] * _mu[:, node // 2], dim=0\n            ) / torch.sum(_mu[:, node // 2], dim=0)\n\n            # Clamp to avoid log(0) when a node receives all or none of the mass.\n            alpha = torch.clamp(alpha, 1e-6, 1 - 1e-6)\n\n            coeff = self.penalty_list[layer_idx]\n\n            penalty -= 0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha))\n\n        return penalty\n\n    def _data_augment(self, X):\n        \"\"\"Add a constant input `1` onto the front of each sample.\"\"\"\n        batch_size = X.size()[0]\n        X = X.view(batch_size, -1)\n        bias = torch.ones(batch_size, 1).to(self.device)\n        X = torch.cat((bias, X), 1)\n\n        return X\n\n    def _validate_parameters(self):\n\n        if not self.depth > 0:\n            msg = (\"The tree depth should be strictly positive, but got {}\"\n                   \" instead.\")\n            raise ValueError(msg.format(self.depth))\n\n        if not self.lamda >= 0:\n            msg = (\n                \"The coefficient of the regularization term should not be\"\n                \" negative, but got {} instead.\"\n            )\n            raise ValueError(msg.format(self.lamda))","metadata":{"trusted":true},"execution_count":null,"outputs":[]},
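{"cell_type":"markdown","source":"A quick shape check for `SDT` (added for illustration; not part of the original pipeline): a random batch routed through a small tree should yield one logit per class, plus a scalar penalty in training mode.","metadata":{}},{"cell_type":"code","source":"# Illustrative sanity check: route a random batch through a small SDT.\n_toy_tree = SDT(input_dim=8, output_dim=2, depth=3, lamda=1e-3, use_cuda=False)\n_toy_logits, _toy_penalty = _toy_tree(torch.randn(4, 8), is_training_data=True)\nprint(_toy_logits.shape)  # torch.Size([4, 2]) -- one logit per class\nprint(_toy_penalty)       # scalar regularization term","metadata":{"trusted":true},"execution_count":null,"outputs":[]},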
{"cell_type":"code","source":"class TorchDataset(torch.utils.data.Dataset):\n\n    def __init__(self, *data, **options):\n\n        n_data = len(data)\n        if n_data == 0:\n            raise ValueError(\"At least one set required as input\")\n\n        self.data = data\n        means = options.pop('means', None)\n        stds = options.pop('stds', None)\n        self.transform = options.pop('transform', None)\n        self.test = options.pop('test', False)\n\n        if options:\n            raise TypeError(\"Invalid parameters passed: %s\" % str(options))\n\n        if means is not None:\n            assert stds is not None, \"must specify both <means> and <stds>\"\n            self.normalize = lambda data: [(d - m) / s for d, m, s in zip(data, means, stds)]\n        else:\n            self.normalize = lambda data: data\n\n    def __len__(self):\n        return len(self.data[0])\n\n    def __getitem__(self, idx):\n        data = self.normalize([s[idx] for s in self.data])\n        if self.transform:\n            if self.test:\n                data = sum([[self.transform.test_transform(d)] * 2 for d in data], [])\n            else:\n                data = sum([self.transform(d) for d in data], [])\n\n        return data","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"#@title Synthetic data\ndef set_npseed(seed):\n    np.random.seed(seed)\n\n\ndef set_torchseed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\n# classification data\n\ndef data_gen_decision_tree(num_data=1000, dim=2, seed=0, w_list=None, b_list=None, vals=None, num_levels=2, threshold=0.0):\n    set_npseed(seed=seed)\n\n    # Construct a complete decision tree with 2**num_levels-1 internal nodes,\n    # e.g. num_levels=2 means there are 3 internal nodes.\n    # w_list, b_list are lists of size num_internal_nodes.\n    # vals is a list of size num_leaf_nodes, with values +1 or 0.\n    # threshold is the pruning margin: points closer than this to any\n    # splitting hyperplane are discarded (it used to be read from a global).\n    num_internal_nodes = 2**num_levels - 1\n    num_leaf_nodes = 2**num_levels\n    stats = np.zeros(num_internal_nodes+num_leaf_nodes) # number of datapoints reaching each node; node 0 (the root) sees all of them\n\n    if vals is None: # when leaf labels are not provided, generate them\n        vals = np.arange(0, num_internal_nodes+num_leaf_nodes, 1, dtype=np.int32) % 2 # label 0 or 1 depending on whether the node index is even or odd\n        vals[:num_internal_nodes] = -99 # sentinel for internal nodes; only leaf values are used as labels\n\n    if w_list is None: # if the hyperplane parameters are not provided, generate them\n        w_list = np.random.standard_normal((num_internal_nodes, dim))\n        w_list = w_list/np.linalg.norm(w_list, axis=1)[:, None] # unit-norm weight vectors\n        b_list = np.zeros((num_internal_nodes))\n\n    # Alternative (commented out): sample uniformly from [-1, 1)^dim;\n    # np.random.random_sample returns floats in the half-open interval [0.0, 1.0).\n#     data_x = np.random.random_sample((num_data, dim))*2 - 1.\n#     relevant_stats = data_x @ w_list.T + b_list\n#     curr_index = np.zeros(shape=(num_data), dtype=int)\n\n    # Sample datapoints uniformly on the unit sphere.\n    data_x = np.random.standard_normal((num_data, dim))\n    data_x /= np.sqrt(np.sum(data_x**2, axis=1, keepdims=True))\n    # relevant_stats[i, j] = w_j . x_i; the bias b_j is added level by level in\n    # the loop below, so it is not added here as well (the original added it twice).\n    relevant_stats = data_x @ w_list.T\n    curr_index = np.zeros(shape=(num_data), dtype=int) # current node of each datapoint, from the root (0) down to a leaf\n\n    for level in range(num_levels):\n        nodes_curr_level = list(range(2**level - 1, 2**(level+1) - 1))\n        for el in nodes_curr_level:\n#             b_list[el]=-1*np.median(relevant_stats[curr_index==el,el])\n            relevant_stats[:, el] += b_list[el]\n        decision_variable = np.choose(curr_index, relevant_stats.T) # value of each datapoint at its current node\n\n        # Go down-right if wx+b>0, down-left otherwise,\n        # i.e. 0 -> 1 if w[0]x+b[0]<0 and 0 -> 2 otherwise.\n        curr_index = (curr_index+1)*2 - (1-(decision_variable > 0)) # update the current node based on decision_variable\n\n    bound_dist = np.min(np.abs(relevant_stats), axis=1) # distance of each datapoint to its closest hyperplane; points lying exactly on a hyperplane (value 0) are removed below
\n    thres = threshold\n    labels = vals[curr_index] # leaf label of each datapoint after traversing the whole tree\n\n    # Prune datapoints that lie within `thres` of any splitting hyperplane so\n    # that the two classes are cleanly separated.\n    data_x_pruned = data_x[bound_dist > thres]\n    labels_pruned = labels[bound_dist > thres]\n    relevant_stats = np.sign(data_x_pruned @ w_list.T + b_list) # +1/-1: which side of each hyperplane each datapoint falls on\n    nodes_active = np.zeros((len(data_x_pruned), num_internal_nodes+num_leaf_nodes), dtype=np.int32) # whether each datapoint reaches each node\n\n    for node in range(num_internal_nodes+num_leaf_nodes):\n        if node == 0:\n            stats[node] = len(relevant_stats) # every datapoint reaches the root\n            nodes_active[:, 0] = 1\n            continue\n        parent = (node-1)//2\n        nodes_active[:, node] = nodes_active[:, parent]\n        right_child = node-(parent*2)-1 # 0 means left child, 1 means right child (e.g. node 1 has children 3 and 4)\n        if right_child == 1:\n            nodes_active[:, node] *= relevant_stats[:, parent] > 0 # the right child is reached when the parent value is > 0\n        if right_child == 0:\n            nodes_active[:, node] *= relevant_stats[:, parent] < 0 # the left child is reached otherwise\n        stats = nodes_active.sum(axis=0) # number of datapoints reaching each node (all at the root, then split left/right)\n    return ((data_x_pruned, labels_pruned), (w_list, b_list, vals), stats)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},
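{"cell_type":"markdown","source":"A small smoke test for the generator (added for illustration; the variable names are arbitrary): draw a tiny dataset from a depth-2 ground-truth tree and inspect the label balance and per-node counts.","metadata":{}},{"cell_type":"code","source":"# Illustrative smoke test for data_gen_decision_tree.\n((toy_x, toy_y), (toy_w, toy_b, toy_vals), toy_stats) = data_gen_decision_tree(\n    num_data=1000, dim=2, seed=0, num_levels=2, threshold=0.05)\nprint('data:', toy_x.shape, '| label counts:', np.bincount(toy_y))\nprint('datapoints reaching each node:', toy_stats)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},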
{"cell_type":"code","source":"class Dataset_syn:\n    # Thin wrapper exposing the module-level train/validation/test splits\n    # created in the config cell below.\n    def __init__(self, dataset, data_path='./DATA'):\n        if dataset == \"syn\":\n            self.X_train = train_data\n            self.y_train = train_data_labels\n            self.X_valid = vali_data\n            self.y_valid = vali_data_labels\n            self.X_test = test_data\n            self.y_test = test_data_labels\n        self.data_path = data_path\n        self.dataset = dataset","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"**Change this cell to run different configs**","metadata":{}},{"cell_type":"code","source":"def onehot_coding(target, device, output_dim):\n    \"\"\"Convert the class labels into one-hot encoded vectors.\"\"\"\n    target_onehot = torch.FloatTensor(target.size()[0], output_dim).to(device)\n    target_onehot.data.zero_()\n    target_onehot.scatter_(1, target.view(-1, 1), 1.0)\n    return target_onehot\n\n\n# Experiment configuration\nseed=365\nnum_levels=4\nthreshold = 0 # data separation distance\noutput_dim=1 # placeholder; overridden to 2 inside the loop below\n\n\ndata_configs = [\n    {\"input_dim\": 20, \"num_data\": 40000},\n    {\"input_dim\": 100, \"num_data\": 60000},\n    {\"input_dim\": 500, \"num_data\": 100000}\n]\n\n# Code block to run for each config\nfor config in data_configs:\n    input_dim = config[\"input_dim\"]\n    num_data = config[\"num_data\"]\n\n    ((data_x, labels), (w_list, b_list, vals), stats) = data_gen_decision_tree(\n        dim=input_dim, seed=seed, num_levels=num_levels,\n        threshold=threshold, num_data=num_data)\n    seed_set = seed\n    w_list_old = np.array(w_list)\n    b_list_old = np.array(b_list)\n    print(\"num label 1:\", sum(labels==1))\n    print(\"num label 0:\", sum(labels==0))\n    print(\"Seed= \", seed_set)\n    num_data = len(data_x)\n    num_train = num_data//2\n    num_vali = num_data//4\n    num_test = num_data//4\n\n    train_data = data_x[:num_train,:]\n    train_data_labels = labels[:num_train]\n\n    vali_data = data_x[num_train:num_train+num_vali,:]\n    vali_data_labels = labels[num_train:num_train+num_vali]\n\n    test_data = data_x[num_train+num_vali:,:]\n    test_data_labels = labels[num_train+num_vali:]\n\n    # Parameters\n    input_dim = input_dim  # the number of input dimensions\n    output_dim = 2         # the number of outputs (i.e., the number of classes)\n    depth = 5              # tree depth\n    lamda = 1e-3           # coefficient of the regularization term\n    lr = 1e-3              # learning rate\n    weight_decay = 5e-4    # weight decay\n    batch_size = 128       # batch size\n    epochs = 500           # the number of training epochs\n    log_interval = 100     # the number of batches to wait before printing logs\n    use_cuda = torch.cuda.is_available()  # fall back to CPU when no GPU is present\n\n    # Model and Optimizer\n    tree = SDT(input_dim, output_dim, depth, lamda, use_cuda).to(device)\n    tree = tree.float()\n\n    optimizer = torch.optim.Adam(tree.parameters(),\n                                 lr=lr,\n                                 weight_decay=weight_decay)\n    DATA_NAME = \"syn\"\n    syn_data = Dataset_syn(DATA_NAME)  # renamed from `data` to avoid shadowing the batch variable below\n    BATCH_SIZE = batch_size\n    train_loader = DataLoader(TorchDataset(syn_data.X_train, syn_data.y_train), batch_size=BATCH_SIZE, num_workers=16, shuffle=True)\n    valloader = DataLoader(TorchDataset(syn_data.X_valid, syn_data.y_valid), batch_size=BATCH_SIZE*2, num_workers=16, shuffle=False)\n    test_loader = DataLoader(TorchDataset(syn_data.X_test, syn_data.y_test), batch_size=BATCH_SIZE*2, num_workers=16, shuffle=False)\n\n    # Utils\n    best_testing_acc = 0.0\n    testing_acc_list = []\n    training_loss_list = []\n    criterion = nn.CrossEntropyLoss()\n    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n    for epoch in range(epochs):\n\n        # Training\n        tree.train()\n        for batch_idx, (data, target) in enumerate(train_loader):\n\n            batch_size = data.size()[0]\n            data, target = data.to(device), target.to(device)\n            # Convert to torch.int64\n            target = target.to(torch.int64)\n            data = data.float()\n            # One-hot targets are kept from the original pipeline but are not\n            # used by CrossEntropyLoss, which expects class indices.\n            target_onehot = onehot_coding(target, device, output_dim)\n            output, penalty = tree.forward(data, is_training_data=True)\n\n            loss = criterion(output, target.view(-1))\n            loss += penalty\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # Print training status\n            if batch_idx % log_interval == 0:\n                pred = output.data.max(1)[1]\n                correct = pred.eq(target.view(-1).data).sum().item()\n\n                msg = (\n                    \"Epoch: {:02d} | Batch: {:03d} | Loss: {:.5f} |\"\n                    \" Correct: {:03d}/{:03d}\"\n                )\n                print(msg.format(epoch, batch_idx, loss, correct, batch_size))\n                training_loss_list.append(loss.cpu().data.numpy())\n\n        # Evaluating\n        tree.eval()\n        correct = 0\n\n        for batch_idx, (data, target) in enumerate(test_loader):\n\n            batch_size = data.size()[0]\n            
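# In eval mode, forward() returns logits only, so there is no penalty to unpack.\n            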
data, target = data.to(device), target.to(device)\n            data = data.float()\n            output = F.softmax(tree.forward(data), dim=1)\n\n            pred = output.data.max(1)[1]\n            correct += pred.eq(target.view(-1).data).sum().item()\n\n        accuracy = 100.0 * float(correct) / len(test_loader.dataset)\n\n        if accuracy > best_testing_acc:\n            best_testing_acc = accuracy\n\n        msg = (\n            \"\\nEpoch: {:02d} | Testing Accuracy: {}/{} ({:.3f}%) |\"\n            \" Historical Best: {:.3f}%\\n\"\n        )\n        print(\n            msg.format(\n                epoch, correct,\n                len(test_loader.dataset),\n                accuracy,\n                best_testing_acc\n            )\n        )\n        testing_acc_list.append(accuracy)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},
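{"cell_type":"markdown","source":"Optional (added for illustration): plot the curves collected above. The lists are re-initialized per config, so this shows the last config only; assumes `matplotlib` is available.","metadata":{}},{"cell_type":"code","source":"# Illustrative plotting cell for the curves collected during training.\nimport matplotlib.pyplot as plt\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\nax1.plot(training_loss_list)\nax1.set_xlabel('logged batch')\nax1.set_ylabel('training loss')\nax2.plot(testing_acc_list)\nax2.set_xlabel('epoch')\nax2.set_ylabel('test accuracy (%)')\nfig.tight_layout()\nplt.show()","metadata":{"trusted":true},"execution_count":null,"outputs":[]}]}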