import numpy as np
from causallearn.utils.cit import CIT
from itertools import combinations
from copy import deepcopy
import pydotplus
import os
import igraph
def count_precision_recall_f1(tp, fp, fn):
    """Turn confusion counts into (precision, recall, f1).

    Args:
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.

    Returns:
        A tuple ``(precision, recall, f1)``.  ``precision`` is None when
        ``tp + fp == 0`` and ``recall`` is None when ``tp + fn == 0``.  ``f1`` is
        None if either of them is None, 0.0 if either is zero, and the harmonic
        mean otherwise.
    """
    precision = None if tp + fp == 0 else float(tp) / (tp + fp)
    recall = None if tp + fn == 0 else float(tp) / (tp + fn)

    if precision is None or recall is None:
        return precision, recall, None
    if precision == 0 or recall == 0:
        # Avoid 0/0 in the harmonic mean.
        return precision, recall, 0.0
    return precision, recall, float(2 * precision * recall) / (precision + recall)

def count_dag_accuracy(B_bin_true, B_bin_est):
    """Score an estimated directed graph against the true DAG.

    Both arguments are square matrices in which a nonzero ``B[r, c]`` is read
    as the edge r -> c.  Each estimated edge is classified as
      * a hit        - the same directed edge is in the truth,
      * a reversal   - only the opposite direction is in the truth,
      * spurious     - neither direction is in the truth.
    Reversals and spurious edges are both counted as false positives.

    Returns:
        dict with keys 'f1', 'precision', 'recall' (each possibly None, see
        ``count_precision_recall_f1``) and 'shd' = extra skeleton edges +
        missing skeleton edges + reversals.
    """
    # Flat (row-major) indices of the nonzero entries.
    est_idx = np.flatnonzero(B_bin_est)
    true_idx = np.flatnonzero(B_bin_true)
    true_rev_idx = np.flatnonzero(B_bin_true.T)
    true_either_dir = np.concatenate([true_idx, true_rev_idx])

    hits = np.intersect1d(est_idx, true_idx, assume_unique=True)
    spurious = np.setdiff1d(est_idx, true_either_dir, assume_unique=True)
    not_hits = np.setdiff1d(est_idx, true_idx, assume_unique=True)
    flipped = np.intersect1d(not_hits, true_rev_idx, assume_unique=True)
    missed = np.setdiff1d(true_idx, hits, assume_unique=True)

    # SHD: compare undirected skeletons (lower triangle of B + B.T), then add
    # one per reversed edge.
    est_skel = np.flatnonzero(np.tril(B_bin_est + B_bin_est.T))
    true_skel = np.flatnonzero(np.tril(B_bin_true + B_bin_true.T))
    n_extra = len(np.setdiff1d(est_skel, true_skel, assume_unique=True))
    n_missing = len(np.setdiff1d(true_skel, est_skel, assume_unique=True))
    shd = n_extra + n_missing + len(flipped)

    precision, recall, f1 = count_precision_recall_f1(
        tp=len(hits),
        fp=len(flipped) + len(spurious),
        fn=len(missed),
    )
    return {'f1': f1, 'precision': precision, 'recall': recall, 'shd': shd}


def given_set(i, j, G):
    """Collect the inner endpoint nodes of every shortest path from i to j.

    ``G`` is turned into a graph with ``igraph.Graph.Adjacency`` (igraph's
    default mode is directed - confirm for the installed version).  For each
    shortest path with more than one edge, the node right after ``i`` and the
    node right before ``j`` are gathered.  Single-edge paths contribute nothing,
    so the result is empty whenever ``G`` has a direct i -> j edge.

    Returns:
        A set of node indices (presumably empty when j is unreachable).
    """
    shortest = igraph.Graph.Adjacency(G).get_all_shortest_paths(i, j)
    return {node
            for path in shortest
            if len(path) != 2
            for node in (path[1], path[-2])}

def get_adjSet(i, G, n_node):
    """List every node k < n_node joined to i in either direction.

    A node counts as adjacent when ``G[i][k] == 1`` or ``G[k][i] == 1``.
    Indices are returned in ascending order.
    """
    return [k for k in range(n_node) if G[i][k] == 1 or G[k][i] == 1]
def get_adj_ij(i, j, G, n_node):
    """Return every node k < n_node with ``G[i][k] == 1`` and ``G[k][j] == 1``.

    The original wrote ``G[i][k] ==1 & G[k][j] == 1``.  Because ``&`` binds
    tighter than ``==``, Python read that as the chained comparison
    ``G[i][k] == (1 & G[k][j]) == 1``.  That form raises TypeError on float
    matrices and misfires on entries other than 0/1.  A boolean ``and`` is
    what was intended.
    """
    return [k for k in range(n_node) if G[i][k] == 1 and G[k][j] == 1]
def fisher_z_test(i, j, K, sample, result, alpha=0.05, method="kci"):
    """Test X_i _||_ X_j | X_K on ``sample`` and log the outcome.

    NOTE: despite its name, this runs causal-learn's ``CIT`` with ``method``,
    which defaults to "kci" (kernel CI test) rather than Fisher-z.  The name
    is kept so existing callers keep working.

    Args:
        i, j: column indices of the two tested variables.
        K: conditioning column indices (list).
        sample: data handed to ``CIT``.
        result: list; one entry ``['{i}_{j}_{K}___{p}']`` is appended per call.
        alpha: significance level (default 0.05, as before).
        method: causal-learn CIT method name (default "kci", as before).

    Returns:
        True when the p-value is >= ``alpha`` (independence not rejected),
        otherwise False.  A NaN p-value yields False.
    """
    Pvalue = CIT(sample, method)(i, j, K)
    result.append([f'{i}_{j}_{K}___{Pvalue}'])
    return bool(Pvalue >= alpha)

def skeleton(n_node, sample):
    """Learn the undirected skeleton with the PC adjacency search.

    The search starts from the complete graph on ``n_node`` variables.  At
    level ``l`` (0, 1, 2, ...) every still-adjacent pair (i, j) is tested for
    conditional independence given each size-``l`` subset of:
      * the current neighbours of i, excluding j, and then
      * the current neighbours of j, excluding i.
    The first subset for which ``fisher_z_test`` reports independence removes
    the edge and is stored as the separating set.  The search ends once no
    adjacent pair has enough neighbours to form a conditioning set of size ``l``.

    Fixes relative to the previous version:
      * Conditioning sets drawn from j's neighbourhood are also searched.
        Before, only i's side was tried, and pairs are only visited with i < j,
        so some separations could never be found.
      * The search stops at the first separating set.  This avoids redundant
        (expensive) tests and keeps S from being overwritten.

    Args:
        n_node: number of variables.
        sample: data passed unchanged to ``fisher_z_test`` (presumably an
            (n_samples, n_node) array - confirm against callers).

    Returns:
        C: (n_node, n_node) float array, symmetric, with C[i][j] == 1 iff the
           edge i - j survives.  The diagonal is 0.
        S: nested list.  S[i][j] (== S[j][i]) is the separating set recorded
           when edge i - j was removed, and [] for edges never removed.
        CI_result: log of every CI test, as appended by ``fisher_z_test``.
    """
    C = np.ones((n_node, n_node))
    np.fill_diagonal(C, 0)
    S = [[[] for _ in range(n_node)] for _ in range(n_node)]
    # Each unordered pair once: i ascending, partner descending (original order).
    pairs = [(i, j) for i in range(n_node) for j in range(n_node - 1, i, -1)]
    CI_result = []

    l = 0
    while True:
        any_testable = False
        for i, j in pairs:
            if C[i][j] != 1:
                continue
            # Conditioning sets already tried for this pair at this level, so a
            # set reachable from both neighbourhoods is tested only once.
            tried = set()
            removed = False
            for a, b in ((i, j), (j, i)):
                candidates = [k for k in get_adjSet(a, C, n_node) if k != b]
                if len(candidates) < l:
                    continue
                any_testable = True
                for K in combinations(candidates, l):
                    key = frozenset(K)
                    if key in tried:
                        continue
                    tried.add(key)
                    if fisher_z_test(i, j, list(K), sample, CI_result):
                        C[i][j] = 0
                        C[j][i] = 0
                        S[i][j] = list(K)
                        S[j][i] = list(K)
                        removed = True
                        break
                if removed:
                    break
        if not any_testable:
            break
        l += 1

    return C, S, CI_result

def direction(C, S):
    """Orient unshielded colliders (v-structures) on a skeleton.

    Consider every triple i - j - k with:
      * i adjacent to j and j adjacent to k in the skeleton ``C``;
      * i and k not adjacent;
      * both edges still undirected in the working copy.
    If j is not in the separating set ``S[i][k]``, the triple is oriented as
    i -> j <- k by clearing ``G[j][i]`` and ``G[j][k]``.

    The previous version left a ``pdb.set_trace()`` in this loop, which halted
    execution at every orientation; it is removed.

    Args:
        C: square 0/1 skeleton array (e.g. the first return value of
           ``skeleton``).  It is not modified.
        S: nested list of separating sets, indexed ``S[i][k]``.

    Returns:
        G: int copy of ``C`` where ``G[a][b] == 1 and G[b][a] == 0`` encodes
           a -> b and ``G[a][b] == G[b][a] == 1`` an undirected edge.
        triples: list of every candidate ``[i, j, k]`` examined, in scan order.
    """
    G = C.astype(int)  # astype returns a fresh array, so C stays untouched
    n_node = C.shape[0]

    triples = [
        [i, j, k]
        for i in range(n_node)
        for j in range(n_node)
        if i != j and C[i][j] == 1
        for k in range(n_node)
        if C[j][k] == 1 and k != i
    ]

    # NOTE(review): the check uses the working copy G, so once an edge of a
    # triple has been oriented by an earlier triple, later triples sharing that
    # edge are skipped (the result depends on scan order).
    for i, j, k in triples:
        both_undirected = (G[i][j] == 1 and G[j][i] == 1
                           and G[k][j] == 1 and G[j][k] == 1)
        unshielded = G[i][k] == 0 and G[k][i] == 0
        if both_undirected and unshielded and j not in S[i][k]:
            G[j][i] = 0
            G[j][k] = 0

    return G, triples
def adjacency_matrix_to_dot(adjacency_matrix, node_names=None, selection=None, latent=None):
    """Render an adjacency matrix as a pydotplus graph.

    Only pairs with j < i are scanned, so each pair is drawn at most once:
      * M[i, j] == 1 and M[j, i] == 0  ->  i -> j
      * M[i, j] == 0 and M[j, i] == 1  ->  j -> i
      * M[i, j] == 1 and M[j, i] == 1  ->  undirected i - j
    Any other combination of entries (e.g. -1 or 2) draws no edge.

    Args:
        adjacency_matrix: square array supporting ``M[i, j]`` indexing.
        node_names: optional labels, one per node.  When None or empty, nodes
            are left unlabeled.  A NumPy array of names is also accepted.
        selection: optional pairs (a, b).  Adds a node 's' with edges
            a -> s and b -> s.
        latent: optional groups of nodes.  Adds a node 'L' with an edge to
            every group member.

    Returns:
        pydotplus.Dot graph (graph_type 'graph', rankdir 'TB').
    """
    # None instead of mutable [] defaults; behavior for callers is unchanged.
    if selection is None:
        selection = []
    if latent is None:
        latent = []
    # Explicit length test: `if node_names:` raises for NumPy arrays of names.
    has_names = node_names is not None and len(node_names) > 0

    graph = pydotplus.Dot(graph_type='graph', rankdir='TB')

    num_nodes = len(adjacency_matrix)
    for i in range(num_nodes):
        graph.add_node(pydotplus.Node(str(i), label=str(node_names[i]) if has_names else None))

    for i in range(num_nodes):
        for j in range(i):
            if adjacency_matrix[i, j] == 1 and adjacency_matrix[j, i] == 0:
                graph.add_edge(pydotplus.Edge(str(i), str(j), dir="forward"))
            elif adjacency_matrix[i, j] == 0 and adjacency_matrix[j, i] == 1:
                graph.add_edge(pydotplus.Edge(str(j), str(i), dir="forward"))
            elif adjacency_matrix[i, j] == 1 and adjacency_matrix[j, i] == 1:
                graph.add_edge(pydotplus.Edge(str(i), str(j), dir="none"))

    if len(selection) != 0:
        graph.add_node(pydotplus.Node('s', label='s' if has_names else None))
        for s in selection:
            graph.add_edge(pydotplus.Edge(str(s[0]), 's', dir="forward"))
            graph.add_edge(pydotplus.Edge(str(s[1]), 's', dir="forward"))
    if len(latent) != 0:
        graph.add_node(pydotplus.Node('L', label='L' if has_names else None))
        for group in latent:
            for member in group:
                graph.add_edge(pydotplus.Edge('L', str(member), dir="forward"))
    return graph

def count_skeleton_accuracy(skeleton_true, skeleton_est):
    """Score an estimated undirected skeleton against the true skeleton.

    Only the strict upper triangle of each (d, d) array is compared, so every
    undirected edge is counted once and the diagonal is ignored.  Any nonzero
    entry is treated as an edge.

    Args:
        skeleton_true: (d, d) NumPy array of the true skeleton.
        skeleton_est: (d, d) NumPy array of the estimated skeleton.

    Returns:
        dict with:
          'f1_skeleton'        TP / (TP + (FP + FN) / 2), or None when both
                               graphs have no edges;
          'precision_skeleton' 1 - FP / max(#predicted, 1) (1.0 if nothing is
                               predicted);
          'recall_skeleton'    TP / max(#true, 1);
          'shd_skeleton'       FP + FN.
    """
    d = len(skeleton_true)
    upper = np.triu_indices(d, k=1)
    pred = np.flatnonzero(skeleton_est[upper])   # edges in the estimated skeleton
    cond = np.flatnonzero(skeleton_true[upper])  # edges in the true skeleton

    # True positive: an edge present in both skeletons (direction is irrelevant).
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # False positive: estimated but not true.
    false_pos = np.setdiff1d(pred, cond, assume_unique=True)
    # False negative: true but not estimated.
    false_neg = np.setdiff1d(cond, pred, assume_unique=True)

    n_tp, n_fp, n_fn = len(true_pos), len(false_pos), len(false_neg)
    fdr = float(n_fp) / max(len(pred), 1)   # FP / |predicted edges|
    tpr = float(n_tp) / max(len(cond), 1)   # TP / |true edges|

    # Explicit zero check instead of a bare `except:` (which would also have
    # swallowed unrelated errors such as KeyboardInterrupt).
    denom = n_tp + 0.5 * (n_fp + n_fn)
    f1 = n_tp / denom if denom else None

    # On an undirected skeleton, SHD is just extra + missing edges.
    shd = n_fp + n_fn
    return {'f1_skeleton': f1, 'precision_skeleton': 1 - fdr, 'recall_skeleton': tpr, 'shd_skeleton': shd}

def count_skeleton_acc(skeleton_true, skeleton_est):
    """Return True iff both matrices match entry-by-entry on the leading block.

    The block is size x size, with size = ``skeleton_true.shape[0]``.  The
    scan stops at the first mismatch.
    """
    size = skeleton_true.shape[0]
    mismatch = any(
        skeleton_true[row][col] != skeleton_est[row][col]
        for row in range(size)
        for col in range(size)
    )
    return not mismatch


def dag_to_pag(adj, latent, selection):
    """Rewrite edge marks of ``adj`` in place for latent groups and selected pairs.

    Steps, as implemented below:
      1. For every lower-triangle entry ``adj[i][j] == 1`` (i > j), set
         ``adj[j][i] = -1``.
      2. For every pair (a, b) inside each latent group, rewrite the marks:
         (0, 0) -> (1, 1),  (-1, 1) -> (2, 1),  (1, -1) -> (1, 2).
      3. For every selection pair (a, b), rewrite the marks:
         (0, 0) -> (-1, -1),  (-1, 1) -> (-1, 2),  (1, -1) -> (2, -1).
    Here (x, y) means (adj[a][b], adj[b][a]).

    The previous version wrote these tests as ``x == v & y == w``.  Python
    parses that as the chained comparison ``x == (v & y) == w``, so:
      * the (0, 0) tests ignored y;
      * the (-1, 1) tests actually required (1, 1);
      * the (1, -1) tests could never be true.
    They now use ``and``.

    NOTE(review): step 1 only scans the lower triangle, so entries
    ``adj[i][j] == 1`` with i < j get no -1 partner - confirm this is intended.

    Args:
        adj: square integer array; it is modified in place.
        latent: iterable of node groups.
        selection: iterable of (a, b) node pairs.

    Returns:
        The same ``adj`` object, after modification.
    """
    d = adj.shape[0]
    for i in range(d):
        for j in range(i):
            if adj[i][j] == 1:
                adj[j][i] = -1

    for la in latent:
        for a, b in combinations(la, 2):
            if adj[a][b] == 0 and adj[b][a] == 0:
                adj[a][b] = 1
                adj[b][a] = 1
            elif adj[a][b] == -1 and adj[b][a] == 1:
                adj[a][b] = 2
                adj[b][a] = 1
            elif adj[a][b] == 1 and adj[b][a] == -1:
                adj[a][b] = 1
                adj[b][a] = 2

    for se in selection:
        a, b = se[0], se[1]
        if adj[a][b] == 0 and adj[b][a] == 0:
            adj[a][b] = -1
            adj[b][a] = -1
        elif adj[a][b] == -1 and adj[b][a] == 1:
            adj[a][b] = -1
            adj[b][a] = 2
        elif adj[a][b] == 1 and adj[b][a] == -1:
            adj[a][b] = 2
            adj[b][a] = -1
    return adj

def _intervention_pvalues(obs, data, i, j, c_set=()):
    """Run the four KCI tests that contrast observational and interventional data.

    Observational rows ``obs`` are stacked with the interventional rows
    ``data[f'per_{j}']`` (resp. ``data[f'per_{i}']``).  An indicator column
    marks each row: 0 = observational, 1 = interventional.  Each pooled matrix
    has the column layout ``[X_i, X_j, indicator, *c_set]``.

    Returns:
        (Upj, Cpj, Upi, Cpi), the values returned by causal-learn's CIT("kci"):
          on the data pooled with per_j:
            Upj = test(X_i, I | c_set)
            Cpj = test(X_i, I | X_j, c_set)
          on the data pooled with per_i:
            Upi = test(X_j, I | c_set)
            Cpi = test(X_j, I | X_i, c_set)
    """
    c_set = list(c_set)
    observational = np.concatenate(
        (obs[:, [i, j]], np.zeros((obs.shape[0], 1)), obs[:, c_set]), axis=1)

    def pooled(per):
        interventional = np.concatenate(
            (per[:, [i, j]], np.ones((per.shape[0], 1)), per[:, c_set]), axis=1)
        return np.concatenate((observational, interventional), axis=0)

    extra = list(range(3, 3 + len(c_set)))
    cit_j = CIT(pooled(data[f'per_{j}']), "kci")
    upj = cit_j(0, 2, extra)
    cpj = cit_j(0, 2, extra + [1])
    # Bug fix: the second pair of tests previously reused the per_j object in
    # the conditional phase; it must use the data pooled with per_i.
    cit_i = CIT(pooled(data[f'per_{i}']), "kci")
    upi = cit_i(1, 2, extra)
    cpi = cit_i(1, 2, extra + [0])
    return upj, cpj, upi, cpi


def _pvalue_pattern(pvalues, alpha):
    """Encode each p-value as 'H' (> alpha), 'L' (< alpha) or '?' (== alpha or NaN).

    Mirrors the original strict comparisons: a '?' matches no decision branch.
    """
    return ''.join('H' if p > alpha else 'L' if p < alpha else '?' for p in pvalues)


acc_selection = []
acc_latent = []
acc_dag = []
f1_dag = []
recall_dag = []
shd_dag = []
times = 0
samples = []
count = 0
interven = 'hard'
d = 7
# sample = [47, 125, 231, 415, 461, 597, 639, 928, 1012, 1027, 1048, 1075, 1144, 1190, 1399, 1586, 1661, 1750, 1927, 1928]
sample = [11, 13, 14, 35, 67, 68, 77, 87, 99, 121, 133, 161, 173, 186, 189, 192, 197, 208, 225, 230]
# Significance level for every intervention-based KCI decision below.
threshold = 0.05
# Detections for the most recently processed sample (printed after the loop).
# Initialised here so the final print also works when every sample is skipped.
result = {'latent': [], 'selection': []}

# `sample_id` (not `i`) so the inner node loops do not shadow it.
for sample_id in sample:
# for sample_id in range(3000):
    print(sample_id)
    # file_path = f'./v_{d}/{interven}/sample_{sample_id}'
    file_path = f'./latent_with_selection/General_G/{interven}/v_{d}/sample_{sample_id}'
    if not os.path.exists(file_path):
        continue
    # Materialise every array and close the .npz file handle immediately.
    with np.load(os.path.join(file_path, f'sample_{interven}.npz'), allow_pickle=True) as npz:
        data = {key: npz[key] for key in npz.files}
    obs = data['obs']
    print(obs.shape)
    if obs.shape[0] < 100:
        continue
    n_node = obs.shape[1]

    # Skeleton: only samples whose skeleton is recovered exactly are scored.
    ske, sep_set, CI_ske_result = skeleton(n_node, obs)
    ske_correct = count_skeleton_acc(data['ske'], ske)
    print(ske_correct)
    if not ske_correct:
        continue
    samples.append([ske_correct, sample_id])

    # Phase 1: classify every undirected edge from the intervention tests.
    # Pattern order is (Upj, Cpj, Upi, Cpi).
    dag = ske.astype(int)
    result = {'latent': [], 'selection': []}
    correct_set = []     # [from, to] pairs re-examined in phase 2
    condition_set = []   # phase-1 verdict for the same index: 'S_C', 'L_C' or 'F_D'
    for i in range(n_node):
        for j in range(i):
            if not (dag[i][j] == 1 and dag[j][i] == 1):
                continue
            pattern = _pvalue_pattern(_intervention_pvalues(obs, data, i, j), threshold)
            if pattern == 'HLLH':
                dag[j][i] = 0
            elif pattern == 'LHHL':
                dag[i][j] = 0
            elif pattern == 'LHLH':
                dag[j][i] = 0
                dag[i][j] = 0
                result['selection'].append([i, j])
            elif pattern == 'LLLH':
                dag[j][i] = 0
                correct_set.append([i, j])
                condition_set.append('S_C')
                result['selection'].append([i, j])
            elif pattern == 'LHLL':
                dag[i][j] = 0
                correct_set.append([j, i])
                condition_set.append('S_C')
                result['selection'].append([i, j])
            elif pattern == 'HLLL':
                dag[j][i] = 0
                correct_set.append([i, j])
                condition_set.append('L_C')
                result['latent'].append([i, j])
            elif pattern == 'LLHL':
                dag[i][j] = 0
                correct_set.append([j, i])
                condition_set.append('L_C')
                result['latent'].append([i, j])
            elif pattern == 'HLHL':
                dag[j][i] = 0
                dag[i][j] = 0
                result['latent'].append([i, j])
            else:
                correct_set.append([i, j])
                condition_set.append('F_D')
                correct_set.append([j, i])
                condition_set.append('F_D')

    # Phase 2: re-test the queued pairs conditioning on nodes from given_set.
    # NOTE(review): every queued pair still has dag[i][j] == 1, so with a
    # directed igraph graph its only shortest path is the direct edge and
    # given_set returns an empty set - phase 2 then runs no tests. Confirm.
    for index, (i, j) in enumerate(correct_set):
        c_set_all = given_set(i, j, dag)
        # NOTE(review): subsets go up to size len(c_set_all) when it has <= 4
        # nodes, but only up to size 3 when it has more (kept as before).
        max_size = len(c_set_all) if len(c_set_all) <= 4 else 3
        for k in range(1, max_size + 1):
            for m in combinations(c_set_all, k):
                c_set = [v for v in m if v != i and v != j]
                if not c_set:
                    continue
                pattern = _pvalue_pattern(_intervention_pvalues(obs, data, i, j, c_set), threshold)
                condition = condition_set[index]
                if condition == 'S_C':
                    if pattern not in ('LLLL', 'LLLH', 'HLLH'):
                        dag[i][j] = 0
                elif condition == 'L_C':
                    if pattern not in ('LLLL', 'HLLL', 'HLLH'):
                        dag[i][j] = 0
                elif condition == 'F_D':
                    if pattern == 'HLLH':
                        dag[j][i] = 0
                    elif pattern == 'LHLH':
                        dag[j][i] = 0
                        dag[i][j] = 0
                        result['selection'].append([i, j])
                    elif pattern == 'LLLH':
                        dag[j][i] = 0
                        result['selection'].append([i, j])
                    elif pattern == 'HLLL':
                        dag[j][i] = 0
                        result['latent'].append([i, j])
                    elif pattern == 'HLHL':
                        dag[j][i] = 0
                        dag[i][j] = 0
                        result['latent'].append([i, j])

    # Evaluation.  `pair in ndarray` is an elementwise any() in NumPy (it
    # matched pairs sharing a single entry), so compare unordered pairs as
    # frozensets instead.  Assumes data['selection'] holds (a, b) pairs.
    true_selection = {frozenset((int(p[0]), int(p[1]))) for p in data['selection']}
    count_s = sum(1 for s in result['selection'] if frozenset((s[0], s[1])) in true_selection)
    # Count each predicted latent pair at most once, even if several true
    # latent groups contain it (previously the accuracy could exceed 1).
    count_l = sum(1 for l in result['latent']
                  if any(l[0] in t and l[1] in t for t in data['latent']))

    if len(data['selection']) != 0:
        if len(result['selection']) == 0:
            acc_selection.append(0)
        else:
            acc_selection.append(count_s / len(result['selection']))
    if len(data['latent']) != 0:
        if len(result['latent']) == 0:
            acc_latent.append(0)
        else:
            acc_latent.append(count_l / len(result['latent']))

    ret_dire = count_dag_accuracy(data['dag'], dag)
    print("Directions 1 by CausalDAG: ", ret_dire)
    for metric in ('f1', 'recall', 'precision'):
        if ret_dire[metric] is None:
            ret_dire[metric] = 0
    f1_dag.append(ret_dire['f1'])
    recall_dag.append(ret_dire['recall'])
    acc_dag.append(ret_dire['precision'])
    shd_dag.append(ret_dire['shd'])
    times += 1
    if times == 20:
        break

print(times)
print(samples)
print(result['selection'])
print(f'the average accuracy of selection is {np.mean(acc_selection)}, variance is {np.var(acc_selection)}')
print(f'the average accuracy of latent is {np.mean(acc_latent)}, variance is {np.var(acc_latent)}')
print(f'the average accuracy of dag is {np.mean(acc_dag)}, variance is {np.var(acc_dag)}')
print(f'the average accuracy of recall of dag is {np.mean(recall_dag)}, variance is {np.var(recall_dag)}')
print(f'the average f1 score of dag is {np.mean(f1_dag)}, variance is {np.var(f1_dag)}')
print(f'the average shd of dag is {np.mean(shd_dag)}, variance is {np.var(shd_dag)}')
                    
                    


                






            
            
