import numpy as np

def tree_gen(pdf_values, elements, bp, bot):
    """Build a Knuth-Yao style generating tree from fixed-point probabilities.

    Nodes are labelled heap-style: the root is 1 and the children of node
    ``n`` are ``2n`` and ``2n + 1``. At level ``i`` (0-based), every non-leaf
    node of the current layer is split into two children. The elements whose
    ``i``-th probability bit is 1 then become leaves, taking the rightmost
    children of the new layer. After the last bit, any remaining non-leaf
    nodes of the deepest layer become leaves labelled ``bot``.

    Args:
        pdf_values: np array of shape (|elements|, >= bp). Entry ``[e, i]`` is
            the ``i``-th bit (weight 2**-(i+1)) of element ``e``'s probability.
        elements: np array of element names, aligned with the rows of
            ``pdf_values``.
        bp: bits of precision, i.e. the number of bit columns consumed.
        bot: element assigned to leaves that hold the leftover probability.

    Returns:
        ``(m, layers, leaves, leaves_elements, cur_id)`` where

        - ``m`` maps each node label to a sequential id (root -> 0, other
          nodes numbered 1, 2, ... in creation order);
        - ``layers[k]`` is the list of node labels created at depth ``k``;
        - ``leaves`` lists the leaf labels in the order they were assigned;
        - ``leaves_elements`` is parallel to ``leaves`` and gives each leaf's
          element;
        - ``cur_id`` is the last id handed out.

    Raises:
        ValueError: if some bit column asks for more leaves than there are
            free nodes on that layer. This means the encoded probabilities
            sum to more than 1.
    """
    m = {1: 0}
    cur_id = 0
    layers = [[1]]
    leaves = []
    leaves_elements = []
    leaf_set = set()  # mirrors `leaves` for O(1) membership tests

    for i in range(bp):
        # Elements whose i-th bit is set become leaves on the new layer.
        bi_elements = elements[pdf_values[:, i] == 1]
        n_leaves = len(bi_elements)

        # Split every non-leaf node of the current layer into two children.
        # Each child gets the next sequential id.
        new_nodes = []
        for node in layers[i]:
            if node in leaf_set:
                continue
            for child in (2 * node, 2 * node + 1):
                cur_id += 1
                m[child] = cur_id
                new_nodes.append(child)

        if n_leaves > len(new_nodes):
            raise ValueError(
                f"bit column {i} requires {n_leaves} leaves but only "
                f"{len(new_nodes)} nodes are available; the encoded "
                "probabilities sum to more than 1"
            )

        # The rightmost n_leaves nodes of the new layer become leaves.
        split = len(new_nodes) - n_leaves
        new_leaves = new_nodes[split:]
        leaves.extend(new_leaves)
        leaf_set.update(new_leaves)
        leaves_elements.extend(bi_elements.tolist())

        layers.append(new_nodes)

        # On the last layer, every node that did not receive an element
        # becomes a `bot` leaf carrying the leftover probability mass.
        if i == bp - 1:
            remaining = new_nodes[:split]
            leaves.extend(remaining)
            leaves_elements.extend(
                (bot * np.ones(len(remaining), dtype=int)).tolist()
            )

    return m, layers, leaves, leaves_elements, cur_id

def truncated_discrete_Gaussian(bt, bp, sigma):
    """Tabulate a truncated, symmetric discrete Gaussian as probability bits.

    The support is ``x`` in ``[-(2**(bt-1) - 1), 2**(bt-1) - 1]``, with
    weights ``exp(-x**2 / (2 * sigma**2))``. The weights are normalised over
    that whole symmetric range. Points whose normalised probability is below
    ``2**-bp`` are then dropped. Each kept probability is truncated (rounded
    down) to ``bp`` binary digits.

    Args:
        bt: bits for the domain; the non-negative half is
            ``0 .. 2**(bt-1) - 1``.
        bp: bits of precision, with the first bit worth 1/2
            (e.g. ``[1, 0, 1]`` = 1/2 + 1/8).
        sigma: standard deviation of the Gaussian.

    Returns:
        ``(full_domain_bit_reps, full_domain, p_norm)`` where

        - ``full_domain_bit_reps`` is an int array of shape (k, bp). Row ``r``
          holds the truncated probability bits of ``full_domain[r]``.
        - ``full_domain`` is the kept support in increasing order. It is
          symmetric around 0.
        - ``p_norm`` is the normalisation constant, computed before
          truncation.
    """
    x = np.arange(2 ** (bt - 1))

    # Unnormalised Gaussian weights for the non-negative half of the domain.
    p = np.exp(-np.power(x, 2) / (2 * sigma ** 2))

    # Normalise over the symmetric domain: x = 0 is counted once, and every
    # x > 0 twice (once for +x and once for -x).
    p_norm = 2 * np.sum(p) - p[0]

    # p does not increase with x, so this mask keeps a prefix of x.
    # The slicing of x below relies on that.
    new_p = p / p_norm
    truncated_p = new_p[new_p >= 2.0 ** (-bp)]
    n_kept = len(truncated_p)

    # Greedy binary expansion, truncated to bp digits: set bit i whenever
    # adding 2**-i does not overshoot the value. `<=` (not `<`) keeps exact
    # dyadic values such as 0.5 -> [1, 0, ...].
    bit_reps = []
    for val in truncated_p:
        bits = []
        approx_val = 0.0
        for i in range(1, bp + 1):
            step = 2.0 ** (-i)
            if approx_val + step <= val:
                bits.append(1)
                approx_val += step
            else:
                bits.append(0)
        bit_reps.append(bits)

    # Force a (n_kept, bp) shape, so an empty result is still a 2-D table.
    pos_bits = np.array(bit_reps, dtype=int).reshape(n_kept, bp)

    # Mirror onto the negative half; x = 0 appears only once.
    full_domain_bit_reps = np.concatenate([np.flip(pos_bits, axis=0), pos_bits[1:]])
    full_domain = np.concatenate([np.flip(-x[:n_kept]), x[1:n_kept]])

    return full_domain_bit_reps, full_domain, p_norm
