import xges
import numpy as np
import graph


def run_xges(data_path, lambda_bic):
    """Fit XGES to a CSV dataset and return the learned graph as a CausalGraph.

    Args:
        data_path: Path to a comma-separated file. The first line is skipped
            (treated as a header); every remaining field must be numeric.
        lambda_bic: Value forwarded to ``xges.XGES`` as ``alpha``
            (presumably the BIC penalty multiplier -- confirm against the
            xges documentation).

    Returns:
        ``graph.CausalGraph`` over ``p`` nodes, constructed with the "cpdag"
        tag. For the adjacency matrix ``mat`` produced by XGES, a pair with
        ``mat[u, v] == 1`` and ``mat[v, u] == 0`` becomes the directed edge
        ``(u, v)``. A pair with both entries equal to 1 becomes the undirected
        edge ``(u, v)`` with ``u < v``. Both edge lists are in row-major
        ``(u, v)`` order.

    Raises:
        ValueError: if the file contains missing or non-numeric fields.
    """
    X = np.genfromtxt(data_path, delimiter=",", skip_header=1)
    # genfromtxt silently turns empty or unparsable fields into NaN; fail
    # loudly instead of letting the structure search run on corrupted data.
    if np.isnan(X).any():
        raise ValueError(
            f"{data_path}: data contains missing or non-numeric values"
        )

    pdag = xges.XGES(alpha=lambda_bic).fit(X)
    # NOTE(review): assumes to_adjacency_matrix() returns a square
    # array-like -- confirm against the xges API.
    mat = np.asarray(pdag.to_adjacency_matrix())
    p, _ = mat.shape

    present = mat == 1.0
    absent = mat == 0.0

    # np.argwhere returns indices in row-major order, i.e. the same order as
    # scanning u then v; .tolist() yields plain Python ints.
    dir_edges = [tuple(e) for e in np.argwhere(present & absent.T).tolist()]
    # Restrict to the strict upper triangle so each undirected pair is
    # reported exactly once, as (u, v) with u < v.
    undir_mask = np.triu(present & present.T, k=1)
    undir_edges = [tuple(e) for e in np.argwhere(undir_mask).tolist()]

    return graph.CausalGraph(p, dir_edges, undir_edges, "cpdag")


if __name__ == "__main__":
    import sys

    # Build the usage line from argv[0] so it always names the script that
    # was actually invoked rather than a hard-coded (and wrong) file name.
    usage = f"Usage: python {sys.argv[0]} <data.csv> <lambda>"

    if len(sys.argv) != 3:
        print(usage, file=sys.stderr)
        sys.exit(1)

    data_path = sys.argv[1]
    try:
        lambda_bic = float(sys.argv[2])
    except ValueError:
        # Report a bad <lambda> as a usage error instead of a raw traceback.
        print(usage, file=sys.stderr)
        print(f"<lambda> must be a number, got {sys.argv[2]!r}", file=sys.stderr)
        sys.exit(1)

    # write() is defined on graph.CausalGraph (not visible here); its output
    # destination is determined by that class.
    run_xges(data_path, lambda_bic).write()
