import numpy as np
from causallearn.utils.cit import CIT
from itertools import combinations
from copy import deepcopy
import pydotplus
import os
def count_precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from raw confusion counts.

    Args:
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.

    Returns:
        Tuple ``(precision, recall, f1)``. ``precision`` is ``None`` when
        ``tp + fp == 0`` and ``recall`` is ``None`` when ``tp + fn == 0``.
        ``f1`` is ``None`` if either of those is ``None``, ``0.0`` if either
        of them is zero, and their harmonic mean otherwise.
    """
    predicted_positive = tp + fp
    actual_positive = tp + fn

    precision = None if predicted_positive == 0 else float(tp) / predicted_positive
    recall = None if actual_positive == 0 else float(tp) / actual_positive

    if precision is None or recall is None:
        return precision, recall, None
    if precision == 0 or recall == 0:
        return precision, recall, 0.0
    return precision, recall, float(2 * precision * recall) / (precision + recall)

def count_dag_accuracy(B_bin_true, B_bin_est):
    """Compare an estimated directed graph against the ground-truth graph.

    Both inputs are square adjacency matrices in which a nonzero entry
    ``B[i, j]`` denotes an edge ``i -> j``. An undirected edge in the
    estimate (both ``B_bin_est[i, j]`` and ``B_bin_est[j, i]`` nonzero) whose
    true counterpart is ``i -> j`` is scored as one true positive plus one
    reversed edge.

    Args:
        B_bin_true: ground-truth adjacency matrix, shape (d, d).
        B_bin_est: estimated adjacency matrix, shape (d, d).

    Returns:
        dict with keys ``'f1'``, ``'precision'`` and ``'recall'`` (each may be
        ``None`` when undefined, see ``count_precision_recall_f1``) and
        ``'shd'``, the structural Hamming distance (extra + missing skeleton
        edges + reversed edges).
    """
    # Linear (flattened) indices of the nonzero entries.
    pred = np.flatnonzero(B_bin_est)
    cond = np.flatnonzero(B_bin_true)
    cond_reversed = np.flatnonzero(B_bin_true.T)
    # union1d sorts and de-duplicates, so the assume_unique=True below holds
    # even if the reference graph happens to contain a 2-cycle.
    cond_skeleton = np.union1d(cond, cond_reversed)

    # Estimated edges that match a true edge with the correct direction.
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # Estimated edges with no true counterpart in either direction.
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    # Estimated edges that exist in the truth only in the opposite direction.
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    # True edges not recovered with the correct direction.
    false_neg = np.setdiff1d(cond, true_pos, assume_unique=True)

    # Structural Hamming distance: compare the symmetrised skeletons on the
    # lower triangle, then add one per reversed edge.
    pred_lower = np.flatnonzero(np.tril(B_bin_est + B_bin_est.T))
    cond_lower = np.flatnonzero(np.tril(B_bin_true + B_bin_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)

    precision, recall, f1 = count_precision_recall_f1(tp=len(true_pos),
                                                      fp=len(reverse) + len(false_pos),
                                                      fn=len(false_neg))
    return {'f1': f1, 'precision': precision, 'recall': recall, 'shd': shd}


def get_adjSet(i, G, n_node):
    """Return the neighbours of node ``i`` in ``G`` in increasing index order.

    Node ``j`` counts as a neighbour when ``G[i][j] == 1`` or
    ``G[j][i] == 1``, so an edge in either direction (or an undirected edge
    stored as both entries) is enough. Candidates ``0 .. n_node - 1`` are
    scanned.
    """
    return [j for j in range(n_node) if G[i][j] == 1 or G[j][i] == 1]
def get_adj_ij(i, j, G, n_node):
    """Return every node ``k`` with ``G[i][k] == 1`` and ``G[k][j] == 1``.

    In other words, the intermediate nodes of directed paths ``i -> k -> j``
    of length two, in increasing index order.

    The previous ``G[i][k] ==1 & G[k][j] == 1`` was parsed as a chained
    comparison around ``1 & G[k][j]`` (``&`` binds tighter than ``==``).
    That raised ``TypeError`` on float matrices and accepted odd integer
    entries. The explicit ``and`` below fixes both.
    """
    return [k for k in range(n_node) if G[i][k] == 1 and G[k][j] == 1]
def fisher_z_test(i, j, K, sample, result, alpha=0.07):
    """Run a Fisher-z conditional-independence test of ``i`` and ``j`` given ``K``.

    Args:
        i, j: indices of the two variables in ``sample``.
        K: list of variable indices to condition on.
        sample: data matrix passed to causallearn's ``CIT``. Presumably
            shaped (n_samples, n_vars); TODO confirm.
        result: list that receives one log entry ``[f'{i}_{j}_{K}___{p}']``
            per call (mutated in place).
        alpha: significance level. Independence is accepted when the
            p-value is ``>= alpha``. Defaults to the previously hard-coded
            value 0.07.

    Returns:
        True if independence is accepted, False otherwise.
    """
    # NOTE(review): a new CIT object is built on every call. If CIT
    # precomputes statistics over `sample` at construction, that work is
    # repeated for every test; consider reusing one object (unverified).
    fisher_z_obj = CIT(sample, "fisherz")
    p_value = fisher_z_obj(i, j, K)
    result.append([f'{i}_{j}_{K}___{p_value}'])
    return bool(p_value >= alpha)

def skeleton(n_node, sample):
    """Estimate the undirected skeleton using the PC algorithm's adjacency search.

    Starts from the complete undirected graph. For levels ``l = 0, 1, 2, ...``,
    each ordered pair ``(i, j)`` that is still adjacent is tested for
    conditional independence given every size-``l`` subset of the
    neighbours of ``i`` other than ``j``. The first separating set found
    removes the edge and is recorded. The search stops once no adjacent
    ordered pair has at least ``l`` other neighbours. As in the original PC
    algorithm, the result can depend on the order in which pairs are
    visited.

    Args:
        n_node: number of variables.
        sample: data matrix handed to ``fisher_z_test``.

    Returns:
        ``(C, S)``.
        ``C`` is a symmetric float ``(n_node, n_node)`` matrix with
        ``C[i][j] == 1`` iff ``i`` and ``j`` remain adjacent; the diagonal
        is 0.
        ``S[i][j]`` is the separating set (a list of node indices) that
        removed the edge ``i - j``. It is ``[]`` if the edge was kept or was
        removed by a marginal (``l == 0``) test.
    """
    C = np.ones((n_node, n_node))
    np.fill_diagonal(C, 0)

    S = [[[] for _ in range(n_node)] for _ in range(n_node)]

    # Ordered pairs: PC must condition on subsets of adj(i) AND of adj(j),
    # so both (i, j) and (j, i) are visited.
    pairs = [(i, j) for i in range(n_node) for j in range(n_node) if i != j]

    CI_result = []  # log of every CI test, appended to by fisher_z_test
    l = 0
    while True:
        testable = False
        for (i, j) in pairs:
            if C[i][j] != 1:
                continue
            adj_set = get_adjSet(i, C, n_node)
            adj_set.remove(j)
            if len(adj_set) < l:
                continue
            testable = True
            for K in combinations(adj_set, l):
                if fisher_z_test(i, j, list(K), sample, CI_result):
                    C[i][j] = 0
                    C[j][i] = 0
                    S[i][j] = list(K)
                    S[j][i] = list(K)
                    break  # one separating set is enough
        if not testable:
            break
        l += 1

    return C, S

def direction(C, S):
    """Orient unshielded colliders (v-structures) in a skeleton.

    For every triple ``i - j - k`` in the skeleton ``C`` whose two edges are
    still undirected in the working graph, and whose endpoints ``i`` and
    ``k`` are non-adjacent there: if ``j`` is not in the separating set
    ``S[i][k]``, the triple is oriented as ``i -> j <- k``. Triples are
    visited in order. Because both edges must still be undirected, an edge
    already oriented by an earlier triple blocks later ones, so the result
    can depend on visiting order.

    (A leftover ``pdb.set_trace()`` that halted execution on every
    orientation has been removed.)

    Args:
        C: square 0/1 skeleton matrix (``C[i][j] == 1`` means adjacent).
        S: nested list where ``S[i][k]`` holds the separating set of ``i`` and
            ``k``, as returned by ``skeleton``.

    Returns:
        ``(G, triples)``.
        ``G`` is an int copy of ``C`` with colliders oriented:
        ``G[a][b] == 1 and G[b][a] == 0`` encodes ``a -> b``, and both entries
        equal to 1 encode an undirected edge.
        ``triples`` lists every ``[i, j, k]`` with ``C[i][j] == 1``,
        ``C[j][k] == 1`` and ``k != i``.
    """
    G = C.astype(int)  # astype returns a new array, so C is left untouched
    n_node = C.shape[0]

    pairs = [(i, j) for i in range(n_node) for j in range(n_node)
             if i != j and C[i][j] == 1]
    triples = [[i, j, k] for (i, j) in pairs for k in range(n_node)
               if C[j][k] == 1 and k != i]

    for i, j, k in triples:
        both_undirected = (G[i][j] == 1 and G[j][i] == 1
                           and G[k][j] == 1 and G[j][k] == 1)
        unshielded = G[i][k] == 0 and G[k][i] == 0
        if both_undirected and unshielded and j not in S[i][k]:
            G[j][i] = 0
            G[j][k] = 0

    return G, triples
def adjacency_matrix_to_dot(adjacency_matrix, node_names=None, selection=None):
    """Render an adjacency matrix as a ``pydotplus`` graph.

    For each unordered pair ``i > j``:
    - ``m[i, j] == 1`` and ``m[j, i] == 0`` draws ``i -> j``;
    - ``m[i, j] == 0`` and ``m[j, i] == 1`` draws ``j -> i``;
    - both entries equal to 1 draws an undirected edge.

    Each pair in ``selection`` is drawn as two arrows into an extra node
    ``'s'``.

    Args:
        adjacency_matrix: square 2-D array indexable as ``m[i, j]``.
        node_names: optional sequence of labels, one per node. When falsy,
            nodes get no explicit label.
        selection: optional sequence of ``(a, b)`` node-index pairs. ``None``
            (the default) or an empty sequence adds no selection node.

    Returns:
        The constructed ``pydotplus.Dot`` graph.
    """
    # A None default avoids a shared mutable default list.
    if selection is None:
        selection = []

    graph = pydotplus.Dot(graph_type='graph', rankdir='TB')

    num_nodes = len(adjacency_matrix)
    for i in range(num_nodes):
        graph.add_node(pydotplus.Node(str(i), label=str(node_names[i]) if node_names else None))

    for i in range(num_nodes):
        for j in range(i):
            if adjacency_matrix[i, j] == 1 and adjacency_matrix[j, i] == 0:
                graph.add_edge(pydotplus.Edge(str(i), str(j), dir="forward"))
            elif adjacency_matrix[i, j] == 0 and adjacency_matrix[j, i] == 1:
                graph.add_edge(pydotplus.Edge(str(j), str(i), dir="forward"))
            elif adjacency_matrix[i, j] == 1 and adjacency_matrix[j, i] == 1:
                graph.add_edge(pydotplus.Edge(str(i), str(j), dir="none"))

    if len(selection) != 0:
        graph.add_node(pydotplus.Node('s', label='s' if node_names else None))
        for s in selection:
            graph.add_edge(pydotplus.Edge(str(s[0]), 's', dir="forward"))
            graph.add_edge(pydotplus.Edge(str(s[1]), 's', dir="forward"))

    return graph

def count_skeleton_accuracy(skeleton_true, skeleton_est):
    """Score an estimated undirected skeleton against the true one.

    Only the strict upper triangle of each matrix is read, so both inputs
    are treated as symmetric adjacency matrices in which nonzero means
    "edge".

    Args:
        skeleton_true: true skeleton as a square array.
        skeleton_est: estimated skeleton as a square array of the same size.

    Returns:
        dict with these keys:
        ``'f1_skeleton'``: ``TP / (TP + (FP + FN) / 2)``, or ``None`` when
            both graphs are empty.
        ``'precision_skeleton'``: ``1 - FP / max(|pred|, 1)``, which is 1.0
            when nothing is predicted.
        ``'recall_skeleton'``: ``TP / max(|true|, 1)``.
        ``'shd_skeleton'``: number of extra plus missing edges.
    """
    d = len(skeleton_true)
    upper = np.triu_indices(d, k=1)
    pred = np.flatnonzero(skeleton_est[upper])   # edges in the estimated graph
    cond = np.flatnonzero(skeleton_true[upper])  # edges in the true graph

    # TP: in both graphs; FP: only estimated; FN: only true.
    n_tp = len(np.intersect1d(pred, cond, assume_unique=True))
    n_fp = len(np.setdiff1d(pred, cond, assume_unique=True))
    n_fn = len(np.setdiff1d(cond, pred, assume_unique=True))

    # max(., 1) avoids 0/0 for empty graphs.
    fdr = float(n_fp) / max(len(pred), 1)
    tpr = float(n_tp) / max(len(cond), 1)

    # F1 is undefined when every count is zero (replaces the former bare except).
    denom = n_tp + 0.5 * (n_fp + n_fn)
    f1 = n_tp / denom if denom > 0 else None

    # Skeleton SHD: extra edges plus missing edges.
    shd = n_fp + n_fn
    return {'f1_skeleton': f1, 'precision_skeleton': 1 - fdr, 'recall_skeleton': tpr, 'shd_skeleton': shd}

def count_skeleton_acc(skeleton_true, skeleton_est):
    """Return True iff the two skeletons agree entry by entry.

    The comparison covers the ``d x d`` index range with
    ``d = skeleton_true.shape[0]``. It stops at the first mismatching entry.
    """
    size = skeleton_true.shape[0]
    return not any(
        skeleton_true[row][col] != skeleton_est[row][col]
        for row in range(size)
        for col in range(size)
    )
# Per-sample metric accumulators (one fresh list each).
acc_selection, acc_dag, f1_dag, recall_dag, shd_dag = [], [], [], [], []

# Counter of evaluated samples and the list of per-sample skeleton records.
times = 0
samples = []
count = 0  # NOTE(review): not referenced elsewhere in this script; appears unused

# Experiment configuration: intervention type, number of variables, and
# number of selection pairs (all three appear in the data directory path).
interven = 'hard'
d = 10
number_of_selection = 4
# sample = [47, 125, 231, 415, 461, 597, 639, 928, 1012, 1027, 1048, 1075, 1144, 1190, 1399, 1586, 1661, 1750, 1927, 1928]
# for i in sample:
for sample_idx in range(8000):
    print(sample_idx)
    # file_path = f'./v_{d}/{interven}/sample_{i}'
    file_path = f'./selection/v_{d}/{interven}/select_{number_of_selection}/sample_{sample_idx}'
    if not os.path.exists(file_path):
        continue
    # NOTE(review): allow_pickle=True can execute pickled objects stored in
    # the file; only run this on trusted, locally generated data.
    data = np.load(os.path.join(file_path, f'sample_{interven}.npz'), allow_pickle=True)
    obs = data['obs']
    # Skip samples with fewer than 2000 observational rows.
    if obs.shape[0] < 2000:
        continue
    n_sample, n_node = obs.shape[0], obs.shape[1]
    node_name = [str(k) for k in range(n_node)]

    # Stage 1: PC skeleton on observational data. Only samples whose
    # estimated skeleton exactly matches data['ske'] are evaluated further.
    ske, sep_set = skeleton(n_node, obs)
    skeleton_ok = count_skeleton_acc(data['ske'], ske)
    if not skeleton_ok:
        continue
    samples.append([skeleton_ok, sample_idx])
    ske_graph = adjacency_matrix_to_dot(ske, node_name, data['selection'])
    image_file_path = os.path.join(file_path, f'ske_{interven}.png')
    ske_graph.write_png(image_file_path)

    # dag, triples = direction(ske,sep_set)
    # cpdag = deepcopy(dag)
    # cpdag_graph = adjacency_matrix_to_dot(cpdag, node_name, data['selection'])
    # image_file_path = './selection_data_visual/cpdag.png'
    # cpdag_graph.write_png(image_file_path)

    # Stage 2: orient each undirected skeleton edge (or flag it as selection)
    # using interventional data. Presumably data[f'per_{k}'] holds samples
    # collected under an intervention on node k -- TODO confirm.
    dag = ske.astype(int)  # astype copies, so `ske` is left untouched
    threshold = 0.03
    s_indicator = np.zeros([n_node, n_node])
    s_without_cause = []
    result = {
        'direct_cause': {},
        'direct_cause_givenadj': {},
        'selection_without_cause': {},
        'selection_with_cause': {},
        'others': {},
    }
    for i in range(n_node):
        for j in range(i):
            # Only edges that are still undirected are examined.
            # (Explicit `and`: the old `a == 1 & b == 1` relied on `&`
            # binding tighter than `==`.)
            if not (dag[i][j] == 1 and dag[j][i] == 1):
                continue
            # Columns of data_i / data_j: 0 = X_i, 1 = X_j, 2 = regime
            # indicator (0 for observational rows, 1 for per_i / per_j rows).
            obs_ij = np.concatenate((obs[:, [i, j]], np.zeros((n_sample, 1))), axis=1)
            data_i_org = data[f'per_{i}']
            data_j_org = data[f'per_{j}']
            n_i = data_i_org.shape[0]
            n_j = data_j_org.shape[0]
            data_i_p = np.concatenate((data_i_org[:, [i, j]], np.ones((n_i, 1))), axis=1)
            data_j_p = np.concatenate((data_j_org[:, [i, j]], np.ones((n_j, 1))), axis=1)
            data_i = np.concatenate((obs_ij, data_i_p), axis=0)
            data_j = np.concatenate((obs_ij, data_j_p), axis=0)
            # U* = unconditional, C* = conditional p-values of X vs. indicator.
            CIT_obj = CIT(data_j, "fisherz")
            Upj_value = CIT_obj(0, 2, set([]))
            Cpj_value = CIT_obj(0, 2, set([1]))
            CIT_obi = CIT(data_i, "fisherz")
            Upi_value = CIT_obi(1, 2, set([]))
            Cpi_value = CIT_obi(1, 2, set([0]))
            if Upj_value > threshold:
                dag[j][i] = 0
                result['direct_cause'][f'{i}-{j}'] = Upj_value
            elif Upi_value > threshold:
                dag[i][j] = 0
                result['direct_cause'][f'{j}-{i}'] = Upi_value
            elif (Upj_value < threshold) and (Cpj_value > threshold) and (Upi_value < threshold) and (Cpi_value > threshold):
                dag[j][i] = 0
                dag[i][j] = 0
                s_indicator[i][j] = s_indicator[j][i] = 1
            else:
                # Undecided: repeat the tests conditioning on every neighbour of i or j.
                s_indicator[i][j] = s_indicator[j][i] = 1
                adj = set(get_adjSet(i, dag, n_node) + get_adjSet(j, dag, n_node))
                adj.remove(i)
                adj.remove(j)
                adj_cols = list(adj)
                data_adj_i = np.concatenate((obs[:, adj_cols], data_i_org[:, adj_cols]), axis=0)
                data_adj_j = np.concatenate((obs[:, adj_cols], data_j_org[:, adj_cols]), axis=0)
                # Columns 0-2 as above, then one column per neighbour.
                data_p_i = np.concatenate((data_i, data_adj_i), axis=1)
                data_p_j = np.concatenate((data_j, data_adj_j), axis=1)
                g_adj = list(range(3, data_p_i.shape[1]))
                CIT_obj = CIT(data_p_j, "fisherz")
                Upj_value = CIT_obj(0, 2, g_adj)
                Cpj_value = CIT_obj(0, 2, g_adj + [1])
                # The i-side tests must use the data intervened on i; they
                # previously reused CIT_obj (built on the j-side data).
                CIT_obi = CIT(data_p_i, "fisherz")
                Upi_value = CIT_obi(1, 2, g_adj)
                Cpi_value = CIT_obi(1, 2, g_adj + [0])
                if Upj_value > threshold:
                    dag[j][i] = 0
                    result['direct_cause_givenadj'][f'{i}-{j}'] = Upj_value
                    s_indicator[i][j] = s_indicator[j][i] = 0
                elif Upi_value > threshold:
                    dag[i][j] = 0
                    result['direct_cause_givenadj'][f'{j}-{i}'] = Upi_value
                    s_indicator[i][j] = s_indicator[j][i] = 0
                elif (Upj_value < threshold) and (Cpj_value > threshold) and (Upi_value < threshold) and (Cpi_value > threshold):
                    dag[j][i] = 0
                    dag[i][j] = 0
                    result['selection_without_cause'][f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]
                    for node in adj_cols:
                        per_data = data[f'per_{node}']
                        CItest = CIT(per_data, "fisherz")
                        pvalue = CItest(i, j, [node])
                        if pvalue > threshold:
                            s_indicator[i][j] = s_indicator[j][i] = 0
                elif (Upi_value < threshold) and (Cpi_value < threshold):
                    dag[i][j] = 0
                    result['selection_with_cause'][f'{j}-{i}'] = [Upi_value, Cpi_value, Upj_value, Cpj_value]
                elif (Upj_value < threshold) and (Cpj_value < threshold):
                    dag[j][i] = 0
                    result['selection_with_cause'][f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]
                else:
                    result['others'][f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]

    # Collect flagged selection pairs and count those present in the ground truth.
    true_selection = data['selection'].tolist()
    true_s = 0
    selection = []
    for i in range(n_node):
        for j in range(i):
            if (s_indicator[i][j] == 1) or (s_indicator[j][i] == 1):
                selection.append([i, j])
                if [i, j] in true_selection or [j, i] in true_selection:
                    true_s += 1

    print(selection)
    print(result)
    dag_graph = adjacency_matrix_to_dot(dag, node_name, selection)
    image_file_path = os.path.join(file_path, 'dag.png')
    dag_graph.write_png(image_file_path)
    # Fraction of flagged selection pairs that are correct (0 when none flagged).
    if len(selection) == 0:
        acc_selection.append(0)
    else:
        acc_selection.append(true_s / len(selection))

    ret_dire = count_dag_accuracy(data['dag'], dag)
    print("Directions 1 by CausalDAG: ", ret_dire)
    # Undefined metrics (None) are recorded as 0 so they can be averaged.
    for key in ('f1', 'recall', 'precision'):
        if ret_dire[key] is None:
            ret_dire[key] = 0
    f1_dag.append(ret_dire['f1'])
    recall_dag.append(ret_dire['recall'])
    acc_dag.append(ret_dire['precision'])
    shd_dag.append(ret_dire['shd'])
    times += 1
    if times == 20:
        break

# Summary over all evaluated samples. Note: acc_dag holds the per-sample
# edge precision returned by count_dag_accuracy.
print(times)
print(samples)
print(f'the average accuracy of selection is {np.mean(acc_selection)}, variance is {np.var(acc_selection)}')
print(f'the average accuracy of dag is {np.mean(acc_dag)}, variance is {np.var(acc_dag)}')
print(f'the average accuracy of recall of dag is {np.mean(recall_dag)}, variance is {np.var(recall_dag)}')
print(f'the average f1 score of dag is {np.mean(f1_dag)}, variance is {np.var(f1_dag)}')
print(f'the average shd of dag is {np.mean(shd_dag)}, variance is {np.var(shd_dag)}')
                    
                    


                






            
            
