from .tpnn import ATPNN, TPNN
import numpy as np
import copy
from scipy.stats import invgamma

class TPNNS:
    """Ensemble of ``config['K_max']`` TPNN components sampled by MCMC.

    Each component ``self.model[t]`` holds a ``structure`` (list of ATPNN
    units, one per used input variable), ``variables`` (indices of the input
    variables it does not use), a scalar ``height`` and an on/off indicator
    ``z`` that is mirrored in ``self.zs``. The sampler offers
    Metropolis-Hastings moves on the structure (``add`` / ``delete`` /
    ``change``), a Langevin move on the units' ``b`` / ``gamma``
    (``b_gamma_update``), a leapfrog HMC move on ``height``
    (``height_update``), an MH flip of ``z`` (``z_update``) and an update of
    the nuisance parameter ``self.nui`` (``nui_update``).

    Attributes set in ``__init__``:
        p: number of input variables, ``data.shape[1]``.
        alpha, gamma: depth-prior hyperparameters; the probability of growing
            past depth ``d`` is ``alpha * (1 + d) ** -gamma``.
        max_depth: ``min(config['max_depth'], p)``.
        depth_prior: normalised prior over depths ``1..max_depth``;
            ``depth_prior[i]`` is the probability of depth ``i + 1``.
        w: ``config['feature_weight']`` normalised to sum to one (used only
            by proposals).
        model: list of TPNN components.
        variable_sets: object array; ``variable_sets[t]`` is the Python set of
            variables used by ``model[t]``.
        zs: float array of 0/1 indicators kept equal to ``model[t].z``.
        nui: for ``y_dist == 'normal'`` a dict with keys ``sigma2``,
            ``inv_gamma_nu`` and ``inv_gamma_lambda``; otherwise a float
            intercept that ``forward`` adds to the predictor.
    """

    def __init__(self, data: np.ndarray, config: dict) -> None:
        """Build the ensemble.

        Args:
            data: design matrix; ``data.shape[1]`` fixes the number of
                variables ``p`` (it is also forwarded to ``tpnn_from_prior``).
            config: sampler configuration. Keys read here: ``alpha``,
                ``gamma``, ``max_depth``, ``feature_weight`` (length ``p``),
                ``K_max``, ``init_from_prior``, ``y_dist``, ``nui`` and, for
                non-normal responses, ``const_var``.
        """
        self.p = data.shape[1]
        self.alpha = config['alpha']
        self.gamma = config['gamma']
        self.max_depth = min(config['max_depth'], self.p)
        self.depth_prior = self._build_depth_prior(self.alpha, self.gamma, self.max_depth)

        # Normalise a float copy: the caller's config stays untouched, and
        # list / integer-array weights no longer break an in-place division.
        weights = np.asarray(config['feature_weight'], dtype=float)
        self.w = weights / np.sum(weights).item()
        assert len(self.w) == self.p

        self.model = []
        for _ in range(config['K_max']):
            if config['init_from_prior']:           # usually false
                tpnn_t = self.tpnn_from_prior(data, config)
            else:
                tpnn_t = TPNN(self.p, config['y_dist'])
            self.model.append(tpnn_t)
        assert len(self.model) == config['K_max']

        # Object array of sets so that boolean / np.where indexing
        # (weight_pot_ref_tree, _propose_variable_set) returns the stored sets.
        self.variable_sets = np.empty(len(self.model), dtype=object)
        for idx, tpnn_t in enumerate(self.model):
            self.variable_sets[idx] = {a.variable for a in tpnn_t.structure}
        assert len(self.variable_sets) == config['K_max']

        self.zs = np.ones((config['K_max'],))

        self.nui = config['nui']
        if config['y_dist'] == 'normal':
            assert self.nui
        else:
            # intercept drawn from its N(0, const_var) prior
            self.nui = np.random.normal(0, 1, size=1).item() * np.sqrt(config['const_var'])

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _build_depth_prior(alpha: float, gamma: float, max_depth: int) -> np.ndarray:
        """Return the normalised prior over depths ``1..max_depth``.

        With growth probability ``q_d = alpha * (1 + d) ** -gamma`` the
        unnormalised mass of depth ``k`` is ``q_0 * ... * q_{k-1} * (1 - q_k)``.
        Index ``i`` of the result is the probability of depth ``i + 1``.
        """
        grow_prob = alpha * (1.0 + np.arange(max_depth + 1)) ** (-gamma)
        prior = np.cumprod(grow_prob[:-1]) * (1.0 - grow_prob[1:])
        return prior / np.sum(prior).item()

    def _sample_depth(self) -> int:
        """Draw a depth in ``1..max_depth`` from ``self.depth_prior``."""
        # depths must line up with depth_prior: index i <-> depth i + 1
        depths = np.arange(1, self.max_depth + 1)
        return int(np.random.choice(depths, size=1, p=self.depth_prior).item())

    @staticmethod
    def _metropolis_accept(log_acceptance_ratio: float) -> bool:
        """Return True with probability ``min(1, exp(log_acceptance_ratio))``.

        The comparison is done in log space, so large ratios cannot overflow
        ``np.exp``; a NaN ratio is always rejected.
        """
        dice = np.random.uniform(0, 1, size=1).item()
        with np.errstate(divide='ignore'):         # dice == 0.0 -> log = -inf
            log_dice = np.log(dice)
        return bool(log_dice < log_acceptance_ratio)

    def _sample_atpnn(self, variable, data: np.ndarray, config: dict) -> "ATPNN":
        """Build an ATPNN on ``variable`` with ``b`` and ``gamma`` drawn from their priors.

        ``b ~ U(b_hyperparam1, b_hyperparam2)`` and
        ``gamma ~ Gamma(shape=gamma_shape, scale=gamma_scale)``.
        """
        new_b = np.random.uniform(config['b_hyperparam1'], config['b_hyperparam2'], size=1).item()
        new_gamma = np.random.gamma(shape=config['gamma_shape'], scale=config['gamma_scale'], size=1).item()
        return ATPNN(variable, data, new_b, new_gamma)

    def _commit_structure(self, t: int, new_tpnn: "TPNN") -> None:
        """Install ``new_tpnn`` as component ``t`` and refresh its variable set."""
        self.model[t] = new_tpnn
        self.variable_sets[t] = {a.variable for a in new_tpnn.structure}

    # ----------------------------------------------------------- construction

    def tpnn_from_prior(self, data: np.ndarray, config: dict, variable_set: np.ndarray = None) -> "TPNN":
        """Draw a TPNN component from the prior.

        If ``variable_set`` is given and non-empty, only the units' ``b`` and
        ``gamma`` are sampled for those variables. Otherwise the depth and the
        variables are also drawn from their prior (depth from
        ``depth_prior``, variables uniformly without replacement). Note that
        ``config['feature_weight']`` is applied only by proposals, not here.

        Returns:
            A new TPNN whose ``height`` is drawn from N(0, var_height / K_max).
        """
        tpnn = TPNN(self.p, config['y_dist'])
        if variable_set is None or len(variable_set) == 0:
            selected_depth = self._sample_depth()
            variable_set = np.random.choice(np.arange(self.p), size=(selected_depth,), replace=False)

        for new_variable in variable_set:
            tpnn.structure.append(self._sample_atpnn(new_variable, data, config))
            tpnn.variables = tpnn.variables[tpnn.variables != new_variable]
            assert len(tpnn.variables) + len(tpnn.structure) == self.p

        tpnn.height = np.random.normal(0, 1, size=1).item() * np.sqrt(config['var_height'] / config['K_max'])
        return tpnn

    def new_atpnn(self, tpnn: "TPNN", data: np.ndarray, config: dict) -> "ATPNN":
        """Return a new ATPNN on a variable not yet used by ``tpnn``.

        The variable is chosen from ``tpnn.variables`` with probability
        proportional to ``self.w``; ``b`` / ``gamma`` come from their priors.
        """
        cand_weight = self.w[tpnn.variables]
        cand_weight = cand_weight / np.sum(cand_weight).item()
        new_variable = np.random.choice(tpnn.variables, size=1, p=cand_weight).item()
        return self._sample_atpnn(new_variable, data, config)

    # -------------------------------------------------------- structure moves

    def add(self, t: int, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict) -> bool:
        """Birth move: propose appending one ATPNN unit to component ``t``.

        Args:
            t: index of the component to update.
            lambda_t: predictor of the other components, passed to
                ``log_likelihood``.
            data, y: training inputs and response.
            config: sampler configuration (``add_prob``, ``del_prob`` and the
                prior hyperparameters are read).

        Returns:
            True if the proposal was accepted and ``self.model[t]`` replaced.
        """
        current = self.model[t]
        # No birth at maximal depth or when every variable is already used.
        if len(current.structure) >= self.max_depth or len(current.variables) == 0:
            return False

        tpnn = copy.deepcopy(current)
        component = self.new_atpnn(tpnn, data, config)

        new_tpnn = copy.deepcopy(tpnn)
        new_tpnn.structure.append(component)
        assert len(new_tpnn.structure) == len(tpnn.structure) + 1
        new_tpnn.variables = new_tpnn.variables[new_tpnn.variables != component.variable]
        assert len(new_tpnn.variables) + len(new_tpnn.structure) == self.p

        log_lr = new_tpnn.log_likelihood(lambda_t, data, y, self.nui) - tpnn.log_likelihood(lambda_t, data, y, self.nui)
        d = len(new_tpnn.structure)
        # depth-prior ratio P(d) / P(d - 1)
        log_structure_pr = (np.log(self.alpha) - self.gamma * np.log(d)
                            + np.log(1 - self.alpha * (1 + d) ** (-self.gamma))
                            - np.log(1 - self.alpha * d ** (-self.gamma)))
        # reverse delete vs forward add; the -log(p - d + 1) term is consistent
        # with a uniform 1/C(p, d) variable-set prior combined with the 1/d
        # choice of the reverse delete
        log_propose_ratio = (np.log(config['del_prob']) - np.log(config['add_prob'])
                             + np.log(np.sum(self.w[tpnn.variables])) - np.log(self.w[component.variable])
                             - np.log(self.p - d + 1))

        if self._metropolis_accept(log_lr + log_structure_pr + log_propose_ratio):
            self._commit_structure(t, new_tpnn)
            return True
        return False

    def delete(self, t: int, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict) -> bool:
        """Death move: propose removing a uniformly chosen unit of component ``t``.

        Arguments are as in ``add``. Returns True if the proposal was accepted.
        """
        current = self.model[t]
        d = len(current.structure)
        if d == 0:
            return False

        tpnn = copy.deepcopy(current)
        what_to_remove = np.random.choice(np.arange(d), size=1).item()

        new_tpnn = copy.deepcopy(tpnn)
        removed_atpnn = new_tpnn.structure.pop(what_to_remove)
        assert len(new_tpnn.structure) == len(tpnn.structure) - 1
        new_tpnn.variables = np.append(new_tpnn.variables, removed_atpnn.variable)
        assert len(new_tpnn.variables) + len(new_tpnn.structure) == self.p

        log_lr = new_tpnn.log_likelihood(lambda_t, data, y, self.nui) - tpnn.log_likelihood(lambda_t, data, y, self.nui)
        # exact inverse of the ratios used in add
        log_structure_pr = (-np.log(self.alpha) + self.gamma * np.log(d)
                            + np.log(1 - self.alpha * d ** (-self.gamma))
                            - np.log(1 - self.alpha * (1 + d) ** (-self.gamma)))
        log_propose_ratio = (np.log(config['add_prob']) - np.log(config['del_prob'])
                             + np.log(self.w[removed_atpnn.variable]) - np.log(np.sum(self.w[new_tpnn.variables]))
                             + np.log(self.p - d + 1))

        if self._metropolis_accept(log_lr + log_structure_pr + log_propose_ratio):
            self._commit_structure(t, new_tpnn)
            return True
        return False

    def change(self, t: int, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict) -> bool:
        """Swap move: replace a uniformly chosen unit of component ``t`` by a new one.

        The new unit uses a variable not currently in the component (chosen
        as in ``new_atpnn``). Arguments are as in ``add``. Returns True if the
        proposal was accepted.
        """
        current = self.model[t]
        d = len(current.structure)
        if d == 0 or len(current.variables) == 0:
            return False

        tpnn = copy.deepcopy(current)
        component = self.new_atpnn(tpnn, data, config)
        what_to_remove = np.random.choice(np.arange(d), size=1).item()

        new_tpnn = copy.deepcopy(tpnn)
        removed_atpnn = new_tpnn.structure.pop(what_to_remove)
        new_tpnn.variables = np.append(new_tpnn.variables, removed_atpnn.variable)
        new_tpnn.structure.append(component)
        new_tpnn.variables = new_tpnn.variables[new_tpnn.variables != component.variable]
        assert len(new_tpnn.structure) == len(tpnn.structure)
        assert len(new_tpnn.variables) + len(new_tpnn.structure) == self.p

        log_lr = new_tpnn.log_likelihood(lambda_t, data, y, self.nui) - tpnn.log_likelihood(lambda_t, data, y, self.nui)
        # depth is unchanged, so only the weighted variable-choice terms remain
        log_propose_ratio = (np.log(self.w[removed_atpnn.variable]) + np.log(np.sum(self.w[tpnn.variables]))
                             - np.log(self.w[component.variable]) - np.log(np.sum(self.w[new_tpnn.variables])))

        if self._metropolis_accept(log_lr + log_propose_ratio):
            self._commit_structure(t, new_tpnn)
            return True
        return False

    # -------------------------------------------------------- parameter moves

    @staticmethod
    def _log_gamma_prior(gammas: np.ndarray, config: dict) -> float:
        """Unnormalised Gamma(gamma_shape, gamma_scale) log density summed over ``gammas``."""
        return ((config['gamma_shape'] - 1) * np.sum(np.log(gammas))
                - np.sum(gammas) / config['gamma_scale']).item()

    @staticmethod
    def _log_gamma_prior_grad(gammas: np.ndarray, config: dict) -> np.ndarray:
        """Elementwise gradient of ``_log_gamma_prior`` with respect to ``gammas``."""
        return (config['gamma_shape'] - 1) / gammas - 1 / config['gamma_scale']

    def b_gamma_update(self, t: int, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict) -> bool:
        """Langevin (MALA) update of every unit's ``b`` and ``gamma`` in component ``t``.

        The target is ``log_likelihood`` plus the Gamma prior on ``gamma``; the
        step size is ``config['bg_step_size']``. Returns True if accepted.
        """
        tpnn = copy.deepcopy(self.model[t])
        d = len(tpnn.structure)
        eps = config['bg_step_size']

        cur_bs = np.array([unit.b for unit in tpnn.structure], dtype=float)                 # (d,)
        # assumes ATPNN stores gamma as `.gamma` (its constructor keyword) — TODO confirm
        cur_gammas = np.array([unit.gamma for unit in tpnn.structure], dtype=float)         # (d,)
        bs_score, gammas_score = tpnn.b_gamma_score_function(lambda_t, data, y, self.nui)
        bs_score = np.asarray(bs_score, dtype=float)
        # not in place: the returned score array is left untouched
        gammas_score = np.asarray(gammas_score, dtype=float) + self._log_gamma_prior_grad(cur_gammas, config)

        # Langevin proposal
        b_momentum = np.random.normal(0, 1, size=(d,))
        new_bs = cur_bs + eps ** 2 * bs_score / 2 + eps * b_momentum
        gamma_momentum = np.random.normal(0, 1, size=(d,))
        new_gammas = cur_gammas + eps ** 2 * gammas_score / 2 + eps * gamma_momentum

        # The Gamma prior has zero density at gamma <= 0: reject outright.
        if np.any(new_gammas <= 0):
            return False
        # NOTE(review): there is no prior term for b; moves outside the
        # U(b_hyperparam1, b_hyperparam2) support used by tpnn_from_prior are
        # not rejected here — confirm this is intended.

        new_tpnn = copy.deepcopy(tpnn)
        for idx, unit in enumerate(new_tpnn.structure):
            new_tpnn.structure[idx] = ATPNN(var=unit.variable, data=data, b=new_bs[idx], gamma=new_gammas[idx])

        new_bs_score, new_gammas_score = new_tpnn.b_gamma_score_function(lambda_t, data, y, self.nui)
        new_bs_score = np.asarray(new_bs_score, dtype=float)
        new_gammas_score = np.asarray(new_gammas_score, dtype=float) + self._log_gamma_prior_grad(new_gammas, config)

        new_b_momentum = b_momentum + eps * bs_score / 2 + eps * new_bs_score / 2
        new_gamma_momentum = gamma_momentum + eps * gammas_score / 2 + eps * new_gammas_score / 2

        log_acceptance_ratio = new_tpnn.log_likelihood(lambda_t, data, y, self.nui) - tpnn.log_likelihood(lambda_t, data, y, self.nui)
        log_acceptance_ratio += self._log_gamma_prior(new_gammas, config) - self._log_gamma_prior(cur_gammas, config)
        log_acceptance_ratio -= (np.sum(new_b_momentum ** 2) - np.sum(b_momentum ** 2)).item() / 2
        log_acceptance_ratio -= (np.sum(new_gamma_momentum ** 2) - np.sum(gamma_momentum ** 2)).item() / 2

        if self._metropolis_accept(log_acceptance_ratio):
            self.model[t] = new_tpnn
            return True
        return False

    def _height_log_post_grad(self, tpnn: "TPNN", lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict):
        """Gradient in ``height`` of the N(0, var_height / K_max) log prior plus ``height_score_function``."""
        return (-tpnn.height * config['K_max'] / config['var_height']
                + tpnn.height_score_function(lambda_t, data, y, self.nui))

    def height_update(self, t: int, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, config: dict) -> bool:
        """HMC update of ``height`` for component ``t``.

        Runs ``config['leapfrog_L']`` leapfrog steps of size
        ``config['step_size'] / config['leapfrog_L']``. Returns True if accepted.

        Raises:
            ValueError: if ``config['leapfrog_L'] < 1``.
        """
        n_steps = config['leapfrog_L']
        if n_steps < 1:
            raise ValueError("config['leapfrog_L'] must be >= 1")
        eps = config['step_size'] / n_steps

        tpnn = copy.deepcopy(self.model[t])
        momentum = np.random.normal(0, 1, size=1).item()

        new_tpnn = copy.deepcopy(tpnn)
        grad = self._height_log_post_grad(new_tpnn, lambda_t, data, y, config)
        new_momentum = momentum + eps * grad / 2              # initial half step
        for _ in range(n_steps):
            new_tpnn.height = new_tpnn.height + eps * new_momentum
            grad = self._height_log_post_grad(new_tpnn, lambda_t, data, y, config)
            new_momentum += eps * grad
        new_momentum -= eps * grad / 2                         # last full step -> half step

        log_acceptance_ratio = new_tpnn.log_likelihood(lambda_t, data, y, self.nui) - tpnn.log_likelihood(lambda_t, data, y, self.nui)
        log_acceptance_ratio -= (new_tpnn.height ** 2 - tpnn.height ** 2) * config['K_max'] / (2 * config['var_height'])
        log_acceptance_ratio -= (new_momentum ** 2 - momentum ** 2) / 2

        if self._metropolis_accept(log_acceptance_ratio):
            self.model[t] = new_tpnn
            return True
        return False

    # ------------------------------------------------------------ full sweep

    def tpnns_update(self, data: np.ndarray, y: np.ndarray, config: dict) -> None:
        """One sweep over all components.

        Active components (``z == 1``) get a structure move, an optional
        ``b``/``gamma`` move and a height move. Inactive ones get a proposal of
        a fresh variable set. One character per component is appended to
        ``self.adc_log``, ``self.bg_log`` and ``self.height_log``.
        """
        lambda_t = self.forward(data, config)
        adc_prob = np.array([config['add_prob'], config['del_prob'], config['change_prob']], dtype=float)
        adc_prob /= np.sum(adc_prob).item()

        self.adc_log = ''
        self.bg_log = ''
        self.height_log = ''

        for t in range(config['K_max']):
            if self.model[t].z:
                self._update_active_tree(t, lambda_t, adc_prob, data, y, config)
            else:
                self._update_inactive_tree(t, data, config)

    def _update_active_tree(self, t: int, lambda_t: np.ndarray, adc_prob: np.ndarray,
                            data: np.ndarray, y: np.ndarray, config: dict) -> None:
        """Structure, b/gamma and height moves for active component ``t``; updates ``lambda_t`` in place."""
        # predictor of all components except t
        lambda_t -= self.model[t].forward(data)

        moves = {'add': self.add, 'del': self.delete, 'change': self.change}
        adc = np.random.choice(['add', 'del', 'change'], size=1, p=adc_prob).item()
        updated = moves[adc](t, lambda_t, data, y, config)
        self.adc_log += adc[0] if updated else 's'

        if self.model[t].structure and config['bg_update']:
            bg_updated = self.b_gamma_update(t, lambda_t, data, y, config)
            self.bg_log += 'Y' if bg_updated else '_'
        else:
            self.bg_log += 'Z'                  # zero structure (or bg updates disabled)

        height_updated = self.height_update(t, lambda_t, data, y, config)
        self.height_log += 'H' if height_updated else '_'

        lambda_t += self.model[t].forward(data)

    def _propose_variable_set(self, config: dict) -> set:
        """Propose a fresh variable set for an inactive component.

        With probability ``M / (M + k)`` (``k`` = number of active components)
        the set is drawn from the prior; otherwise an active component's set
        is copied and extended by one unused variable chosen proportionally to
        ``self.w``, when possible.
        """
        k = int(np.sum(self.zs).item())
        from_prior = np.random.uniform(0, 1, size=1).item()
        if from_prior < config['M'] / (config['M'] + k):
            ref_depth = self._sample_depth()
            return set(np.random.choice(np.arange(self.p), size=(ref_depth,), replace=False))

        active_sets = self.variable_sets[np.where(self.zs == 1)]
        ref_tpnn_idx = np.random.choice(np.arange(k), size=1).item()
        # copy: the stored set belongs to an active component and must not be mutated
        ref_variable = set(active_sets[ref_tpnn_idx])
        variable_cand = np.array(list(set(range(self.p)) - ref_variable))
        if len(variable_cand) > 0 and len(ref_variable) < self.max_depth:
            cand_weight = self.w[variable_cand]
            cand_weight = cand_weight / np.sum(cand_weight).item()
            ref_variable.add(np.random.choice(variable_cand, size=1, p=cand_weight).item())
        assert len(ref_variable) >= 1
        return ref_variable

    def _log_proposal_correction(self, variable_set: set, config: dict) -> float:
        """Return ``log(1 + W(S) * C(p, |S|) / (depth_prior[|S| - 1] * M))``.

        ``W(S)`` is ``weight_pot_ref_tree(S)``; this is the log ratio of the
        mixture proposal density of ``S`` to its prior-only part.
        """
        weight = self.weight_pot_ref_tree(variable_set)
        if weight <= 0:
            return 0.0                          # log(1 + 0), without evaluating log(0)
        size = len(variable_set)
        lx = (np.log(weight).item()
              - np.log(self.depth_prior[size - 1]).item()
              + self.log_combination(self.p, size)
              - np.log(config['M']).item())
        return float(np.logaddexp(0.0, lx))

    def _update_inactive_tree(self, t: int, data: np.ndarray, config: dict) -> None:
        """Independence-MH refresh of the variable set of inactive component ``t``."""
        old_variable = self.variable_sets[t]
        ref_variable = self._propose_variable_set(config)

        if len(old_variable) == 0:
            log_acceptance_ratio = 0.
        else:
            log_acceptance_ratio = (self._log_proposal_correction(old_variable, config)
                                    - self._log_proposal_correction(ref_variable, config))

        if self._metropolis_accept(log_acceptance_ratio):
            new_tpnn = self.tpnn_from_prior(data, config, np.array(list(ref_variable)))
            new_tpnn.z = 0
            self.model[t] = new_tpnn
            self.variable_sets[t] = ref_variable
            self.adc_log += 'z'
        else:
            self.adc_log += 's'
        self.bg_log += '_'
        self.height_log += '_'

    def z_update(self, data: np.ndarray, y: np.ndarray, config: dict) -> None:
        """MH flip of the indicator ``z`` of one uniformly chosen component.

        Sets ``self.z_updated`` to whether the flip was accepted.
        """
        z_idx = np.random.choice(np.arange(config['K_max']), size=1).item()

        z_tpnn = copy.deepcopy(self.model[z_idx])
        assert z_tpnn.z == self.zs[z_idx].item()
        new_z_tpnn = copy.deepcopy(z_tpnn)
        new_z_tpnn.z = 1 - z_tpnn.z

        # predictor of all components except z_idx
        lambda_z_idx = self.forward(data, config) - z_tpnn.forward(data)
        log_lr = (new_z_tpnn.log_likelihood(lambda_z_idx, data, y, self.nui)
                  - z_tpnn.log_likelihood(lambda_z_idx, data, y, self.nui))

        n = data.shape[0]
        cur_zs_sum = np.sum(self.zs).item()
        if z_tpnn.z:    # switching off
            log_pr = config['c0'] * np.log(n).item() + np.log(config['K_max'] - cur_zs_sum + 1).item() - np.log(cur_zs_sum).item()
        else:           # switching on
            log_pr = -config['c0'] * np.log(n).item() + np.log(cur_zs_sum + 1).item() - np.log(config['K_max'] - cur_zs_sum).item()

        if self._metropolis_accept(log_lr + log_pr):
            self.model[z_idx] = new_z_tpnn
            self.zs[z_idx] = 1 - self.zs[z_idx]
            assert self.model[z_idx].z == self.zs[z_idx].item()
            self.z_updated = True
        else:
            self.model[z_idx] = z_tpnn
            self.z_updated = False

    # ------------------------------------------------------- nuisance params

    def _intercept_mala_step(self, data: np.ndarray, config: dict, log_lik, score) -> bool:
        """One MALA (single leapfrog step) update of the intercept ``self.nui``.

        The intercept has a N(0, const_var) prior. ``log_lik(f)`` and
        ``score(f)`` return the summed log-likelihood and its derivative with
        respect to the linear predictor ``f``, which shifts one-for-one with
        the intercept.
        """
        assert self.nui is not None
        assert config['const_var'] > 0
        assert config['const_step_size'] > 0
        eps = config['const_step_size']
        prior_var = config['const_var']

        def log_post(f, b0):
            return log_lik(f) - b0 ** 2 / (2 * prior_var)

        def grad(f, b0):
            return score(f) - b0 / prior_var

        forwarded_value = self.forward(data, config)
        momentum = np.random.normal(0, 1, size=1).item()

        cur_grad = grad(forwarded_value, self.nui)
        new_nui = self.nui + (eps ** 2 * cur_grad) / 2 + eps * momentum
        new_forwarded_value = forwarded_value - self.nui + new_nui
        new_grad = grad(new_forwarded_value, new_nui)
        new_momentum = momentum + eps * cur_grad / 2 + eps * new_grad / 2

        log_acceptance_ratio = ((log_post(new_forwarded_value, new_nui) - log_post(forwarded_value, self.nui))
                                - (new_momentum ** 2 - momentum ** 2) / 2)
        if self._metropolis_accept(log_acceptance_ratio):
            self.nui = new_nui
            return True
        return False

    def nui_update(self, data: np.ndarray, y: np.ndarray, config: dict) -> None:
        """Update the nuisance parameter.

        normal: Gibbs draw of ``sigma2`` from its inverse-gamma full conditional.
        ber / poisson: MALA step on the intercept.
        """
        if config['y_dist'] == 'normal':
            assert self.nui['sigma2']
            forwarded_value = self.forward(data, config)
            new_inv_gamma_1 = (self.nui['inv_gamma_nu'] + data.shape[0]) / 2
            new_inv_gamma_2 = (self.nui['inv_gamma_nu'] * self.nui['inv_gamma_lambda'] + np.sum((y - forwarded_value) ** 2).item()) / 2
            self.nui['sigma2'] = invgamma.rvs(new_inv_gamma_1, 0, new_inv_gamma_2, size=1).item()

        elif config['y_dist'] == 'ber':
            # logaddexp / exp(f - logaddexp) avoid overflow of log(1 + e^f) and e^f / (1 + e^f)
            self._intercept_mala_step(
                data, config,
                log_lik=lambda f: np.sum(f * y - np.logaddexp(0.0, f)).item(),
                score=lambda f: np.sum(y - np.exp(f - np.logaddexp(0.0, f))).item(),
            )

        elif config['y_dist'] == 'poisson':
            self._intercept_mala_step(
                data, config,
                log_lik=lambda f: np.sum(y * f - np.exp(f)).item(),
                score=lambda f: np.sum(y - np.exp(f)).item(),
            )

    # ------------------------------------------------------------- prediction

    def forward(self, data: np.ndarray, config: dict) -> np.ndarray:
        """Return the linear predictor: sum of all components' ``forward`` (+ intercept if non-normal)."""
        result = np.zeros((data.shape[0],))
        for tpnn_t in self.model:
            result += tpnn_t.forward(data)
        if config['y_dist'] != 'normal':
            result += self.nui
        return result

    def evaluate(self, ss: int, data: np.ndarray, y: np.ndarray, test_data: np.ndarray, test_y: np.ndarray, config, print_res=False):
        """Store train / test metrics in ``self.train_metric`` / ``self.test_metric``.

        normal: RMSE scaled by ``config['y_std']`` (3 decimals);
        ber: AUROC of the linear predictor (4 decimals);
        poisson: RMSE of ``exp(predictor)`` (3 decimals).
        """
        assert data.shape[1] == test_data.shape[1]
        assert data.shape[0] == y.shape[0]
        assert test_data.shape[0] == test_y.shape[0]

        if config['y_dist'] == 'normal':
            train_fitted_value = self.forward(data, config)
            train_rmse = np.sqrt(np.mean((y - train_fitted_value) ** 2)).item() * config['y_std']
            test_fitted_value = self.forward(test_data, config)
            test_rmse = np.sqrt(np.mean((test_y - test_fitted_value) ** 2)).item() * config['y_std']
            self.train_metric = np.round(train_rmse, 3)
            self.test_metric = np.round(test_rmse, 3)

        elif config['y_dist'] == 'ber':
            from sklearn.metrics import roc_auc_score
            train_auroc = roc_auc_score(y, self.forward(data, config))
            test_auroc = roc_auc_score(test_y, self.forward(test_data, config))
            self.train_metric = np.round(train_auroc, 4)
            self.test_metric = np.round(test_auroc, 4)

        elif config['y_dist'] == 'poisson':
            train_fitted_value = np.exp(self.forward(data, config))
            train_rmse = np.sqrt(np.mean((y - train_fitted_value) ** 2)).item()
            test_fitted_value = np.exp(self.forward(test_data, config))
            test_rmse = np.sqrt(np.mean((test_y - test_fitted_value) ** 2)).item()
            self.train_metric = np.round(train_rmse, 3)
            self.test_metric = np.round(test_rmse, 3)

        if print_res:
            print(f'fold {config["fold"]} {ss+1}th TPNNS\ttrain : {self.train_metric}\ttest : {self.test_metric}\tz l1-norm:{np.sum(self.zs).astype(int).item()}')

    # -------------------------------------------------------------- utilities

    def log_combination(self, n: int, x: int) -> float:
        """Return log C(n, x) as ``sum_{i < x} [log(n - i) - log(x - i)]``."""
        assert n >= x
        res = 0.
        for i in range(x):
            res += np.log(n - i).item() - np.log(x - i).item()
        return res

    def weight_pot_ref_tree(self, variable: set) -> float:
        r"""Reference-tree weight of ``variable`` with respect to the active components.

        variable : S_t or S_t^new
        Returns sum_{i: z_i = 1} I(\exists j_i \in S_i^c s.t. S_i \cup {j_i} = S_t)
        * w_{j_i} / \sum_{v \in S_i^c} w_v
        """
        weight = 0.
        total_w = np.sum(self.w).item()
        for ref_set in self.variable_sets[np.where(self.zs == 1)]:
            extra = variable - ref_set
            # ref_set must be S_t minus exactly one variable j_i
            if len(extra) == 1 and ref_set <= variable:
                (added_var,) = extra
                complement_w = total_w - np.sum(self.w[list(ref_set)]).item()
                weight += self.w[int(added_var)].item() / complement_w
        return weight

        

