import gadjid
import numpy as np
import numpy.linalg as la

from config import Config
import run_dagma
import graph
import algo


def metrics(
    config: Config,
    algo_id: str,
    true_dag: graph.CausalGraph,
    out_graph: graph.CausalGraph,
    X,
    reported_score,
):
    """Compute every metric listed in ``config.eval.metrics`` for one result.

    Args:
        config: Run configuration; reads ``config.eval.metrics``,
            ``config.algo.lambda_bic`` and ``config.algo.lambda_dagma``.
        algo_id: Identifier of the algorithm run. When several BIC penalties
            are configured, the parts after the first ``-`` are parsed with
            ``algo.read_params`` to recover the ``"lambda"`` used by this run.
        true_dag: Ground-truth graph.
        out_graph: Graph returned by the algorithm.
        X: Data matrix passed on to the score-based metrics.
        reported_score: Score reported by the algorithm itself; only used for
            the ``"dagma-loss"`` metric of linear DAGMA runs.

    Returns:
        A tuple with one entry per configured metric, in configuration order.
        An entry is the empty string when ``out_graph`` is not valid and the
        metric cannot be evaluated on it.

    Raises:
        ValueError: If a configured metric name is not supported.
    """
    true_cpdag = true_dag.to_cpdag()
    out_cpdag = out_graph.to_cpdag()
    out_dag = out_graph.to_dag()

    is_valid = out_graph.is_valid()

    computed_vals = []
    for metric in config.eval.metrics:
        # SHD is still reported for an invalid output as long as the output
        # is of type "cpdag"; every other metric is skipped for invalid graphs.
        if not (metric == "shd" and out_graph.graph_type == "cpdag") and not is_valid:
            computed_vals.append("")
            continue

        if metric == "shd":
            val = shd(true_cpdag, out_cpdag)
        elif metric == "aid":
            val = ancestor_aid(true_cpdag, out_cpdag)
        elif metric == "bic":
            bic_val = config.algo.lambda_bic[0]
            if len(config.algo.lambda_bic) > 1:
                # Several penalties were swept: take the one encoded in the
                # algorithm identifier of this particular run.
                params = algo.read_params(algo_id.split("-")[1:])
                bic_val = float(params["lambda"])
            val = bic(out_dag, X, bic_val)
        elif metric == "dagma-loss":
            W = mle_params(out_dag, X)
            val = run_dagma.dagma_loss(X, W, config.algo.lambda_dagma)
            # For linear DAGMA, keep the better of the refitted loss and the
            # loss the algorithm reported itself.
            if algo_id.startswith("dagma") and not algo_id.startswith(
                "dagma_nonlinear"
            ):
                val = min(val, reported_score)
        else:
            # The f-prefix is required so the offending name is interpolated.
            raise ValueError(f"metric {metric} not supported")
        computed_vals.append(val)
    return tuple(computed_vals)


def empty_metrics(config):
    """Return a tuple holding one empty string per configured metric.

    Args:
        config: Run configuration; only ``config.eval.metrics`` is iterated.

    Returns:
        A tuple of ``""`` placeholders with the same length as the metric list.
    """
    return tuple("" for _ in config.eval.metrics)


def shd(g1, g2):
    """Count the unordered node pairs on which two graphs disagree.

    A pair ``{u, v}`` counts once if either matrix entry ``[u, v]`` or
    ``[v, u]`` differs between the two graphs, so a missing, extra, reversed
    or differently oriented edge each contributes exactly one.

    Args:
        g1: First graph; must provide ``num_nodes()`` and ``to_matrix()``.
        g2: Second graph with the same number of nodes as ``g1``.

    Returns:
        The number of differing node pairs as an ``int``.
    """
    n_nodes = g1.num_nodes()
    assert n_nodes == g2.num_nodes(), (
        "graphs need to have same number of vertices"
    )
    first = g1.to_matrix()
    second = g2.to_matrix()

    # Visit each unordered pair once (i < j) and compare both directions.
    return sum(
        1
        for i in range(n_nodes)
        for j in range(i + 1, n_nodes)
        if first[i, j] != second[i, j] or first[j, i] != second[j, i]
    )


def ancestor_aid(g1, g2):
    """Return the second element of gadjid's ancestor adjustment distance.

    Both graphs are converted with ``to_matrix()`` and handed to
    ``gadjid.ancestor_aid`` with edges read from row to column.

    Args:
        g1: First graph (passed as gadjid's first argument).
        g2: Second graph (passed as gadjid's second argument).

    Returns:
        Element ``[1]`` of the gadjid result.
        NOTE(review): per the gadjid documentation this is the absolute
        number of mistakes (element ``[0]`` being the normalised distance) —
        confirm against the installed version.
    """
    result = gadjid.ancestor_aid(
        g1.to_matrix(),
        g2.to_matrix(),
        edge_direction="from row to column",
    )
    return result[1]


def bic(dag, X, penalty):
    """Compute a Gaussian BIC-style score of ``dag`` on data ``X``.

    For every node ``u`` the score adds ``n * log(residual variance of u given
    its neighbours in the transposed graph)`` plus
    ``penalty * log(n) * (number of those neighbours)``. The residual variance
    is read off the last diagonal entry of the Cholesky factor of the
    corresponding covariance submatrix.

    Args:
        dag: Graph providing ``transpose_graph()``, whose result exposes
            ``dir_neighbors[u]`` (presumably the parents of ``u`` in ``dag``
            — confirm against the graph module).
        X: Data of shape ``(n, p)``, one column per node.
        penalty: Multiplier of the ``log(n)`` complexity term.

    Returns:
        The accumulated score as a float.

    Raises:
        numpy.linalg.LinAlgError: If a covariance submatrix is not positive
            definite.
    """
    tg = dag.transpose_graph()

    n, p = X.shape
    # np.cov squeezes the single-variable case to a 0-d array; restore the
    # (p, p) shape so the submatrix indexing below also works for p == 1.
    cov = np.atleast_2d(np.cov(X.T, bias=True, dtype=np.double))
    log_n = np.log(n)

    score = 0.0
    for u in range(p):
        # The order of the conditioning set does not affect the last Cholesky
        # pivot; sorting merely makes the index list deterministic.
        parents = sorted(tg.dir_neighbors[u])
        idx = parents + [u]
        L = la.cholesky(cov[np.ix_(idx, idx)])
        # L[-1, -1] ** 2 is the conditional variance, hence the factor 2.
        score += 2.0 * n * np.log(L[-1, -1]) + penalty * log_n * len(parents)
    return score


def mle_params(g: graph.CausalGraph, X):
    """Fit least-squares edge weights for the structure of ``g``.

    Every node is regressed (without intercept) on the nodes that list it in
    their ``dir_neighbors``; the coefficients are stored column-wise.

    Args:
        g: Graph whose ``dir_neighbors[u]`` yields the nodes ``v`` with an
            edge ``u -> v``.
        X: Data of shape ``(n, p)``, one column per node.

    Returns:
        A ``(p, p)`` float array ``W`` where ``W[u, v]`` is the coefficient of
        ``u`` in the regression of ``v``; entries without an edge are zero.
    """
    _, p = X.shape
    G = np.zeros((p, p), dtype=np.int8)
    for u in range(p):
        for v in g.dir_neighbors[u]:
            G[u, v] = 1
    W = np.zeros((p, p), dtype=float)

    for i in range(p):
        parents_idx = np.flatnonzero(G[:, i])
        if parents_idx.size == 0:
            continue
        # lstsq works on the design matrix directly: it avoids squaring the
        # condition number via the normal equations and returns the
        # minimum-norm solution instead of raising when parents are collinear
        # or there are fewer samples than parents.
        beta, *_ = np.linalg.lstsq(X[:, parents_idx], X[:, i], rcond=None)
        W[parents_idx, i] = beta

    return W
