import numpy as np

from dagsolver_utils import least_square_cost, apply_threshold, count_accuracy, compute_norm_distance, \
    find_optimal_threshold_for_shd



def log_acc_metrics(acc, run, prefix):
    """Forward the recognised accuracy metrics found in *acc* to *run*.

    Every tracked key that is present in *acc* is logged through
    ``run.log_metric(f'{prefix}_{key}', acc[key])``, in the fixed order of
    the ``tracked`` tuple below. Tracked keys missing from *acc* are skipped,
    and keys of *acc* that are not tracked are ignored.
    """
    tracked = (
        'shd', 'shd_w', 'shd_a', 'fdr', 'least_square_cost', 'norm_distance',
        'cond', 'true_pos', 'tpr', 'false_pos', 'nnz', 'false_neg',
        'precision', 'f1score', 'g_score',
    )
    # Lazily filter so each membership test happens right before its log call,
    # exactly as in a plain loop with an inline `if`.
    present = (name for name in tracked if name in acc)
    for name in present:
        run.log_metric(f'{prefix}_{name}', acc[name])


def _evaluate_threshold(X, Y, W_true, A_true, W_est, A_est, W_bi_est, B_all_true, cost_W_true, threshold):
    """Threshold the estimates at *threshold* and score them against the truth.

    Returns ``(W_est_t, W_bi_est_t, A_est_t, acc_all)`` where ``acc_all`` is the
    dict from ``count_accuracy`` extended with ``'least_square_cost'`` (excess
    over the ground-truth cost) and ``'norm_distance'``.
    """
    W_est_t = apply_threshold(W_est, threshold)
    B_est_t = W_est_t != 0
    W_bi_est_t = apply_threshold(W_bi_est, threshold)
    B_bi_est_t = W_bi_est_t != 0
    # CPDAG-style encoding: directed edges contribute +1, undirected edges -1.
    B_all_est_t = B_est_t + (-1 * B_bi_est_t)
    A_est_t = [apply_threshold(A_i_est, threshold) for A_i_est in A_est]
    acc_all = count_accuracy(B_all_true, B_all_est_t, A_true, A_est_t)
    # Report the cost relative to the ground-truth model, not in absolute terms.
    acc_all['least_square_cost'] = least_square_cost(X, W_est_t, Y, A_est_t) - cost_W_true
    acc_all['norm_distance'] = compute_norm_distance(W_true, W_est_t, A_true, A_est_t)
    return W_est_t, W_bi_est_t, A_est_t, acc_all


def calculate_metrics(X, Y, W_true, B_true, A_true, W_est, A_est, W_bi_true, W_bi_est,
                      fixed_thresholds=(0.5, 0.3, 0.15, 0.05)):
    """Score thresholded estimates against the ground truth and return the best one.

    The SHD-optimal threshold ``best_t`` is obtained from
    ``find_optimal_threshold_for_shd``. Each threshold in *fixed_thresholds* is
    evaluated and its accuracy dict printed as a diagnostic (labelled
    ``t<threshold>``). ``best_t`` is then evaluated and printed under the label
    ``best``. A fixed threshold equal to ``best_t`` is evaluated only once,
    under ``best``.

    Only the strict upper triangle (``np.triu(..., k=1)``) of the bidirected
    matrices is used. The combined true graph is ``B_true`` as ints minus the
    indicator of non-zero true bidirected entries.

    NOTE(review): W_* / B_* are presumably (d, d) weight/adjacency matrices and
    A_* an iterable of additional (e.g. lagged) weight matrices. Confirm
    against callers.

    Args:
        X, Y: data passed through to ``least_square_cost``.
        W_true, B_true, A_true: ground-truth weights, adjacency and A matrices.
        W_est, A_est: estimated weights and A matrices (unthresholded).
        W_bi_true, W_bi_est: true / estimated bidirected weights. If
            ``W_bi_est`` is None, both are treated as all-zero (no bidirected
            edges are compared). If only ``W_bi_true`` is None, it is treated
            as all-zero.
        fixed_thresholds: additional thresholds evaluated for diagnostics only.

    Returns:
        ``(best_W, best_Wbi, best_A, best_shd, best_f1_score)``: the estimates
        thresholded at ``best_t``, the SHD reported by
        ``find_optimal_threshold_for_shd``, and the ``'f1score'`` entry of the
        best accuracy dict (None if ``count_accuracy`` did not report one).
    """
    cost_W_true = least_square_cost(X, W_true, Y, A_true)
    if W_bi_est is None:
        # Without a bidirected estimate, compare as if neither graph had
        # bidirected edges (this mirrors the established behaviour).
        W_bi_est = np.zeros_like(W_true)
        W_bi_true = np.zeros_like(W_true)
    elif W_bi_true is None:
        # Previously np.triu(None) raised here; an absent truth means "none".
        W_bi_true = np.zeros_like(W_true)
    W_bi_est = np.triu(W_bi_est, k=1)
    W_bi_true = np.triu(W_bi_true, k=1)
    B_bi_true = W_bi_true != 0
    B_all_true = B_true.astype(int) - B_bi_true.astype(int)

    best_t, best_shd = find_optimal_threshold_for_shd(B_true, W_est, A_true, A_est, W_bi_true, W_bi_est)

    def evaluate(threshold):
        return _evaluate_threshold(X, Y, W_true, A_true, W_est, A_est, W_bi_est,
                                   B_all_true, cost_W_true, threshold)

    for threshold in fixed_thresholds:
        if threshold == best_t:
            continue  # evaluated exactly once below under the 'best' label
        acc_all = evaluate(threshold)[-1]
        print(f't{threshold}')
        print(acc_all)

    # Evaluating best_t explicitly guarantees the best results always exist
    # (the old `assert best_W is not None` could fail e.g. for a NaN best_t).
    best_W, best_Wbi, best_A, acc_best = evaluate(best_t)
    print('best')
    print(acc_best)

    return best_W, best_Wbi, best_A, best_shd, acc_best.get('f1score')




