from vector_sbnModel import *
from utils import namenum
import torch


def _draw_split(split_prob, reverse_map, random_p):
    """Draw one split from a categorical CPD, with optional uniform exploration.

    ``split_prob`` is passed to ``torch.multinomial`` and the drawn index is
    looked up positionally in ``reverse_map`` (presumably a 1-D tensor of
    non-negative weights aligned with ``reverse_map`` -- confirm against the
    model).  When ``random_p`` is not None, with probability ``random_p`` the
    drawn split is replaced by one chosen uniformly over all
    ``len(split_prob)`` indices.
    """
    # The CPD draw always happens first, even when exploration then overrides
    # it, so the torch RNG stream advances identically regardless of random_p.
    split = reverse_map[torch.multinomial(split_prob, 1).item()]
    if random_p is not None and np.random.random() < random_p:
        split = reverse_map[np.random.randint(0, len(split_prob))]
    return split


def _attach_child(tree_model, child, clade, sibling, node_split_stack):
    """Finish ``child`` as a named leaf, or queue it for further splitting.

    ``clade`` and ``sibling`` are '0'/'1' strings of length ``tree_model.ntaxa``.
    A clade with more than one taxon is pushed onto ``node_split_stack`` as the
    subsplit string ``sibling + clade``; a single-taxon clade becomes a leaf.
    """
    if clade.count('1') > 1:
        node_split_stack.append((child, sibling + clade))
        return
    child.name = tree_model.taxa[clade.find('1')]
    child.clade_bitarr = bitarray(clade)
    child.split_bitarr = min([child.clade_bitarr, ~child.clade_bitarr]).to01()


def sample_tree(tree_model, rooted=False, random_p=None):
    """Sample a binary tree top-down from the split CPDs of ``tree_model``.

    Subsplits are encoded as strings of ``2 * ntaxa`` '0'/'1' characters. The
    second half is the clade being split. The first half is its sibling clade,
    which is all zeros for the root. Starting from the full taxon set, each
    internal node draws a split of its clade. The root draws from
    ``tree_model.rs_CPDs``; every other node draws from
    ``tree_model.get_subsplit_CPDs(subsplit)``. The node then gets two
    children: one for the drawn split and one for its complement within the
    clade. Exactly ``ntaxa - 1`` splits are drawn.

    Split nodes get ``clade_bitarr`` (a bitarray of their clade) and
    ``split_bitarr`` (the '01' string of the lexicographically smaller of clade
    and its complement). Leaves get the same two attributes plus ``name`` taken
    from ``tree_model.taxa``.

    Args:
        tree_model: model exposing ``ntaxa``, ``taxa``, ``rs_CPDs``,
            ``rs_reverse_map``, ``get_subsplit_CPDs`` and ``ss_reverse_map``.
        rooted: if False, ``root.unroot()`` is called before returning.
        random_p: optional exploration probability. At each split, with this
            probability the drawn split is replaced by one chosen uniformly
            over the indices of that split's CPD.

    Returns:
        The root ``Tree`` node.

    Note:
        Sampling uses both the torch RNG (``torch.multinomial``) and the numpy
        RNG (``np.random``). Seed both for reproducible samples.
    """
    ntaxa = tree_model.ntaxa
    root = Tree()
    node_split_stack = [(root, '0' * ntaxa + '1' * ntaxa)]
    for _ in range(ntaxa - 1):
        node, split_bitarr = node_split_stack.pop()
        parent_clade_bitarr = bitarray(split_bitarr[ntaxa:])
        node.clade_bitarr = parent_clade_bitarr
        node.split_bitarr = min([parent_clade_bitarr, ~parent_clade_bitarr]).to01()

        if node.is_root():
            split = _draw_split(
                tree_model.rs_CPDs, tree_model.rs_reverse_map, random_p
            )
        else:
            split = _draw_split(
                tree_model.get_subsplit_CPDs(split_bitarr),
                tree_model.ss_reverse_map[split_bitarr],
                random_p,
            )
        comp_split = (parent_clade_bitarr ^ bitarray(split)).to01()

        # Both children are created before either is finalized. c1 is queued
        # before c2, so when both are internal the LIFO stack expands c2 first;
        # this order determines the sequence of RNG draws.
        c1 = node.add_child()
        c2 = node.add_child()
        _attach_child(tree_model, c1, split, comp_split, node_split_stack)
        _attach_child(tree_model, c2, comp_split, split, node_split_stack)

    if not rooted:
        root.unroot()
    return root