import networkx as nx
import numpy as np
from networkx import from_numpy_array, DiGraph
from numpy.linalg import norm

import utils as notears_utils

class ExDagDataException(Exception):
    """Module-specific exception type.

    Judging by the name it is presumably meant for data-related failures;
    none of the code visible in this module raises it.
    """


def apply_threshold(W, w_threshold):
    """Return a copy of ``W`` with small-magnitude entries set to zero.

    Entries whose absolute value is strictly below ``w_threshold`` are
    zeroed; entries equal to the threshold are kept. ``W`` itself is left
    untouched.
    """
    below_threshold = np.abs(W) < w_threshold
    thresholded = np.copy(W)
    thresholded[below_threshold] = 0
    return thresholded


def compute_shd(W_true, W_est):
    """Compare the non-zero patterns of two weight matrices.

    Both matrices are reduced to boolean "entry is non-zero" masks and
    handed to ``notears_utils.count_accuracy`` (the helper from the imported
    ``utils`` module, not the ``count_accuracy`` defined in this file).

    Returns:
        A pair ``(shd, acc)`` where ``acc`` is the mapping returned by the
        helper and ``shd`` is its ``'shd'`` entry.
    """
    true_support = W_true != 0
    est_support = W_est != 0
    metrics = notears_utils.count_accuracy(true_support, est_support)
    return metrics['shd'], metrics


def find_minimal_dag_threshold(W):
    """Find the smallest magnitude threshold that turns ``W`` into a DAG.

    The candidate thresholds are the distinct non-zero magnitudes of ``W`` in
    ascending order. For each candidate, every entry whose magnitude is
    strictly below it is zeroed, and the first candidate for which
    ``notears_utils.is_dag`` accepts the result is returned.

    The input matrix is never modified; thresholding happens on a copy.

    Args:
        W: weight matrix (NumPy array).

    Returns:
        ``(threshold, W_thresholded)``. When ``W`` already is a DAG the
        threshold is ``0`` and ``W`` itself is returned unchanged.

    Raises:
        AssertionError: if no candidate yields a DAG. Entries equal to the
            largest magnitude are never removed (the comparison is strict),
            so this can happen when those entries alone are not acyclic.
    """
    if notears_utils.is_dag(W):
        return 0, W

    W_t = np.copy(W)
    magnitudes = np.abs(W_t)
    # np.unique sorts ascending and drops duplicates; the "> 0" filter also
    # excludes NaN, matching a plain ``abs(t) > 0`` test.
    for t_candidate in np.unique(magnitudes[magnitudes > 0]):
        # Magnitudes of the original entries are enough: anything zeroed by
        # an earlier (smaller) candidate is below this one as well.
        W_t[magnitudes < t_candidate] = 0
        if notears_utils.is_dag(W_t):
            return t_candidate, W_t

    # Explicit raise instead of ``assert False`` so the failure is not
    # silently skipped when Python runs with assertions disabled (-O).
    raise AssertionError('Should always find a dag')


def find_optimal_threshold_for_shd(B_true, W_est, A_true, A_est, W_bi_true, W_bi_est):
    """Search for the magnitude threshold that minimises the total SHD.

    Candidate thresholds are the distinct non-zero magnitudes found in
    ``W_est`` and in every matrix of ``A_est`` (``W_bi_est`` is thresholded
    too, but contributes no candidates). Candidates are tried in ascending
    order, so among thresholds with the same SHD the smallest one is
    returned. If there is no non-zero entry at all, the single candidate
    ``0`` is used.

    For every candidate, ``W_est``, each ``A_est[i]`` and ``W_bi_est`` are
    thresholded with ``apply_threshold`` and scored with ``calculate_shd``.
    Both the true and the estimated instantaneous graph are encoded as
    "non-zero mask minus bidirected non-zero mask", so an entry set only in
    the bidirected matrix becomes ``-1`` (the undirected marker understood by
    ``calculate_dag_shd``).

    Args:
        B_true: true instantaneous adjacency (must support ``astype``).
        W_est: estimated instantaneous weights.
        A_true: sequence of true lagged matrices.
        A_est: sequence of estimated lagged weight matrices.
        W_bi_true: true bidirected weights.
        W_bi_est: estimated bidirected weights.

    Returns:
        ``(best_t, best_shd)``. The search starts from the largest candidate
        and an SHD of ``B_true.shape[0] ** 2``; those values are returned
        when no candidate scores strictly better.

    Raises:
        ValueError: propagated from ``calculate_shd`` when a thresholded
            estimate fails its validation (e.g. it is not a DAG).
    """
    values = {abs(t) for t in W_est.flatten() if abs(t) > 0}
    for A_i_est in A_est:
        values.update(abs(t) for t in A_i_est.flatten() if abs(t) > 0)

    # Sorted so the scan order (and therefore tie-breaking) is well defined
    # instead of following set iteration order.
    possible_thresholds = sorted(values) if values else [0]

    best_t = possible_thresholds[-1]
    best_shd = B_true.shape[0] ** 2

    B_bi_true = W_bi_true != 0
    B_all_true = B_true.astype(int) - B_bi_true.astype(int)

    for t_candidate in possible_thresholds:
        B_est_t = apply_threshold(W_est, t_candidate) != 0
        B_bi_est_t = apply_threshold(W_bi_est, t_candidate) != 0
        A_est_t = [apply_threshold(A_i_est, t_candidate) for A_i_est in A_est]

        # CPDAG encoding: undirected edges end up as -1.
        B_all_est_t = B_est_t.astype(int) - B_bi_est_t.astype(int)
        shd, _, _ = calculate_shd(B_all_true, B_all_est_t, A_true, A_est_t)

        if shd < best_shd:
            best_t = t_candidate
            best_shd = shd
    return best_t, best_shd


def find_optimal_multiple_thresholds(W_true, W_est, A_true, A_est, W_bi_true, Wbi):
    """Pick a separate SHD-optimal threshold for ``W`` and for each lag matrix.

    NOTE: this function is disabled. The leading assertion fails on every
    call, so the body below only runs when assertions are turned off.
    ``W_bi_true`` and ``Wbi`` are accepted but not used.

    Returns (when it runs):
        ``(best_w_t, best_a_t, total_shd, best_W, best_a, best_acc)`` where
        the ``*_a*`` values are lists with one entry per lag matrix and
        ``total_shd`` adds the per-matrix SHDs together.
    """
    assert False, 'this method needs to be fixed'
    best_w_t, best_w_shd, best_W, _ = find_optimal_threshold_single_matrix(W_true, W_est)

    # One (threshold, shd, thresholded matrix, accuracy) result per lag.
    per_lag = [
        find_optimal_threshold_single_matrix(a_true_i, a_est_i)
        for a_true_i, a_est_i in zip(A_true, A_est)
    ]
    best_a_t = [lag_t for lag_t, _, _, _ in per_lag]
    best_a_shd = [lag_shd for _, lag_shd, _, _ in per_lag]
    best_a = [lag_matrix for _, _, lag_matrix, _ in per_lag]

    best_acc = count_accuracy(W_true, best_W !=0, A_true, best_a)

    total_shd = best_w_shd + sum(best_a_shd)
    return best_w_t, best_a_t, total_shd, best_W, best_a, best_acc


def find_optimal_threshold_single_matrix(W_true, W_est):
    """Search the threshold on ``W_est`` that minimises SHD against ``W_true``.

    NOTE: this function is disabled. The leading assertion fails on every
    call, so the body below only runs when assertions are turned off.

    The unthresholded estimate provides the starting score; each non-zero
    magnitude of ``W_est`` is then tried in ascending order and replaces the
    current best only when its SHD is strictly smaller.

    Returns (when it runs):
        ``(best_t, best_shd, best_W, best_acc)``.
    """
    assert False, 'this method needs to be fixed'
    candidates = sorted((abs(t) for t in W_est.flatten() if abs(t) > 0))
    best_t = max(candidates) if candidates else 0

    # Baseline: score of the estimate without any thresholding.
    best_W = W_est
    best_shd, _, _ = calculate_shd(W_true, W_est !=0, [], [], test_dag=False)
    best_acc = count_accuracy(W_true, W_est !=0, [], [], test_dag=False)

    for t_candidate in candidates:
        W_est_t = apply_threshold(W_est, t_candidate)
        shd, _, _ = calculate_shd(W_true, W_est_t !=0, [], [], test_dag=False)
        if not shd < best_shd:
            continue
        best_t = t_candidate
        best_shd = shd
        best_acc = count_accuracy(W_true, W_est_t !=0, [], [], test_dag=False)
        best_W = W_est_t
    return best_t, best_shd, best_W, best_acc

def least_square_cost(X, W, Y, A):
    """Sum of squared residuals of a linear model with lagged terms.

    For every sample ``i`` and variable ``j`` the residual is::

        X[i, j] - sum_{k != j} X[i, k] * W[k, j]
                - sum_{k, t}   Y[t][i, k] * A[t][k, j]

    The diagonal of ``W`` is ignored (``k == j`` is skipped); the lagged sum
    has no such exclusion.

    Args:
        X: 2-D array, one row per sample and one column per variable.
        W: weights indexed ``W[k, j]``.
        Y: sequence of arrays indexed like ``X``, one per entry of ``A``.
        A: sequence of weight matrices indexed ``A[t][k, j]``; must be as
            long as ``Y``.

    Returns:
        The sum of the squared residuals over all ``(i, j)``.
    """
    n, d = X.shape
    p = len(A)
    assert len(Y) == len(A)

    def residual(i, j):
        instantaneous = sum(X[i, k] * W[k, j] for k in range(d) if k != j)
        lagged = sum(Y[t][i, k] * A[t][k, j] for k in range(d) for t in range(p))
        return X[i, j] - instantaneous - lagged

    return sum(residual(i, j) ** 2 for i in range(n) for j in range(d))


def plot(W, nodelist, filename=None, dpi=None):
    """Draw the directed graph encoded by ``W``.

    ``W`` is converted to a ``DiGraph`` via ``from_numpy_array`` using
    ``nodelist`` for the nodes, laid out with Graphviz ``dot`` and drawn with
    white nodes.

    Args:
        W: adjacency/weight matrix passed to ``from_numpy_array``.
        nodelist: forwarded to ``from_numpy_array`` as ``nodelist``.
        filename: when given, the figure is written there as PNG and then
            closed; when ``None`` the figure is shown instead.
        dpi: forwarded to ``savefig``.
    """
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    graph = from_numpy_array(W, create_using=DiGraph, nodelist=nodelist)
    fig, ax = plt.subplots()
    layout = nx.drawing.nx_agraph.graphviz_layout(graph, prog='dot')
    nx.draw_networkx(graph, ax=ax, pos=layout, node_color="white", arrowsize=15)

    if filename is None:
        plt.show()
        return
    fig.savefig(filename, format='png', bbox_inches='tight', dpi=dpi)
    plt.close(fig)


def plot_heatmap(W, names_x, names_y, filename=None, dpi=None):
    """Render ``W`` as a heatmap with a colour scale symmetric around zero.

    Args:
        W: 2-D array to display.
        names_x: tick labels for the x axis.
        names_y: tick labels for the y axis.
        filename: when given, the figure is written there as PNG and then
            closed; when ``None`` the figure is shown instead.
        dpi: forwarded to ``savefig``.
    """
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    def strip_lag_suffix(labels):
        # Keep only the part of each label before the first "_lag".
        return [label.split("_lag")[0] for label in labels]

    x_labels = strip_lag_suffix(names_x)
    y_labels = strip_lag_suffix(names_y)

    fig, ax = plt.subplots()

    # Same bound on both sides so that zero maps to the middle of the colormap.
    limit = max(abs(W.min()), abs(W.max()))
    image = ax.imshow(W, cmap='coolwarm', interpolation='nearest', vmin=-limit, vmax=limit)

    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_yticklabels(y_labels)

    fig.colorbar(image, ax=ax)
    fig.tight_layout()

    if filename is None:
        plt.show()
        return
    fig.savefig(filename, format='png', bbox_inches='tight', dpi=dpi)
    plt.close(fig)


def calculate_dag_shd(B_true, B_est, test_dag=True):
    """Structural Hamming distance between two square adjacency matrices.

    Entries are ``1`` for a directed edge, ``0`` for no edge and ``-1`` for an
    undirected edge (stored in only one of the two mirrored cells). Each
    unordered node pair ``{i, j}`` with ``i != j`` is compared once; diagonal
    entries are never looked at.

    Cost per differing pair:
        * 0.5 when the estimate is the truth with its direction flipped;
        * 1 when one side is undirected and the other has no edge;
        * 0.5 when exactly one side is undirected otherwise;
        * 1 for every other mismatch.

    Args:
        B_true: reference matrix; not validated beyond its shape.
        B_est: estimated matrix, validated as described under *Raises*.
        test_dag: when true and ``B_est`` has no ``-1`` entry, additionally
            require ``notears_utils.is_dag(B_est)``.

    Returns:
        The accumulated cost (an ``int`` unless a 0.5 cost was added).

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``B_est`` holds other values than allowed, stores an
            undirected edge in both mirrored cells, or fails the DAG check.
    """
    assert B_true.shape == B_est.shape

    has_undirected = (B_est == -1).any()
    if has_undirected:  # cpdag
        valid = (B_est == 0) | (B_est == 1) | (B_est == -1)
        if not valid.all():
            raise ValueError('B_est should take value in {0,1,-1}')
        if ((B_est == -1) & (B_est.T == -1)).any():
            raise ValueError('undirected edge should only appear once')
    else:  # dag
        valid = (B_est == 0) | (B_est == 1)
        if not valid.all():
            raise ValueError('B_est should take value in {0,1}')
        if test_dag and not notears_utils.is_dag(B_est):
            raise ValueError('B_est should be a DAG')

    undirected = (-1, -1)
    no_edge = (0, 0)

    def edge_state(matrix, row, col):
        # Pair of mirrored cells; any -1 collapses to the undirected marker.
        cells = (matrix[row, col], matrix[col, row])
        return undirected if min(cells) == -1 else cells

    total = 0
    for row in range(1, B_true.shape[0]):
        for col in range(row):
            est = edge_state(B_est, row, col)
            ref = edge_state(B_true, row, col)
            if est == ref:
                continue

            if est == ref[::-1]:
                total += 0.5  # same edge, opposite direction
            elif (est == undirected and ref == no_edge) or (est == no_edge and ref == undirected):
                total += 1  # undirected edge on one side, nothing on the other
            elif est == undirected or ref == undirected:
                total += 0.5  # undirected on one side, directed on the other
            else:
                total += 1

    return total


def calculate_dag_shd_old(B_true, B_est, test_dag=True):
    """Earlier SHD implementation based on flat index sets.

    ``B_est`` is validated exactly like in ``calculate_dag_shd`` (values in
    ``{0, 1}`` or ``{0, 1, -1}``, undirected edges stored once, optional DAG
    check). The distance is the number of extra plus missing entries in the
    lower triangle of the symmetrised matrices, plus the number of estimated
    directed edges that are absent from ``B_true`` but present in its
    transpose.

    Raises:
        ValueError: if ``B_est`` fails validation.
    """
    if (B_est == -1).any():  # cpdag
        if not ((B_est == 0) | (B_est == 1) | (B_est == -1)).all():
            raise ValueError('B_est should take value in {0,1,-1}')
        if ((B_est == -1) & (B_est.T == -1)).any():
            raise ValueError('undirected edge should only appear once')
    else:  # dag
        if not ((B_est == 0) | (B_est == 1)).all():
            raise ValueError('B_est should take value in {0,1}')
        if test_dag and not notears_utils.is_dag(B_est):
            raise ValueError('B_est should be a DAG')

    # Directed estimates that are not true edges but match a flipped true edge.
    est_directed = np.flatnonzero(B_est == 1)
    true_edges = np.flatnonzero(B_true)
    true_edges_flipped = np.flatnonzero(B_true.T)
    not_in_truth = np.setdiff1d(est_directed, true_edges, assume_unique=True)
    n_reversed = len(np.intersect1d(not_in_truth, true_edges_flipped, assume_unique=True))

    # Skeleton comparison on the lower triangle of the symmetrised matrices.
    est_skeleton = np.flatnonzero(np.tril(B_est + B_est.T))
    true_skeleton = np.flatnonzero(np.tril(B_true + B_true.T))
    n_extra = len(np.setdiff1d(est_skeleton, true_skeleton, assume_unique=True))
    n_missing = len(np.setdiff1d(true_skeleton, est_skeleton, assume_unique=True))

    return n_extra + n_missing + n_reversed

def calculate_shd(B_true, B_est, A_true, A_est, test_dag=True):
    """Total SHD over the instantaneous graph and all lagged matrices.

    ``B_true``/``B_est`` are compared as given via ``calculate_dag_shd``.
    Each lagged pair ``A_true[i]``/``A_est[i]`` is first reduced to its
    non-zero mask and compared without the DAG check. ``A_true`` decides how
    many lagged pairs are compared.

    Returns:
        ``(total, shd_w, shd_a)``: the sum, the instantaneous part and the
        lagged part.
    """
    shd_w = calculate_dag_shd(B_true, B_est, test_dag=test_dag)
    shd_a = sum(
        calculate_dag_shd(A_true[lag] != 0, A_est[lag] != 0, test_dag=False)
        for lag in range(len(A_true))
    )
    return shd_w + shd_a, shd_w, shd_a


def _count_accuracy_stats(B_true, B_est):
    """Entry-wise classification statistics of an estimated edge mask.

    Predicted positives are the entries of ``B_est`` equal to ``1`` (which
    includes ``True``); actual positives are the non-zero entries of
    ``B_true``. Both are compared by flat index, so the two arrays are
    expected to have the same 2-D shape.

    Every ratio guards its denominator, so empty predictions or an empty
    ground truth yield ``0`` instead of a division error.

    Returns:
        dict with the keys ``fdr``, ``tpr``, ``fpr``, ``nnz`` (number of
        predicted positives), ``true_pos``, ``false_pos``, ``false_neg``,
        ``cond`` (number of actual positives), ``precision``, ``f1score``
        and ``g_score``.
    """
    # Linear indices of predicted and actual positives.
    pred = np.flatnonzero(B_est == 1)
    positive = np.flatnonzero(B_true)
    true_pos = np.intersect1d(pred, positive, assume_unique=True)
    false_pos = np.setdiff1d(pred, positive, assume_unique=True)

    n_pred = len(pred)
    n_positive = len(positive)
    n_true_pos = len(true_pos)
    n_false_pos = len(false_pos)
    cond_neg_size = B_true.shape[0] * B_true.shape[1] - n_positive

    fdr = float(n_false_pos) / max(n_pred, 1)
    tpr = float(n_true_pos) / max(n_positive, 1)
    fpr = float(n_false_pos) / max(cond_neg_size, 1)
    precision = n_true_pos / max(n_true_pos + n_false_pos, 1)

    # Harmonic mean of precision and recall. Only a zero denominator needs
    # guarding: clamping it to at least 1 would understate F1 whenever
    # tpr + precision < 1.
    pr_sum = tpr + precision
    f1 = 2 * (tpr * precision) / pr_sum if pr_sum > 0 else 0.0

    g_score = max(n_true_pos - n_false_pos, 0) / max(n_positive, 1)
    return {
        'fdr': fdr,
        'tpr': tpr,  # recall, sensitivity
        'fpr': fpr,
        'nnz': n_pred,
        'true_pos': n_true_pos,
        'false_pos': n_false_pos,
        'false_neg': n_positive - n_true_pos,
        'cond': n_positive,
        'precision': precision,
        'f1score': f1,
        'g_score': g_score
    }


def count_accuracy(B_true, B_est, A_true, A_est, test_dag=True):
    """Combine SHD and entry-wise accuracy statistics for ``B`` and all ``A``.

    The SHD parts come from ``calculate_shd``. For the remaining statistics,
    ``B`` and the lagged matrices are stacked along axis 0 and the non-zero
    mask of the stacked estimate is scored against the stacked truth with
    ``_count_accuracy_stats``.

    Args:
        B_true: true instantaneous matrix.
        B_est: estimated instantaneous matrix.
        A_true: sequence of true lagged matrices (list, tuple or stacked
            array).
        A_est: sequence of estimated lagged matrices, as long as ``A_true``.
        test_dag: forwarded to ``calculate_shd``.

    Returns:
        The statistics dict extended with ``'shd'`` (total), ``'shd_w'``
        (instantaneous part) and ``'shd_a'`` (lagged part).

    Raises:
        AssertionError: if ``A_true`` and ``A_est`` differ in length.
        ValueError: propagated from ``calculate_shd``.
    """
    assert len(A_true) == len(A_est)
    shd, w_shd, a_shd = calculate_shd(B_true, B_est, A_true, A_est, test_dag=test_dag)

    # Unpacking instead of ``[B] + A`` so that tuples work and a stacked
    # ndarray is not broadcast-added to B; np.concatenate returns a new
    # array, so no defensive copies are needed.
    m_est = np.concatenate([B_est, *A_est], axis=0)
    m_true = np.concatenate([B_true, *A_true], axis=0)

    acc = _count_accuracy_stats(m_true, m_est != 0)
    acc['shd'] = shd
    acc['shd_w'] = w_shd
    acc['shd_a'] = a_shd
    return acc


def compute_combined_shd(W_true, W_est, A_true, A_est):
    """Threshold search on the element-wise sum of ``W`` and all ``A`` matrices.

    Copies of ``W_true`` and ``W_est`` are used as accumulators, the lagged
    matrices are added to them in place, and the two sums are handed to
    ``find_optimal_threshold_single_matrix``, whose result is returned as is.
    The inputs are not modified.
    """
    combined_true = np.copy(W_true)
    combined_est = np.copy(W_est)
    for lag_true, lag_est in zip(A_true, A_est):
        combined_est += lag_est
        combined_true += lag_true

    return find_optimal_threshold_single_matrix(combined_true, combined_est)

def compute_norm_distance(W_true, W_est, A_true, A_est):
    """Norm of the difference between estimated and true parameters.

    ``W`` and all lagged matrices are stacked along axis 0 and
    ``numpy.linalg.norm`` (default arguments) is applied to the difference of
    the two stacks.

    Args:
        W_true: true instantaneous matrix.
        W_est: estimated instantaneous matrix.
        A_true: sequence of true lagged matrices (list, tuple or stacked
            array).
        A_est: sequence of estimated lagged matrices, matching ``A_true``.

    Returns:
        The norm of ``stack(est) - stack(true)``.
    """
    # Unpacking instead of ``[W] + A`` so that tuples work and a stacked
    # ndarray is not broadcast-added to W; np.concatenate returns a new
    # array, so no defensive copies are needed.
    m_est = np.concatenate([W_est, *A_est], axis=0)
    m_true = np.concatenate([W_true, *A_true], axis=0)
    return norm(m_est - m_true)

def tupledict_to_np_matrix(tuple_dict, d):
    """Build a dense ``d`` x ``d`` array from a ``{(i, j): value}`` mapping.

    Positions that do not appear in ``tuple_dict`` stay at zero.
    """
    dense = np.zeros((d, d))
    for position, value in tuple_dict.items():
        row, col = position
        dense[row, col] = value
    return dense

if __name__ == '__main__':
    # Manual smoke checks for calculate_shd; each call prints the tuple
    # (total, instantaneous part, lagged part). No lagged matrices are used.
    n = 10
    B = np.triu(np.random.randint(2, size=(n, n)), k=1)
    print(f"Number of non-zero entries in B: {np.count_nonzero(B)}")

    # Random upper-triangular graph compared with itself.
    print(calculate_shd(B, B, [], []))

    # Same truth against an empty estimate: every edge is missing.
    B_est = np.zeros_like(B)
    print(calculate_shd(B, B_est, [], []))

    # All true edges turned into undirected ones (-1), estimate still empty.
    B = -B
    print(calculate_shd(B, B_est, [], []))

    # A single undirected true edge, estimate still empty.
    B = np.zeros((n, n))
    B[0, 1] = -1
    print(calculate_shd(B, B_est, [], []))

    # Undirected true edge, estimate has it directed.
    B_est[1, 0] = 1
    print(calculate_shd(B, B_est, [], []))

    # Directed true edge 0 -> 1, estimate has 1 -> 0 (reversed).
    B[0, 1] = 1
    print(calculate_shd(B, B_est, [], []))

    # Undirected edge on both sides, stored in opposite cells.
    B, B_est = np.zeros((n, n)), np.zeros((n, n))
    B[0, 1] = -1
    B_est[1, 0] = -1
    print(calculate_shd(B, B_est, [], []))

    # Undirected true edge against a directed estimate, from fresh matrices.
    B, B_est = np.zeros((n, n)), np.zeros((n, n))
    B[0, 1] = -1
    B_est[1, 0] = 1
    print(calculate_shd(B, B_est, [], []))
