import contextlib
import os
import sys

import numpy as np
from dagma.linear import DagmaLinear
from dagma.utils import is_dag

import graph


def dagma_loss(X, W, lambda1):
    """Return the L1-penalised least-squares score of a weight matrix.

    Computes ``0.5 * tr((I - W)^T S (I - W)) + lambda1 * sum(|W|)`` where
    ``S = X^T X / n`` is built from the ``n`` rows of ``X``.

    Args:
        X: 2-D array with ``n`` rows and ``p`` columns.
        W: Array that can be subtracted from the ``p x p`` identity.
        lambda1: Multiplier applied to the sum of absolute entries of ``W``.

    Returns:
        The scalar score (fit term plus penalty term).
    """
    num_rows, num_cols = X.shape
    residual = np.eye(num_cols, dtype=np.float64) - W
    second_moment = (X.T @ X) / num_rows
    # Same association as tr(R^T (S R)) so rounding matches term for term.
    fit_term = 0.5 * np.trace(residual.T @ (second_moment @ residual))
    penalty_term = lambda1 * np.sum(np.abs(W))
    return fit_term + penalty_term


def run_dagma(data_path, lambda1):
    """Fit a linear DAGMA model to a CSV file and return the learned graph.

    The file is read with ``np.genfromtxt`` using a comma delimiter and
    skipping the first line. The loss of the fitted weights is printed to
    stdout before any pruning takes place.

    Args:
        data_path: Path of the comma-separated data file.
        lambda1: Value forwarded to ``DagmaLinear.fit`` as ``lambda1`` and
            used as the penalty multiplier in the printed loss.

    Returns:
        ``graph.CausalGraph(p, dir_edges, [], "dag")`` where ``p`` is the
        first dimension of the fitted weight matrix and ``dir_edges`` lists
        the ``(row, col)`` index pairs of its non-zero off-diagonal entries
        in row-major order.
    """
    X = np.genfromtxt(data_path, delimiter=",", skip_header=1)
    model = DagmaLinear(loss_type="l2")

    # Silence anything the solver prints. The context managers guarantee that
    # sys.stdout is restored and the devnull handle is closed even when fit()
    # raises, which the previous manual swap did not.
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        W = model.fit(X, lambda1=lambda1)

    print(dagma_loss(X, W, lambda1))

    p, _ = W.shape

    # Zero out the smallest-magnitude remaining entry until is_dag accepts W.
    while not is_dag(W):
        magnitudes = np.abs(W)
        magnitudes[magnitudes == 0] = np.inf  # already-removed entries never win argmin
        weakest = np.unravel_index(magnitudes.argmin(), W.shape)
        W[weakest] = 0.0

    # np.argwhere walks W in row-major order, matching a nested u/v loop;
    # int() keeps the indices as plain Python ints.
    dir_edges = [
        (int(u), int(v)) for u, v in np.argwhere(W != 0) if u != v
    ]
    return graph.CausalGraph(p, dir_edges, [], "dag")


if __name__ == "__main__":
    import sys

    # Command-line entry point: expects exactly two arguments after the
    # script name, the CSV path and the lambda1 value.
    expected_argc = 3
    if len(sys.argv) != expected_argc:
        print("Usage: python run_dagma.py <data.csv> <lambda1>")
        sys.exit(1)

    data_path = sys.argv[1]
    lambda1 = float(sys.argv[2])

    # Build the graph from the data, then invoke its write() method.
    learned_graph = run_dagma(data_path, lambda1)
    learned_graph.write()
