import numpy as np

# Local application imports
from .spn import SPNFG, ISPNFG

def construct_fg(num_nodes,
                 num_factors,
                 max_copies,
                 spn_target='factor',
                 normalize=True,
                 tau=1.0,
                 p_conn=0.25,
                 sparsity_temp=0.0):
    """Constructs a factor graph using spn

    The SPN ranges over the factors when ``spn_target == 'factor'`` and over
    the nodes for any other value of ``spn_target``.

    Args:
        num_nodes (int): Number of graph nodes
        num_factors (int): Number of factors
        max_copies (int): Maximum network width, should be a power of 2.
        spn_target (str, optional): Target for spn, {factor, node}. Defaults to 'factor'.
        normalize (bool, optional): Whether to use normalized probability. Defaults to True.
        tau (float, optional): Temperature for gumbel softmax. Defaults to 1.0.
        p_conn (float, optional): Node-to-factor density parameter. Defaults to 0.25.
        sparsity_temp (float, optional): Weight for sparsity control. Defaults to 0.0.

    Returns:
        :class:`SPNFG`: Factor Graph, with ``ready()`` already called on it.
    """
    # Any target other than 'factor' builds the SPN over the nodes.
    num_vars = num_factors if spn_target == 'factor' else num_nodes

    fg = SPNFG(num_nodes=num_nodes,
               num_factors=num_factors,
               spn_target=spn_target,
               normalize=normalize,
               tau=tau)
    # _construct_fg appends the Bernoulli leaf layer, the sum/product layers
    # and the final layer; delegating keeps this single source of truth for the
    # layer layout instead of a second copy of the same loop.
    _construct_fg(fg, num_vars, max_copies, p_conn, sparsity_temp)
    fg.ready()
    return fg


def _construct_fg(fg,
                  num_vars,
                  max_copies,
                  p_conn,
                  sparsity_temp,
                  intervention=False):
    num_leaves = 2 * num_vars
    remainder_num = num_leaves % 4  # remainder node number
    leaf_to_var = np.repeat(np.arange(num_vars), 2)
    if intervention:
        fn_bern = fg.add_bernoulli_layer_int
        fn_sum = fg.add_sum_layer_int
        fn_prod = fg.add_product_layer_int
        fn_final = fg.add_final_layer_int
    else:
        fn_bern = fg.add_bernoulli_layer
        fn_sum = fg.add_sum_layer
        fn_prod = fg.add_product_layer
        fn_final = fg.add_final_layer

    fn_bern(num_leaves, leaf_to_var)
    # number of product layers = log(num_vars)
    # product layers reduce vars_partitions by factor of 2
    # sum layers: form deterministic partitions
    num_product_layers = (num_vars//2).bit_length()
    partitions, copies = num_vars, 2

    for i in range(num_product_layers):
        MAX_C = max_copies  # power of 2
        if copies > MAX_C:
            new_copies = MAX_C
            fn_sum(partitions * new_copies,
                   remainder_num=remainder_num,
                   p_conn=p_conn,
                   sparsity_temp=sparsity_temp)
            copies = new_copies
        if partitions % 2 == 1 and i > 0:
            remainder_num = remainder_num + new_copies

        new_partitions = partitions//2
        new_copies = copies ** 2

        # cartesian product between copies of two children vars_partition
        fn_prod(new_partitions*new_copies,
                new_copies,
                new_partitions,
                remainder_num=remainder_num)
        copies = new_copies
        partitions = new_partitions

    fn_final(num_vars, max_copies, p_conn, sparsity_temp)
    return fg


def construct_ifg(num_nodes,
                  num_factors,
                  num_interventions,
                  max_copies,
                  spn_target='factor',
                  normalize=True,
                  tau=1.0,
                  p_conn=(0.25, 0.25),
                  sparsity_temp=(0.0, 0.0)):
    """Constructs a factor graph with non-genetic interventions using spn

    Two SPNs are stacked onto the same :class:`ISPNFG`: first the regular
    (observational) one, then the interventional one built with the ``*_int``
    layer builders. With ``spn_target == 'factor'`` both range over the
    factors; for any other target the regular SPN ranges over the nodes and
    the interventional one over the interventions.

    Args:
        num_nodes (int): Number of graph nodes
        num_factors (int): Number of factors
        num_interventions (int): Number of interventions
        max_copies (int): Maximum network width, should be a power of 2.
        spn_target (str, optional): Target for spn, {factor, node}. Defaults to 'factor'.
        normalize (bool, optional): Whether to use normalized probability. Defaults to True.
        tau (float, optional): Temperature for gumbel softmax. Defaults to 1.0.
        p_conn (tuple[float], optional): Node-to-factor density parameter.
            The first entry is for the original factor graph and second one is for the interventional part.
            Defaults to (0.25, 0.25).
        sparsity_temp (tuple[float], optional): Weight for sparsity control.
            The first entry is for the original factor graph and second one is for the interventional part.
            Defaults to (0.0, 0.0).

    Returns:
        :class:`ISPNFG`: Factor Graph, with ``ready()`` already called on it.
    """
    ifg = ISPNFG(num_nodes=num_nodes,
                 num_factors=num_factors,
                 num_intervention=num_interventions,
                 spn_target=spn_target,
                 normalize=normalize,
                 tau=tau)

    # Pick the variable counts for the regular and interventional SPNs.
    if spn_target == 'factor':
        num_vars, num_int_vars = num_factors, num_factors
    else:
        num_vars, num_int_vars = num_nodes, num_interventions

    # Regular factor graph first, then the interventional part.
    _construct_fg(ifg, num_vars, max_copies, p_conn[0], sparsity_temp[0])
    _construct_fg(ifg, num_int_vars, max_copies, p_conn[1], sparsity_temp[1],
                  intervention=True)
    ifg.ready()
    return ifg