
class Node:
    """Binary-tree node describing the shape of a tree.

    A node starts out as a leaf; ``split`` attaches two fresh children.
    ``key`` is not assigned anywhere in this class, so it must be set on
    every node by external code before ``display`` / ``_display_aux`` are used.
    """

    def __init__(self, parent=None, depth=3, *args, **kwargs):
        """Create a leaf attached to ``parent`` (``None`` means root).

        ``depth`` and any extra arguments are accepted but not used here;
        the stored depth is always derived from the parent by ``set_depth``.
        """
        # Links to neighbouring nodes; ``parent`` stays None for the root.
        self.parent = parent
        self.right = None
        self.left = None

        # Becomes True once ``split`` has been called on this node.
        self._split_attempt = False

        # Parameter row associated with this node, filled in by ``split``.
        self.split_reference = None

        # Derive the depth from the parent chain.
        self.set_depth()

    def set_depth(self):
        """Set ``self.depth``; the root additionally gets ``keys_set = False``."""
        if not self.is_root():
            self.depth = self.parent.depth + 1
            return
        self.keys_set = False
        self.depth = 0

    def is_root(self):
        """Return True when the node has no parent."""
        return self.parent is None

    def is_leaf(self):
        """Return True when the node has neither a left nor a right child."""
        return self.left is None and self.right is None

    def split(self, split_reference):
        """Turn this leaf into an internal node with two new leaf children.

        Raises ``AssertionError`` when the node already has a child.
        """
        # Only leaves may be split.
        assert self.is_leaf()

        self._split_attempt = True

        # Remember which parameter row belongs to this split.
        self.split_reference = split_reference

        # Attach the two children.
        self.left = Node(self)
        self.right = Node(self)

    def is_left_child(self):
        """Return True when this node is its parent's left child."""
        return (not self.is_root()) and self.parent.left == self

    def is_right_child(self):
        """Return True when this node is its parent's right child."""
        return (not self.is_root()) and self.parent.right == self

    ########### Functions for printing ###########
    # Adapted from https://stackoverflow.com/questions/34012886/print-binary-tree-level-by-level-in-python

    def display(self):
        """Print an ASCII rendering of the tree; only valid on the root."""
        if not self.is_root():
            print("Cannot call self.display on a non-root node. Try using the get_root function.")
            return
        if not self.keys_set:
            print("Keys not set yet")
            return
        rendered = self._display_aux()[0]
        for row in rendered:
            print(row)

    def _display_aux(self):
        """Returns list of strings, width, height, and horizontal coordinate of the root."""
        has_left = self.left is not None
        has_right = self.right is not None

        # Leaf: the picture is just the key itself.
        if not has_left and not has_right:
            text = '%s' % self.key
            size = len(text)
            return [text], size, 1, size // 2

        # Left child only.
        if not has_right:
            rows, w, h, mid = self.left._display_aux()
            text = '%s' % self.key
            size = len(text)
            top = ' ' * (mid + 1) + '_' * (w - mid - 1) + text
            branch = ' ' * mid + '/' + ' ' * (w - mid - 1 + size)
            pad = ' ' * size
            body = [row + pad for row in rows]
            return [top, branch] + body, w + size, h + 2, w + size // 2

        # Right child only.
        if not has_left:
            rows, w, h, mid = self.right._display_aux()
            text = '%s' % self.key
            size = len(text)
            top = text + '_' * mid + ' ' * (w - mid)
            branch = ' ' * (size + mid) + '\\' + ' ' * (w - mid - 1)
            pad = ' ' * size
            body = [pad + row for row in rows]
            return [top, branch] + body, w + size, h + 2, size // 2

        # Both children: render each side, then glue them under the key.
        l_rows, l_w, l_h, l_mid = self.left._display_aux()
        r_rows, r_w, r_h, r_mid = self.right._display_aux()
        text = '%s' % self.key
        size = len(text)
        top = (' ' * (l_mid + 1) + '_' * (l_w - l_mid - 1) + text
               + '_' * r_mid + ' ' * (r_w - r_mid))
        branch = (' ' * l_mid + '/' + ' ' * (l_w - l_mid - 1 + size + r_mid)
                  + '\\' + ' ' * (r_w - r_mid - 1))
        # Pad the shorter picture with blank rows so both have equal height.
        if l_h < r_h:
            l_rows += [' ' * l_w] * (r_h - l_h)
        elif r_h < l_h:
            r_rows += [' ' * r_w] * (l_h - r_h)
        gap = ' ' * size
        merged = [a + gap + b for a, b in zip(l_rows, r_rows)]
        return [top, branch] + merged, l_w + r_w + size, max(l_h, r_h) + 2, l_w + size // 2





class variationalRegressionTree:
    # X: data matrix of shape (number of datapoints, dimension).
    def __init__(self, depth, X, y, device = 'cpu', bias = 5.0):
        """Allocate variational and prior parameters plus index tables.

        Args:
            depth: Maximum depth of the tree; determines ``self.b = 2**(depth - 1)``.
            X: Data matrix; ``X.shape[0]`` is stored as ``self.D`` and
                ``X.shape[1]`` as ``self.d``. Converted with ``torch.tensor``
                when it is not already a tensor.
            y: Targets; converted with ``torch.tensor`` when not a tensor.
            device: Device on which the parameters and index tables are created.
            bias: Constant added to the split logits ``γ`` and ``γ2``.

        NOTE(review): ``X`` and ``y`` are stored as given -- they are neither
        cast to double nor moved to ``device`` here, while every parameter is
        created as double on ``device``. Callers presumably have to pass data
        with a matching dtype/device; confirm.
        """
        ###### Sizes ######
        self.b = 2**(depth - 1) # Number of leaves of a complete tree of depth d. 
        self.d = int(X.shape[1]) # Dimension of the data 

        self.D = int(X.shape[0]) # Number of rows of X
        
        self.device = device
        
        

        ### Variational parameters ###
        w = 1000 # All random initial values are divided by w, i.e. start small.
        self.γ = torch.randn((self.b - 1,), device=device).double()/w # Logits of the bernoulli splits
        self.γ = self.γ   + bias*torch.ones_like(self.γ, device=device)
        # self.γ = 10*torch.ones_like(self.γ, device=device)
        self.γ.requires_grad_()

        self.μ = torch.randn(self.b - 1, self.d, device=device).double()/w
        self.μ.requires_grad_() # Gaussian splits 

        # Initialised around -10, i.e. a very small initial variance exp(σ).
        self.σ = torch.randn(self.b - 1, self.d, device=device).double()/w - 10.0 
        self.σ.requires_grad_() # Log variances of splits 
        self.ζ = torch.randn((2*self.b - 1, self.d), device=device).double()/w
        self.ζ.requires_grad_() # Mean parameter 

        self.ξ = torch.randn((2*self.b - 1,), device=device).double()/w - 10.0 # Log variance parameter 
        self.ξ.requires_grad_()

        ### Prior parameters ###

        # Unlike γ (length b - 1), γ2 has one entry per node (2b - 1), and it is
        # NOT marked as requiring grad (the call below is commented out), so
        # its .grad stays None.
        self.γ2 = torch.randn((2*self.b - 1,), device=device).double()/w # Logits of the bernoulli splits
        self.γ2 = self.γ2  + bias*torch.ones_like(self.γ2)
        # self.γ2 = 10*torch.ones_like(self.γ2, device=device)
        # self.γ2.requires_grad_()
        
        self.μ2 = torch.randn(self.b - 1, self.d, device=device).double()/w
        self.μ2.requires_grad_() # Gaussian splits 
        
        self.σ2 = torch.randn(self.b - 1, self.d, device=device).double()/w - 10.0 
        self.σ2.requires_grad_() # Log variances

        self.ζ2 = torch.randn((2*self.b - 1, self.d), device=device).double()/w
        self.ζ2.requires_grad_() # Parameters (together with x this forms the mean.)

        self.ξ2 = torch.randn((2*self.b - 1,), device=device).double()/w - 10.0 
        self.ξ2.requires_grad_() 

        # Temperature used to scale the Gumbel-softmax logits in sample_splits.
        # train() overwrites this with a plain Python float.
        self.τ = torch.tensor(1.00, device=device).double()
        # self.τ.requires_grad_() # Temperature to anneal the softmax-gumbel samples. 

        # May eventually be needed for inference. 
        self._tree_sample = None # Root node of the tree. 


        # Placeholder (default float dtype) until sample_splits() overwrites it.
        self.s = torch.zeros(self.γ.shape[0],device=device)  # Bernoulli/GumbelSoftmax samples used to reconstruct the splitting rules. 
        
        # Reparametrization trick samples, filled in by sample_tree_parameters().
        self.ϵ_rules = None # Noise variables used to reconstruct the tree split rules. 
        self.ϵ_leaf = None # Parameters defining the distributions at each of the leaf nodes. 


        self.depth = depth # Maximum depth of the tree. 

        if not torch.is_tensor(X):
            self.X = torch.tensor(X)
        else:
            self.X = X

        if not torch.is_tensor(y):
            self.y = torch.tensor(y)
        else:
            self.y = y 


        ######### Precomputed index tables used by the vectorized routines #########
        self.num_rows = 2*self.b - 1  # One row per node.

        # log2(b) + 1, which equals `depth` because b = 2**(depth - 1).
        self.num_columns = int(torch.log2(torch.tensor(self.b,  device=device)) + 1)

        self.M2 = torch.arange(1, self.num_rows + 1, device=device).view(-1,1)*torch.ones(self.num_rows, self.num_columns, device=device) #Row index (k)
        self.M = torch.arange(1, self.num_columns + 1, device=device).view(1,-1)*torch.ones(self.num_rows, self.num_columns, device=device) #Column index (i)

        # I[k-1, i-1]  = floor(k / 2**i); 0 once the division runs past node 1.
        # I2[k-1, i-1] = floor(k / 2**(i-1)).
        # NOTE(review): with 1-based heap numbering (children of k are 2k and
        # 2k+1) these look like the i-th and (i-1)-th ancestors of node k --
        # confirm against utils_reg.create_tree.
        self.I = torch.floor(self.M2/(2**self.M)).long()
        self.I2 = torch.floor(self.M2/2**(self.M-1)).long()

    def get_max_gradient(self):
        """Return the largest and smallest gradient entries over all parameters.

        Parameters whose ``.grad`` is ``None`` are skipped. This matters
        because ``γ2`` is never marked as requiring grad, so its ``.grad`` is
        always ``None``; bailing out on the first missing gradient made this
        method return ``None`` unconditionally.

        Returns:
            A ``(max, min)`` tuple of NumPy scalars, or ``None`` when no
            parameter has a gradient yet (e.g. before the first backward pass).
        """
        with torch.no_grad():
            params = [self.γ, self.γ2,
                      self.μ,
                      self.σ,
                      self.ζ,
                      self.ξ,
                      self.μ2,
                      self.σ2,
                      self.ζ2,
                      self.ξ2]
            # Flatten every available gradient into a NumPy vector.
            params_grad = [
                param.grad.view(-1).cpu().numpy()
                for param in params
                if param.grad is not None
            ]
            if not params_grad:
                return None

            return max(map(np.max, params_grad)), min(map(np.min, params_grad))
        
    def _create_tree(self):
        """Build a tree from the rounded split probabilities ``logistic(γ)``.

        Returns whatever ``create_tree`` returns for a fresh root ``Node``.
        """
        ###### TODO: debug this function ######
        with torch.no_grad():
            # Round each split probability to a hard 0/1 decision.
            split_bits = torch.round(logistic(self.γ))
            return create_tree(split_bits, self.depth - 1, Node(depth=self.depth))

    def _predict_deterministic(self, x, η):
        """Route ``x`` down the tree rooted at ``η`` and return the leaf output.

        At an internal node the sign of ``x · μ[split_reference]`` picks the
        branch (non-negative goes left); at the leaf the result is
        ``x · ζ[split_reference]``.
        """
        with torch.no_grad():
            node = η
            # Walk down until a leaf is reached.
            while not node.is_leaf():
                score = torch.dot(x, self.μ[node.split_reference])
                node = node.left if score >= 0 else node.right
            return torch.dot(x, self.ζ[node.split_reference])


    def predict_deterministic(self, x):
        """Predict for ``x`` with the tree built by ``_create_tree``."""
        root = self._create_tree()
        prediction = self._predict_deterministic(x, root)
        return prediction


    def predict_stochastic(self, x):
        """Sample one tree from the posterior and draw a prediction for ``x``."""
        sampled = self.sample_tree_posterior()
        tree, split_weights, leaf_weights = sampled
        return self._predict_stochastic(x, tree, split_weights, leaf_weights)

    def predictive_posterior(self, x, samples = 100):
        """Collect ``samples`` independent stochastic predictions for ``x``.

        Each prediction comes from ``predict_stochastic``, which samples its
        own tree. The extra tree that used to be sampled up front was never
        used (it only cost time and consumed random numbers), so it is gone.

        Returns:
            A list with ``samples`` entries, one per call to ``predict_stochastic``.
        """
        return [self.predict_stochastic(x) for _ in range(samples)]



    def _predict_stochastic(self, x, η, β, θ):
        """Randomly route ``x`` down the tree at ``η`` and sample a leaf output.

        At each internal node a Bernoulli draw with probability
        ``logistic(x · β[split_reference])`` decides the branch (1 goes left,
        0 goes right). At the leaf the result is a normal draw with mean
        ``x · θ[split_reference]`` and standard deviation
        ``exp(0.5 * ξ[split_reference])``, returned as a Python float.
        """
        with torch.no_grad():
            node = η
            while not node.is_leaf():
                score = torch.dot(x, β[node.split_reference])
                coin_flip = torch.bernoulli(logistic(score)).item()
                if coin_flip == 1:
                    node = node.left
                elif coin_flip == 0:
                    node = node.right
                else:
                    raise Exception("Error...")

            leaf = node.split_reference
            mean = torch.dot(x, θ[leaf])
            sd = torch.exp(0.5*self.ξ[leaf])
            return torch.normal(mean, sd).item()

    def sample_posterior_predictive_vectorized(self, x, samples = 100):
        """Average ``samples`` vectorized posterior draws of the prediction for ``x``.

        Every draw is converted to a Python list and the stack is rebuilt with
        ``torch.tensor``, so the result has the default dtype and lives on the
        default device. The mean is taken over the sample axis.
        """
        draws = [
            self.sample_tree_posterior_vectorized(x).tolist()
            for _ in tqdm(range(samples), position = 0, leave = False)
        ]
        stacked = torch.tensor(draws)
        return torch.mean(stacked, dim = 0)


    def predict(self, x, samples = 100):
        """Alias for ``sample_posterior_predictive_vectorized``."""
        averaged = self.sample_posterior_predictive_vectorized(x, samples=samples)
        return averaged

    def sample_tree_posterior_vectorized(self, x):
        """Sample one tree and one path per example, and return the leaf means.

        Side effect: calls ``sample_tree_parameters``, which overwrites
        ``self.s``, ``self.ϵ_rules`` and ``self.ϵ_leaf``.

        Args:
            x: Batch of inputs. Assumed to be (n_examples, d) so that
                ``β @ x.T`` and ``ζ[...] * x`` line up -- TODO confirm.

        Returns:
            One value per example: the dot product of the example with the
            ``ζ`` row of the leaf it was routed to. Note that ``ζ`` itself is
            used, not a noisy sample of it, and no observation noise is added.
        """
        with torch.no_grad():
            self.sample_tree_parameters()

            # Hard split decisions: one Bernoulli draw per entry of γ.
            s = logistic(self.γ)
            s = torch.bernoulli(s)

            # Per-node indicator of being a leaf under the sampled splits.
            leaf_ind = self.get_leaf_indicators_sample(s)

            # Reparameterised draw of the split weights.
            β = self.ϵ_rules*torch.exp(0.5*self.σ) + self.μ

            # Despite the name these are raw scores; logistic() is applied below.
            probabilities = β@x.T

            # One left/right coin flip per (split row, example).
            path_left = torch.bernoulli(logistic(probabilities))
            path_right = 1-path_left 

            # 1 where the index table holds a real (non-zero) index.
            mask = (self.I != torch.zeros_like(self.I)).double()

            # Zero-based row indices; masked positions point at row 0 and are
            # overwritten with 1 right below.
            idx = (self.I-1)*mask
            idx = idx.long()

            s_left = path_left[idx] 
            s_right = path_right[idx]

            s_left[mask == 0] = 1
            s_right[mask == 0] = 1

            # Column 0 is the node's own 1-based index, the remaining columns
            # are self.I shifted right by one (its last column is dropped).
            check_tensor = torch.hstack((torch.arange(1, self.I.shape[0] + 1).view(-1,1).to(self.device), self.I))[:,:-1]*mask
            check_tensor[0,0] = 1

            # NOTE: the names are swapped relative to the values -- `ind_even`
            # is 1 where the index is ODD and `ind_odd` is 1 where it is EVEN.
            ind_even = (check_tensor % 2).double()
            ind_odd = 1 - ind_even

            ind_even = ind_even.unsqueeze(-1)
            ind_odd = ind_odd.unsqueeze(-1)

            # Even indices take the "left" flip, odd indices the "right" flip.
            s_left = ind_odd*s_left
            s_right = ind_even*s_right


            T = s_right + s_left 
            mask_aug= mask.unsqueeze(-1).repeat(1,1,T.shape[-1])
            T[mask_aug == 0] = 1  # Masked levels must not affect the product.
            p = T.prod(dim = 1)   # Product over the level axis.
            p = p*leaf_ind.unsqueeze(-1)  # Keep only nodes that are leaves.
            p = p.T
            # Highest-scoring node per example (first one on ties).
            example_leafs = torch.argmax(p, dim = 1)
            mean = torch.sum(self.ζ[example_leafs]*x, dim = 1)
            return mean 





    def sample_tree_posterior(self):
        """Sample split bits and weights, and build the matching tree.

        Returns:
            ``(tree, β, θ)`` where ``tree`` is the result of ``create_tree``
            and ``β``, ``θ`` come from ``sample_tree_posterior_parameters``.
        """
        with torch.no_grad():
            split_bits, β, θ = self.sample_tree_posterior_parameters()
            skeleton = create_tree(
                torch.round(split_bits),
                self.depth - 1,
                Node(depth=self.depth),
            )
            return skeleton, β, θ


    def sample_tree_posterior_parameters(self):
        """Draw split bits, split weights and leaf weights without tracking gradients.

        Returns:
            ``(s, β, θ)``: Bernoulli split bits with probabilities
            ``logistic(γ)``; ``β = μ + noise * exp(0.5 σ)``; and
            ``θ = ζ + noise * exp(0.5 ξ)`` with ``ξ`` broadcast across columns.
        """
        with torch.no_grad():
            # Random draws happen in this order: bits, rule noise, leaf noise.
            s = torch.bernoulli(logistic(self.γ))

            rule_noise = torch.randn(self.b - 1, self.d, device=self.device)
            β = rule_noise*torch.exp(0.5*self.σ) + self.μ

            leaf_noise = torch.randn(2*self.b - 1, self.d, device=self.device)
            θ = leaf_noise*torch.exp(0.5*self.ξ).view(-1,1) + self.ζ

        return s, β, θ

    def sample_splits(self):
        """Store relaxed (Gumbel-softmax) split samples in ``self.s``.

        Gradients can flow through this computation into ``γ``.
        """
        uniform = torch.rand(self.b - 1, 2,  device=self.device)
        gumbel = -torch.log(-torch.log(uniform))
        π = logistic(self.γ)  # Split probabilities from the logits.
        log_probs = torch.stack((torch.log(π), torch.log(1-π)), dim = 1)
        relaxed = softmax((gumbel + log_probs)/self.τ, dim = 1)
        # Column 0 is the weight on "split".
        self.s = relaxed[:,0]



    def sample_tree_parameters(self):
        """Resample ``self.s`` and the two standard-normal noise tensors."""
        self.sample_splits()
        rule_shape = (self.b - 1, self.d)
        leaf_shape = (2*self.b - 1, self.d)
        self.ϵ_rules = torch.randn(*rule_shape, device=self.device)
        self.ϵ_leaf = torch.randn(*leaf_shape, device=self.device)




    def train(self, epochs = 1000, lr0 = 0.1, lrf = 1e-3, clip_norm = 10.0, optimizer = 'clippedAdam', h1 = 1, h2 = -1, tree_num = None):
        """Minimise the negative ELBO by gradient descent.

        Args:
            epochs: Number of optimisation steps.
            lr0: Initial learning rate.
            lrf: Target final learning rate; only used to derive the decay
                factor passed to ``ClippedAdam``.
            clip_norm: Passed to ``ClippedAdam``.
            optimizer: ``'clippedAdam'``, ``'adaGrad'``, or anything else for Adam.
            h1, h2: Weights of the cross-entropy and entropy terms.
            tree_num: Optional label shown in the progress bar.

        Returns:
            List with the ELBO value (``-loss``) of every step.
        """
        lrd = (lrf/lr0)**(1/epochs)  # Per-step learning-rate decay.

        # The temperature restarts at 1 and is annealed towards 0.1.
        self.τ = 1.00
        lrd_τ = (0.1/self.τ)**(1/epochs)

        trainable = [
            self.γ, self.γ2,
            self.μ, self.σ, self.ζ, self.ξ,
            self.μ2, self.σ2, self.ζ2, self.ξ2,
        ]
        if optimizer == 'clippedAdam':
            opt = ClippedAdam(trainable, lr=lr0, lrd = lrd, betas = (0.95, 0.999),
                              clip_norm=clip_norm)
        elif optimizer == 'adaGrad':
            opt = Adagrad(trainable, lr=lr0)
        else:
            opt = Adam(trainable, lr=lr0)

        history = []
        with tqdm(range(epochs), leave = False, position=0) as bar:
            for step in bar:
                opt.zero_grad()
                # Fresh noise every 100 steps; otherwise only the splits.
                if step % 100 == 0:
                    self.sample_tree_parameters()
                else:
                    self.sample_splits()

                loss, likelihood, cross_entropy, entropy = self.calculate_negative_ELBO(h1, h2)

                if tree_num is None:
                    bar.set_description(f"Iteration {step + 1}/{epochs} ")
                else:
                    bar.set_description(f"Iteration {step + 1}/{epochs}, Tree: {tree_num} ")
                bar.set_postfix(ELBO=-1*loss.item(), likelihood=likelihood.item(), ce = h1*cross_entropy.item(), e = h2*entropy.item(), lr = opt.param_groups[0]['lr'], )

                loss.backward()

                # Plain Python floats here, so no autograd context is needed.
                history.append(-1*loss.item())
                self.τ *= lrd_τ

                opt.step()

        return history
                
                
    def calculate_negative_ELBO(self, h1= 1, h2 = -1):
        """Combine likelihood, cross-entropy and entropy into the loss.

        Returns:
            ``(loss, likelihood, cross_entropy, entropy)`` where
            ``loss = -(likelihood + h1*cross_entropy + h2*entropy)``.
        """
        likelihood = self.calculate_full_likelihood()
        cross_entropy = self.calculate_cross_entropy_vectorized()
        entropy = self.calculate_entropy_vectorized()

        elbo = likelihood + h1*cross_entropy + h2*entropy
        return -1*elbo, likelihood, cross_entropy, entropy
    
    def calculate_full_likelihood(self):
        """Log-likelihood term evaluated on the stored data ``self.X``, ``self.y``."""
        features, targets = self.X, self.y
        return self.calculate_likelihood_batch(features, targets)

    def calculate_likelihood_batch(self, X, y):
        """Log-likelihood of ``y`` given ``X`` under the current samples.

        For every example the Gaussian log-density at each node is weighted
        by that node's path and leaf indicators, combined with ``logsumexp``
        and normalised by the total indicator weight; the per-example values
        are then summed.

        Args:
            X: Inputs; assumed (n_examples, d) so that ``ζ @ X.T`` works -- TODO confirm.
            y: Targets, reshaped to a column with ``view(-1, 1)``.

        Returns:
            Scalar tensor: sum over examples of ``logT - logS``.
        """
        leaf_indicators = self.get_leaf_indicators_vectorized()
        O = self.get_path_indicators_vectorized_batch(X)
        
        ξ_sd = torch.exp(0.5*self.ξ).view(-1,1)

        # Row j holds the linear prediction of every node for example j.
        mean_prediction_vector = (self.ζ@X.T).T

        # Squared standardised residual, scaled by -1/2.
        weights = 1/ξ_sd
        error = torch.pow(weights.view(1,-1)*(y.view(-1,1) - mean_prediction_vector), 2)
        error = -0.5*error

        # Normalising term -0.5*ξ - 0.5*log(2π), with π approximated by 3.14159.
        log_var = -0.5*(self.ξ.view(-1)).view(1,-1) - (1/2)*torch.log(torch.tensor(2*3.14159).to(self.device))
        # Clip away exact zeros so the logarithms below stay finite.
        O = torch.clip(O, min=1e-6, max=1e9)
        leaf_indicators = torch.clip(leaf_indicators, min=1e-6, max=1e9)
        log_f = log_var + error + torch.log(O) + torch.log(leaf_indicators)

        logT = torch.logsumexp(log_f, dim = 1)
        # Log of the total weight, used to normalise the mixture.
        partition_function = torch.log(leaf_indicators) + torch.log(O)
        logS = torch.logsumexp(partition_function, dim = 1)
        return (logT - logS).sum()
    
    
    def calculate_entropy_vectorized(self):
        """Single-sample estimate of the log-density of the variational distribution.

        Uses the noise currently stored in ``self.ϵ_rules`` / ``self.ϵ_leaf``
        and the splits in ``self.s`` (read through the indicator helpers).

        NOTE(review): the value is a sum of log-probability terms of the
        sampled quantities under the variational parameters, so it looks like
        an estimate of E_q[log q] rather than the entropy itself; the sign is
        supplied by ``h2`` in ``calculate_negative_ELBO`` -- confirm.

        Returns:
            Scalar tensor ``T1 + T2``.
        """
        ψ = self.get_int_indicators_vectorized()

        σ_var = torch.exp(self.σ)  # Unused below.
        σ_sd = torch.exp(0.5*self.σ)
        
        
        # Reparameterised draw of the split weights from the variational (μ, σ).
        β = self.μ + self.ϵ_rules*σ_sd

        # Per-row Gaussian log-density of β under N(μ, exp(σ)), summed over dims.
        T = -0.5*self.σ - 0.5*torch.pow((self.μ - β)/σ_sd,2)

        T = T.sum(dim=1)
        T = T - (self.d/2)*torch.log(torch.tensor(2*3.14159).to(self.device)) # This constant factor actually matters because of the inner product in T1

        # First sum: ψ-weighted log split probability plus rule log-density.

        T1 = ψ@(torch.log(logistic(self.γ)) + T)

        κ = self.get_leaf_indicators_vectorized()

        ξ_sd = torch.exp(0.5*self.ξ).view(-1,1)
        ξ_var = torch.exp(self.ξ).view(-1,1)  # Unused below.

        # Reparameterised draw of the leaf weights.
        θ =  self.ζ + self.ϵ_leaf*ξ_sd 

        # Per-row Gaussian log-density of θ under N(ζ, exp(ξ)), summed over dims.
        S =  -0.5*torch.pow((self.ζ - θ)/ξ_sd, 2) - 0.5*self.ξ.view(-1,1)

        S = S.sum(dim=1)
        S = S - (self.d/2)*torch.log(torch.tensor(2*3.14159).to(self.device))

        # Second sum: κ-weighted leaf log-density, plus the log-probability of
        # "no split" for the first b - 1 entries (the ones that have a γ).
        T2 = κ@(S) + κ[:self.b - 1]@torch.log(1-logistic(self.γ))
        entropy = T1 + T2
        return entropy
    
    def calculate_cross_entropy_vectorized(self):
        """Single-sample estimate of the prior log-density of the sampled tree.

        The samples ``β`` and ``θ`` are drawn with the variational parameters
        (``μ, σ, ζ, ξ``) and then scored under the prior parameters
        (``γ2, μ2, σ2, ζ2, ξ2``).

        Returns:
            Scalar tensor ``T1 + T2``.
        """
        ψ = self.get_int_indicators_vectorized()

        σ_sd = torch.exp(0.5*self.σ)
        σ2_sd = torch.exp(0.5*self.σ2)
        
        # Reparameterised draw of the split weights from the variational (μ, σ).
        β = self.μ  + self.ϵ_rules*σ_sd
        
        # Per-row Gaussian log-density of β under the prior N(μ2, exp(σ2)).
        T = -0.5*self.σ2 - 0.5*torch.pow((self.μ2 - β)/σ2_sd,2)
        T = T.sum(dim=1)
        T = T - (self.d/2)*torch.log(torch.tensor(2*3.14159).to(self.device)) # This constant factor actually matters because of the inner product in T1
        # First sum; γ2 has 2b - 1 entries, only the first b - 1 are used here.
        T1 = ψ@(torch.log(logistic(self.γ2)[:self.b - 1]) + T)
        
        κ = self.get_leaf_indicators_vectorized()

        ξ_sd = torch.exp(0.5*self.ξ).view(-1,1)
        ξ2_sd = torch.exp(0.5*self.ξ2).view(-1,1)

        # Reparameterised draw of the leaf weights from the variational (ζ, ξ).
        θ = self.ζ + self.ϵ_leaf*ξ_sd 
        # Per-row Gaussian log-density of θ under the prior N(ζ2, exp(ξ2)).
        S =  -0.5*torch.pow((self.ζ2 - θ)/ξ2_sd, 2) - 0.5*self.ξ2.view(-1,1)
        S = S.sum(dim=1) 
        S = S - (self.d/2)*torch.log(torch.tensor(2*3.14159).to(self.device))

        # Second sum: all 2b - 1 entries of γ2 contribute the "no split" term.
        T2 = κ@(S + torch.log(1-logistic(self.γ2)))
        cross_entropy = T1 + T2
        
        return cross_entropy 
    
    def get_int_indicators_vectorized(self):
        """Return, for each of the first ``b - 1`` nodes, the weight of it being split.

        Entry ``k`` is ``self.s[k]`` times the product of the ``self.s``
        values looked up through row ``k`` of ``self.I`` (index 0 maps to a
        constant 1).
        """
        n_internal = self.b - 1
        # Prepend a 1 so that index 0 in the table contributes nothing.
        padded = torch.hstack((torch.tensor([1.0], device=self.device), self.s))
        table = self.I[:n_internal]
        reach = torch.gather(padded.repeat(n_internal, 1), 1, table).prod(dim=1)
        return reach*self.s

    def get_leaf_indicators_sample(self, s):
        """
            Given a bit string defining a tree, return a vector with one entry
            per node that is 1 at the positions where the node is a leaf.
        """
        n_nodes = 2*self.b - 1
        # Prepend a 1 so that index 0 in the table contributes nothing.
        padded = torch.hstack((torch.tensor([1.0], device=self.device), s))
        reach = torch.gather(padded.repeat(n_nodes, 1), 1, self.I).prod(dim=1)

        # The first b - 1 nodes stop with weight 1 - s; the last b always stop.
        stop = torch.hstack((1-s, torch.ones(self.b, device=self.device)))
        return reach*stop

    def get_leaf_indicators_vectorized(self):
        """Leaf weights for the current (relaxed) split samples ``self.s``.

        This is exactly ``get_leaf_indicators_sample`` evaluated at
        ``self.s``; delegating removes a duplicated copy of that computation.
        Gradients still flow through ``self.s``.

        Returns:
            Tensor with one entry per node (``2*b - 1`` entries).
        """
        return self.get_leaf_indicators_sample(self.s)
    
    def get_path_indicators_vectorized_batch(self, X):
        """Soft probability of each example reaching each node.

        For every (node, level) pair the split weights are looked up through
        ``self.I``, dotted with each example, sign-flipped according to the
        parity of ``self.I2`` and squashed with ``logistic``; the result is
        multiplied over the level axis.

        Args:
            X: Inputs; assumed (n_examples, d) -- TODO confirm.

        Returns:
            Tensor with one row per example and one column per node.
        """
        
        # NOTE: despite the name this is a standard deviation, exp(0.5*σ).
        σ_var = torch.exp(0.5*self.σ)
        β = σ_var*self.ϵ_rules + self.μ

        # Row 0 is all zeros so that table index 0 selects a zero weight vector.
        ζhat = torch.vstack((torch.zeros_like(β[0]), β))

        num_rows = self.num_rows
        num_cols = self.num_columns

        I = self.I
        I2 = self.I2 

        # Expand the index table along the feature axis for gather().
        Iaug = self.I.unsqueeze(2).repeat(1,1,ζhat.shape[1])

        ζhat_aug = ζhat.unsqueeze(0).repeat(num_rows,1,1) 
        C = torch.gather(ζhat_aug,1,Iaug)

        A = C@X.T  ### A[i,k,j] = β_{i//2^k} x_{j}

        # (-1)**I2: +1 where the I2 entry is even, -1 where it is odd.
        B = (-1*torch.ones(num_rows, num_cols, device=self.device))**(I2)
        B = B.unsqueeze(2).repeat(1,1,A.shape[2])
        M = A*B  # Unused; the same product is recomputed on the next line.
        O = logistic(A*B)
        # logistic(0) == 0.5 comes from the zero padding row; those factors
        # are reset to 1 so they do not affect the product.
        # NOTE(review): a genuine score of exactly 0 would be reset as well.
        O[O == 0.5] = 1.0
        O = O.prod(dim=1)
        O = O.T
        
        return O

    def rmse(self, yhat, y):
        """Root-mean-square error between ``yhat`` and ``y`` as a Python float.

        Both tensors are moved to ``self.device`` and flattened first.
        """
        predicted = yhat.to(self.device).view(-1)
        target = y.to(self.device).view(-1)
        squared_error = (predicted - target)**2
        return torch.sqrt(squared_error.mean()).item()
        

class CVTree:
    def __init__(self, X, y, n_trees = 10, device = 'cpu', tree_depth = 4, bias = 0.0):
        """Hold out 5% of the data and create the candidate trees.

        Args:
            X, y: Full data set; split with ``train_test_split`` (fixed seed 1).
            n_trees: Number of candidate trees.
            device: Device for the data tensors and the trees.
            tree_depth: Either one depth for all trees, or a list of depths.
                With a list, one tree is built per entry, up to ``n_trees``
                (a shorter list yields fewer trees).
            bias: Passed to every ``variationalRegressionTree``.

        The trees receive the converted double tensors on ``device``.
        Previously the raw split was passed on, so the conversion done here
        never reached the trees and their data could disagree in dtype or
        device with their parameters.
        """
        X_train, X_cv, y_train, y_cv = train_test_split(X, y, test_size = 0.05, random_state = 1)

        def as_double(data):
            # Check each piece on its own instead of assuming the split
            # returns the same container type for all of them.
            tensor = data if torch.is_tensor(data) else torch.tensor(data)
            return tensor.double().to(device)

        self.X_train = as_double(X_train)
        self.X_cv = as_double(X_cv)
        self.y_train = as_double(y_train)
        self.y_cv = as_double(y_cv)

        if isinstance(tree_depth, list):
            depths = [depth for _, depth in zip(range(n_trees), tree_depth)]
        else:
            depths = [tree_depth for _ in range(n_trees)]

        self.trees = [
            variationalRegressionTree(depth, self.X_train, self.y_train, device=device, bias = bias)
            for depth in depths
        ]

        self.best_tree = None

        self.device = device
        
    def rmse(self, yhat, y):
        """Root-mean-square error between ``yhat`` and ``y`` as a Python float."""
        flat_pred = yhat.to(self.device).view(-1)
        flat_true = y.to(self.device).view(-1)
        mean_sq = ((flat_pred - flat_true)**2).mean()
        return torch.sqrt(mean_sq).item()
        
    def train_trees(self, epochs = 1000, lr0 = 0.1, lrf = 0.001, clip_norm = 10.0, optimizer = 'clippedAdam', samples = 100):
        """Train every tree and keep the one with the lowest hold-out RMSE.

        Args:
            epochs: One epoch count for all trees, or a list with one entry
                per tree.
            lr0, lrf, clip_norm, optimizer: Forwarded to each tree's ``train``.
            samples: Number of posterior samples used for the hold-out prediction.

        Returns:
            The best hold-out RMSE. ``self.best_tree`` is set to the matching
            tree; on ties the later tree wins.

        Raises:
            Exception: If ``epochs`` is a list whose length differs from the
                number of trees.
        """
        if isinstance(epochs, list):
            if len(epochs) != len(self.trees):
                raise Exception("Epochs array must be equal to the number of trees trained.")
            schedule = epochs
        else:
            schedule = [epochs]*len(self.trees)

        winner = None
        lowest = float('inf')
        for position, (candidate, n_epochs) in enumerate(zip(self.trees, schedule)):
            candidate.train(epochs = n_epochs, lr0 = lr0, lrf = lrf, clip_norm = clip_norm, optimizer = optimizer, tree_num = position + 1)

            yhat_cv = candidate.sample_posterior_predictive_vectorized(self.X_cv, samples = samples)
            score = self.rmse(yhat_cv, self.y_cv)

            if score <= lowest:
                winner, lowest = position, score

        ## Choose tree with best cv score
        self.best_tree = self.trees[winner]
        return lowest
    
    def predict(self, x, samples = 100):
        """Predict with the tree selected by ``train_trees``.

        Raises ``AssertionError`` when no tree has been selected yet.
        """
        chosen = self.best_tree
        assert chosen is not None
        return chosen.sample_posterior_predictive_vectorized(x, samples = samples)
    

        
            



# NOTE(review): every import below runs only when this module's __name__ is
# exactly 'variationalRegressionTree', i.e. when the file is imported as a
# top-level module of that name. Run as a script (__name__ == '__main__') or
# imported under a package-qualified name, none of torch, np, tqdm, logistic,
# create_tree, ... are bound, and the classes above fail with NameError on
# first use. Confirm whether the guard is intentional.
if __name__ == 'variationalRegressionTree':
    import numpy as np 
    import torch 
    import clippedAdam

    from clippedAdam import ClippedAdam 

    from torch import softmax  
    from math import floor 
    from tqdm import tqdm, trange 
    from torch.optim import Adam , SGD, Adagrad
    from time import sleep

    # Project-local helpers: logistic and create_tree are used by the classes above.
    from utils_reg import logistic, set_keys, create_tree

    from sklearn.model_selection import train_test_split
