import networkx as nx
import numpy as np

# modified from existing package (name hidden for anonymity)
def randomTRIL_variabledegree(N, max_degree=5, p_degree=0.5, connected=False):
    """
    Generate a random strictly lower-triangular adjacency matrix whose
    number of sampled pairs per round is binomially distributed.

    Adapted from https://stackoverflow.com/a/56514463

    The generator runs ``N`` rounds. In each round it draws
    ``degree = max(1, Binomial(max_degree, p_degree))`` and then samples
    ``degree`` node pairs ``(v1, v2)`` uniformly from ``0..N-1``. Every pair
    with ``v1 != v2`` sets ``mat[max(v1, v2), min(v1, v2)] = 1``. Pairs with
    ``v1 == v2`` are discarded, so the diagonal stays zero.

    :param N: size of the matrix (number of nodes)
    :type N: integer

    :param max_degree: number of trials of the per-round binomial draw
    :type max_degree: integer

    :param p_degree: success probability of the per-round binomial draw
    :type p_degree: float

    :param connected: if True, the matrix is passed through
        ``remove_tril_singletons`` so that nodes without any edge get
        connected to another node
    :type connected: boolean

    :returns: a random strictly lower-triangular adjacency matrix of shape
        ``(N, N)`` and dtype ``uint8``
    :rtype: numpy array
    """
    mat = np.zeros((N, N), dtype=np.uint8)
    for _ in range(N):
        degree = max(1, np.random.binomial(max_degree, p_degree))
        for _ in range(degree):
            # randint's upper bound is exclusive: use N (not N - 1) so the
            # last node can be drawn as an endpoint too.
            v1 = np.random.randint(0, N)
            v2 = np.random.randint(0, N)
            if v1 > v2:
                mat[v1, v2] = 1
            elif v1 < v2:
                mat[v2, v1] = 1
    if connected:
        mat = remove_tril_singletons(mat)
    return mat


def randomDAG(N, max_degree=5, p_degree=0.5, connected=True):
    """
    Generate a random DAG on the nodes ``0..N-1``.

    Edges come from ``randomTRIL_variabledegree``. Every nonzero entry
    ``(row, col)`` of its lower-triangular matrix becomes the directed edge
    ``row -> col``. Every such edge points from a higher to a lower index,
    so the graph is acyclic.

    :param N: number of nodes
    :type N: integer

    :param max_degree: forwarded to ``randomTRIL_variabledegree``
    :type max_degree: integer

    :param p_degree: forwarded to ``randomTRIL_variabledegree``
    :type p_degree: float

    :param connected: if True, the matrix generator is asked to connect
        edgeless nodes. Any remaining weakly connected components are then
        each joined to the first component by one random edge, so the
        returned DAG is weakly connected. If False, the graph may be
        disconnected and may contain isolated nodes.
    :type connected: boolean

    :returns: a random DAG containing all ``N`` nodes
    :rtype: NetworkX DiGraph
    """
    adjacency_matrix = randomTRIL_variabledegree(
        N, max_degree=max_degree, p_degree=p_degree, connected=connected
    )
    rows, cols = np.where(adjacency_matrix == 1)
    gr = nx.DiGraph()
    gr.add_edges_from(zip(rows.tolist(), cols.tolist()))
    # Nodes without any edge would otherwise be missing from the graph.
    # They are added after the edges, so the node order (and therefore the
    # component order below) is unchanged when every node already has an edge.
    gr.add_nodes_from(range(N))
    if connected:
        components = list(nx.weakly_connected_components(gr))
        if len(components) > 1:
            root = tuple(components[0])
            for component in components[1:]:
                # An edge between two distinct weak components cannot close
                # a cycle. Cast to int so node labels stay plain Python ints.
                v1 = int(np.random.choice(root))
                v2 = int(np.random.choice(tuple(component)))
                gr.add_edge(v1, v2)
    assert nx.is_directed_acyclic_graph(gr)
    if connected:
        # Connectivity is only guaranteed (and only checked) when requested.
        assert nx.is_weakly_connected(gr)
    return gr

def remove_tril_singletons(T):
    """
    Connect every isolated node of a lower-triangular adjacency matrix.

    A node ``i`` is isolated when row ``i`` and column ``i`` of ``T`` are
    both all-zero. For each such node a partner ``j != i`` is drawn
    uniformly at random. The edge is written below the diagonal
    (``T[max(i, j), min(i, j)] = 1``), so the matrix stays lower triangular.

    The isolated set is computed once up front. If one isolated node is
    linked to another, the second one still receives its own extra edge.

    ``T`` is modified in place and also returned.

    :param T: lower triangular matrix representing a DAG
    :type T: numpy array

    :returns: ``T`` with every node incident to at least one edge. A matrix
        with fewer than two rows is returned unchanged, because there is no
        other node to connect to.
    :rtype: numpy array
    """
    N = T.shape[0]
    if N < 2:
        # With zero or one node no partner exists.
        return T
    degree = T.sum(0) + T.sum(1)
    for i in np.flatnonzero(degree == 0):
        # Uniform draw over the N - 1 nodes other than i: sample 0..N-2 and
        # shift values >= i up by one. This avoids a rejection loop.
        j = np.random.randint(0, N - 1)
        if j >= i:
            j += 1
        if i > j:
            T[i, j] = 1
        else:
            T[j, i] = 1
    return T

def randomTRIL(N, degree=5, connected=False):
    """
    Generate a random strictly lower-triangular adjacency matrix.

    Adapted from https://stackoverflow.com/a/56514463

    The generator runs ``N`` rounds of ``degree`` draws. Each draw samples a
    node pair ``(v1, v2)`` uniformly from ``0..N-1``. Every pair with
    ``v1 != v2`` sets ``mat[max(v1, v2), min(v1, v2)] = 1``. Pairs with
    ``v1 == v2`` are discarded, so the diagonal stays zero.

    :param N: size of the matrix (number of nodes)
    :type N: integer

    :param degree: number of node pairs sampled per round
    :type degree: integer

    :param connected: if True, the matrix is passed through
        ``remove_tril_singletons`` so that nodes without any edge get
        connected to another node
    :type connected: boolean

    :returns: a random strictly lower-triangular adjacency matrix of shape
        ``(N, N)`` and dtype ``uint8``
    :rtype: numpy array
    """
    mat = np.zeros((N, N), dtype=np.uint8)
    for _ in range(N):
        for _ in range(degree):
            # randint's upper bound is exclusive: use N (not N - 1) so the
            # last node can be drawn as an endpoint too.
            v1 = np.random.randint(0, N)
            v2 = np.random.randint(0, N)
            if v1 > v2:
                mat[v1, v2] = 1
            elif v1 < v2:
                mat[v2, v1] = 1
    if connected:
        mat = remove_tril_singletons(mat)
    return mat