import numpy as np
import networkx as nx

'''
Utility functions for initializing the candidate parent sets K and for finding cycles in a directed graph.
'''

def initialize_K(data):
    """Build the initial (empty) candidate parent sets.

    Args:
        data: any object exposing an integer attribute ``n``
            (presumably the number of variables — confirm with callers).

    Returns:
        A dict (not a list) mapping every index ``0 .. n-1`` to a
        ``(1, n)`` float array of zeros.
    """
    num_vars = data.n
    parent_sets = {}
    for idx in range(num_vars):
        # One fresh row per index so later in-place edits don't alias.
        parent_sets[idx] = np.zeros((1, num_vars))
    return parent_sets


def find_cycles(graph, C_set=None, *, verbose=True):
    """Find one directed cycle in ``graph`` and record it in ``C_set``.

    The adjacency matrix is transposed before the DiGraph is built, so a
    nonzero ``graph[j][i]`` becomes the directed edge ``i -> j``.
    NOTE(review): presumably rows index children and columns their candidate
    parents — confirm with callers.

    Args:
        graph: square 2-D adjacency matrix. Objects with a ``.T`` attribute
            (e.g. NumPy arrays) are used as-is; anything else (e.g. nested
            lists) is converted with ``np.asarray`` first.
        C_set: list of previously found clusters. It is appended to in place.
            If None, a new list is created.
        verbose: when True (the default), print the graph summary and any
            cycle found, matching the historical behavior.

    Returns:
        ``(cluster, C_set)``. ``cluster`` is the list of distinct nodes on the
        cycle returned by ``nx.find_cycle``, in traversal order. It is None when
        no cycle exists. ``C_set`` is the (possibly new) list of clusters.
    """
    if C_set is None:
        C_set = []
    # Keep the original `.T` path for array-likes; only plain sequences are
    # converted, since `list.T` does not exist.
    adjacency = graph if hasattr(graph, "T") else np.asarray(graph)
    G = nx.from_numpy_array(adjacency.T, create_using=nx.DiGraph)
    if verbose:
        print(G)
    try:
        cycle = nx.find_cycle(G)
    except nx.NetworkXNoCycle:  # graph is acyclic
        return None, C_set
    # find_cycle yields (u, v) edges; their tails trace the cycle's nodes.
    # dict.fromkeys de-duplicates while preserving traversal order.
    cluster = list(dict.fromkeys(edge[0] for edge in cycle))
    C_set.append(cluster)
    if verbose:
        print("Find a cycle: ", cluster)
    return cluster, C_set