import numpy as np
from dagma.nonlinear import DagmaNonlinear, DagmaMLP
import graph
import os
import sys


def run_dagma(data_path, lambda1, lambda2=0.005):
    """Fit a nonlinear DAGMA model to a CSV dataset and return the learned DAG.

    Args:
        data_path: Path to a comma-separated file. The first line is skipped
            as a header; each remaining row is read as one sample and each
            column as one variable.
        lambda1: Value forwarded as ``lambda1`` to ``DagmaNonlinear.fit``.
        lambda2: Value forwarded as ``lambda2`` to ``DagmaNonlinear.fit``.
            Defaults to 0.005, the value this function previously hard-coded.

    Returns:
        A ``graph.CausalGraph`` with one node per column and a directed edge
        ``(u, v)`` for every off-diagonal entry ``W[u, v]`` of the fitted
        weight matrix that is nonzero.

    Raises:
        ValueError: If the parsed data is not 2-D (e.g. a single-row or
            single-column file) or contains non-finite values.
    """
    X = np.genfromtxt(data_path, delimiter=",", skip_header=1)
    # genfromtxt collapses a single-row / single-column file to a 1-D array,
    # which would otherwise fail below with an opaque unpacking error.
    if X.ndim != 2:
        raise ValueError(
            f"expected a 2-D samples x variables table in {data_path!r}, "
            f"got array of shape {X.shape}"
        )
    # genfromtxt fills missing or unparseable cells with NaN instead of
    # failing; a NaN weight would then pass the nonzero test and show up as a
    # spurious edge, so reject such input explicitly.
    if not np.all(np.isfinite(X)):
        raise ValueError(f"{data_path!r} contains missing or non-numeric values")

    _, p = X.shape
    # One MLP per variable: p inputs -> 10 hidden units -> 1 output.
    eq_model = DagmaMLP(dims=[p, 10, 1], bias=True)
    model = DagmaNonlinear(eq_model)
    W = np.asarray(model.fit(X, lambda1=lambda1, lambda2=lambda2))

    # Keep every nonzero off-diagonal weight as a directed edge u -> v.
    # argwhere scans in row-major order, matching a nested u/v loop.
    is_edge = W != 0.0
    np.fill_diagonal(is_edge, False)
    dir_edges = [(int(u), int(v)) for u, v in np.argwhere(is_edge)]
    return graph.CausalGraph(p, dir_edges, [], "dag")


if __name__ == "__main__":
    # Command-line entry point: run_dagma.py <data.csv> <lambda1>.
    # ``sys`` is already imported at module level, so no local re-import.
    if len(sys.argv) != 3:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit("Usage: python run_dagma.py <data.csv> <lambda1>")

    data_path = sys.argv[1]
    try:
        lambda1 = float(sys.argv[2])
    except ValueError:
        sys.exit(f"error: lambda1 must be a number, got {sys.argv[2]!r}")
    # `not x >= 0` also rejects NaN, which float() happily parses.
    if not lambda1 >= 0:
        sys.exit(f"error: lambda1 must be non-negative, got {lambda1}")

    run_dagma(data_path, lambda1).write()
